Skip to content

Commit

Permalink
Add ErrorPart and Moderation Error struct
Browse files Browse the repository at this point in the history
  • Loading branch information
zakiali committed Jan 10, 2025
1 parent e99b3d4 commit fb6f1d6
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 37 deletions.
96 changes: 76 additions & 20 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use axum::{
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::message::{Message, MessageContent};
use goose::providers::base::{Moderation, ModerationError, ModerationResult};
use mcp_core::{content::Content, role::Role};
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{error, json, Value};
use std::{
convert::Infallible,
pin::Pin,
Expand Down Expand Up @@ -159,6 +160,20 @@ impl ProtocolFormatter {
format!("a:{}\n", response)
}

fn format_error(error: &str) -> String {
// Error messages start with "3:" in the new protocol.
format!("3:{}\n", error)
}

fn format_moderation_error(error: &ModerationError) -> String {
let error_part = match error {
ModerationError::ContentFlagged { categories, .. } => {
format!("Content was flagged in the following categories: {}", categories)
}
};
format!("3:{}\n", error_part)
}

fn format_finish(reason: &str) -> String {
// Finish messages start with "d:"
let finish = json!({
Expand Down Expand Up @@ -193,8 +208,10 @@ async fn stream_message(
.await?;
}
Err(err) => {
let result =
vec![Content::text(format!("Error {}", err)).with_priority(0.0)];
// Send an error message first
tx.send(ProtocolFormatter::format_error(&err.to_string())).await?;
// Then send an empty tool response to maintain the protocol
let result = vec![Content::text("Error occurred").with_priority(0.0)];
tx.send(ProtocolFormatter::format_tool_response(
&response.id,
&result,
Expand All @@ -209,22 +226,27 @@ async fn stream_message(
for content in message.content {
match content {
MessageContent::ToolRequest(request) => {
if let Ok(tool_call) = request.tool_call {
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
&tool_call.name,
&tool_call.arguments,
))
.await?;
} else {
// if the llm generates an invalid object tool call, we still have
// to include it in the history. It always comes with a response indicating the error
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
"invalid name",
&json!({}),
))
.await?;
match request.tool_call {
Ok(tool_call) => {
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
&tool_call.name,
&tool_call.arguments,
))
.await?;
}
Err(err) => {
println!("Error: {}", err);
// Send error message for invalid tool call
tx.send(ProtocolFormatter::format_error(&err.to_string())).await?;
// Send a placeholder tool call to maintain protocol
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
"invalid_tool",
&json!({"error": err.to_string()}),
))
.await?;
}
}
}
MessageContent::Text(text) => {
Expand Down Expand Up @@ -278,6 +300,13 @@ async fn handler(
Ok(stream) => stream,
Err(e) => {
tracing::error!("Failed to start reply stream: {}", e);
// Check if it's a moderation error
if let Some(moderation_error) = e.downcast_ref::<ModerationError>() {
let _ = tx.send(ProtocolFormatter::format_moderation_error(moderation_error)).await;
} else {
// Send a generic error message
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
}
// Send a finish message with error as the reason
let _ = tx.send(ProtocolFormatter::format_finish("error")).await;
return;
Expand All @@ -291,11 +320,18 @@ async fn handler(
Ok(Some(Ok(message))) => {
if let Err(e) = stream_message(message, &tx).await {
tracing::error!("Error sending message through channel: {}", e);
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
break;
}
}
Ok(Some(Err(e))) => {
tracing::error!("Error processing message: {}", e);
// Check if it's a moderation error
if let Some(moderation_error) = e.downcast_ref::<ModerationError>() {
let _ = tx.send(ProtocolFormatter::format_moderation_error(moderation_error)).await;
} else {
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
}
break;
}
Ok(None) => {
Expand Down Expand Up @@ -503,6 +539,26 @@ mod tests {
assert!(formatted.starts_with("a:"));
assert!(formatted.contains("\"toolCallId\":\"123\""));

// Test error formatting
let formatted = ProtocolFormatter::format_error("Test error");
assert!(formatted.starts_with("3:"));
assert!(formatted.contains("\"message\":\"Test error\""));
assert!(formatted.contains("\"code\":\"server_error\""));

// Test moderation error formatting
let moderation_error = ModerationError::ContentFlagged {
categories: "hate, violence".to_string(),
category_scores: Some(json!({
"hate": 0.9,
"violence": 0.8
})),
};
let formatted = ProtocolFormatter::format_moderation_error(&moderation_error);
assert!(formatted.starts_with("3:"));
assert!(formatted.contains("\"code\":\"content_flagged\""));
assert!(formatted.contains("\"categories\":\"hate, violence\""));
assert!(formatted.contains("\"categoryScores\""));

// Test finish formatting
let formatted = ProtocolFormatter::format_finish("stop");
assert!(formatted.starts_with("d:"));
Expand Down Expand Up @@ -560,4 +616,4 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
}
}
}
}
40 changes: 23 additions & 17 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@ use std::collections::HashMap;
use std::sync::Arc;
use tokio::select;
use tokio::sync::RwLock;
use thiserror::Error;

use super::configs::ModelConfig;
use crate::message::{Message, MessageContent};
use mcp_core::role::Role;
use mcp_core::tool::Tool;

#[derive(Error, Debug)]
pub enum ModerationError {
#[error("Content was flagged for moderation in categories: {categories}")]
ContentFlagged {
categories: String,
category_scores: Option<serde_json::Value>,
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderUsage {
pub model: String,
Expand Down Expand Up @@ -197,10 +207,10 @@ pub trait Provider: Send + Sync + Moderation {
let categories = result.categories
.unwrap_or_else(|| vec!["unknown".to_string()])
.join(", ");
return Err(anyhow::anyhow!(
"Content was flagged for moderation in categories: {}",
categories
));
return Err(ModerationError::ContentFlagged {
categories,
category_scores: result.category_scores,
}.into());
}

// Moderation passed, wait for completion
Expand All @@ -215,10 +225,10 @@ pub trait Provider: Send + Sync + Moderation {
let categories = moderation_result.categories
.unwrap_or_else(|| vec!["unknown".to_string()])
.join(", ");
return Err(anyhow::anyhow!(
"Content was flagged for moderation in categories: {}",
categories
));
return Err(ModerationError::ContentFlagged {
categories,
category_scores: moderation_result.category_scores,
}.into());
}

Ok(completion_result)
Expand Down Expand Up @@ -338,10 +348,8 @@ mod tests {
let result = provider.complete("system", &[test_message], &[]).await;

assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Content was flagged"));
let err = result.unwrap_err();
assert!(err.downcast_ref::<ModerationError>().is_some());
}

#[tokio::test]
Expand Down Expand Up @@ -407,10 +415,8 @@ mod tests {
let result = provider.complete("system", &[test_message], &[]).await;

assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Content was flagged"));
let err = result.unwrap_err();
assert!(err.downcast_ref::<ModerationError>().is_some());
}

#[tokio::test]
Expand Down Expand Up @@ -647,4 +653,4 @@ mod tests {
*count
);
}
}
}

0 comments on commit fb6f1d6

Please sign in to comment.