From fb6f1d6e6cebec800af9a946da9f6b5cf01ed01e Mon Sep 17 00:00:00 2001 From: Zaki Ali Date: Fri, 10 Jan 2025 09:20:43 -0800 Subject: [PATCH] Add ErrorPart and Moderation Error struct --- crates/goose-server/src/routes/reply.rs | 96 +++++++++++++++++++------ crates/goose/src/providers/base.rs | 40 ++++++----- 2 files changed, 99 insertions(+), 37 deletions(-) diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index 141a584f..be9cc870 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -9,9 +9,10 @@ use axum::{ use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use goose::message::{Message, MessageContent}; +use goose::providers::base::{Moderation, ModerationError, ModerationResult}; use mcp_core::{content::Content, role::Role}; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{error, json, Value}; use std::{ convert::Infallible, pin::Pin, @@ -159,6 +160,20 @@ impl ProtocolFormatter { format!("a:{}\n", response) } + fn format_error(error: &str) -> String { + // Error messages start with "3:" in the new protocol. + format!("3:{}\n", error) + } + + fn format_moderation_error(error: &ModerationError) -> String { + let error_part = match error { + ModerationError::ContentFlagged { categories, .. } => { + format!("Content was flagged in the following categories: {}", categories) + } + }; + format!("3:{}\n", error_part) + } + fn format_finish(reason: &str) -> String { // Finish messages start with "d:" let finish = json!({ @@ -193,8 +208,10 @@ async fn stream_message( .await?; } Err(err) => { - let result = - vec![Content::text(format!("Error {}", err)).with_priority(0.0)]; + // Send an error message first + tx.send(ProtocolFormatter::format_error(&err.to_string())).await?; + // Then send an empty tool response to maintain the protocol + let result = vec![Content::text("Error occurred").with_priority(0.0)]; tx.send(ProtocolFormatter::format_tool_response( &response.id, &result, @@ -209,22 +226,27 @@ async fn stream_message( for content in message.content { match content { MessageContent::ToolRequest(request) => { - if let Ok(tool_call) = request.tool_call { - tx.send(ProtocolFormatter::format_tool_call( - &request.id, - &tool_call.name, - &tool_call.arguments, - )) - .await?; - } else { - // if the llm generates an invalid object tool call, we still have - // to include it in the history. It always comes with a response indicating the error - tx.send(ProtocolFormatter::format_tool_call( - &request.id, - "invalid name", - &json!({}), - )) - .await?; + match request.tool_call { + Ok(tool_call) => { + tx.send(ProtocolFormatter::format_tool_call( + &request.id, + &tool_call.name, + &tool_call.arguments, + )) + .await?; + } + Err(err) => { + println!("Error: {}", err); + // Send error message for invalid tool call + tx.send(ProtocolFormatter::format_error(&err.to_string())).await?; + // Send a placeholder tool call to maintain protocol + tx.send(ProtocolFormatter::format_tool_call( + &request.id, + "invalid_tool", + &json!({"error": err.to_string()}), + )) + .await?; + } } } MessageContent::Text(text) => { @@ -278,6 +300,13 @@ async fn handler( Ok(stream) => stream, Err(e) => { tracing::error!("Failed to start reply stream: {}", e); + // Check if it's a moderation error + if let Some(moderation_error) = e.downcast_ref::() { + let _ = tx.send(ProtocolFormatter::format_moderation_error(moderation_error)).await; + } else { + // Send a generic error message + let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await; + } // Send a finish message with error as the reason let _ = tx.send(ProtocolFormatter::format_finish("error")).await; return; @@ -291,11 +320,18 @@ async fn handler( Ok(Some(Ok(message))) => { if let Err(e) = stream_message(message, &tx).await { tracing::error!("Error sending message through channel: {}", e); + let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await; break; } } Ok(Some(Err(e))) => { tracing::error!("Error processing message: {}", e); + // Check if it's a moderation error + if let Some(moderation_error) = e.downcast_ref::() { + let _ = tx.send(ProtocolFormatter::format_moderation_error(moderation_error)).await; + } else { + let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await; + } break; } Ok(None) => { @@ -503,6 +539,26 @@ mod tests { assert!(formatted.starts_with("a:")); assert!(formatted.contains("\"toolCallId\":\"123\"")); + // Test error formatting + let formatted = ProtocolFormatter::format_error("Test error"); + assert!(formatted.starts_with("3:")); + assert!(formatted.contains("\"message\":\"Test error\"")); + assert!(formatted.contains("\"code\":\"server_error\"")); + + // Test moderation error formatting + let moderation_error = ModerationError::ContentFlagged { + categories: "hate, violence".to_string(), + category_scores: Some(json!({ + "hate": 0.9, + "violence": 0.8 + })), + }; + let formatted = ProtocolFormatter::format_moderation_error(&moderation_error); + assert!(formatted.starts_with("3:")); + assert!(formatted.contains("\"code\":\"content_flagged\"")); + assert!(formatted.contains("\"categories\":\"hate, violence\"")); + assert!(formatted.contains("\"categoryScores\"")); + // Test finish formatting let formatted = ProtocolFormatter::format_finish("stop"); assert!(formatted.starts_with("d:")); @@ -560,4 +616,4 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); } } -} +} \ No newline at end of file diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index e00ac933..fa31a136 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -6,12 +6,22 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::select; use tokio::sync::RwLock; +use thiserror::Error; use super::configs::ModelConfig; use crate::message::{Message, MessageContent}; use mcp_core::role::Role; use mcp_core::tool::Tool; +#[derive(Error, Debug)] +pub enum ModerationError { + #[error("Content was flagged for moderation in categories: {categories}")] + ContentFlagged { + categories: String, + category_scores: Option, + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProviderUsage { pub model: String, @@ -197,10 +207,10 @@ pub trait Provider: Send + Sync + Moderation { let categories = result.categories .unwrap_or_else(|| vec!["unknown".to_string()]) .join(", "); - return Err(anyhow::anyhow!( - "Content was flagged for moderation in categories: {}", - categories - )); + return Err(ModerationError::ContentFlagged { + categories, + category_scores: result.category_scores, + }.into()); } // Moderation passed, wait for completion @@ -215,10 +225,10 @@ pub trait Provider: Send + Sync + Moderation { let categories = moderation_result.categories .unwrap_or_else(|| vec!["unknown".to_string()]) .join(", "); - return Err(anyhow::anyhow!( - "Content was flagged for moderation in categories: {}", - categories - )); + return Err(ModerationError::ContentFlagged { + categories, + category_scores: moderation_result.category_scores, + }.into()); } Ok(completion_result) @@ -338,10 +348,8 @@ mod tests { let result = provider.complete("system", &[test_message], &[]).await; assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Content was flagged")); + let err = result.unwrap_err(); + assert!(err.downcast_ref::().is_some()); } #[tokio::test] @@ -407,10 +415,8 @@ mod tests { let result = provider.complete("system", &[test_message], &[]).await; assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Content was flagged")); + let err = result.unwrap_err(); + assert!(err.downcast_ref::().is_some()); } #[tokio::test] @@ -647,4 +653,4 @@ mod tests { *count ); } -} +} \ No newline at end of file