refactor: use async-openai-alt and upgrade to 0.26.1
Signed-off-by: Wei Zhang <[email protected]>
zwpaper committed Jan 7, 2025
1 parent 88991ee commit 9d7c3c0
Showing 16 changed files with 73 additions and 63 deletions.
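In short, the commit swaps the upstream async-openai crate for the async-openai-alt fork across the workspace: each Cargo.toml renames the dependency, and each Rust module renames the import path. An illustrative sketch of the rename, not taken verbatim from any one file below:

```rust
// Before (async-openai 0.20):
// use async_openai::{config::OpenAIConfig, error::OpenAIError};

// After (async-openai-alt 0.26.1): identical module layout under the fork's crate name.
use async_openai_alt::{config::OpenAIConfig, error::OpenAIError};
```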
36 changes: 18 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -68,7 +68,7 @@ mime_guess = "2.0.4"
 assert_matches = "1.5"
 insta = "1.34.0"
 logkit = "0.3"
-async-openai = "0.20"
+async-openai-alt = "0.26.1"
 tracing-test = "0.2"
 clap = "4.3.0"
 ratelimit = "0.10"
2 changes: 1 addition & 1 deletion crates/http-api-bindings/Cargo.toml
@@ -17,7 +17,7 @@ serde_json = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
 ollama-api-bindings = { path = "../ollama-api-bindings" }
-async-openai.workspace = true
+async-openai-alt.workspace = true
 tokio.workspace = true
 tracing.workspace = true
 leaky-bucket = "1.1.2"
4 changes: 2 additions & 2 deletions crates/http-api-bindings/src/chat/mod.rs
@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use async_openai::config::OpenAIConfig;
+use async_openai_alt::config::OpenAIConfig;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 
@@ -34,7 +34,7 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
     let config = builder.build().expect("Failed to build config");
 
     let engine = Box::new(
-        async_openai::Client::with_config(config)
+        async_openai_alt::Client::with_config(config)
             .with_http_client(create_reqwest_client(api_endpoint)),
     );
 
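As an aside, the fork keeps upstream async-openai's builder API, so client construction is unchanged apart from the crate name. A minimal sketch with placeholder endpoint and key (the real values come from `HttpModelConfig`):

```rust
use async_openai_alt::{config::OpenAIConfig, Client};

fn build_client() -> Client<OpenAIConfig> {
    // Placeholder values; Tabby derives these from its model config.
    let config = OpenAIConfig::new()
        .with_api_base("http://localhost:8080/v1")
        .with_api_key("api-key");

    // Same construction pattern as the diff above, minus the custom
    // reqwest client that Tabby installs via with_http_client.
    Client::with_config(config)
}
```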
2 changes: 1 addition & 1 deletion crates/http-api-bindings/src/rate_limit.rs
@@ -1,4 +1,4 @@
-use async_openai::{
+use async_openai_alt::{
     error::OpenAIError,
     types::{
         ChatCompletionResponseStream, CreateChatCompletionRequest, CreateChatCompletionResponse,
2 changes: 1 addition & 1 deletion crates/llama-cpp-server/Cargo.toml
@@ -24,7 +24,7 @@ anyhow.workspace = true
 which = "6"
 serde.workspace = true
 serdeconv.workspace = true
-async-openai.workspace = true
+async-openai-alt.workspace = true
 
 [build-dependencies]
 cmake = "0.1"
10 changes: 5 additions & 5 deletions crates/llama-cpp-server/src/lib.rs
@@ -3,7 +3,7 @@ mod supervisor;
 use std::{path::PathBuf, sync::Arc};
 
 use anyhow::Result;
-use async_openai::error::OpenAIError;
+use async_openai_alt::error::OpenAIError;
 use async_trait::async_trait;
 use futures::stream::BoxStream;
 use serde::Deserialize;
@@ -161,15 +161,15 @@ impl ChatCompletionServer {
 impl ChatCompletionStream for ChatCompletionServer {
     async fn chat(
         &self,
-        request: async_openai::types::CreateChatCompletionRequest,
-    ) -> Result<async_openai::types::CreateChatCompletionResponse, OpenAIError> {
+        request: async_openai_alt::types::CreateChatCompletionRequest,
+    ) -> Result<async_openai_alt::types::CreateChatCompletionResponse, OpenAIError> {
         self.chat_completion.chat(request).await
     }
 
     async fn chat_stream(
         &self,
-        request: async_openai::types::CreateChatCompletionRequest,
-    ) -> Result<async_openai::types::ChatCompletionResponseStream, OpenAIError> {
+        request: async_openai_alt::types::CreateChatCompletionRequest,
+    ) -> Result<async_openai_alt::types::ChatCompletionResponseStream, OpenAIError> {
         self.chat_completion.chat_stream(request).await
     }
 }
2 changes: 1 addition & 1 deletion crates/tabby-inference/Cargo.toml
@@ -16,7 +16,7 @@ derive_builder.workspace = true
 futures = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 trie-rs = "0.1.1"
-async-openai.workspace = true
+async-openai-alt.workspace = true
 secrecy = "0.8"
 reqwest.workspace = true
 tracing.workspace = true
6 changes: 3 additions & 3 deletions crates/tabby-inference/src/chat.rs
@@ -1,4 +1,4 @@
-use async_openai::{
+use async_openai_alt::{
     config::OpenAIConfig,
     error::OpenAIError,
     types::{
@@ -85,7 +85,7 @@ impl ExtendedOpenAIConfig {
     }
 }
 
-impl async_openai::config::Config for ExtendedOpenAIConfig {
+impl async_openai_alt::config::Config for ExtendedOpenAIConfig {
     fn headers(&self) -> reqwest::header::HeaderMap {
         self.base.headers()
     }
@@ -108,7 +108,7 @@ impl async_openai::config::Config for ExtendedOpenAIConfig {
 }
 
 #[async_trait]
-impl ChatCompletionStream for async_openai::Client<ExtendedOpenAIConfig> {
+impl ChatCompletionStream for async_openai_alt::Client<ExtendedOpenAIConfig> {
     async fn chat(
         &self,
         request: CreateChatCompletionRequest,
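For context, Tabby's `ChatCompletionStream` trait now speaks the fork's request/response types end to end. A minimal sketch of a call site, assuming a handle to any implementation such as the `async_openai_alt::Client<ExtendedOpenAIConfig>` impl above; the model id and prompt text are hypothetical:

```rust
use std::sync::Arc;

use async_openai_alt::types::{
    ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
    ChatCompletionRequestUserMessageContent, CreateChatCompletionRequestArgs,
};
use tabby_inference::ChatCompletionStream;

async fn ask(engine: Arc<dyn ChatCompletionStream>) -> anyhow::Result<String> {
    // Build a 0.26-style request: content is a typed enum, the role is
    // implied by the message variant, and remaining fields come from Default.
    let request = CreateChatCompletionRequestArgs::default()
        .model("my-model") // hypothetical model id
        .messages(vec![ChatCompletionRequestMessage::User(
            ChatCompletionRequestUserMessage {
                content: ChatCompletionRequestUserMessageContent::Text("Hello!".into()),
                ..Default::default()
            },
        )])
        .build()?;

    let response = engine.chat(request).await?;
    Ok(response.choices[0].message.content.clone().unwrap_or_default())
}
```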
2 changes: 1 addition & 1 deletion crates/tabby/Cargo.toml
@@ -59,7 +59,7 @@ axum-prometheus = "0.6"
 uuid.workspace = true
 color-eyre = { version = "0.6.3" }
 reqwest.workspace = true
-async-openai.workspace = true
+async-openai-alt.workspace = true
 spinners = "4.1.1"
 regex.workspace = true
 
4 changes: 2 additions & 2 deletions crates/tabby/src/routes/chat.rs
@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use async_openai::error::OpenAIError;
+use async_openai_alt::error::OpenAIError;
 use axum::{
     extract::State,
     response::sse::{Event, KeepAlive, Sse},
@@ -36,7 +36,7 @@ pub async fn chat_completions_utoipa(_request: Json<serde_json::Value>) -> Statu
 pub async fn chat_completions(
     State(state): State<Arc<dyn ChatCompletionStream>>,
     TypedHeader(MaybeUser(user)): TypedHeader<MaybeUser>,
-    Json(mut request): Json<async_openai::types::CreateChatCompletionRequest>,
+    Json(mut request): Json<async_openai_alt::types::CreateChatCompletionRequest>,
 ) -> Result<Sse<impl Stream<Item = Result<Event, anyhow::Error>>>, StatusCode> {
     if let Some(user) = user {
         request.user.replace(user);
2 changes: 1 addition & 1 deletion ee/tabby-schema/Cargo.toml
@@ -10,7 +10,7 @@ schema-language = ["juniper/schema-language"]
 
 [dependencies]
 anyhow.workspace = true
-async-openai.workspace = true
+async-openai-alt.workspace = true
 async-trait.workspace = true
 axum = { workspace = true }
 base64 = "0.22.0"
2 changes: 1 addition & 1 deletion ee/tabby-schema/src/schema/mod.rs
@@ -20,7 +20,7 @@ pub mod worker;
 use std::{sync::Arc, time::Instant};
 
 use access_policy::{AccessPolicyService, SourceIdAccessPolicy};
-use async_openai::{
+use async_openai_alt::{
     error::OpenAIError,
     types::{
         ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
2 changes: 1 addition & 1 deletion ee/tabby-webserver/Cargo.toml
@@ -53,7 +53,7 @@ strum.workspace = true
 cron = "0.12.1"
 async-stream.workspace = true
 logkit.workspace = true
-async-openai.workspace = true
+async-openai-alt.workspace = true
 ratelimit.workspace = true
 cached.workspace = true
 
54 changes: 32 additions & 22 deletions ee/tabby-webserver/src/service/answer.rs
@@ -7,11 +7,14 @@ use std::{
 };
 
 use anyhow::anyhow;
-use async_openai::{
+use async_openai_alt::{
     error::OpenAIError,
     types::{
-        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
-        ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Role,
+        ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
+        ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
+        ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
+        ChatCompletionRequestUserMessageArgs, ChatCompletionRequestUserMessageContent,
+        CreateChatCompletionRequestArgs, Role,
     },
 };
 use async_stream::stream;
@@ -438,8 +441,9 @@ fn convert_messages_to_chat_completion_request(
     if !config.system_prompt.is_empty() {
         output.push(ChatCompletionRequestMessage::System(
             ChatCompletionRequestSystemMessage {
-                content: config.system_prompt.clone(),
-                role: Role::System,
+                content: ChatCompletionRequestSystemMessageContent::Text(
+                    config.system_prompt.clone(),
+                ),
                 name: None,
             },
         ));
@@ -452,36 +456,42 @@ fn convert_messages_to_chat_completion_request(
             thread::Role::User => Role::User,
         };
 
-        let content = if role == Role::User {
+        let message: ChatCompletionRequestMessage = if role == Role::User {
             if i % 2 != 0 {
                 bail!("User message must be followed by assistant message");
             }
 
             let y = &messages[i + 1];
 
-            build_user_prompt(&x.content, &y.attachment, None)
+            let content = build_user_prompt(&x.content, &y.attachment, None);
+            ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
+                content: ChatCompletionRequestUserMessageContent::Text(
+                    helper.rewrite_tag(&content),
+                ),
+                ..Default::default()
+            })
         } else {
-            x.content.clone()
+            ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
+                content: Some(ChatCompletionRequestAssistantMessageContent::Text(
+                    x.content.clone(),
+                )),
+                ..Default::default()
+            })
         };
 
-        output.push(ChatCompletionRequestMessage::System(
-            ChatCompletionRequestSystemMessage {
-                content: helper.rewrite_tag(&content),
-                role,
-                name: None,
-            },
-        ));
+        output.push(message);
     }
 
-    output.push(ChatCompletionRequestMessage::System(
-        ChatCompletionRequestSystemMessage {
-            content: helper.rewrite_tag(&build_user_prompt(
-                &messages[messages.len() - 1].content,
-                attachment,
-                user_attachment_input,
-            )),
-            role: Role::User,
-            name: None,
-        },
-    ));
+    output.push(ChatCompletionRequestMessage::User(
+        ChatCompletionRequestUserMessage {
+            content: ChatCompletionRequestUserMessageContent::Text(helper.rewrite_tag(
+                &build_user_prompt(
+                    &messages[messages.len() - 1].content,
+                    attachment,
+                    user_attachment_input,
+                ),
+            )),
+            ..Default::default()
+        },
+    ));

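The substantive change in this file follows from async-openai-alt 0.26's message types: per-role structs no longer carry an explicit `role` field, and `content` is a typed enum rather than a plain `String`. A minimal before/after sketch; the prompt text is illustrative:

```rust
use async_openai_alt::types::{
    ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
    ChatCompletionRequestSystemMessageContent,
};

// 0.20 shape (no longer compiles):
//     ChatCompletionRequestSystemMessage {
//         content: "You are a helpful assistant.".to_string(),
//         role: Role::System,
//         name: None,
//     }

// 0.26 shape: the enum variant implies the role, and content is typed.
fn system_message(prompt: &str) -> ChatCompletionRequestMessage {
    ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
        content: ChatCompletionRequestSystemMessageContent::Text(prompt.to_string()),
        name: None,
    })
}
```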
4 changes: 2 additions & 2 deletions ee/tabby-webserver/src/service/answer/testutils/mod.rs
@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use async_openai::{
+use async_openai_alt::{
     error::OpenAIError,
     types::{
         ChatChoice, ChatChoiceStream, ChatCompletionResponseMessage, ChatCompletionResponseStream,
@@ -44,7 +44,7 @@ impl ChatCompletionStream for FakeChatCompletionStream {
         _request: CreateChatCompletionRequest,
     ) -> Result<CreateChatCompletionResponse, OpenAIError> {
         if self.return_error {
-            return Err(OpenAIError::ApiError(async_openai::error::ApiError {
+            return Err(OpenAIError::ApiError(async_openai_alt::error::ApiError {
                 message: "error".to_string(),
                 code: None,
                 param: None,
