Implement Chat Completions API (foundation-model-stack#240)
* Implement Chat Completions API

Signed-off-by: declark1 <[email protected]>

* Conditionally enable chat completions endpoint

Signed-off-by: declark1 <[email protected]>

* Update chat completions to chat completions detection, rename items for alignment

Signed-off-by: declark1 <[email protected]>

---------

Signed-off-by: declark1 <[email protected]>
declark1 authored Nov 5, 2024
1 parent f2010e1 commit c15b2a2
Showing 4 changed files with 98 additions and 19 deletions.
12 changes: 6 additions & 6 deletions src/clients/openai.rs
@@ -58,9 +58,9 @@ impl OpenAiClient {
#[instrument(skip_all, fields(request.model))]
pub async fn chat_completions(
&self,
- request: ChatCompletionRequest,
+ request: ChatCompletionsRequest,
headers: HeaderMap,
- ) -> Result<ChatCompletionResponse, Error> {
+ ) -> Result<ChatCompletionsResponse, Error> {
let url = self.client.base_url().join("/v1/chat/completions").unwrap();
let headers = with_traceparent_header(headers);
let stream = request.stream.unwrap_or_default();
@@ -94,7 +94,7 @@ impl OpenAiClient {
}
}
});
- Ok(ChatCompletionResponse::Streaming(rx))
+ Ok(ChatCompletionsResponse::Streaming(rx))
} else {
let response = self
.client
@@ -136,19 +136,19 @@ impl Client for OpenAiClient {
}

#[derive(Debug)]
- pub enum ChatCompletionResponse {
+ pub enum ChatCompletionsResponse {
Unary(ChatCompletion),
Streaming(mpsc::Receiver<Result<sse::Event, Infallible>>),
}

- impl From<ChatCompletion> for ChatCompletionResponse {
+ impl From<ChatCompletion> for ChatCompletionsResponse {
fn from(value: ChatCompletion) -> Self {
Self::Unary(value)
}
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
- pub struct ChatCompletionRequest {
+ pub struct ChatCompletionsRequest {
/// A list of messages comprising the conversation so far.
pub messages: Vec<Message>,
/// ID of the model to use.
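
As an aside, the renamed client API can be exercised roughly as follows. This is a minimal caller-side sketch, not part of the commit: it assumes the crate-internal items from this diff (OpenAiClient, ChatCompletionsRequest, ChatCompletionsResponse, the client Error type) plus HeaderMap are in scope, and that a tokio runtime drives the future.

async fn call_chat_completions(client: &OpenAiClient) -> Result<(), Error> {
    // Placeholder request: the struct derives Default; real callers set
    // `model`, `messages`, and optionally `stream`.
    let request = ChatCompletionsRequest::default();
    match client.chat_completions(request, HeaderMap::new()).await? {
        // `stream` false or unset: a single ChatCompletion comes back.
        ChatCompletionsResponse::Unary(completion) => println!("{completion:?}"),
        // `stream: true`: SSE events arrive on an mpsc receiver.
        ChatCompletionsResponse::Streaming(mut rx) => {
            while let Some(event) = rx.recv().await {
                let _event = event; // each item is a Result<sse::Event, Infallible>
            }
        }
    }
    Ok(())
}
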
23 changes: 22 additions & 1 deletion src/orchestrator.rs
@@ -17,6 +17,7 @@

pub mod errors;
pub use errors::Error;
+ pub mod chat_completions_detection;
pub mod streaming;
pub mod unary;

@@ -35,7 +36,7 @@ use crate::{
text_context_doc::ContextType, TextChatDetectorClient, TextContextDocDetectorClient,
TextGenerationDetectorClient,
},
- openai::OpenAiClient,
+ openai::{ChatCompletionsRequest, OpenAiClient},
ClientMap, GenerationClient, NlpClient, TextContentsDetectorClient, TgisClient,
},
config::{DetectorType, GenerationProvider, OrchestratorConfig},
@@ -469,6 +470,26 @@ impl StreamingClassificationWithGenTask {
}
}

+ #[derive(Debug)]
+ pub struct ChatCompletionsDetectionTask {
+ /// Unique identifier of request trace
+ pub trace_id: TraceId,
+ /// Chat completion request
+ pub request: ChatCompletionsRequest,
+ // Headermap
+ pub headers: HeaderMap,
+ }

+ impl ChatCompletionsDetectionTask {
+ pub fn new(trace_id: TraceId, request: ChatCompletionsRequest, headers: HeaderMap) -> Self {
+ Self {
+ trace_id,
+ request,
+ headers,
+ }
+ }
+ }

#[cfg(test)]
mod tests {
use super::*;
20 changes: 20 additions & 0 deletions src/orchestrator/chat_completions_detection.rs
@@ -0,0 +1,20 @@
use tracing::{info, instrument};

use super::{ChatCompletionsDetectionTask, Error, Orchestrator};
use crate::clients::openai::{ChatCompletionsResponse, OpenAiClient};

impl Orchestrator {
#[instrument(skip_all, fields(trace_id = ?task.trace_id, headers = ?task.headers))]
pub async fn handle_chat_completions_detection(
&self,
task: ChatCompletionsDetectionTask,
) -> Result<ChatCompletionsResponse, Error> {
info!("handling chat completions detection task");
let client = self
.ctx
.clients
.get_as::<OpenAiClient>("chat_generation")
.expect("chat_generation client not found");
Ok(client.chat_completions(task.request, task.headers).await?)
}
}
62 changes: 50 additions & 12 deletions src/server.rs
@@ -44,18 +44,20 @@ use opentelemetry::trace::TraceContextExt;
use rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig};
use tokio::{net::TcpListener, signal};
use tokio_rustls::TlsAcceptor;
+ use tokio_stream::wrappers::ReceiverStream;
use tower_http::trace::TraceLayer;
use tower_service::Service;
use tracing::{debug, error, info, instrument, warn, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use webpki::types::{CertificateDer, PrivateKeyDer};

use crate::{
+ clients::openai::{ChatCompletionsRequest, ChatCompletionsResponse},
models::{self, InfoParams, InfoResponse},
orchestrator::{
- self, ChatDetectionTask, ClassificationWithGenTask, ContextDocsDetectionTask,
- DetectionOnGenerationTask, GenerationWithDetectionTask, Orchestrator,
- StreamingClassificationWithGenTask, TextContentDetectionTask,
+ self, ChatCompletionsDetectionTask, ChatDetectionTask, ClassificationWithGenTask,
+ ContextDocsDetectionTask, DetectionOnGenerationTask, GenerationWithDetectionTask,
+ Orchestrator, StreamingClassificationWithGenTask, TextContentDetectionTask,
},
tracing_utils,
};
@@ -160,7 +162,7 @@ pub async fn run(
}

// (2b) Add main guardrails server routes
- let app = Router::new()
+ let mut router = Router::new()
.route(
&format!("{}/classification-with-text-generation", API_PREFIX),
post(classification_with_gen),
@@ -191,16 +193,25 @@ .route(
.route(
&format!("{}/detection/generated", TEXT_API_PREFIX),
post(detect_generated),
- )
- .with_state(shared_state)
- .layer(
- TraceLayer::new_for_http()
- .make_span_with(tracing_utils::incoming_request_span)
- .on_request(tracing_utils::on_incoming_request)
- .on_response(tracing_utils::on_outgoing_response)
- .on_eos(tracing_utils::on_outgoing_eos),
);

+ // If chat generation is configured, enable the chat completions detection endpoint.
+ if shared_state.orchestrator.config().chat_generation.is_some() {
+ info!("Enabling chat completions detection endpoint");
+ router = router.route(
+ "/api/v2/chat/completions-detection",
+ post(chat_completions_detection),
+ );
+ }

+ let app = router.with_state(shared_state).layer(
+ TraceLayer::new_for_http()
+ .make_span_with(tracing_utils::incoming_request_span)
+ .on_request(tracing_utils::on_incoming_request)
+ .on_response(tracing_utils::on_outgoing_response)
+ .on_eos(tracing_utils::on_outgoing_eos),
+ );

// (2c) Generate main guardrails server handle based on whether TLS is needed
let listener: TcpListener = TcpListener::bind(&http_addr)
.await
@@ -488,6 +499,33 @@ async fn detect_generated(
}
}

+ #[instrument(skip_all)]
+ async fn chat_completions_detection(
+ State(state): State<Arc<ServerState>>,
+ headers: HeaderMap,
+ WithRejection(Json(request), _): WithRejection<Json<ChatCompletionsRequest>, Error>,
+ ) -> Result<impl IntoResponse, Error> {
+ let trace_id = Span::current().context().span().span_context().trace_id();
+ info!(?trace_id, "handling request");
+ let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers);
+ let task = ChatCompletionsDetectionTask::new(trace_id, request, headers);
+ match state
+ .orchestrator
+ .handle_chat_completions_detection(task)
+ .await
+ {
+ Ok(response) => match response {
+ ChatCompletionsResponse::Unary(response) => Ok(Json(response).into_response()),
+ ChatCompletionsResponse::Streaming(response_rx) => {
+ let response_stream = ReceiverStream::new(response_rx);
+ let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+ Ok(sse.into_response())
+ }
+ },
+ Err(error) => Err(error.into()),
+ }
+ }

/// Shutdown signal handler
async fn shutdown_signal() {
let ctrl_c = async {
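
For reference, here is a hypothetical client-side sketch (not part of the commit) that exercises the new endpoint once it is enabled. The path /api/v2/chat/completions-detection and the `model`/`messages` fields come from this diff; the OpenAI-style message shape, the listen address and port, and the reqwest/serde_json/tokio dependencies are assumptions. The route is only registered when `chat_generation` is present in the orchestrator config.

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Request body modeled on the OpenAI chat completions format; the exact
    // fields accepted by ChatCompletionsRequest beyond `model` and `messages`
    // are not shown in this diff.
    let body = json!({
        "model": "example-model",
        "messages": [{ "role": "user", "content": "Hello" }]
    });
    // Placeholder host/port for wherever the guardrails HTTP server listens.
    let response = reqwest::Client::new()
        .post("http://localhost:8033/api/v2/chat/completions-detection")
        .json(&body)
        .send()
        .await?;
    println!("status: {}", response.status());
    println!("body: {}", response.text().await?);
    Ok(())
}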
