Skip to content

Commit

Permalink
update goose agent
Browse files Browse the repository at this point in the history
  • Loading branch information
salman1993 committed Jan 10, 2025
1 parent f6b874c commit 646d61d
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use chrono::{DateTime, TimeZone, Utc};
use mcp_client::McpService;
use rust_decimal_macros::dec;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::{debug, instrument};

use super::system::{SystemConfig, SystemError, SystemInfo, SystemResult};
use crate::prompt_template::load_prompt_file;
use crate::providers::base::{Provider, ProviderUsage};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient};
use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait};
use mcp_client::transport::{SseTransport, StdioTransport, Transport};
use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult};

Expand All @@ -20,7 +22,7 @@ static DEFAULT_TIMESTAMP: LazyLock<DateTime<Utc>> =

/// Manages MCP clients and their interactions
pub struct Capabilities {
clients: HashMap<String, Arc<Mutex<McpClient>>>,
clients: HashMap<String, Arc<Mutex<Box<dyn McpClientTrait>>>>,
instructions: HashMap<String, String>,
provider: Box<dyn Provider>,
provider_usage: Mutex<Vec<ProviderUsage>>,
Expand Down Expand Up @@ -87,14 +89,18 @@ impl Capabilities {
/// Add a new MCP system based on the provided client type
// TODO IMPORTANT need to ensure this times out if the system command is broken!
pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> {
let mut client: McpClient = match config {
let mut client: Box<dyn McpClientTrait> = match config {
SystemConfig::Sse { ref uri } => {
let transport = SseTransport::new(uri);
McpClient::new(transport.start().await?)
let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(10));
Box::new(McpClient::new(service))
}
SystemConfig::Stdio { ref cmd, ref args } => {
let transport = StdioTransport::new(cmd, args.to_vec());
McpClient::new(transport.start().await?)
let handle = transport.start().await?;
let service = McpService::with_timeout(handle, Duration::from_secs(10));
Box::new(McpClient::new(service))
}
};

Expand Down Expand Up @@ -258,7 +264,10 @@ impl Capabilities {
}

/// Find and return a reference to the appropriate client for a tool call
fn get_client_for_tool(&self, prefixed_name: &str) -> Option<Arc<Mutex<McpClient>>> {
fn get_client_for_tool(
&self,
prefixed_name: &str,
) -> Option<Arc<Mutex<Box<dyn McpClientTrait>>>> {
prefixed_name
.split_once("__")
.and_then(|(client_name, _)| self.clients.get(client_name))
Expand Down

0 comments on commit 646d61d

Please sign in to comment.