358 lines
11 KiB
Rust
358 lines
11 KiB
Rust
//! MCP protocol server implementation
|
|
//!
|
|
//! Implements the Model Context Protocol for stdio-based communication
|
|
//! with AI agents and tool orchestrators.
|
|
|
|
use crate::tools::McpGateTools;
|
|
use crate::types::*;
|
|
use cognitum_gate_tilezero::{GateThresholds, TileZero};
|
|
use std::sync::Arc;
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::sync::RwLock;
|
|
use tracing::{debug, error, info, warn};
|
|
|
|
/// MCP Gate Server
|
|
pub struct McpGateServer {
|
|
/// Tools handler
|
|
tools: McpGateTools,
|
|
/// Server info
|
|
server_info: ServerInfo,
|
|
}
|
|
|
|
/// Server information
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct ServerInfo {
|
|
/// Server name
|
|
pub name: String,
|
|
/// Server version
|
|
pub version: String,
|
|
/// Protocol version
|
|
pub protocol_version: String,
|
|
}
|
|
|
|
impl Default for ServerInfo {
|
|
fn default() -> Self {
|
|
Self {
|
|
name: "mcp-gate".to_string(),
|
|
version: env!("CARGO_PKG_VERSION").to_string(),
|
|
protocol_version: "2024-11-05".to_string(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Server capabilities
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct ServerCapabilities {
|
|
/// Tool capabilities
|
|
pub tools: ToolCapabilities,
|
|
}
|
|
|
|
/// Tool capabilities
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct ToolCapabilities {
|
|
/// Whether tool listing changes are supported
|
|
#[serde(rename = "listChanged")]
|
|
pub list_changed: bool,
|
|
}
|
|
|
|
impl Default for ServerCapabilities {
|
|
fn default() -> Self {
|
|
Self {
|
|
tools: ToolCapabilities {
|
|
list_changed: false,
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
impl McpGateServer {
|
|
/// Create a new server with default configuration
|
|
pub fn new() -> Self {
|
|
let thresholds = GateThresholds::default();
|
|
let tilezero = Arc::new(RwLock::new(TileZero::new(thresholds)));
|
|
Self {
|
|
tools: McpGateTools::new(tilezero),
|
|
server_info: ServerInfo::default(),
|
|
}
|
|
}
|
|
|
|
/// Create a new server with custom thresholds
|
|
pub fn with_thresholds(thresholds: GateThresholds) -> Self {
|
|
let tilezero = Arc::new(RwLock::new(TileZero::new(thresholds)));
|
|
Self {
|
|
tools: McpGateTools::new(tilezero),
|
|
server_info: ServerInfo::default(),
|
|
}
|
|
}
|
|
|
|
/// Create a new server with a shared TileZero instance
|
|
pub fn with_tilezero(tilezero: Arc<RwLock<TileZero>>) -> Self {
|
|
Self {
|
|
tools: McpGateTools::new(tilezero),
|
|
server_info: ServerInfo::default(),
|
|
}
|
|
}
|
|
|
|
/// Run the server on stdio
|
|
pub async fn run_stdio(&self) -> Result<(), std::io::Error> {
|
|
info!("Starting MCP Gate server on stdio");
|
|
|
|
let stdin = tokio::io::stdin();
|
|
let mut stdout = tokio::io::stdout();
|
|
let reader = BufReader::new(stdin);
|
|
let mut lines = reader.lines();
|
|
|
|
while let Ok(Some(line)) = lines.next_line().await {
|
|
if line.trim().is_empty() {
|
|
continue;
|
|
}
|
|
|
|
debug!("Received: {}", line);
|
|
|
|
let response = self.handle_message(&line).await;
|
|
|
|
if let Some(resp) = response {
|
|
let resp_json = serde_json::to_string(&resp).unwrap_or_default();
|
|
debug!("Sending: {}", resp_json);
|
|
stdout.write_all(resp_json.as_bytes()).await?;
|
|
stdout.write_all(b"\n").await?;
|
|
stdout.flush().await?;
|
|
}
|
|
}
|
|
|
|
info!("MCP Gate server shutting down");
|
|
Ok(())
|
|
}
|
|
|
|
/// Handle a single message
|
|
async fn handle_message(&self, message: &str) -> Option<JsonRpcResponse> {
|
|
let request: JsonRpcRequest = match serde_json::from_str(message) {
|
|
Ok(req) => req,
|
|
Err(e) => {
|
|
error!("Failed to parse request: {}", e);
|
|
return Some(JsonRpcResponse::error(
|
|
serde_json::Value::Null,
|
|
-32700,
|
|
format!("Parse error: {}", e),
|
|
));
|
|
}
|
|
};
|
|
|
|
let result = self.handle_request(&request).await;
|
|
Some(result)
|
|
}
|
|
|
|
/// Handle a JSON-RPC request
|
|
async fn handle_request(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
|
|
match request.method.as_str() {
|
|
"initialize" => self.handle_initialize(request),
|
|
"initialized" => {
|
|
// Notification, no response needed
|
|
JsonRpcResponse::success(request.id.clone(), serde_json::json!({}))
|
|
}
|
|
"tools/list" => self.handle_tools_list(request),
|
|
"tools/call" => self.handle_tools_call(request).await,
|
|
"shutdown" => {
|
|
info!("Received shutdown request");
|
|
JsonRpcResponse::success(request.id.clone(), serde_json::json!({}))
|
|
}
|
|
_ => {
|
|
warn!("Unknown method: {}", request.method);
|
|
JsonRpcResponse::error(
|
|
request.id.clone(),
|
|
-32601,
|
|
format!("Method not found: {}", request.method),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle initialize request
|
|
fn handle_initialize(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
|
|
info!("Handling initialize request");
|
|
|
|
let result = serde_json::json!({
|
|
"protocolVersion": self.server_info.protocol_version,
|
|
"capabilities": ServerCapabilities::default(),
|
|
"serverInfo": {
|
|
"name": self.server_info.name,
|
|
"version": self.server_info.version
|
|
}
|
|
});
|
|
|
|
JsonRpcResponse::success(request.id.clone(), result)
|
|
}
|
|
|
|
/// Handle tools/list request
|
|
fn handle_tools_list(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
|
|
info!("Handling tools/list request");
|
|
|
|
let tools = McpGateTools::list_tools();
|
|
let result = serde_json::json!({
|
|
"tools": tools
|
|
});
|
|
|
|
JsonRpcResponse::success(request.id.clone(), result)
|
|
}
|
|
|
|
/// Handle tools/call request
|
|
async fn handle_tools_call(&self, request: &JsonRpcRequest) -> JsonRpcResponse {
|
|
info!("Handling tools/call request");
|
|
|
|
// Parse the tool call from params
|
|
let tool_call: McpToolCall = match serde_json::from_value(request.params.clone()) {
|
|
Ok(tc) => tc,
|
|
Err(e) => {
|
|
return JsonRpcResponse::error(
|
|
request.id.clone(),
|
|
-32602,
|
|
format!("Invalid params: {}", e),
|
|
);
|
|
}
|
|
};
|
|
|
|
// Call the tool
|
|
match self.tools.call_tool(tool_call).await {
|
|
Ok(result) => {
|
|
let response_content = match result {
|
|
McpToolResult::Success { content } => serde_json::json!({
|
|
"content": [{
|
|
"type": "text",
|
|
"text": serde_json::to_string_pretty(&content).unwrap_or_default()
|
|
}]
|
|
}),
|
|
McpToolResult::Error { error } => serde_json::json!({
|
|
"content": [{
|
|
"type": "text",
|
|
"text": error
|
|
}],
|
|
"isError": true
|
|
}),
|
|
};
|
|
JsonRpcResponse::success(request.id.clone(), response_content)
|
|
}
|
|
Err(e) => JsonRpcResponse::error(request.id.clone(), e.code(), e.to_string()),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for McpGateServer {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Configuration for the MCP Gate server
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct McpGateConfig {
|
|
/// Gate thresholds
|
|
#[serde(default)]
|
|
pub thresholds: GateThresholds,
|
|
/// Log level
|
|
#[serde(default = "default_log_level")]
|
|
pub log_level: String,
|
|
}
|
|
|
|
/// Log level used when none is configured.
fn default_log_level() -> String {
    String::from("info")
}
|
|
|
|
impl Default for McpGateConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
thresholds: GateThresholds::default(),
|
|
log_level: default_log_level(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a JSON-RPC 2.0 request with id `1` for `method` and `params`.
    fn make_request(method: &str, params: serde_json::Value) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            method: method.to_string(),
            params,
        }
    }

    #[test]
    fn test_server_info_default() {
        let ServerInfo {
            name,
            protocol_version,
            ..
        } = ServerInfo::default();
        assert_eq!(name, "mcp-gate");
        assert_eq!(protocol_version, "2024-11-05");
    }

    #[test]
    fn test_server_capabilities_default() {
        let ServerCapabilities { tools } = ServerCapabilities::default();
        assert!(!tools.list_changed);
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let server = McpGateServer::new();
        let response = server
            .handle_request(&make_request("initialize", serde_json::json!({})))
            .await;

        assert!(response.error.is_none());
        let result = response
            .result
            .expect("initialize must produce a result");
        assert_eq!(result["protocolVersion"], "2024-11-05");
    }

    #[tokio::test]
    async fn test_handle_tools_list() {
        let server = McpGateServer::new();
        let response = server
            .handle_request(&make_request("tools/list", serde_json::json!({})))
            .await;

        let result = response
            .result
            .expect("tools/list must produce a result");
        let tool_count = result["tools"]
            .as_array()
            .expect("`tools` must be an array")
            .len();
        assert_eq!(tool_count, 3);
    }

    #[tokio::test]
    async fn test_handle_tools_call() {
        let server = McpGateServer::new();
        let params = serde_json::json!({
            "name": "permit_action",
            "arguments": {
                "action_id": "test-1",
                "action_type": "config_change"
            }
        });
        let response = server
            .handle_request(&make_request("tools/call", params))
            .await;

        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    #[tokio::test]
    async fn test_handle_unknown_method() {
        let server = McpGateServer::new();
        let response = server
            .handle_request(&make_request("unknown/method", serde_json::json!({})))
            .await;

        let err = response
            .error
            .expect("unknown method must produce an error");
        assert_eq!(err.code, -32601);
    }
}
|