diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 4007896a..ba527bd7 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -88,12 +88,19 @@ impl Capabilities { // TODO IMPORTANT need to ensure this times out if the system command is broken! pub async fn add_system(&mut self, config: SystemConfig) -> SystemResult<()> { let mut client: McpClient = match config { - SystemConfig::Sse { ref uri } => { - let transport = SseTransport::new(uri); + SystemConfig::Sse { + ref uri, + ref secrets, + } => { + let transport = SseTransport::new(uri, secrets.get_env()); McpClient::new(transport.start().await?) } - SystemConfig::Stdio { ref cmd, ref args } => { - let transport = StdioTransport::new(cmd, args.to_vec()); + SystemConfig::Stdio { + ref cmd, + ref args, + ref secrets, + } => { + let transport = StdioTransport::new(cmd, args.to_vec(), secrets.get_env()); McpClient::new(transport.start().await?) } }; diff --git a/crates/goose/src/agents/system.rs b/crates/goose/src/agents/system.rs index 022064a6..2ff0077b 100644 --- a/crates/goose/src/agents/system.rs +++ b/crates/goose/src/agents/system.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use mcp_client::client::Error as ClientError; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -15,25 +17,63 @@ pub enum SystemError { pub type SystemResult = Result; +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct Secrets { + /// A map of environment variables to set, e.g. API_KEY -> some_secret + #[serde(default)] + #[serde(flatten)] + map: HashMap, +} + +impl Secrets { + pub fn new(map: HashMap) -> Self { + Self { map } + } + + pub fn default() -> Self { + Self::new(HashMap::new()) + } + + pub fn get_env(&self) -> HashMap { + self.map + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect() + } +} + /// Represents the different types of MCP systems that can be added to the manager #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(tag = "type")] pub enum SystemConfig { /// Server-sent events client with a URI endpoint - Sse { uri: String }, + Sse { + uri: String, + #[serde(default)] + secrets: Secrets, + }, /// Standard I/O client with command and arguments - Stdio { cmd: String, args: Vec }, + Stdio { + cmd: String, + args: Vec, + #[serde(default)] + secrets: Secrets, + }, } impl SystemConfig { pub fn sse>(uri: S) -> Self { - Self::Sse { uri: uri.into() } + Self::Sse { + uri: uri.into(), + secrets: Secrets::default(), + } } pub fn stdio>(cmd: S) -> Self { Self::Stdio { cmd: cmd.into(), args: vec![], + secrets: Secrets::default(), } } @@ -43,8 +83,9 @@ impl SystemConfig { S: Into, { match self { - Self::Stdio { cmd, .. } => Self::Stdio { + Self::Stdio { cmd, secrets, .. } => Self::Stdio { cmd, + secrets, args: args.into_iter().map(Into::into).collect(), }, other => other, @@ -55,8 +96,8 @@ impl SystemConfig { impl std::fmt::Display for SystemConfig { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - SystemConfig::Sse { uri } => write!(f, "SSE({})", uri), - SystemConfig::Stdio { cmd, args } => write!(f, "Stdio({} {})", cmd, args.join(" ")), + SystemConfig::Sse { uri, .. } => write!(f, "SSE({})", uri), + SystemConfig::Stdio { cmd, args, .. } => write!(f, "Stdio({} {})", cmd, args.join(" ")), } } } diff --git a/crates/mcp-client/examples/clients.rs b/crates/mcp-client/examples/clients.rs index 0f208085..dec3beef 100644 --- a/crates/mcp-client/examples/clients.rs +++ b/crates/mcp-client/examples/clients.rs @@ -4,8 +4,8 @@ use mcp_client::{ }; use rand::Rng; use rand::SeedableRng; -use std::sync::Arc; use std::time::Duration; +use std::{collections::HashMap, sync::Arc}; use tracing_subscriber::EnvFilter; #[tokio::main] @@ -122,7 +122,7 @@ async fn create_stdio_client( _name: &str, _version: &str, ) -> Result> { - let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]); + let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); Ok(McpClient::new(transport.start().await?)) } @@ -130,6 +130,6 @@ async fn create_sse_client( _name: &str, _version: &str, ) -> Result> { - let transport = SseTransport::new("http://localhost:8000/sse"); + let transport = SseTransport::new("http://localhost:8000/sse", HashMap::new()); Ok(McpClient::new(transport.start().await?)) } diff --git a/crates/mcp-client/examples/sse.rs b/crates/mcp-client/examples/sse.rs index f6c79563..1e2db772 100644 --- a/crates/mcp-client/examples/sse.rs +++ b/crates/mcp-client/examples/sse.rs @@ -1,6 +1,7 @@ use anyhow::Result; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient}; use mcp_client::transport::{SseTransport, Transport}; +use std::collections::HashMap; use std::time::Duration; use tracing_subscriber::EnvFilter; @@ -16,7 +17,7 @@ async fn main() -> Result<()> { .init(); // Create the base transport - let transport = SseTransport::new("http://localhost:8000/sse"); + let transport = SseTransport::new("http://localhost:8000/sse", HashMap::new()); // Start transport let handle = transport.start().await?; diff --git a/crates/mcp-client/examples/stdio.rs b/crates/mcp-client/examples/stdio.rs index 52d7ee14..0053e4f8 100644 --- a/crates/mcp-client/examples/stdio.rs +++ b/crates/mcp-client/examples/stdio.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use anyhow::Result; use mcp_client::client::{ClientCapabilities, ClientInfo, Error as ClientError, McpClient}; use mcp_client::transport::{StdioTransport, Transport}; @@ -15,7 +17,7 @@ async fn main() -> Result<(), ClientError> { .init(); // 1) Create the transport - let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()]); + let transport = StdioTransport::new("uvx", vec!["mcp-server-git".to_string()], HashMap::new()); // 2) Start the transport to get a handle let transport_handle = transport.start().await?; diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index 4b2da6ce..49990cb7 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + // This example shows how to use the mcp-client crate to interact with a server that has a simple counter tool. // The server is started by running `cargo run -p mcp-server` in the root of the mcp-server crate. use anyhow::Result; @@ -23,6 +25,7 @@ async fn main() -> Result<(), ClientError> { .into_iter() .map(|s| s.to_string()) .collect(), + HashMap::new(), ); // Start the transport to get a handle diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index 74e4ab84..70f40a58 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -4,6 +4,7 @@ use eventsource_client::{Client, SSE}; use futures::TryStreamExt; use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; use reqwest::Client as HttpClient; +use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; use tokio::time::{timeout, Duration}; @@ -205,13 +206,15 @@ impl SseActor { #[derive(Clone)] pub struct SseTransport { sse_url: String, + env: HashMap, } /// The SSE transport spawns an `SseActor` on `start()`. impl SseTransport { - pub fn new>(sse_url: S) -> Self { + pub fn new>(sse_url: S, env: HashMap) -> Self { Self { sse_url: sse_url.into(), + env: env, } } @@ -238,6 +241,11 @@ impl SseTransport { #[async_trait] impl Transport for SseTransport { async fn start(&self) -> Result { + // Set environment variables + for (key, value) in &self.env { + std::env::set_var(key, value); + } + // Create a channel for outgoing TransportMessages let (tx, rx) = mpsc::channel(32); diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index fb77a2b1..c470b7f4 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; @@ -103,18 +104,25 @@ impl StdioActor { pub struct StdioTransport { command: String, args: Vec, + env: HashMap, } impl StdioTransport { - pub fn new>(command: S, args: Vec) -> Self { + pub fn new>( + command: S, + args: Vec, + env: HashMap, + ) -> Self { Self { command: command.into(), args, + env: env, } } async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout), Error> { let mut process = Command::new(&self.command) + .envs(&self.env) .args(&self.args) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped())