Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
zakiali committed Jan 10, 2025
1 parent fb6f1d6 commit f299217
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
21 changes: 15 additions & 6 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ impl ProtocolFormatter {
fn format_moderation_error(error: &ModerationError) -> String {
let error_part = match error {
ModerationError::ContentFlagged { categories, .. } => {
format!("Content was flagged in the following categories: {}", categories)
format!(
"Content was flagged in the following categories: {}",
categories
)
}
};
format!("3:{}\n", error_part)
Expand Down Expand Up @@ -209,7 +212,8 @@ async fn stream_message(
}
Err(err) => {
// Send an error message first
tx.send(ProtocolFormatter::format_error(&err.to_string())).await?;
tx.send(ProtocolFormatter::format_error(&err.to_string()))
.await?;
// Then send an empty tool response to maintain the protocol
let result = vec![Content::text("Error occurred").with_priority(0.0)];
tx.send(ProtocolFormatter::format_tool_response(
Expand Down Expand Up @@ -238,7 +242,8 @@ async fn stream_message(
Err(err) => {
println!("Error: {}", err);
// Send error message for invalid tool call
tx.send(ProtocolFormatter::format_error(&err.to_string())).await?;
tx.send(ProtocolFormatter::format_error(&err.to_string()))
.await?;
// Send a placeholder tool call to maintain protocol
tx.send(ProtocolFormatter::format_tool_call(
&request.id,
Expand Down Expand Up @@ -302,10 +307,14 @@ async fn handler(
tracing::error!("Failed to start reply stream: {}", e);
// Check if it's a moderation error
if let Some(moderation_error) = e.downcast_ref::<ModerationError>() {
let _ = tx.send(ProtocolFormatter::format_moderation_error(moderation_error)).await;
let _ = tx
.send(ProtocolFormatter::format_moderation_error(moderation_error))
.await;
} else {
// Send a generic error message
let _ = tx.send(ProtocolFormatter::format_error(&e.to_string())).await;
let _ = tx
.send(ProtocolFormatter::format_error(&e.to_string()))
.await;
}
// Send a finish message with error as the reason
let _ = tx.send(ProtocolFormatter::format_finish("error")).await;
Expand Down Expand Up @@ -616,4 +625,4 @@ mod tests {
assert_eq!(response.status(), StatusCode::OK);
}
}
}
}
6 changes: 3 additions & 3 deletions crates/goose/src/providers/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use tokio::select;
use tokio::sync::RwLock;
use thiserror::Error;

use super::configs::ModelConfig;
use crate::message::{Message, MessageContent};
Expand All @@ -19,7 +19,7 @@ pub enum ModerationError {
ContentFlagged {
categories: String,
category_scores: Option<serde_json::Value>,
}
},
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -653,4 +653,4 @@ mod tests {
*count
);
}
}
}
2 changes: 1 addition & 1 deletion crates/goose/src/providers/databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ impl Moderation for DatabricksProvider {
);

let auth_header = self.ensure_auth_header().await?;
let payload = json!({
let payload = json!({
"messages": [
{
"role": "user",
Expand Down

0 comments on commit f299217

Please sign in to comment.