diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index cea2b0db..da6cb402 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -33,7 +33,7 @@ impl Capabilities { /// Add a new MCP system based on the provided client type // TODO IMPORTANT need to ensure this times out if the system command is broken! pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> { - let client: McpClient = match config { + let mut client: McpClient = match config { SystemConfig::Sse { ref uri } => { let transport = SseTransport::new(uri); McpClient::new(transport.start().await?) diff --git a/crates/goose/src/agents/default.rs b/crates/goose/src/agents/default.rs index b032c56c..9b6916a3 100644 --- a/crates/goose/src/agents/default.rs +++ b/crates/goose/src/agents/default.rs @@ -58,7 +58,6 @@ impl DefaultAgent { &resources, Some(model_name), ); - let mut status_content: Vec = Vec::new(); if approx_count > target_limit { @@ -217,6 +216,7 @@ impl Agent for DefaultAgent { } // Update conversation history for the start of the reply + let resources = capabilities.get_resources().await?; let mut messages = self .prepare_inference( &system_prompt, @@ -224,8 +224,12 @@ impl Agent for DefaultAgent { messages, &Vec::new(), estimated_limit, - &capabilities.provider().get_model_config().model_name, - &capabilities.get_resources().await?, + &capabilities + .provider() + .get_model_config() + .model_name + .clone(), + &resources, ) .await?; diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index bec3232a..f6c79563 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -22,7 +22,7 @@ async fn main() -> Result<()> { let handle = transport.start().await?; // Create client - let client = McpClient::new(handle); + let mut client = McpClient::new(handle); println!("Client created\n"); // Initialize diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 1876f096..52d7ee14 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -21,7 +21,7 @@ async fn main() -> Result<(), ClientError> { let transport_handle = transport.start().await?; // 3) Create the client - let client = McpClient::new(transport_handle); + let mut client = McpClient::new(transport_handle); // Initialize let server_info = client @@ -45,5 +45,9 @@ async fn main() -> Result<(), ClientError> { .await?; println!("Tool result: {tool_result:?}\n"); + // List resources + let resources = client.list_resources().await?; + println!("Available resources: {resources:?}\n"); + Ok(()) } diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index 615b2f46..4b2da6ce 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -29,7 +29,7 @@ async fn main() -> Result<(), ClientError> { let transport_handle = transport.start().await.unwrap(); // Create client - let client = McpClient::new(transport_handle); + let mut client = McpClient::new(transport_handle); // Initialize let server_info = client diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index ff02738f..81e64610 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,16 +1,16 @@ use std::sync::atomic::{AtomicU64, Ordering}; +use crate::transport::TransportHandle; use mcp_core::protocol::{ CallToolResult, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, ReadResourceResult, + ServerCapabilities, METHOD_NOT_FOUND, }; use serde::{Deserialize, Serialize}; use serde_json::Value; use thiserror::Error; use tokio::sync::Mutex; -use tower::{Service, ServiceExt}; - -use crate::transport::TransportHandle; // for Service::ready() +use tower::{Service, ServiceExt}; // for Service::ready() /// Error type for MCP client operations. #[derive(Debug, Error)] @@ -27,6 +27,9 @@ pub enum Error { #[error("Unexpected response from server")] UnexpectedResponse, + #[error("Not initialized")] + NotInitialized, + #[error("Timeout or service not ready")] NotReady, } @@ -55,6 +58,7 @@ pub struct InitializeParams { pub struct McpClient { service: Mutex, next_id: AtomicU64, + server_capabilities: Option, } impl McpClient { @@ -63,6 +67,7 @@ impl McpClient { Self { service: Mutex::new(transport_handle), next_id: AtomicU64::new(1), + server_capabilities: None, // set during initialization } } @@ -135,7 +140,7 @@ impl McpClient { } pub async fn initialize( - &self, + &mut self, info: ClientInfo, capabilities: ClientCapabilities, ) -> Result { @@ -151,24 +156,80 @@ impl McpClient { self.send_notification("notifications/initialized", serde_json::json!({})) .await?; + self.server_capabilities = Some(result.capabilities.clone()); + Ok(result) } + fn completed_initialization(&self) -> bool { + self.server_capabilities.is_some() + } + pub async fn list_resources(&self) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + // If resources is not supported, return an empty list + if self + .server_capabilities + .as_ref() + .unwrap() + .resources + .is_none() + { + return Ok(ListResourcesResult { resources: vec![] }); + } + self.send_request("resources/list", serde_json::json!({})) .await } pub async fn read_resource(&self, uri: &str) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + // If resources is not supported, return an error + if self + .server_capabilities + .as_ref() + .unwrap() + .resources + .is_none() + { + return Err(Error::RpcError { + code: METHOD_NOT_FOUND, + message: "Server does not support 'resources' capability".to_string(), + }); + } + let params = serde_json::json!({ "uri": uri }); self.send_request("resources/read", params).await } pub async fn list_tools(&self) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + // If tools is not supported, return an empty list + if self.server_capabilities.as_ref().unwrap().tools.is_none() { + return Ok(ListToolsResult { tools: vec![] }); + } + self.send_request("tools/list", serde_json::json!({})).await } pub async fn call_tool(&self, name: &str, arguments: Value) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + // If tools is not supported, return an error + if self.server_capabilities.as_ref().unwrap().tools.is_none() { + return Err(Error::RpcError { + code: METHOD_NOT_FOUND, + message: "Server does not support 'tools' capability".to_string(), + }); + } + let params = serde_json::json!({ "name": name, "arguments": arguments }); self.send_request("tools/call", params).await }