From 667187e791ecde08ed97a435fcdf8a33d7bed075 Mon Sep 17 00:00:00 2001 From: Michael Neale Date: Tue, 7 Jan 2025 08:59:00 +1100 Subject: [PATCH] feat: openrouter provider (#538) --- crates/goose-server/src/configuration.rs | 37 ++++ crates/goose/src/providers.rs | 1 + crates/goose/src/providers/configs.rs | 1 + crates/goose/src/providers/factory.rs | 6 +- crates/goose/src/providers/model_pricing.rs | 13 ++ crates/goose/src/providers/openrouter.rs | 196 ++++++++++++++++++++ ui/desktop/src/components/ApiKeyWarning.tsx | 109 ++++++++--- ui/desktop/src/main.ts | 2 +- 8 files changed, 339 insertions(+), 26 deletions(-) create mode 100644 crates/goose/src/providers/openrouter.rs diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index 870b0fba0..8e8be4f89 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -51,6 +51,21 @@ pub enum ProviderSettings { #[serde(default)] estimate_factor: Option, }, + OpenRouter { + #[serde(default = "default_openrouter_host")] + host: String, + api_key: String, + #[serde(default = "default_model")] + model: String, + #[serde(default)] + temperature: Option, + #[serde(default)] + max_tokens: Option, + #[serde(default)] + context_limit: Option, + #[serde(default)] + estimate_factor: Option, + }, Databricks { #[serde(default = "default_databricks_host")] host: String, @@ -139,6 +154,7 @@ impl ProviderSettings { ProviderSettings::Google { .. } => ProviderType::Google, ProviderSettings::Groq { .. } => ProviderType::Groq, ProviderSettings::Anthropic { .. } => ProviderType::Anthropic, + ProviderSettings::OpenRouter { .. } => ProviderType::OpenRouter, } } @@ -162,6 +178,23 @@ impl ProviderSettings { .with_context_limit(context_limit) .with_estimate_factor(estimate_factor), }), + ProviderSettings::OpenRouter { + host, + api_key, + model, + temperature, + max_tokens, + context_limit, + estimate_factor, + } => ProviderConfig::OpenRouter(OpenAiProviderConfig { + host, + api_key, + model: ModelConfig::new(model) + .with_temperature(temperature) + .with_max_tokens(max_tokens) + .with_context_limit(context_limit) + .with_estimate_factor(estimate_factor), + }), ProviderSettings::Databricks { host, model, @@ -317,6 +350,10 @@ fn default_port() -> u16 { 3000 } +pub fn default_openrouter_host() -> String { + "https://openrouter.ai".to_string() +} + fn default_model() -> String { OPEN_AI_DEFAULT_MODEL.to_string() } diff --git a/crates/goose/src/providers.rs b/crates/goose/src/providers.rs index 6d564eeb8..aa34f537b 100644 --- a/crates/goose/src/providers.rs +++ b/crates/goose/src/providers.rs @@ -12,6 +12,7 @@ pub mod utils; pub mod google; pub mod groq; +pub mod openrouter; #[cfg(test)] pub mod mock; diff --git a/crates/goose/src/providers/configs.rs b/crates/goose/src/providers/configs.rs index 94f6d585d..0e613a2f5 100644 --- a/crates/goose/src/providers/configs.rs +++ b/crates/goose/src/providers/configs.rs @@ -15,6 +15,7 @@ pub enum ProviderConfig { Anthropic(AnthropicProviderConfig), Google(GoogleProviderConfig), Groq(GroqProviderConfig), + OpenRouter(OpenAiProviderConfig), } /// Configuration for model-specific settings and limits diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs index 58ad7513b..4c7f0a6a9 100644 --- a/crates/goose/src/providers/factory.rs +++ b/crates/goose/src/providers/factory.rs @@ -1,7 +1,7 @@ use super::{ anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig, databricks::DatabricksProvider, google::GoogleProvider, groq::GroqProvider, - ollama::OllamaProvider, openai::OpenAiProvider, + ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider, }; use anyhow::Result; use strum_macros::EnumIter; @@ -14,6 +14,7 @@ pub enum ProviderType { Anthropic, Google, Groq, + OpenRouter, } pub fn get_provider(config: ProviderConfig) -> Result> { @@ -28,5 +29,8 @@ pub fn get_provider(config: ProviderConfig) -> Result Ok(Box::new(GoogleProvider::new(google_config)?)), ProviderConfig::Groq(groq_config) => Ok(Box::new(GroqProvider::new(groq_config)?)), + ProviderConfig::OpenRouter(openrouter_config) => { + Ok(Box::new(OpenRouterProvider::new(openrouter_config)?)) + } } } diff --git a/crates/goose/src/providers/model_pricing.rs b/crates/goose/src/providers/model_pricing.rs index 0c3a15f7d..fc5b074a5 100644 --- a/crates/goose/src/providers/model_pricing.rs +++ b/crates/goose/src/providers/model_pricing.rs @@ -61,11 +61,24 @@ lazy_static::lazy_static! { input_token_price: dec!(15.00), output_token_price: dec!(75.00), }); + // OpenRouter Models + m.insert("anthropic/claude-3-sonnet".to_string(), Pricing { + input_token_price: dec!(3.00), + output_token_price: dec!(15.00), + }); + m.insert("claude-3-sonnet".to_string(), Pricing { + input_token_price: dec!(3.00), + output_token_price: dec!(15.00), + }); // OpenAI m.insert("gpt-4o".to_string(), Pricing { input_token_price: dec!(2.50), output_token_price: dec!(10.00), }); + m.insert("gpt-4".to_string(), Pricing { + input_token_price: dec!(2.50), + output_token_price: dec!(10.00), + }); m.insert("gpt-4o-2024-11-20".to_string(), Pricing { input_token_price: dec!(2.50), output_token_price: dec!(10.00), diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs new file mode 100644 index 000000000..e33885e32 --- /dev/null +++ b/crates/goose/src/providers/openrouter.rs @@ -0,0 +1,196 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use reqwest::Client; +use serde_json::Value; +use std::time::Duration; + +use super::base::ProviderUsage; +use super::base::{Provider, Usage}; +use super::configs::OpenAiProviderConfig; +use super::configs::{ModelConfig, ProviderModelConfig}; +use super::model_pricing::cost; +use super::model_pricing::model_pricing_for; +use super::utils::{get_model, handle_response}; +use crate::message::Message; +use crate::providers::openai_utils::{ + check_openai_context_length_error, create_openai_request_payload_with_concat_response_content, + get_openai_usage, openai_response_to_message, +}; +use mcp_core::tool::Tool; + +pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet"; + +pub struct OpenRouterProvider { + client: Client, + config: OpenAiProviderConfig, +} + +impl OpenRouterProvider { + pub fn new(config: OpenAiProviderConfig) -> Result { + let client = Client::builder() + .timeout(Duration::from_secs(600)) // 10 minutes timeout + .build()?; + + Ok(Self { client, config }) + } + + async fn post(&self, payload: Value) -> Result { + let url = format!( + "{}/api/v1/chat/completions", + self.config.host.trim_end_matches('/') + ); + + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.config.api_key)) + .header("HTTP-Referer", "https://github.com/block/goose") + .header("X-Title", "Goose") + .json(&payload) + .send() + .await?; + + handle_response(payload, response).await? + } +} + +#[async_trait] +impl Provider for OpenRouterProvider { + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } + + async fn complete( + &self, + system: &str, + messages: &[Message], + tools: &[Tool], + ) -> Result<(Message, ProviderUsage)> { + // Create the base payload + let payload = create_openai_request_payload_with_concat_response_content( + &self.config.model, + system, + messages, + tools, + )?; + + // Make request + let response = self.post(payload).await?; + + // Raise specific error if context length is exceeded + if let Some(error) = response.get("error") { + if let Some(err) = check_openai_context_length_error(error) { + return Err(err.into()); + } + return Err(anyhow!("OpenRouter API error: {}", error)); + } + + // Parse response + let message = openai_response_to_message(response.clone())?; + let usage = self.get_usage(&response)?; + let model = get_model(&response); + let cost = cost(&usage, &model_pricing_for(&model)); + + Ok((message, ProviderUsage::new(model, usage, cost))) + } + + fn get_usage(&self, data: &Value) -> Result { + get_openai_usage(data) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::message::MessageContent; + use crate::providers::configs::ModelConfig; + use crate::providers::mock_server::{ + create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool, + get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS, + TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS, + }; + use rust_decimal_macros::dec; + use wiremock::MockServer; + + async fn _setup_mock_response(response_body: Value) -> (MockServer, OpenRouterProvider) { + let mock_server = setup_mock_server("/api/v1/chat/completions", response_body).await; + + // Create the OpenRouterProvider with the mock server's URL as the host + let config = OpenAiProviderConfig { + host: mock_server.uri(), + api_key: "test_api_key".to_string(), + model: ModelConfig::new("gpt-3.5-turbo".to_string()).with_temperature(Some(0.7)), + }; + + let provider = OpenRouterProvider::new(config).unwrap(); + (mock_server, provider) + } + + #[tokio::test] + async fn test_complete_basic() -> Result<()> { + let model_name = "gpt-4"; + // Mock response for normal completion + let response_body = + create_mock_open_ai_response(model_name, "Hello! How can I assist you today?"); + + let (_, provider) = _setup_mock_response(response_body).await; + + // Prepare input messages + let messages = vec![Message::user().with_text("Hello?")]; + + // Call the complete method + let (message, usage) = provider + .complete("You are a helpful assistant.", &messages, &[]) + .await?; + + // Assert the response + if let MessageContent::Text(text) = &message.content[0] { + assert_eq!(text.text, "Hello! How can I assist you today?"); + } else { + panic!("Expected Text content"); + } + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + assert_eq!(usage.model, model_name); + assert_eq!(usage.cost, Some(dec!(0.00018))); + + Ok(()) + } + + #[tokio::test] + async fn test_complete_tool_request() -> Result<()> { + // Mock response for tool calling + let response_body = create_mock_open_ai_response_with_tools("gpt-4"); + + let (_, provider) = _setup_mock_response(response_body).await; + + // Input messages + let messages = vec![Message::user().with_text("What's the weather in San Francisco?")]; + + // Call the complete method + let (message, usage) = provider + .complete( + "You are a helpful assistant.", + &messages, + &[create_test_tool()], + ) + .await?; + + // Assert the response + if let MessageContent::ToolRequest(tool_request) = &message.content[0] { + let tool_call = tool_request.tool_call.as_ref().unwrap(); + assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME); + assert_eq!(tool_call.arguments, get_expected_function_call_arguments()); + } else { + panic!("Expected ToolCall content"); + } + + assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS)); + assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS)); + assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS)); + + Ok(()) + } +} diff --git a/ui/desktop/src/components/ApiKeyWarning.tsx b/ui/desktop/src/components/ApiKeyWarning.tsx index 5ff3a438b..ffe05b01d 100644 --- a/ui/desktop/src/components/ApiKeyWarning.tsx +++ b/ui/desktop/src/components/ApiKeyWarning.tsx @@ -1,42 +1,103 @@ import React from 'react'; import { Card } from './ui/card'; import { Bird } from './ui/icons'; +import { ChevronDown } from 'lucide-react'; interface ApiKeyWarningProps { className?: string; } +interface CollapsibleProps { + title: string; + children: React.ReactNode; + defaultOpen?: boolean; +} + +function Collapsible({ title, children, defaultOpen = false }: CollapsibleProps) { + const [isOpen, setIsOpen] = React.useState(defaultOpen); + + return ( +
+ + {isOpen && ( +
+ {children} +
+ )} +
+ ); +} + +const OPENAI_CONFIG = `export GOOSE_PROVIDER__TYPE=openai +export GOOSE_PROVIDER__HOST=https://api.openai.com +export GOOSE_PROVIDER__MODEL=gpt-4 +export GOOSE_PROVIDER__API_KEY=your_api_key_here`; + +const ANTHROPIC_CONFIG = `export GOOSE_PROVIDER__TYPE=anthropic +export GOOSE_PROVIDER__HOST=https://api.anthropic.com +export GOOSE_PROVIDER__MODEL=claude-3-sonnet +export GOOSE_PROVIDER__API_KEY=your_api_key_here`; + +const DATABRICKS_CONFIG = `export GOOSE_PROVIDER__TYPE=databricks +export GOOSE_PROVIDER__HOST=your_databricks_host +export GOOSE_PROVIDER__MODEL=claude-3-sonnet-2`; + +const OPENROUTER_CONFIG = `export GOOSE_PROVIDER__TYPE=openrouter +export GOOSE_PROVIDER__HOST=https://openrouter.ai +export GOOSE_PROVIDER__MODEL=anthropic/claude-3-sonnet +export GOOSE_PROVIDER__API_KEY=your_api_key_here`; + export function ApiKeyWarning({ className }: ApiKeyWarningProps) { return ( - +
-
+

API Key Required

-
- To use Goose, you need to set some combination of the following env variables -
-
- # OpenAI -
-
- export GOOSE_PROVIDER__TYPE=openai
- GOOSE_PROVIDER__HOST=https://api.openai.com
- GOOSE_PROVIDER__MODEL=gpt-4o
- GOOSE_PROVIDER__API_KEY=...
-
-
- # Databricks + Claude -
-
- export GOOSE_PROVIDER__TYPE=databricks
- export GOOSE_PROVIDER__HOST=...
- export GOOSE_PROVIDER__MODEL="claude-3-5-sonnet-2"
-
-
- Please export these and restart the application. +

+ To use Goose, you need to set environment variables for one of the following providers: +

+ +
+ +
+              {OPENAI_CONFIG}
+            
+
+ + +
+              {ANTHROPIC_CONFIG}
+            
+
+ + +
+              {DATABRICKS_CONFIG}
+            
+
+ + +
+              {OPENROUTER_CONFIG}
+            
+
+ +

+ After setting these variables, restart Goose for the changes to take effect. +

); diff --git a/ui/desktop/src/main.ts b/ui/desktop/src/main.ts index 058132746..d15b3ccc1 100644 --- a/ui/desktop/src/main.ts +++ b/ui/desktop/src/main.ts @@ -39,7 +39,7 @@ const checkApiCredentials = () => { //{env-macro-start}// const apiKeyProvidersValid = - ['openai', 'anthropic', 'google', 'groq'].includes(process.env.GOOSE_PROVIDER__TYPE) && + ['openai', 'anthropic', 'google', 'groq', 'openrouter'].includes(process.env.GOOSE_PROVIDER__TYPE) && process.env.GOOSE_PROVIDER__HOST && process.env.GOOSE_PROVIDER__MODEL && process.env.GOOSE_PROVIDER__API_KEY;