Update chat_completions to stream ChatCompletionChunk messages (#268)
Signed-off-by: declark1 <[email protected]>
declark1 authored Jan 9, 2025
1 parent b333ba3 commit 412bb2e
Showing 2 changed files with 86 additions and 47 deletions.
104 changes: 61 additions & 43 deletions src/clients/openai.rs
@@ -15,10 +15,9 @@
 */

-use std::{collections::HashMap, convert::Infallible};
+use std::collections::HashMap;

 use async_trait::async_trait;
-use axum::response::sse;
 use eventsource_stream::Eventsource;
 use futures::StreamExt;
 use http_body_util::BodyExt;
@@ -76,7 +75,6 @@ impl OpenAiClient {
         let (tx, rx) = mpsc::channel(32);
         let mut event_stream = self
             .inner()
-            .clone()
             .post(url, headers, request)
             .await?
             .0
@@ -86,15 +84,31 @@ impl OpenAiClient {
         tokio::spawn(async move {
             while let Some(result) = event_stream.next().await {
                 match result {
-                    Ok(event) => {
-                        let event = sse::Event::default().data(event.data);
-                        let _ = tx.send(Ok(event)).await;
+                    Ok(event) if event.data == "[DONE]" => {
+                        // Send None to signal that the stream completed
+                        let _ = tx.send(Ok(None)).await;
+                        break;
                     }
+                    Ok(event) => match serde_json::from_str::<ChatCompletionChunk>(&event.data)
+                    {
+                        Ok(chunk) => {
+                            let _ = tx.send(Ok(Some(chunk))).await;
+                        }
+                        Err(e) => {
+                            let error = Error::Http {
+                                code: StatusCode::INTERNAL_SERVER_ERROR,
+                                message: format!("deserialization error: {e}"),
+                            };
+                            let _ = tx.send(Err(error)).await;
+                        }
+                    },
                     Err(error) => {
-                        // We received an error from the event stream, send an error event
-                        let event =
-                            sse::Event::default().event("error").data(error.to_string());
-                        let _ = tx.send(Ok(event)).await;
+                        // We received an error from the event stream, send error message
+                        let error = Error::Http {
+                            code: StatusCode::INTERNAL_SERVER_ERROR,
+                            message: error.to_string(),
+                        };
+                        let _ = tx.send(Err(error)).await;
                     }
                 }
             }
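Note: the channel now carries typed results instead of pre-rendered SSE events. A minimal sketch (not part of this commit) of what a consumer of the `rx` half sees — `Ok(Some(chunk))` is a parsed chunk, `Ok(None)` is normal completion after `[DONE]`, and `Err` is a client-side failure:

    // Hypothetical consumer; ChatCompletionChunk and Error are this module's types.
    async fn consume(mut rx: mpsc::Receiver<Result<Option<ChatCompletionChunk>, Error>>) {
        while let Some(message) = rx.recv().await {
            match message {
                Ok(Some(chunk)) => println!("chunk {}", chunk.id),
                Ok(None) => break, // [DONE]: stream completed normally
                Err(error) => {
                    eprintln!("stream error: {error:?}");
                    break;
                }
            }
        }
    }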
@@ -144,7 +158,7 @@ impl HttpClientExt for OpenAiClient {
 #[derive(Debug)]
 pub enum ChatCompletionsResponse {
     Unary(ChatCompletion),
-    Streaming(mpsc::Receiver<Result<sse::Event, Infallible>>),
+    Streaming(mpsc::Receiver<Result<Option<ChatCompletionChunk>, Error>>),
 }

 impl From<ChatCompletion> for ChatCompletionsResponse {
@@ -475,56 +489,59 @@ pub struct Function {
 pub struct ChatCompletion {
     /// A unique identifier for the chat completion.
     pub id: String,
-    /// A list of chat completion choices. Can be more than one if n is greater than 1.
-    pub choices: Vec<ChatCompletionChoice>,
+    /// The object type, which is always `chat.completion`.
+    pub object: String,
     /// The Unix timestamp (in seconds) of when the chat completion was created.
     pub created: i64,
     /// The model used for the chat completion.
     pub model: String,
+    /// A list of chat completion choices. Can be more than one if n is greater than 1.
+    pub choices: Vec<ChatCompletionChoice>,
+    /// Usage statistics for the completion request.
+    pub usage: Usage,
+    /// This fingerprint represents the backend configuration that the model runs with.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
     /// The service tier used for processing the request.
     /// This field is only included if the `service_tier` parameter is specified in the request.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub service_tier: Option<String>,
-    /// This fingerprint represents the backend configuration that the model runs with.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub system_fingerprint: Option<String>,
-    /// The object type, which is always `chat.completion`.
-    pub object: String,
-    /// Usage statistics for the completion request.
-    pub usage: Usage,
 }
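Note: the reorder above only changes the order in which serde writes keys (it serializes fields in declaration order); deserialization matches by name, so existing payloads still parse. A small sketch with an abbreviated, made-up body:

    // Sketch: field order does not affect parsing, only output key order.
    let body = r#"{
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1736380800,
        "model": "example-model",
        "choices": [],
        "usage": {"prompt_tokens": 1, "total_tokens": 1, "completion_tokens": 0}
    }"#;
    let completion: ChatCompletion = serde_json::from_str(body).unwrap();
    assert_eq!(completion.object, "chat.completion");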

 /// A chat completion choice.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ChatCompletionChoice {
-    /// The reason the model stopped generating tokens.
-    pub finish_reason: String,
     /// The index of the choice in the list of choices.
     pub index: usize,
     /// A chat completion message generated by the model.
     pub message: ChatCompletionMessage,
     /// Log probability information for the choice.
     pub logprobs: Option<ChatCompletionLogprobs>,
+    /// The reason the model stopped generating tokens.
+    pub finish_reason: String,
 }

 /// A chat completion message generated by the model.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ChatCompletionMessage {
+    /// The role of the author of this message.
+    pub role: String,
     /// The contents of the message.
     pub content: Option<String>,
+    /// The tool calls generated by the model, such as function calls.
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tool_calls: Vec<ToolCall>,
     /// The refusal message generated by the model.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub refusal: Option<String>,
-    pub tool_calls: Vec<ToolCall>,
-    /// The role of the author of this message.
-    pub role: String,
 }
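Worth noting: `default` in the `tool_calls` attribute is what lets payloads omit the field entirely. A quick sketch with a hypothetical payload:

    // Without #[serde(default, ...)], a body missing "tool_calls" would fail to
    // parse; with it, the field deserializes to an empty Vec.
    let msg: ChatCompletionMessage =
        serde_json::from_str(r#"{"role": "assistant", "content": "Hi"}"#).unwrap();
    assert!(msg.tool_calls.is_empty());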

 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ChatCompletionLogprobs {
     /// A list of message content tokens with log probability information.
     pub content: Option<Vec<ChatCompletionLogprob>>,
     /// A list of message refusal tokens with log probability information.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub refusal: Option<Vec<ChatCompletionLogprob>>,
 }

@@ -555,68 +572,69 @@ pub struct ChatCompletionTopLogprob {
 pub struct ChatCompletionChunk {
     /// A unique identifier for the chat completion. Each chunk has the same ID.
     pub id: String,
-    /// A list of chat completion choices.
-    pub choices: Vec<ChatCompletionChunkChoice>,
+    /// The object type, which is always `chat.completion.chunk`.
+    pub object: String,
     /// The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp.
     pub created: i64,
     /// The model to generate the completion.
     pub model: String,
+    /// This fingerprint represents the backend configuration that the model runs with.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
+    /// A list of chat completion choices.
+    pub choices: Vec<ChatCompletionChunkChoice>,
     /// The service tier used for processing the request.
     /// This field is only included if the service_tier parameter is specified in the request.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub service_tier: Option<String>,
-    /// This fingerprint represents the backend configuration that the model runs with.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub system_fingerprint: Option<String>,
-    /// The object type, which is always `chat.completion.chunk`.
-    pub object: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub usage: Option<Usage>,
 }
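For reference, the SSE payloads that the client loop above parses into this struct look roughly like the following (abbreviated, values made up):

    data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1736380800,"model":"example-model","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}

    data: [DONE]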

 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ChatCompletionChunkChoice {
+    /// The index of the choice in the list of choices.
+    pub index: u32,
     /// A chat completion delta generated by streamed model responses.
     pub delta: ChatCompletionDelta,
     /// Log probability information for the choice.
     pub logprobs: Option<ChatCompletionLogprobs>,
     /// The reason the model stopped generating tokens.
     pub finish_reason: Option<String>,
-    /// The index of the choice in the list of choices.
-    pub index: u32,
 }

 /// A chat completion delta generated by streamed model responses.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ChatCompletionDelta {
+    /// The role of the author of this message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
     /// The contents of the message.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub content: Option<String>,
     /// The refusal message generated by the model.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub refusal: Option<String>,
-    #[serde(skip_serializing_if = "Vec::is_empty")]
+    /// The tool calls generated by the model, such as function calls.
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub tool_calls: Vec<ToolCall>,
-    /// The role of the author of this message.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub role: Option<String>,
 }
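Each delta carries only the piece generated since the previous chunk, so consumers typically concatenate them. A minimal sketch (helper name hypothetical):

    // Fold streamed deltas back into the complete message text.
    fn accumulate(chunks: &[ChatCompletionChunk]) -> String {
        chunks
            .iter()
            .flat_map(|chunk| &chunk.choices)
            .filter_map(|choice| choice.delta.content.as_deref())
            .collect()
    }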

 /// Usage statistics for a completion.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Usage {
-    /// Number of tokens in the generated completion.
-    pub completion_tokens: u32,
     /// Number of tokens in the prompt.
     pub prompt_tokens: u32,
     /// Total number of tokens used in the request (prompt + completion).
     pub total_tokens: u32,
-    /// Breakdown of tokens used in a completion.
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub completion_token_details: Option<CompletionTokenDetails>,
+    /// Number of tokens in the generated completion.
+    pub completion_tokens: u32,
     /// Breakdown of tokens used in the prompt.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub prompt_token_details: Option<PromptTokenDetails>,
+    /// Breakdown of tokens used in a completion.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub completion_token_details: Option<CompletionTokenDetails>,
 }
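Serialized with the attributes above, the two breakdowns are skipped when `None`, so a minimal usage object is just the three counters (values illustrative):

    // Emits {"prompt_tokens":9,"total_tokens":21,"completion_tokens":12}
    let json = serde_json::to_string(&Usage {
        prompt_tokens: 9,
        total_tokens: 21,
        completion_tokens: 12,
        prompt_token_details: None,
        completion_token_details: None,
    })
    .unwrap();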

 #[derive(Debug, Clone, Serialize, Deserialize)]
29 changes: 25 additions & 4 deletions src/server.rs
@@ -37,7 +37,10 @@ use axum::{
     Json, Router,
 };
 use axum_extra::extract::WithRejection;
-use futures::{stream, Stream, StreamExt};
+use futures::{
+    stream::{self, BoxStream},
+    Stream, StreamExt,
+};
 use hyper::body::Incoming;
 use hyper_util::rt::{TokioExecutor, TokioIo};
 use opentelemetry::trace::TraceContextExt;
@@ -505,6 +508,7 @@ async fn chat_completions_detection(
     headers: HeaderMap,
     WithRejection(Json(request), _): WithRejection<Json<ChatCompletionsRequest>, Error>,
 ) -> Result<impl IntoResponse, Error> {
+    use ChatCompletionsResponse::*;
     let trace_id = Span::current().context().span().span_context().trace_id();
     info!(?trace_id, "handling request");
     let headers = filter_headers(&state.orchestrator.config().passthrough_headers, headers);
@@ -515,10 +519,27 @@ async fn chat_completions_detection(
         .await
     {
         Ok(response) => match response {
-            ChatCompletionsResponse::Unary(response) => Ok(Json(response).into_response()),
-            ChatCompletionsResponse::Streaming(response_rx) => {
+            Unary(response) => Ok(Json(response).into_response()),
+            Streaming(response_rx) => {
                 let response_stream = ReceiverStream::new(response_rx);
-                let sse = Sse::new(response_stream).keep_alive(KeepAlive::default());
+                // Convert response stream to a stream of SSE events
+                let event_stream: BoxStream<Result<Event, Infallible>> = response_stream
+                    .map(|message| match message {
+                        Ok(Some(chunk)) => Ok(Event::default().json_data(chunk).unwrap()),
+                        Ok(None) => {
+                            // The stream completed, send [DONE] message
+                            Ok(Event::default().data("[DONE]"))
+                        }
+                        Err(error) => {
+                            let error: Error = orchestrator::Error::from(error).into();
+                            Ok(Event::default()
+                                .event("error")
+                                .json_data(error.to_json())
+                                .unwrap())
+                        }
+                    })
+                    .boxed();
+                let sse = Sse::new(event_stream).keep_alive(KeepAlive::default());
                 Ok(sse.into_response())
             }
         },
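One behavioral detail worth calling out: failures are now surfaced in-band as a named SSE event instead of terminating the response, so on the wire an error looks roughly like this (the body shape comes from `Error::to_json`, elided here):

    event: error
    data: { ... }

Since named events do not trigger a plain `onmessage` handler, a well-behaved client should listen for the `error` event name in addition to default data messages.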
