From 1ef3ba37afb369b6d06279c9a1de3834dc407781 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Mon, 9 Sep 2024 12:07:06 +0300 Subject: [PATCH] Update Anthropic kwargs + docs --- CHANGELOG.md | 10 ++++ Project.toml | 2 +- src/llm_anthropic.jl | 60 +++++++++++++++++++++--- src/llm_openai.jl | 1 + src/streaming.jl | 10 +++- test/llm_anthropic.jl | 105 +++++++++++++++++++++++++++++++++++++++++- 6 files changed, 177 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9773cdc4c..b7a68f31f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added ### Fixed +## [0.53.0] + +### Added +- Added beta headers to enable long outputs (up to 8K tokens) with Anthropic's Sonnet 3.5 (see `?anthropic_extra_headers`). +- Added a kwarg to prefill (`aiprefill`) AI responses with Anthropic's models to improve steerability (see `?aigenerate`). + +### Updated +- Documentation of `aigenerate` to make it clear that if `streamcallback` is provided WITH `flavor` set, there is no automatic configuration and the user must provide the correct `api_kwargs`. +- Grouped Anthropic's beta headers as a comma-separated string as per the latest API specification. 
+ ## [0.52.0] diff --git a/Project.toml b/Project.toml index 716f0d997..6a8db0fe9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.52.0" +version = "0.53.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/llm_anthropic.jl b/src/llm_anthropic.jl index b2ed2686a..c3e6fc8d2 100644 --- a/src/llm_anthropic.jl +++ b/src/llm_anthropic.jl @@ -7,6 +7,7 @@ """ render(schema::AbstractAnthropicSchema, messages::Vector{<:AbstractMessage}; + aiprefill::Union{Nothing, AbstractString} = nothing, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], tools::Vector{<:Dict{String, <:Any}} = Dict{String, Any}[], cache::Union{Nothing, Symbol} = nothing, @@ -15,12 +16,14 @@ Builds a history of the conversation to provide the prompt to the API. All unspecified kwargs are passed as replacements such that `{{key}}=>value` in the template. # Keyword Arguments +- `aiprefill`: A string to be used as a prefill for the AI response. This steers the AI response in a certain direction (and potentially saves output tokens). - `conversation`: Past conversation to be included in the beginning of the prompt (for continued conversations). - `tools`: A list of tools to be used in the conversation. Added to the end of the system prompt to enforce its use. - `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last` and `:all` are supported. 
""" function render(schema::AbstractAnthropicSchema, messages::Vector{<:AbstractMessage}; + aiprefill::Union{Nothing, AbstractString} = nothing, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], tools::Vector{<:Dict{String, <:Any}} = Dict{String, Any}[], cache::Union{Nothing, Symbol} = nothing, @@ -79,16 +82,41 @@ function render(schema::AbstractAnthropicSchema, ## Sense check @assert !isempty(conversation) "AbstractAnthropicSchema requires at least 1 User message, ie, no `prompt` provided!" + ## Apply prefilling of responses + if !isnothing(aiprefill) && !isempty(aiprefill) + aimsg = AIMessage(aiprefill) + push!(conversation, + Dict("role" => role4render(schema, aimsg), + "content" => [Dict{String, Any}("type" => "text", "text" => aiprefill)])) + end return (; system, conversation) end -function anthropic_extra_headers(; has_tools = false, has_cache = false) +""" + anthropic_extra_headers + +Adds API version and beta headers to the request. + +# Kwargs / Beta headers +- `has_tools`: Enables tools in the conversation. +- `has_cache`: Enables prompt caching. +- `has_long_output`: Enables long outputs (up to 8K tokens) with Anthropic's Sonnet 3.5. 
+""" +function anthropic_extra_headers(; + has_tools = false, has_cache = false, has_long_output = false) extra_headers = ["anthropic-version" => "2023-06-01"] + beta_headers = String[] if has_tools - push!(extra_headers, "anthropic-beta" => "tools-2024-04-04") + push!(beta_headers, "tools-2024-04-04") end if has_cache - push!(extra_headers, "anthropic-beta" => "prompt-caching-2024-07-31") + push!(beta_headers, "prompt-caching-2024-07-31") + end + if has_long_output + push!(beta_headers, "max-tokens-3-5-sonnet-2024-07-15") + end + if !isempty(beta_headers) + extra_headers = [extra_headers..., "anthropic-beta" => join(beta_headers, ",")] end return extra_headers end @@ -146,7 +174,8 @@ function anthropic_api( end ## Build the headers extra_headers = anthropic_extra_headers(; - has_tools = haskey(kwargs, :tools), has_cache = !isnothing(cache)) + has_tools = haskey(kwargs, :tools), has_cache = !isnothing(cache), + has_long_output = (max_tokens > 4096 && model in ["claude-3-5-sonnet-20240620"])) headers = auth_header( api_key; bearer = false, x_api_key = true, extra_headers) @@ -174,7 +203,7 @@ function anthropic_api(prompt_schema::TestEchoAnthropicSchema, cache::Union{Nothing, Symbol} = nothing, model::String = "claude-3-haiku-20240307", kwargs...) prompt_schema.model_id = model - prompt_schema.inputs = (; system, messages) + prompt_schema.inputs = (; system, messages = copy(messages)) return prompt_schema end @@ -185,6 +214,7 @@ end return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], streamcallback::Any = nothing, + aiprefill::Union{Nothing, AbstractString} = nothing, http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cache::Union{Nothing, Symbol} = nothing, kwargs...) @@ -201,6 +231,8 @@ Generate an AI response based on a given prompt using the Anthropic API. 
- `dry_run::Bool=false`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`). - `conversation::AbstractVector{<:AbstractMessage}=[]`: Not allowed for this schema. Provided only for compatibility. - `streamcallback::Any`: A callback function to handle streaming responses. Can be simply `stdout` or `StreamCallback` object. See `?StreamCallback` for details. + Note: We configure the `StreamCallback` (and necessary `api_kwargs`) for you, unless you specify the `flavor`. See `?configure_callback!` for details. +- `aiprefill::Union{Nothing, AbstractString}`: A string to be used as a prefill for the AI response. This steers the AI response in a certain direction (and potentially saves output tokens). It MUST NOT end with trailing whitespace. Useful for JSON formatting. - `http_kwargs::NamedTuple`: Additional keyword arguments for the HTTP request. Defaults to empty `NamedTuple`. - `api_kwargs::NamedTuple`: Additional keyword arguments for the Ollama API. Defaults to an empty `NamedTuple`. - `max_tokens::Int`: The maximum number of tokens to generate. Defaults to 2048, because it's a required parameter for the API. @@ -281,6 +313,13 @@ msg = aigenerate("Count from 1 to 10."; streamcallback, model="claudeh") ``` Note: Streaming support is only for Anthropic models and it doesn't yet support tool calling and a few other features (logprobs, refusals, etc.) + +You can also provide a prefill for the AI response to steer the response in a certain direction (eg, formatting, style): +```julia +msg = aigenerate("Sum up 1 to 100."; aiprefill = "I'd be happy to answer in one number without any additional text. The answer is:", model="claudeh") +``` +Note: It MUST NOT end with trailing whitespace. You'll get an API error if it does. 
+ """ function aigenerate( prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE; @@ -290,17 +329,20 @@ function aigenerate( return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], streamcallback::Any = nothing, + aiprefill::Union{Nothing, AbstractString} = nothing, http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), cache::Union{Nothing, Symbol} = nothing, kwargs...) ## global MODEL_ALIASES @assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last` and `:all` are supported for Anthropic Prompt Caching" + @assert (isnothing(aiprefill)||!isempty(strip(aiprefill))) "`aiprefill` must not be empty`" ## Find the unique ID for the model alias provided model_id = get(MODEL_ALIASES, model, model) - conv_rendered = render(prompt_schema, prompt; conversation, cache, kwargs...) + conv_rendered = render(prompt_schema, prompt; aiprefill, conversation, cache, kwargs...) 
if !dry_run + @info conv_rendered.conversation time = @elapsed resp = anthropic_api( prompt_schema, conv_rendered.conversation; api_key, conv_rendered.system, endpoint = "messages", model = model_id, streamcallback, http_kwargs, cache, @@ -308,6 +350,12 @@ function aigenerate( tokens_prompt = get(resp.response[:usage], :input_tokens, 0) tokens_completion = get(resp.response[:usage], :output_tokens, 0) content = mapreduce(x -> get(x, :text, ""), *, resp.response[:content]) |> strip + ## add aiprefill to the content + if !isnothing(aiprefill) && !isempty(aiprefill) + content = aiprefill * content + ## remove the prefill from the end of the conversation + pop!(conv_rendered.conversation) + end ## Build metadata extras = Dict{Symbol, Any}() haskey(resp.response[:usage], :cache_creation_input_tokens) && diff --git a/src/llm_openai.jl b/src/llm_openai.jl index f72419563..77f705303 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -460,6 +460,7 @@ Generate an AI response based on a given prompt using the OpenAI API. - `dry_run::Bool=false`: If `true`, skips sending the messages to the model (for debugging, often used with `return_all=true`). - `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. If not provided, it is initialized as an empty vector. - `streamcallback`: A callback function to handle streaming responses. Can be simply `stdout` or a `StreamCallback` object. See `?StreamCallback` for details. + Note: We configure the `StreamCallback` (and necessary `api_kwargs`) for you, unless you specify the `flavor`. See `?configure_callback!` for details. - `http_kwargs`: A named tuple of HTTP keyword arguments. - `api_kwargs`: A named tuple of API keyword arguments. Useful parameters include: - `temperature`: A float representing the temperature for sampling (ie, the amount of "creativity"). Often defaults to `0.7`. 
diff --git a/src/streaming.jl b/src/streaming.jl index bdd81939d..d84797eaf 100644 --- a/src/streaming.jl +++ b/src/streaming.jl @@ -85,6 +85,9 @@ msg = aigenerate("Count from 1 to 100."; streamcallback) streamcallback = PT.StreamCallback(; verbose=true, throw_on_error=true) msg = aigenerate("Count from 1 to 10."; streamcallback) ``` + +Note: If you provide a `StreamCallback` object to `aigenerate`, we will configure it and necessary `api_kwargs` via `configure_callback!` unless you specify the `flavor` field. +If you provide a `StreamCallback` with a specific `flavor`, we leave all configuration to the user (eg, you need to provide the correct `api_kwargs`). """ @kwdef mutable struct StreamCallback{ T1 <: Any, T2 <: Union{AbstractStreamFlavor, Nothing}} <: @@ -111,8 +114,11 @@ Base.length(cb::AbstractStreamCallback) = length(cb.chunks) api_kwargs...) Configures the callback `cb` for streaming with a given prompt schema. + If no `cb.flavor` is provided, adjusts the `flavor` and the provided `api_kwargs` as necessary. +Eg, for most schemas, we add kwargs like `stream = true` to the `api_kwargs`. +If `cb.flavor` is provided, both `callback` and `api_kwargs` are left unchanged! You need to configure them yourself! """ function configure_callback!(cb::T, schema::AbstractPromptSchema; api_kwargs...) 
where {T <: StreamCallback} @@ -216,7 +222,7 @@ Returns a list of `StreamChunk` and the next spillover (if message was incomplet try JSON3.read(data) catch e - verbose && @warn "Cannot parse JSON: $raw_chunk" + verbose && @warn "Cannot parse JSON: $data" nothing end else @@ -242,7 +248,7 @@ Returns a list of `StreamChunk` and the next spillover (if message was incomplet try JSON3.read(data) catch e - verbose && @warn "Cannot parse JSON: $raw_chunk" + verbose && @warn "Cannot parse JSON: $data" nothing end else diff --git a/test/llm_anthropic.jl b/test/llm_anthropic.jl index 5ccb12a80..c7bb6b1d4 100644 --- a/test/llm_anthropic.jl +++ b/test/llm_anthropic.jl @@ -160,6 +160,53 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature, "content" => [Dict("type" => "text", "text" => "Hello, my name is John", "cache_control" => Dict("type" => "ephemeral"))])]) @test conversation == expected_output + + # Test aiprefill functionality + messages = [ + SystemMessage("Act as a helpful AI assistant"), + UserMessage("Hello, what's your name?") + ] + + # Test with aiprefill + conversation = render(schema, messages; aiprefill = "My name is Claude") + expected_output = (; + system = "Act as a helpful AI assistant", + conversation = [ + Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "Hello, what's your name?")]), + Dict("role" => "assistant", + "content" => [Dict("type" => "text", "text" => "My name is Claude")]) + ]) + @test conversation == expected_output + + # Test without aiprefill + conversation_without_prefill = render(schema, messages) + expected_output_without_prefill = (; + system = "Act as a helpful AI assistant", + conversation = [ + Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "Hello, what's your name?")]) + ]) + @test conversation_without_prefill == expected_output_without_prefill + + # Test with empty aiprefill + conversation_empty_prefill = render(schema, messages; aiprefill = "") + @test 
conversation_empty_prefill == expected_output_without_prefill + + # Test aiprefill with cache + conversation_with_cache = render( + schema, messages; aiprefill = "My name is Claude", cache = :all) + expected_output_with_cache = (; + system = Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"), + "text" => "Act as a helpful AI assistant", "type" => "text")], + conversation = [ + Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "Hello, what's your name?", + "cache_control" => Dict("type" => "ephemeral"))]), + Dict("role" => "assistant", + "content" => [Dict("type" => "text", "text" => "My name is Claude")]) + ]) + @test conversation_with_cache == expected_output_with_cache end @testset "anthropic_extra_headers" begin @@ -177,8 +224,12 @@ end @test anthropic_extra_headers(has_tools = true, has_cache = true) == [ "anthropic-version" => "2023-06-01", - "anthropic-beta" => "tools-2024-04-04", - "anthropic-beta" => "prompt-caching-2024-07-31" + "anthropic-beta" => "tools-2024-04-04,prompt-caching-2024-07-31" + ] + @test anthropic_extra_headers( + has_tools = true, has_cache = true, has_long_output = true) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15" ] end @@ -243,6 +294,41 @@ end "role" => "user", "content" => [Dict("type" => "text", "text" => "Hello World")])] @test schema2.model_id == "claude-3-5-sonnet-20240620" + # Test aiprefill functionality + schema2 = TestEchoAnthropicSchema(; + response = Dict( + :content => [Dict(:text => "The answer is 42")], + :stop_reason => "stop", + :usage => Dict(:input_tokens => 5, :output_tokens => 4)), + status = 200) + + aiprefill = "The answer to the ultimate question of life, the universe, and everything is:" + msg = aigenerate(schema2, UserMessage("What is the answer to everything?"), + model = "claudes", http_kwargs = (; verbose = 3), api_kwargs = (; temperature = 0), + aiprefill = aiprefill) + 
+ expected_output = AIMessage(; + content = aiprefill * "The answer is 42" |> strip, + status = 200, + tokens = (5, 4), + finish_reason = "stop", + cost = msg.cost, + run_id = msg.run_id, + sample_id = msg.sample_id, + extras = Dict{Symbol, Any}(), + elapsed = msg.elapsed) + + @test msg.content == expected_output.content + @test schema2.inputs.system == "Act as a helpful AI assistant" + @test schema2.inputs.messages == [ + Dict("role" => "user", + "content" => [Dict( + "type" => "text", "text" => "What is the answer to everything?")]), + Dict("role" => "assistant", + "content" => [Dict("type" => "text", "text" => aiprefill)]) + ] + @test schema2.model_id == "claude-3-5-sonnet-20240620" + # With caching response3 = Dict( :content => [ @@ -276,6 +362,21 @@ end ## Bad cache @test_throws AssertionError aigenerate( schema3, UserMessage("Hello {{name}}"); model = "claudeo", cache = :bad) + + # Test error throw if aiprefill is empty string + @test_throws AssertionError aigenerate( + AnthropicSchema(), + "Hello World"; + model = "claudeh", + aiprefill = "" + ) + + @test_throws AssertionError aigenerate( + AnthropicSchema(), + "Hello World"; + model = "claudeh", + aiprefill = " " # Only whitespace + ) end @testset "aiextract-Anthropic" begin