Enable Streaming for OpenAI-compatible models
svilupp authored Sep 18, 2024
1 parent 51a46bc commit 2a01bb5
Showing 4 changed files with 71 additions and 50 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)

### Added

### Fixed

+## [0.56.0]
+
+### Updated
+- Enabled Streaming for OpenAI-compatible APIs (e.g., DeepSeek Coder)
+- If streaming to stdout, also print a newline at the end of streaming (to separate multiple outputs).
+
+### Fixed
+- Relaxed the type-assertions in `StreamCallback` to allow for more flexibility.

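For a sense of what this release enables, a minimal usage sketch (it assumes the DeepSeek Coder alias "dscode" is registered in your model registry and the relevant API key is set; both are setup details, not part of this commit):

using PromptingTools

# Tokens are printed to stdout as they arrive; when the stream finishes,
# a trailing newline separates this output from the next one.
msg = aigenerate("Write a one-line haiku about streams.";
    model = "dscode", streamcallback = stdout)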
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
-version = "0.56.0-DEV"
+version = "0.56.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
105 changes: 58 additions & 47 deletions src/llm_openai.jl
@@ -108,11 +108,13 @@ end

"""
    OpenAI.create_chat(schema::CustomOpenAISchema,
-       api_key::AbstractString,
-       model::AbstractString,
-       conversation;
-       url::String="http://localhost:8080",
-       kwargs...)
+       api_key::AbstractString,
+       model::AbstractString,
+       conversation;
+       http_kwargs::NamedTuple = NamedTuple(),
+       streamcallback::Any = nothing,
+       url::String = "http://localhost:8080",
+       kwargs...)

Dispatch to the OpenAI.create_chat function for any OpenAI-compatible API.
@@ -124,12 +126,27 @@ function OpenAI.create_chat(schema::CustomOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
+       http_kwargs::NamedTuple = NamedTuple(),
+       streamcallback::Any = nothing,
        url::String = "http://localhost:8080",
        kwargs...)
    # Build the corresponding provider object
    # create_chat will automatically send our data to the `/chat/completions` endpoint
    provider = CustomProvider(; api_key, base_url = url)
-   OpenAI.create_chat(provider, model, conversation; kwargs...)
+   if !isnothing(streamcallback)
+       ## Take over from OpenAI.jl
+       url = OpenAI.build_url(provider, "chat/completions")
+       headers = OpenAI.auth_header(provider, api_key)
+       streamcallback, new_kwargs = configure_callback!(
+           streamcallback, schema; kwargs...)
+       input = OpenAI.build_params((; messages = conversation, model, new_kwargs...))
+       ## Use the streaming callback
+       resp = streamed_request!(streamcallback, url, headers, input; http_kwargs...)
+       OpenAI.OpenAIResponse(resp.status, JSON3.read(resp.body))
+   else
+       ## Use the OpenAI.jl default
+       OpenAI.create_chat(provider, model, conversation; http_kwargs, kwargs...)
+   end
end
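A sketch of how the streaming branch above is reached from the user-facing API, assuming a local OpenAI-compatible server on port 8080 (the model name and API key are placeholders):

using PromptingTools
const PT = PromptingTools

# Passing a streamcallback routes the request through streamed_request!
# instead of OpenAI.jl's default create_chat path.
msg = aigenerate(PT.CustomOpenAISchema(), "Count to three.";
    model = "my-local-model",                        # placeholder model name
    api_key = "sk-anything",                         # many local servers ignore the key
    api_kwargs = (; url = "http://localhost:8080"),  # forwarded as `url` above
    streamcallback = stdout)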

"""
@@ -170,95 +187,89 @@ function OpenAI.create_chat(schema::MistralOpenAISchema,
        conversation;
        url::String = "https://api.mistral.ai/v1",
        kwargs...)
-   # Build the corresponding provider object
-   # try to override provided api_key because the default is OpenAI key
-   provider = CustomProvider(;
-       api_key = isempty(MISTRALAI_API_KEY) ? api_key : MISTRALAI_API_KEY,
-       base_url = url)
-   OpenAI.create_chat(provider, model, conversation; kwargs...)
+   api_key = isempty(MISTRALAI_API_KEY) ? api_key : MISTRALAI_API_KEY
+   OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::FireworksOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        url::String = "https://api.fireworks.ai/inference/v1",
        kwargs...)
-   # Build the corresponding provider object
-   # try to override provided api_key because the default is OpenAI key
-   provider = CustomProvider(;
-       api_key = isempty(FIREWORKS_API_KEY) ? api_key : FIREWORKS_API_KEY,
-       base_url = url)
-   OpenAI.create_chat(provider, model, conversation; kwargs...)
+   api_key = isempty(FIREWORKS_API_KEY) ? api_key : FIREWORKS_API_KEY
+   OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::TogetherOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        url::String = "https://api.together.xyz/v1",
        kwargs...)
-   # Build the corresponding provider object
-   # try to override provided api_key because the default is OpenAI key
-   provider = CustomProvider(;
-       api_key = isempty(TOGETHER_API_KEY) ? api_key : TOGETHER_API_KEY,
-       base_url = url)
-   OpenAI.create_chat(provider, model, conversation; kwargs...)
+   api_key = isempty(TOGETHER_API_KEY) ? api_key : TOGETHER_API_KEY
+   OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::GroqOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        url::String = "https://api.groq.com/openai/v1",
        kwargs...)
-   # Build the corresponding provider object
-   # try to override provided api_key because the default is OpenAI key
-   provider = CustomProvider(;
-       api_key = isempty(GROQ_API_KEY) ? api_key : GROQ_API_KEY,
-       base_url = url)
-   OpenAI.create_chat(provider, model, conversation; kwargs...)
+   api_key = isempty(GROQ_API_KEY) ? api_key : GROQ_API_KEY
+   OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::DeepSeekOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        url::String = "https://api.deepseek.com/v1",
        kwargs...)
-   # Build the corresponding provider object
-   # try to override provided api_key because the default is OpenAI key
-   provider = CustomProvider(;
-       api_key = isempty(DEEPSEEK_API_KEY) ? api_key : DEEPSEEK_API_KEY,
-       base_url = url)
-   OpenAI.create_chat(provider, model, conversation; kwargs...)
+   api_key = isempty(DEEPSEEK_API_KEY) ? api_key : DEEPSEEK_API_KEY
+   OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::OpenRouterOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        url::String = "https://openrouter.ai/api/v1",
        kwargs...)
-   # Build the corresponding provider object
-   # try to override provided api_key because the default is OpenAI key
-   provider = CustomProvider(;
-       api_key = isempty(OPENROUTER_API_KEY) ? api_key : OPENROUTER_API_KEY,
-       base_url = url)
-   OpenAI.create_chat(provider, model, conversation; kwargs...)
+   api_key = isempty(OPENROUTER_API_KEY) ? api_key : OPENROUTER_API_KEY
+   OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
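Each named provider above is now a thin wrapper: resolve the preferred API key, then delegate to the CustomOpenAISchema method, which carries the streaming support. A hypothetical new provider would follow the same shape (the schema type, key constant, and URL below are invented for illustration):

function OpenAI.create_chat(schema::AcmeOpenAISchema,    # hypothetical schema type
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        url::String = "https://api.acme.example/v1",     # hypothetical base URL
        kwargs...)
    # Prefer the provider-specific key if set, else use the supplied one
    api_key = isempty(ACME_API_KEY) ? api_key : ACME_API_KEY
    # Delegation picks up streaming from the CustomOpenAISchema method above
    OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end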
function OpenAI.create_chat(schema::DatabricksOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
+       http_kwargs::NamedTuple = NamedTuple(),
+       streamcallback::Any = nothing,
        url::String = "https://<workspace_host>.databricks.com",
        kwargs...)
    # Build the corresponding provider object
    provider = CustomProvider(;
        api_key = isempty(DATABRICKS_API_KEY) ? api_key : DATABRICKS_API_KEY,
        base_url = isempty(DATABRICKS_HOST) ? url : DATABRICKS_HOST)
-   # Override standard OpenAI request endpoint
-   OpenAI.openai_request("serving-endpoints/$model/invocations",
-       provider;
-       method = "POST",
-       model,
-       messages = conversation,
-       kwargs...)
+   if !isnothing(streamcallback)
+       throw(ArgumentError("Streaming is not supported for Databricks models yet!"))
+       ## Take over from OpenAI.jl
+       # url = OpenAI.build_url(provider, "serving-endpoints/$model/invocations")
+       # headers = OpenAI.auth_header(provider, api_key)
+       # streamcallback, new_kwargs = configure_callback!(
+       #     streamcallback, schema; kwargs...)
+       # input = OpenAI.build_params((; messages = conversation, model, new_kwargs...))
+       # ## Use the streaming callback
+       # resp = streamed_request!(streamcallback, url, headers, input; http_kwargs...)
+       # OpenAI.OpenAIResponse(resp.status, JSON3.read(resp.body))
+   else
+       # Override standard OpenAI request endpoint
+       OpenAI.openai_request("serving-endpoints/$model/invocations",
+           provider;
+           method = "POST",
+           model,
+           messages = conversation,
+           http_kwargs,
+           kwargs...)
+   end
end
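From the caller's side, the guard above surfaces as an ArgumentError; a small illustrative sketch (the serving-endpoint name is hypothetical, only the error is taken from the code):

using PromptingTools
const PT = PromptingTools

try
    aigenerate(PT.DatabricksOpenAISchema(), "Hello!";
        model = "databricks-dbrx-instruct",  # hypothetical serving endpoint
        streamcallback = stdout)
catch err
    err isa ArgumentError && @warn "Streaming is not supported for Databricks yet" err
end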

# Extend OpenAI create_embeddings to allow for testing
6 changes: 4 additions & 2 deletions src/streaming.jl
@@ -122,7 +122,8 @@ function configure_callback!(cb::T, schema::AbstractPromptSchema;
        api_kwargs...) where {T <: AbstractStreamCallback}
    ## Check if we are in passthrough mode or if we should configure the callback
    if isnothing(cb.flavor)
-       if schema isa OpenAISchema
+       ## Enable streaming for all OpenAI-compatible APIs
+       if schema isa AbstractOpenAISchema
            api_kwargs = (;
                api_kwargs..., stream = true, stream_options = (; include_usage = true))
            flavor = OpenAIStream()
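The effect of widening the dispatch from OpenAISchema to AbstractOpenAISchema can be sketched as follows (the schema instance is illustrative; the return convention matches the call site in llm_openai.jl above):

using PromptingTools
const PT = PromptingTools

cb = PT.StreamCallback(; out = stdout)
# Any OpenAI-compatible schema now configures the callback, not just OpenAISchema
cb, api_kwargs = PT.configure_callback!(cb, PT.CustomOpenAISchema())
# cb.flavor is now a PT.OpenAIStream, and api_kwargs includes stream = true
# plus stream_options = (; include_usage = true)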
@@ -287,7 +288,6 @@ Print the content to the IO output stream `out`.
"""
@inline function print_content(out::IO, text::AbstractString; kwargs...)
    print(out, text)
-   # flush(stdout)
end
"""
    print_content(out::Channel, text::AbstractString; kwargs...)
@@ -471,6 +471,8 @@ function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwargs...)
        end
        HTTP.closeread(stream)
    end
+   ## For aesthetic reasons, if printing to stdout, we send a newline and flush
+   cb.out == stdout && (println(); flush(stdout))

    body = build_response_body(cb.flavor, cb; verbose, cb.kwargs...)
    resp.body = JSON3.write(body)
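The practical effect of the new newline-and-flush, sketched with two back-to-back streamed calls (model alias as in the earlier sketch, illustrative only):

using PromptingTools

# Previously the second answer would start on the same line where the first
# stopped; each streamed stdout output now ends with its own newline.
aigenerate("Name one prime.";     model = "dscode", streamcallback = stdout)
aigenerate("Name another prime."; model = "dscode", streamcallback = stdout)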
