diff --git a/src/clients/openai.rs b/src/clients/openai.rs
index 7abdac21..a626d111 100644
--- a/src/clients/openai.rs
+++ b/src/clients/openai.rs
@@ -58,9 +58,9 @@ impl OpenAiClient {
     #[instrument(skip_all, fields(request.model))]
     pub async fn chat_completions(
         &self,
-        request: ChatCompletionRequest,
+        request: ChatCompletionsRequest,
         headers: HeaderMap,
-    ) -> Result<ChatCompletionResponse, Error> {
+    ) -> Result<ChatCompletionsResponse, Error> {
         let url = self.client.base_url().join("/v1/chat/completions").unwrap();
         let headers = with_traceparent_header(headers);
         let stream = request.stream.unwrap_or_default();
@@ -94,7 +94,7 @@ impl OpenAiClient {
                     }
                 }
             });
-            Ok(ChatCompletionResponse::Streaming(rx))
+            Ok(ChatCompletionsResponse::Streaming(rx))
         } else {
             let response = self
                 .client
@@ -136,19 +136,19 @@ impl Client for OpenAiClient {
 }
 
 #[derive(Debug)]
-pub enum ChatCompletionResponse {
+pub enum ChatCompletionsResponse {
     Unary(ChatCompletion),
     Streaming(mpsc::Receiver<Result<sse::Event, Infallible>>),
 }
 
-impl From<ChatCompletion> for ChatCompletionResponse {
+impl From<ChatCompletion> for ChatCompletionsResponse {
     fn from(value: ChatCompletion) -> Self {
         Self::Unary(value)
     }
 }
 
 #[derive(Debug, Default, Clone, Serialize, Deserialize)]
-pub struct ChatCompletionRequest {
+pub struct ChatCompletionsRequest {
    /// A list of messages comprising the conversation so far.
    pub messages: Vec<Message>,
    /// ID of the model to use.
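The `Streaming` variant hands pre-built SSE events across a tokio channel, assuming `Result<sse::Event, Infallible>` items, which is the shape axum's `Sse` accepts without further mapping (the server.rs handler below relies on exactly that). A minimal standalone sketch of this channel-to-SSE bridge, with a toy producer in place of the OpenAI event loop:

```rust
use std::convert::Infallible;

use axum::response::sse::{Event, KeepAlive, Sse};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

/// Toy illustration of the bridge the `Streaming` variant sets up:
/// a producer task sends `Result<Event, Infallible>` items, and the
/// receiver becomes a stream axum can serve as `text/event-stream`.
async fn sse_from_channel() -> Sse<ReceiverStream<Result<Event, Infallible>>> {
    let (tx, rx) = mpsc::channel::<Result<Event, Infallible>>(32);
    tokio::spawn(async move {
        for n in 0..3 {
            // The real client would forward chat completion chunks here.
            let _ = tx.send(Ok(Event::default().data(format!("chunk {n}")))).await;
        }
    });
    Sse::new(ReceiverStream::new(rx)).keep_alive(KeepAlive::default())
}
```

Because tokio's `mpsc::Receiver` is not itself a `Stream`, the consumer side wraps it in `tokio_stream`'s `ReceiverStream`, which is also why server.rs below gains the `use tokio_stream::wrappers::ReceiverStream;` import.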
diff --git a/src/orchestrator.rs b/src/orchestrator.rs
index 4fbf6ecf..a7fa465d 100644
--- a/src/orchestrator.rs
+++ b/src/orchestrator.rs
@@ -17,6 +17,7 @@
 pub mod errors;
 pub use errors::Error;
 
+pub mod chat_completions_detection;
 pub mod streaming;
 pub mod unary;
 
@@ -35,7 +36,7 @@ use crate::{
            text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient,
            TextGenerationDetectorClient,
        },
-        openai::OpenAiClient,
+        openai::{ChatCompletionsRequest, OpenAiClient},
        ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient,
    },
    config::{DetectorType, GenerationProvider, OrchestratorConfig},
@@ -469,6 +470,26 @@ impl StreamingClassificationWithGenTask {
     }
 }
 
+#[derive(Debug)]
+pub struct ChatCompletionsDetectionTask {
+    /// Unique identifier of request trace
+    pub trace_id: TraceId,
+    /// Chat completion request
+    pub request: ChatCompletionsRequest,
+    /// Request headers
+    pub headers: HeaderMap,
+}
+
+impl ChatCompletionsDetectionTask {
+    pub fn new(trace_id: TraceId, request: ChatCompletionsRequest, headers: HeaderMap) -> Self {
+        Self {
+            trace_id,
+            request,
+            headers,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/orchestrator/chat_completions_detection.rs b/src/orchestrator/chat_completions_detection.rs
new file mode 100644
index 00000000..dd06fab0
--- /dev/null
+++ b/src/orchestrator/chat_completions_detection.rs
@@ -0,0 +1,20 @@
+use tracing::{info, instrument};
+
+use super::{ChatCompletionsDetectionTask, Error, Orchestrator};
+use crate::clients::openai::{ChatCompletionsResponse, OpenAiClient};
+
+impl Orchestrator {
+    #[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))]
+    pub async fn handle_chat_completions_detection(
+        &self,
+        task: ChatCompletionsDetectionTask,
+    ) -> Result<ChatCompletionsResponse, Error> {
+        info!("handling chat completions detection task");
+        let client = self
+            .ctx
+            .clients
+            .get_as::<OpenAiClient>("chat_generation")
+            .expect("chat_generation client not found");
+        Ok(client.chat_completions(task.request, task.headers).await?)
+    }
+}
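The handler resolves the OpenAI client under the fixed name `chat_generation` and downcasts it to a concrete type via `get_as`. `ClientMap`'s internals are outside this diff; a minimal sketch of the lookup-plus-downcast pattern the call suggests, assuming a registry of type-erased clients (the storage shape is an assumption, not the project's actual definition):

```rust
use std::{any::Any, collections::HashMap};

/// Illustrative stand-in for the orchestrator's client registry,
/// assuming clients are stored type-erased and keyed by name.
struct ClientMap(HashMap<String, Box<dyn Any + Send + Sync>>);

impl ClientMap {
    /// Fetch a client by name and downcast it to a concrete type;
    /// returns `None` if the name is unknown or the type differs.
    fn get_as<C: Any>(&self, name: &str) -> Option<&C> {
        self.0.get(name).and_then(|client| client.downcast_ref::<C>())
    }
}
```

Since the lookup can fail at runtime, the `.expect("chat_generation client not found")` above is only safe because the route is registered solely when `chat_generation` is configured, which the server.rs change below enforces.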
diff --git a/src/server.rs b/src/server.rs
index 7907dc86..e7bd12d6 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -44,6 +44,7 @@ use opentelemetry::trace::TraceContextExt;
 use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig};
 use tokio::{net::TcpListener, signal};
 use tokio_rustls::TlsAcceptor;
+use tokio_stream::wrappers::ReceiverStream;
 use tower_http::trace::TraceLayer;
 use tower_service::Service;
 use tracing::{debug, error, info, instrument, warn, Span};
@@ -51,11 +52,12 @@ use tracing_opentelemetry::OpenTelemetrySpanExt;
 use webpki::types::{CertificateDer, PrivateKeyDer};
 
 use crate::{
+    clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse},
     models::{self, InfoParams, InfoResponse},
     orchestrator::{
-        self, ChatDetectionTask, ClassificationWithGenTask, ContextDocsDetectionTask,
-        DetectionOnGenerationTask, GenerationWithDetectionTask, Orchestrator,
-        StreamingClassificationWithGenTask, TextContentDetectionTask,
+        self, ChatCompletionsDetectionTask, ChatDetectionTask, ClassificationWithGenTask,
+        ContextDocsDetectionTask, DetectionOnGenerationTask, GenerationWithDetectionTask,
+        Orchestrator, StreamingClassificationWithGenTask, TextContentDetectionTask,
     },
     tracing_utils,
 };
@@ -160,7 +162,7 @@ pub async fn run(
     }
 
     // (2b) Add main guardrails server routes
-    let app = Router::new()
+    let mut router = Router::new()
        .route(
            &format!("{}/classification-with-text-generation", API_PREFIX),
            post(classification_with_gen),
        )
@@ -191,16 +193,25 @@
        .route(
            &format!("{}/detection/generated", TEXT_API_PREFIX),
            post(detect_generated),
-        )
-        .with_state(shared_state)
-        .layer(
-            TraceLayer::new_for_http()
-                .make_span_with(tracing_utils::incoming_request_span)
-                .on_request(tracing_utils::on_incoming_request)
-                .on_response(tracing_utils::on_outgoing_response)
-                .on_eos(tracing_utils::on_outgoing_eos),
-        );
+        );
+
+    // If chat generation is configured, enable the chat completions detection endpoint.
+    if shared_state.orchestrator.config().chat_generation.is_some() {
+        info!("Enabling chat completions detection endpoint");
+        router = router.route(
+            "/api/v2/chat/completions-detection",
+            post(chat_completions_detection),
+        );
+    }
+
+    let app = router.with_state(shared_state).layer(
+        TraceLayer::new_for_http()
+            .make_span_with(tracing_utils::incoming_request_span)
+            .on_request(tracing_utils::on_incoming_request)
+            .on_response(tracing_utils::on_outgoing_response)
+            .on_eos(tracing_utils::on_outgoing_eos),
+    );
+
     // (2c) Generate main guardrails server handle based on whether TLS is needed
     let listener: TcpListener = TcpListener::bind(&http_addr)
         .await
@@ -488,6 +499,33 @@ async fn detect_generated(
     }
 }
 
+#[instrument(skip_all)]
+async fn chat_completions_detection(
+    State(state): State<Arc<ServerState>>,
+    headers: HeaderMap,
+    WithRejection(Json(request), _): WithRejection<Json<ChatCompletionsRequest>, Error>,
+) -> Result<impl IntoResponse, Error> {
+    let trace_id = Span::current().context().span().span_context().trace_id();
+    info!(?trace_id, "handling request");
+    let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers);
+    let task = ChatCompletionsDetectionTask::new(trace_id, request, headers);
+    match state
+        .orchestrator
+        .handle_chat_completions_detection(task)
+        .await
+    {
+        Ok(response) => match response {
+            ChatCompletionsResponse::Unary(response) => Ok(Json(response).into_response()),
+            ChatCompletionsResponse::Streaming(response_rx) => {
+                let response_stream = ReceiverStream::new(response_rx);
+                let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+                Ok(sse.into_response())
+            }
+        },
+        Err(error) => Err(error.into()),
+    }
+}
+
 /// Shutdown signal handler
 async fn shutdown_signal() {
     let ctrl_c = async {
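Once a `chat_generation` client is configured, the new route accepts an OpenAI-style chat completions body. A sketch of a smoke test using `reqwest` (the address, model ID, and message content are placeholders, and the orchestrator's actual listen port may differ):

```rust
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Placeholder address: point this at a running orchestrator. Without
    // a configured `chat_generation` client the route is never registered,
    // so the request would 404 instead of reaching the handler.
    let response = reqwest::Client::new()
        .post("http://localhost:8033/api/v2/chat/completions-detection")
        .json(&json!({
            "model": "example-model", // placeholder model ID
            "messages": [{ "role": "user", "content": "Hello!" }]
            // setting "stream": true switches the response to SSE chunks
        }))
        .send()
        .await?;
    println!("{}", response.text().await?);
    Ok(())
}
```

A unary request returns the chat completion JSON as-is; with `"stream": true` the handler instead serves the channel-backed SSE stream shown earlier.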