diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fead9587..9142cb84b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +### Fixed + +## [0.56.0] + +### Updated +- Enabled Streaming for OpenAI-compatible APIs (eg, DeepSeek Coder) +- If streaming to stdout, also print a newline at the end of streaming (to separate multiple outputs). + ### Fixed - Relaxed the type-assertions in `StreamCallback` to allow for more flexibility. diff --git a/Project.toml b/Project.toml index 38c34f215..71f67c195 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.56.0-DEV" +version = "0.56.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 1b11fc063..629f64b92 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -108,11 +108,13 @@ end """ OpenAI.create_chat(schema::CustomOpenAISchema, - api_key::AbstractString, - model::AbstractString, - conversation; - url::String="http://localhost:8080", - kwargs...) + api_key::AbstractString, + model::AbstractString, + conversation; + http_kwargs::NamedTuple = NamedTuple(), + streamcallback::Any = nothing, + url::String = "http://localhost:8080", + kwargs...) Dispatch to the OpenAI.create_chat function, for any OpenAI-compatible API. @@ -124,12 +126,27 @@ function OpenAI.create_chat(schema::CustomOpenAISchema, api_key::AbstractString, model::AbstractString, conversation; + http_kwargs::NamedTuple = NamedTuple(), + streamcallback::Any = nothing, url::String = "http://localhost:8080", kwargs...) # Build the corresponding provider object # Create chat will automatically pass our data to endpoint `/chat/completions` provider = CustomProvider(; api_key, base_url = url) - OpenAI.create_chat(provider, model, conversation; kwargs...) + if !isnothing(streamcallback) + ## Take over from OpenAI.jl + url = OpenAI.build_url(provider, "chat/completions") + headers = OpenAI.auth_header(provider, api_key) + streamcallback, new_kwargs = configure_callback!( + streamcallback, schema; kwargs...) + input = OpenAI.build_params((; messages = conversation, model, new_kwargs...)) + ## Use the streaming callback + resp = streamed_request!(streamcallback, url, headers, input; http_kwargs...) + OpenAI.OpenAIResponse(resp.status, JSON3.read(resp.body)) + else + ## Use OpenAI.jl default + OpenAI.create_chat(provider, model, conversation; http_kwargs, kwargs...) + end end """ @@ -170,12 +187,9 @@ function OpenAI.create_chat(schema::MistralOpenAISchema, conversation; url::String = "https://api.mistral.ai/v1", kwargs...) - # Build the corresponding provider object # try to override provided api_key because the default is OpenAI key - provider = CustomProvider(; - api_key = isempty(MISTRALAI_API_KEY) ? api_key : MISTRALAI_API_KEY, - base_url = url) - OpenAI.create_chat(provider, model, conversation; kwargs...) + api_key = isempty(MISTRALAI_API_KEY) ? api_key : MISTRALAI_API_KEY + OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...) end function OpenAI.create_chat(schema::FireworksOpenAISchema, api_key::AbstractString, @@ -183,12 +197,9 @@ function OpenAI.create_chat(schema::FireworksOpenAISchema, conversation; url::String = "https://api.fireworks.ai/inference/v1", kwargs...) - # Build the corresponding provider object # try to override provided api_key because the default is OpenAI key - provider = CustomProvider(; - api_key = isempty(FIREWORKS_API_KEY) ? api_key : FIREWORKS_API_KEY, - base_url = url) - OpenAI.create_chat(provider, model, conversation; kwargs...) + api_key = isempty(FIREWORKS_API_KEY) ? api_key : FIREWORKS_API_KEY + OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...) end function OpenAI.create_chat(schema::TogetherOpenAISchema, api_key::AbstractString, @@ -196,12 +207,8 @@ function OpenAI.create_chat(schema::TogetherOpenAISchema, conversation; url::String = "https://api.together.xyz/v1", kwargs...) - # Build the corresponding provider object - # try to override provided api_key because the default is OpenAI key - provider = CustomProvider(; - api_key = isempty(TOGETHER_API_KEY) ? api_key : TOGETHER_API_KEY, - base_url = url) - OpenAI.create_chat(provider, model, conversation; kwargs...) + api_key = isempty(TOGETHER_API_KEY) ? api_key : TOGETHER_API_KEY + OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...) end function OpenAI.create_chat(schema::GroqOpenAISchema, api_key::AbstractString, @@ -209,12 +216,8 @@ function OpenAI.create_chat(schema::GroqOpenAISchema, conversation; url::String = "https://api.groq.com/openai/v1", kwargs...) - # Build the corresponding provider object - # try to override provided api_key because the default is OpenAI key - provider = CustomProvider(; - api_key = isempty(GROQ_API_KEY) ? api_key : GROQ_API_KEY, - base_url = url) - OpenAI.create_chat(provider, model, conversation; kwargs...) + api_key = isempty(GROQ_API_KEY) ? api_key : GROQ_API_KEY + OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...) end function OpenAI.create_chat(schema::DeepSeekOpenAISchema, api_key::AbstractString, @@ -222,12 +225,8 @@ function OpenAI.create_chat(schema::DeepSeekOpenAISchema, conversation; url::String = "https://api.deepseek.com/v1", kwargs...) - # Build the corresponding provider object - # try to override provided api_key because the default is OpenAI key - provider = CustomProvider(; - api_key = isempty(DEEPSEEK_API_KEY) ? api_key : DEEPSEEK_API_KEY, - base_url = url) - OpenAI.create_chat(provider, model, conversation; kwargs...) + api_key = isempty(DEEPSEEK_API_KEY) ? api_key : DEEPSEEK_API_KEY + OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...) end function OpenAI.create_chat(schema::OpenRouterOpenAISchema, api_key::AbstractString, @@ -235,30 +234,42 @@ function OpenAI.create_chat(schema::OpenRouterOpenAISchema, conversation; url::String = "https://openrouter.ai/api/v1", kwargs...) - # Build the corresponding provider object - # try to override provided api_key because the default is OpenAI key - provider = CustomProvider(; - api_key = isempty(OPENROUTER_API_KEY) ? api_key : OPENROUTER_API_KEY, - base_url = url) - OpenAI.create_chat(provider, model, conversation; kwargs...) + api_key = isempty(OPENROUTER_API_KEY) ? api_key : OPENROUTER_API_KEY + OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...) end function OpenAI.create_chat(schema::DatabricksOpenAISchema, api_key::AbstractString, model::AbstractString, conversation; + http_kwargs::NamedTuple = NamedTuple(), + streamcallback::Any = nothing, url::String = "https://.databricks.com", kwargs...) # Build the corresponding provider object provider = CustomProvider(; api_key = isempty(DATABRICKS_API_KEY) ? api_key : DATABRICKS_API_KEY, base_url = isempty(DATABRICKS_HOST) ? url : DATABRICKS_HOST) - # Override standard OpenAI request endpoint - OpenAI.openai_request("serving-endpoints/$model/invocations", - provider; - method = "POST", - model, - messages = conversation, - kwargs...) + if !isnothing(streamcallback) + throw(ArgumentError("Streaming is not supported for Databricks models yet!")) + ## Take over from OpenAI.jl + # url = OpenAI.build_url(provider, "serving-endpoints/$model/invocations") + # headers = OpenAI.auth_header(provider, api_key) + # streamcallback, new_kwargs = configure_callback!( + # streamcallback, schema; kwargs...) + # input = OpenAI.build_params((; messages = conversation, model, new_kwargs...)) + # ## Use the streaming callback + # resp = streamed_request!(streamcallback, url, headers, input; http_kwargs...) + # OpenAI.OpenAIResponse(resp.status, JSON3.read(resp.body)) + else + # Override standard OpenAI request endpoint + OpenAI.openai_request("serving-endpoints/$model/invocations", + provider; + method = "POST", + model, + messages = conversation, + http_kwargs, + kwargs...) + end end # Extend OpenAI create_embeddings to allow for testing diff --git a/src/streaming.jl b/src/streaming.jl index d6d5bb8ff..fe8c5ea06 100644 --- a/src/streaming.jl +++ b/src/streaming.jl @@ -122,7 +122,8 @@ function configure_callback!(cb::T, schema::AbstractPromptSchema; api_kwargs...) where {T <: AbstractStreamCallback} ## Check if we are in passthrough mode or if we should configure the callback if isnothing(cb.flavor) - if schema isa OpenAISchema + ## Enable streaming for all OpenAI-compatible APIs + if schema isa AbstractOpenAISchema api_kwargs = (; api_kwargs..., stream = true, stream_options = (; include_usage = true)) flavor = OpenAIStream() @@ -287,7 +288,6 @@ Print the content to the IO output stream `out`. """ @inline function print_content(out::IO, text::AbstractString; kwargs...) print(out, text) - # flush(stdout) end """ print_content(out::Channel, text::AbstractString; kwargs...) @@ -471,6 +471,8 @@ function streamed_request!(cb::AbstractStreamCallback, url, headers, input; kwar end HTTP.closeread(stream) end + ## For estetic reasons, if printing to stdout, we send a newline and flush + cb.out == stdout && (println(); flush(stdout)) body = build_response_body(cb.flavor, cb; verbose, cb.kwargs...) resp.body = JSON3.write(body)