
feat: openrouter provider (#538)

michaelneale authored Jan 6, 2025
1 parent 44671d5 commit 667187e
Showing 8 changed files with 339 additions and 26 deletions.
37 changes: 37 additions & 0 deletions crates/goose-server/src/configuration.rs
@@ -51,6 +51,21 @@ pub enum ProviderSettings {
#[serde(default)]
estimate_factor: Option<f32>,
},
OpenRouter {
#[serde(default = "default_openrouter_host")]
host: String,
api_key: String,
#[serde(default = "default_model")]
model: String,
#[serde(default)]
temperature: Option<f32>,
#[serde(default)]
max_tokens: Option<i32>,
#[serde(default)]
context_limit: Option<usize>,
#[serde(default)]
estimate_factor: Option<f32>,
},
Databricks {
#[serde(default = "default_databricks_host")]
host: String,
@@ -139,6 +154,7 @@ impl ProviderSettings {
ProviderSettings::Google { .. } => ProviderType::Google,
ProviderSettings::Groq { .. } => ProviderType::Groq,
ProviderSettings::Anthropic { .. } => ProviderType::Anthropic,
ProviderSettings::OpenRouter { .. } => ProviderType::OpenRouter,
}
}

@@ -162,6 +178,23 @@ impl ProviderSettings {
.with_context_limit(context_limit)
.with_estimate_factor(estimate_factor),
}),
ProviderSettings::OpenRouter {
host,
api_key,
model,
temperature,
max_tokens,
context_limit,
estimate_factor,
} => ProviderConfig::OpenRouter(OpenAiProviderConfig {
host,
api_key,
model: ModelConfig::new(model)
.with_temperature(temperature)
.with_max_tokens(max_tokens)
.with_context_limit(context_limit)
.with_estimate_factor(estimate_factor),
}),
ProviderSettings::Databricks {
host,
model,
@@ -317,6 +350,10 @@ fn default_port() -> u16 {
3000
}

pub fn default_openrouter_host() -> String {
"https://openrouter.ai".to_string()
}

fn default_model() -> String {
OPEN_AI_DEFAULT_MODEL.to_string()
}
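
Taken together, the configuration.rs changes add an OpenRouter arm to ProviderSettings with the same shape as the existing OpenAI variant. A minimal sketch of constructing the new variant by hand (illustrative only; the placeholder key and model choice are not from the commit):

let settings = ProviderSettings::OpenRouter {
    host: default_openrouter_host(), // "https://openrouter.ai"
    api_key: "sk-or-...".to_string(), // placeholder, not a real key
    model: "anthropic/claude-3.5-sonnet".to_string(),
    temperature: Some(0.7),
    max_tokens: None,
    context_limit: None,
    estimate_factor: None,
};
// The match arms added above map this variant to ProviderType::OpenRouter
// and to ProviderConfig::OpenRouter(OpenAiProviderConfig { .. }).
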
1 change: 1 addition & 0 deletions crates/goose/src/providers.rs
@@ -12,6 +12,7 @@ pub mod utils;

pub mod google;
pub mod groq;
pub mod openrouter;

#[cfg(test)]
pub mod mock;
1 change: 1 addition & 0 deletions crates/goose/src/providers/configs.rs
@@ -15,6 +15,7 @@ pub enum ProviderConfig {
Anthropic(AnthropicProviderConfig),
Google(GoogleProviderConfig),
Groq(GroqProviderConfig),
OpenRouter(OpenAiProviderConfig),
}

/// Configuration for model-specific settings and limits
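
Note that OpenRouter reuses OpenAiProviderConfig rather than introducing a dedicated config struct: OpenRouter exposes an OpenAI-compatible chat-completions API, which is also why the provider implementation below leans on the shared openai_utils helpers.
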
6 changes: 5 additions & 1 deletion crates/goose/src/providers/factory.rs
@@ -1,7 +1,7 @@
use super::{
anthropic::AnthropicProvider, base::Provider, configs::ProviderConfig,
databricks::DatabricksProvider, google::GoogleProvider, groq::GroqProvider,
ollama::OllamaProvider, openai::OpenAiProvider,
ollama::OllamaProvider, openai::OpenAiProvider, openrouter::OpenRouterProvider,
};
use anyhow::Result;
use strum_macros::EnumIter;
@@ -14,6 +14,7 @@ pub enum ProviderType {
Anthropic,
Google,
Groq,
OpenRouter,
}

pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send + Sync>> {
@@ -28,5 +29,8 @@ pub fn get_provider(config: ProviderConfig) -> Result<Box<dyn Provider + Send +
}
ProviderConfig::Google(google_config) => Ok(Box::new(GoogleProvider::new(google_config)?)),
ProviderConfig::Groq(groq_config) => Ok(Box::new(GroqProvider::new(groq_config)?)),
ProviderConfig::OpenRouter(openrouter_config) => {
Ok(Box::new(OpenRouterProvider::new(openrouter_config)?))
}
}
}
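
With the factory arm in place, the new backend is reachable through the same get_provider entry point as every other provider. A hedged sketch of the call site (the OPENROUTER_API_KEY variable name is an assumption; the config shape matches the tests below):

let config = ProviderConfig::OpenRouter(OpenAiProviderConfig {
    host: "https://openrouter.ai".to_string(),
    api_key: std::env::var("OPENROUTER_API_KEY")?, // assumed env var name
    model: ModelConfig::new("anthropic/claude-3.5-sonnet".to_string()),
});
let provider = get_provider(config)?; // Box<dyn Provider + Send + Sync>
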
13 changes: 13 additions & 0 deletions crates/goose/src/providers/model_pricing.rs
@@ -61,11 +61,24 @@ lazy_static::lazy_static! {
input_token_price: dec!(15.00),
output_token_price: dec!(75.00),
});
// OpenRouter Models
m.insert("anthropic/claude-3-sonnet".to_string(), Pricing {
input_token_price: dec!(3.00),
output_token_price: dec!(15.00),
});
m.insert("claude-3-sonnet".to_string(), Pricing {
input_token_price: dec!(3.00),
output_token_price: dec!(15.00),
});
// OpenAI
m.insert("gpt-4o".to_string(), Pricing {
input_token_price: dec!(2.50),
output_token_price: dec!(10.00),
});
m.insert("gpt-4".to_string(), Pricing {
input_token_price: dec!(2.50),
output_token_price: dec!(10.00),
});
m.insert("gpt-4o-2024-11-20".to_string(), Pricing {
input_token_price: dec!(2.50),
output_token_price: dec!(10.00),
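
This hunk does not show the Pricing units, but the figures are consistent with USD per million tokens: at gpt-4's listed 2.50 / 10.00, a usage of 12 input and 15 output tokens works out to 12 × 2.50/10^6 + 15 × 10.00/10^6 = 0.00018, exactly the dec!(0.00018) the provider test below asserts. A hypothetical helper under that assumption (approx_cost is not a function in the commit):

use rust_decimal::Decimal;
use rust_decimal_macros::dec;

// Assumes input/output prices are USD per 1,000,000 tokens.
fn approx_cost(input_tokens: u32, output_tokens: u32, p: &Pricing) -> Decimal {
    let million = dec!(1000000);
    Decimal::from(input_tokens) * p.input_token_price / million
        + Decimal::from(output_tokens) * p.output_token_price / million
}
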
196 changes: 196 additions & 0 deletions crates/goose/src/providers/openrouter.rs
@@ -0,0 +1,196 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde_json::Value;
use std::time::Duration;

use super::base::ProviderUsage;
use super::base::{Provider, Usage};
use super::configs::OpenAiProviderConfig;
use super::configs::{ModelConfig, ProviderModelConfig};
use super::model_pricing::cost;
use super::model_pricing::model_pricing_for;
use super::utils::{get_model, handle_response};
use crate::message::Message;
use crate::providers::openai_utils::{
check_openai_context_length_error, create_openai_request_payload_with_concat_response_content,
get_openai_usage, openai_response_to_message,
};
use mcp_core::tool::Tool;

pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet";

pub struct OpenRouterProvider {
client: Client,
config: OpenAiProviderConfig,
}

impl OpenRouterProvider {
pub fn new(config: OpenAiProviderConfig) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(600)) // 10 minutes timeout
.build()?;

Ok(Self { client, config })
}

async fn post(&self, payload: Value) -> Result<Value> {
let url = format!(
"{}/api/v1/chat/completions",
self.config.host.trim_end_matches('/')
);

let response = self
.client
.post(&url)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("HTTP-Referer", "https://github.com/block/goose")
.header("X-Title", "Goose")
.json(&payload)
.send()
.await?;

handle_response(payload, response).await?
}
}

#[async_trait]
impl Provider for OpenRouterProvider {
fn get_model_config(&self) -> &ModelConfig {
self.config.model_config()
}

async fn complete(
&self,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<(Message, ProviderUsage)> {
// Create the base payload
let payload = create_openai_request_payload_with_concat_response_content(
&self.config.model,
system,
messages,
tools,
)?;

// Make request
let response = self.post(payload).await?;

// Raise specific error if context length is exceeded
if let Some(error) = response.get("error") {
if let Some(err) = check_openai_context_length_error(error) {
return Err(err.into());
}
return Err(anyhow!("OpenRouter API error: {}", error));
}

// Parse response
let message = openai_response_to_message(response.clone())?;
let usage = self.get_usage(&response)?;
let model = get_model(&response);
let cost = cost(&usage, &model_pricing_for(&model));

Ok((message, ProviderUsage::new(model, usage, cost)))
}

fn get_usage(&self, data: &Value) -> Result<Usage> {
get_openai_usage(data)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::message::MessageContent;
use crate::providers::configs::ModelConfig;
use crate::providers::mock_server::{
create_mock_open_ai_response, create_mock_open_ai_response_with_tools, create_test_tool,
get_expected_function_call_arguments, setup_mock_server, TEST_INPUT_TOKENS,
TEST_OUTPUT_TOKENS, TEST_TOOL_FUNCTION_NAME, TEST_TOTAL_TOKENS,
};
use rust_decimal_macros::dec;
use wiremock::MockServer;

async fn _setup_mock_response(response_body: Value) -> (MockServer, OpenRouterProvider) {
let mock_server = setup_mock_server("/api/v1/chat/completions", response_body).await;

// Create the OpenRouterProvider with the mock server's URL as the host
let config = OpenAiProviderConfig {
host: mock_server.uri(),
api_key: "test_api_key".to_string(),
model: ModelConfig::new("gpt-3.5-turbo".to_string()).with_temperature(Some(0.7)),
};

let provider = OpenRouterProvider::new(config).unwrap();
(mock_server, provider)
}

#[tokio::test]
async fn test_complete_basic() -> Result<()> {
let model_name = "gpt-4";
// Mock response for normal completion
let response_body =
create_mock_open_ai_response(model_name, "Hello! How can I assist you today?");

let (_, provider) = _setup_mock_response(response_body).await;

// Prepare input messages
let messages = vec![Message::user().with_text("Hello?")];

// Call the complete method
let (message, usage) = provider
.complete("You are a helpful assistant.", &messages, &[])
.await?;

// Assert the response
if let MessageContent::Text(text) = &message.content[0] {
assert_eq!(text.text, "Hello! How can I assist you today?");
} else {
panic!("Expected Text content");
}
assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS));
assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS));
assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS));
assert_eq!(usage.model, model_name);
assert_eq!(usage.cost, Some(dec!(0.00018)));

Ok(())
}

#[tokio::test]
async fn test_complete_tool_request() -> Result<()> {
// Mock response for tool calling
let response_body = create_mock_open_ai_response_with_tools("gpt-4");

let (_, provider) = _setup_mock_response(response_body).await;

// Input messages
let messages = vec![Message::user().with_text("What's the weather in San Francisco?")];

// Call the complete method
let (message, usage) = provider
.complete(
"You are a helpful assistant.",
&messages,
&[create_test_tool()],
)
.await?;

// Assert the response
if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
let tool_call = tool_request.tool_call.as_ref().unwrap();
assert_eq!(tool_call.name, TEST_TOOL_FUNCTION_NAME);
assert_eq!(tool_call.arguments, get_expected_function_call_arguments());
} else {
panic!("Expected ToolCall content");
}

assert_eq!(usage.usage.input_tokens, Some(TEST_INPUT_TOKENS));
assert_eq!(usage.usage.output_tokens, Some(TEST_OUTPUT_TOKENS));
assert_eq!(usage.usage.total_tokens, Some(TEST_TOTAL_TOKENS));

Ok(())
}
}
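
Outside the mock-server tests, end-to-end use of the new provider would look much the same but pointed at the real host. A hedged sketch (assumes a Tokio async context and a valid key; the env var name is an assumption, and this snippet is not part of the commit):

let config = OpenAiProviderConfig {
    host: "https://openrouter.ai".to_string(),
    api_key: std::env::var("OPENROUTER_API_KEY")?, // assumed env var name
    model: ModelConfig::new(OPENROUTER_DEFAULT_MODEL.to_string()),
};
let provider = OpenRouterProvider::new(config)?;
let messages = vec![Message::user().with_text("Hello?")];
let (_reply, usage) = provider
    .complete("You are a helpful assistant.", &messages, &[])
    .await?;
println!("model={} cost={:?}", usage.model, usage.cost);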

(2 more changed files not shown in this view.)