diff --git a/CHANGELOG.md b/CHANGELOG.md index 433a9357e..8566ddb05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.50.0] + +### Breaking Changes +- `AIMessage` and `DataMessage` now have a new field `extras` to hold any API-specific metadata in a simple dictionary. The change is backward-compatible (defaults to `nothing`). + +### Added +- Added EXPERIMENTAL support for Anthropic's new prompt cache (see `?aigenerate` and look for the `cache` kwarg). Note that the COST estimate will be wrong (ignores the caching discount for now). +- Added a new `extras` field to `AIMessage` and `DataMessage` to hold any API-specific metadata in a simple dictionary (eg, used for reporting on the cache hit/miss). + ## [0.49.0] ### Added diff --git a/Project.toml b/Project.toml index d199ad268..372b15dbb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.49.0" +version = "0.50.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/llm_anthropic.jl b/src/llm_anthropic.jl index e2368ddba..ac893d651 100644 --- a/src/llm_anthropic.jl +++ b/src/llm_anthropic.jl @@ -9,6 +9,7 @@ messages::Vector{<:AbstractMessage}; conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], tools::Vector{<:Dict{String, <:Any}} = Dict{String, Any}[], + cache::Union{Nothing, Symbol} = nothing, kwargs...) Builds a history of the conversation to provide the prompt to the API. All unspecified kwargs are passed as replacements such that `{{key}}=>value` in the template. @@ -16,14 +17,17 @@ Builds a history of the conversation to provide the prompt to the API. All unspe # Keyword Arguments - `conversation`: Past conversation to be included in the beginning of the prompt (for continued conversations). - `tools`: A list of tools to be used in the conversation. Added to the end of the system prompt to enforce its use. +- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`, `:last` and `:all` are supported. """ function render(schema::AbstractAnthropicSchema, messages::Vector{<:AbstractMessage}; conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], tools::Vector{<:Dict{String, <:Any}} = Dict{String, Any}[], + cache::Union{Nothing, Symbol} = nothing, kwargs...) ## @assert count(issystemmessage, messages)<=1 "AbstractAnthropicSchema only supports at most 1 System message" + @assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last`, `:all` are supported for Anthropic Prompt Caching" system = nothing @@ -39,7 +43,8 @@ function render(schema::AbstractAnthropicSchema, elseif msg isa UserMessage || msg isa AIMessage content = msg.content push!(conversation, - Dict("role" => role4render(schema, msg), "content" => content)) + Dict("role" => role4render(schema, msg), + "content" => [Dict{String, Any}("type" => "text", "text" => content)])) elseif msg isa UserMessageWithImages error("AbstractAnthropicSchema does not yet support UserMessageWithImages. 
Please use OpenAISchema instead.") end @@ -58,25 +63,52 @@ function render(schema::AbstractAnthropicSchema, end end + ## Apply cache for last message + is_valid_conversation = length(conversation) > 0 && + haskey(conversation[end], "content") && + length(conversation[end]["content"]) > 0 + if is_valid_conversation && (cache == :last || cache == :all) + conversation[end]["content"][end]["cache_control"] = Dict("type" => "ephemeral") + end + if !isnothing(system) && (cache == :system || cache == :all) + ## Apply cache for system message + system = [Dict("type" => "text", "text" => system, + "cache_control" => Dict("type" => "ephemeral"))] + end + ## Sense check @assert !isempty(conversation) "AbstractAnthropicSchema requires at least 1 User message, ie, no `prompt` provided!" return (; system, conversation) end +function anthropic_extra_headers(; has_tools = false, has_cache = false) + extra_headers = ["anthropic-version" => "2023-06-01"] + if has_tools + push!(extra_headers, "anthropic-beta" => "tools-2024-04-04") + end + if has_cache + push!(extra_headers, "anthropic-beta" => "prompt-caching-2024-07-31") + end + return extra_headers +end + ## Model-calling """ - anthropic_api(prompt_schema::AbstractAnthropicSchema, - messages::Vector{<:AbstractMessage} = AbstractMessage[]; - prompt::Union{AbstractString, Nothing} = nothing; - system::Union{Nothing, AbstractString} = nothing, - endpoint::String = "generate", - model::String = "llama2", http_kwargs::NamedTuple = NamedTuple(), + anthropic_api( + prompt_schema::AbstractAnthropicSchema, + messages::Vector{<:AbstractDict{String, <:Any}} = Vector{Dict{String, Any}}(); + api_key::AbstractString = ANTHROPIC_API_KEY, + system::Union{Nothing, AbstractString, AbstractVector{<:AbstractDict}} = nothing, + endpoint::String = "messages", + max_tokens::Int = 2048, + model::String = "claude-3-haiku-20240307", http_kwargs::NamedTuple = NamedTuple(), stream::Bool = false, - url::String = "localhost", port::Int = 11434, + url::String = "https://api.anthropic.com/v1", + cache::Union{Nothing, Symbol} = nothing, kwargs...) -Simple wrapper for a call to Ollama API. +Simple wrapper for a call to Anthropic API. # Keyword Arguments - `prompt_schema`: Defines which prompt template should be applied. @@ -88,32 +120,35 @@ Simple wrapper for a call to Ollama API. - `http_kwargs::NamedTuple`: Additional keyword arguments for the HTTP request. Defaults to empty `NamedTuple`. - `stream`: A boolean indicating whether to stream the response. Defaults to `false`. - `url`: The URL of the Ollama API. Defaults to "localhost". +- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`, `:last` and `:all` are supported. - `kwargs`: Prompt variables to be used to fill the prompt/template """ function anthropic_api( prompt_schema::AbstractAnthropicSchema, messages::Vector{<:AbstractDict{String, <:Any}} = Vector{Dict{String, Any}}(); api_key::AbstractString = ANTHROPIC_API_KEY, - system::Union{Nothing, AbstractString} = nothing, + system::Union{Nothing, AbstractString, AbstractVector{<:AbstractDict}} = nothing, endpoint::String = "messages", max_tokens::Int = 2048, model::String = "claude-3-haiku-20240307", http_kwargs::NamedTuple = NamedTuple(), stream::Bool = false, url::String = "https://api.anthropic.com/v1", + cache::Union{Nothing, Symbol} = nothing, kwargs...) @assert endpoint in ["messages"] "Only 'messages' endpoint is supported." 
- ## + ## body = Dict("model" => model, "max_tokens" => max_tokens, "stream" => stream, "messages" => messages, kwargs...) ## provide system message if !isnothing(system) body["system"] = system end - ## + ## Build the headers + extra_headers = anthropic_extra_headers(; + has_tools = haskey(kwargs, :tools), has_cache = !isnothing(cache)) headers = auth_header( api_key; bearer = false, x_api_key = true, - extra_headers = ["anthropic-version" => "2023-06-01", - "anthropic-beta" => "tools-2024-04-04"]) + extra_headers) api_url = string(url, "/", endpoint) resp = HTTP.post(api_url, headers, JSON3.write(body); http_kwargs...) body = JSON3.read(resp.body) @@ -123,8 +158,9 @@ end function anthropic_api(prompt_schema::TestEchoAnthropicSchema, messages::Vector{<:AbstractDict{String, <:Any}} = Vector{Dict{String, Any}}(); api_key::AbstractString = ANTHROPIC_API_KEY, - system::Union{Nothing, AbstractString} = nothing, + system::Union{Nothing, AbstractString, AbstractVector{<:AbstractDict}} = nothing, endpoint::String = "messages", + cache::Union{Nothing, Symbol} = nothing, model::String = "claude-3-haiku-20240307", kwargs...) prompt_schema.model_id = model prompt_schema.inputs = (; system, messages) @@ -138,6 +174,7 @@ end return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), + cache::Union{Nothing, Symbol} = nothing, kwargs...) Generate an AI response based on a given prompt using the Anthropic API. @@ -154,8 +191,15 @@ Generate an AI response based on a given prompt using the Anthropic API. - `http_kwargs::NamedTuple`: Additional keyword arguments for the HTTP request. Defaults to empty `NamedTuple`. - `api_kwargs::NamedTuple`: Additional keyword arguments for the Ollama API. Defaults to an empty `NamedTuple`. - `max_tokens::Int`: The maximum number of tokens to generate. Defaults to 2048, because it's a required parameter for the API. +- `cache`: A symbol indicating whether to use caching for the prompt. Supported values are `nothing` (no caching), `:system`, `:tools`, `:last` and `:all`. Note that the COST estimate will be wrong (ignores the caching). + - `:system`: Caches the system message + - `:tools`: Caches the tool definitions (and everything before them) + - `:last`: Caches the last message in the conversation (and everything before it) + - `:all`: Cache trigger points are inserted in all of the above places (ie, higher likelihood of cache hit, but also slightly higher cost) - `kwargs`: Prompt variables to be used to fill the prompt/template +Note: At the moment, the cache is only allowed for prompt segments over 1024 tokens (in some cases, over 2048 tokens). You'll get an error if you try to cache short prompts. + # Returns - `msg`: An `AIMessage` object representing the generated AI message, including the content, status, tokens, and elapsed time. Use `msg.content` to access the extracted string. @@ -218,26 +262,36 @@ function aigenerate( return_all::Bool = false, dry_run::Bool = false, conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[], http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(), + cache::Union{Nothing, Symbol} = nothing, kwargs...) 
## global MODEL_ALIASES + @assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last` and `:all` are supported for Anthropic Prompt Caching" ## Find the unique ID for the model alias provided model_id = get(MODEL_ALIASES, model, model) - conv_rendered = render(prompt_schema, prompt; conversation, kwargs...) + conv_rendered = render(prompt_schema, prompt; conversation, cache, kwargs...) if !dry_run time = @elapsed resp = anthropic_api( prompt_schema, conv_rendered.conversation; api_key, - conv_rendered.system, endpoint = "messages", model = model_id, http_kwargs, + conv_rendered.system, endpoint = "messages", model = model_id, http_kwargs, cache, api_kwargs...) tokens_prompt = get(resp.response[:usage], :input_tokens, 0) tokens_completion = get(resp.response[:usage], :output_tokens, 0) content = mapreduce(x -> get(x, :text, ""), *, resp.response[:content]) |> strip + ## Build metadata + extras = Dict{Symbol, Any}() + haskey(resp.response[:usage], :cache_creation_input_tokens) && + (extras[:cache_creation_input_tokens] = resp.response[:usage][:cache_creation_input_tokens]) + haskey(resp.response[:usage], :cache_read_input_tokens) && + (extras[:cache_read_input_tokens] = resp.response[:usage][:cache_read_input_tokens]) + ## Build the message msg = AIMessage(; content, status = Int(resp.status), cost = call_cost(tokens_prompt, tokens_completion, model_id), finish_reason = get(resp.response, :stop_reason, nothing), tokens = (tokens_prompt, tokens_completion), + extras, elapsed = time) ## Reporting verbose && @info _report_stats(msg, model_id) @@ -267,6 +321,7 @@ end http_kwargs::NamedTuple = (retry_non_idempotent = true, retries = 5, readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(), + cache::Union{Nothing, Symbol} = nothing, kwargs...) Extract required information (defined by a struct **`return_type`**) from the provided prompt by leveraging Anthropic's function calling mode. @@ -291,8 +346,15 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi - `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history. If not provided, it is initialized as an empty vector. - `http_kwargs`: A named tuple of HTTP keyword arguments. - `api_kwargs`: A named tuple of API keyword arguments. +- `cache`: A symbol indicating whether to use caching for the prompt. Supported values are `nothing` (no caching), `:system`, `:tools`, `:last` and `:all`. Note that the COST estimate will be wrong (ignores the caching). + - `:system`: Caches the system message + - `:tools`: Caches the tool definitions (and everything before them) + - `:last`: Caches the last message in the conversation (and everything before it) + - `:all`: Cache trigger points are inserted in all of the above places (ie, higher likelihood of cache hit, but also slightly higher cost) - `kwargs`: Prompt variables to be used to fill the prompt/template +Note: At the moment, the cache is only allowed for prompt segments over 1024 tokens (in some cases, over 2048 tokens). You'll get an error if you try to cache short prompts. + # Returns If `return_all=false` (default): - `msg`: An `DataMessage` object representing the extracted data, including the content, status, tokens, and elapsed time. @@ -367,7 +429,7 @@ return_type = MaybeExtract{MyMeasurement} # If LLM extraction fails, it will return a Dict with `error` and `message` fields instead of the result! 
msg = aiextract("Extract measurements from the text: I am giraffe"; model="claudeo", return_type) msg.content -# Output: MaybeExtract{MyMeasurement}(nothing, true, "I'm sorry, but your input of \"I am giraffe\" does not contain any information about a person's age, height or weight measurements that I can extract. To use this tool, please provide a statement that includes at least the person's age, and optionally their height in inches and weight in pounds. Without that information, I am unable to extract the requested measurements.") +# Output: MaybeExtract{MyMeasurement}(nothing, true, "I'm sorry, but your input of "I am giraffe" does not contain any information about a person's age, height or weight measurements that I can extract. To use this tool, please provide a statement that includes at least the person's age, and optionally their height in inches and weight in pounds. Without that information, I am unable to extract the requested measurements.") ``` That way, you can handle the error gracefully and get a reason why extraction failed (in `msg.content.message`). @@ -393,9 +455,11 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP http_kwargs::NamedTuple = (retry_non_idempotent = true, retries = 5, readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(), + cache::Union{Nothing, Symbol} = nothing, kwargs...) ## global MODEL_ALIASES + @assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last` and `:all` are supported for Anthropic Prompt Caching" ## Find the unique ID for the model alias provided model_id = get(MODEL_ALIASES, model, model) @@ -404,17 +468,20 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP sig = function_call_signature(return_type; max_description_length = 100) tools = [Dict("name" => sig["name"], "description" => get(sig, "description", ""), "input_schema" => sig["parameters"])] + ## update tools to use caching + (cache == :tools || cache == :all) && + (tools[end]["cache_control"] = Dict("type" => "ephemeral")) ## Add the function call stopping sequence to the api_kwargs api_kwargs = merge(api_kwargs, (; tools)) ## We provide the tool description to the rendering engine - conv_rendered = render(prompt_schema, prompt; tools, conversation, kwargs...) + conv_rendered = render(prompt_schema, prompt; tools, conversation, cache, kwargs...) if !dry_run time = @elapsed resp = anthropic_api( prompt_schema, conv_rendered.conversation; api_key, - conv_rendered.system, endpoint = "messages", model = model_id, http_kwargs, + conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs, api_kwargs...) tokens_prompt = get(resp.response[:usage], :input_tokens, 0) tokens_completion = get(resp.response[:usage], :output_tokens, 0) @@ -436,13 +503,20 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP @warn "No tool_use found in the response. Returning the raw text instead." 
mapreduce(x -> get(x, :text, ""), *, resp.response[:content]) |> strip end + ## Build metadata + extras = Dict{Symbol, Any}() + haskey(resp.response[:usage], :cache_creation_input_tokens) && + (extras[:cache_creation_input_tokens] = resp.response[:usage][:cache_creation_input_tokens]) + haskey(resp.response[:usage], :cache_read_input_tokens) && + (extras[:cache_read_input_tokens] = resp.response[:usage][:cache_read_input_tokens]) ## Build data message msg = DataMessage(; content, status = Int(resp.status), cost = call_cost(tokens_prompt, tokens_completion, model_id), finish_reason, tokens = (tokens_prompt, tokens_completion), - elapsed = time) + elapsed = time, + extras) ## Reporting verbose && @info _report_stats(msg, model_id) diff --git a/src/messages.jl b/src/messages.jl index 718c906d4..7a991b5f9 100644 --- a/src/messages.jl +++ b/src/messages.jl @@ -78,6 +78,7 @@ Returned by `aigenerate`, `aiclassify`, and `aiscan` functions. - `elapsed::Float64`: The time taken to generate the response in seconds. - `cost::Union{Nothing, Float64}`: The cost of the API call (calculated with information from `MODEL_REGISTRY`). - `log_prob::Union{Nothing, Float64}`: The log probability of the response. +- `extras::Union{Nothing, Dict{Symbol, Any}}`: A dictionary for additional metadata that is not part of the key message fields. Try to limit to a small number of items and singletons to be serializable. - `finish_reason::Union{Nothing, String}`: The reason the response was finished. - `run_id::Union{Nothing, Int}`: The unique ID of the run. - `sample_id::Union{Nothing, Int}`: The unique ID of the sample (if multiple samples are generated, they will all have the same `run_id`). @@ -89,6 +90,7 @@ Base.@kwdef struct AIMessage{T <: Union{AbstractString, Nothing}} <: AbstractCha elapsed::Float64 = -1.0 cost::Union{Nothing, Float64} = nothing log_prob::Union{Nothing, Float64} = nothing + extras::Union{Nothing, Dict{Symbol, Any}} = nothing finish_reason::Union{Nothing, String} = nothing run_id::Union{Nothing, Int} = Int(rand(Int16)) sample_id::Union{Nothing, Int} = nothing @@ -108,6 +110,7 @@ Returned by `aiextract`, and `aiextract` functions. - `elapsed::Float64`: The time taken to generate the response in seconds. - `cost::Union{Nothing, Float64}`: The cost of the API call (calculated with information from `MODEL_REGISTRY`). - `log_prob::Union{Nothing, Float64}`: The log probability of the response. +- `extras::Union{Nothing, Dict{Symbol, Any}}`: A dictionary for additional metadata that is not part of the key message fields. Try to limit to a small number of items and singletons to be serializable. - `finish_reason::Union{Nothing, String}`: The reason the response was finished. - `run_id::Union{Nothing, Int}`: The unique ID of the run. - `sample_id::Union{Nothing, Int}`: The unique ID of the sample (if multiple samples are generated, they will all have the same `run_id`). 
@@ -119,6 +122,7 @@ Base.@kwdef struct DataMessage{T <: Any} <: AbstractDataMessage elapsed::Float64 = -1.0 cost::Union{Nothing, Float64} = nothing log_prob::Union{Nothing, Float64} = nothing + extras::Union{Nothing, Dict{Symbol, Any}} = nothing finish_reason::Union{Nothing, String} = nothing run_id::Union{Nothing, Int} = Int(rand(Int16)) sample_id::Union{Nothing, Int} = nothing @@ -137,7 +141,7 @@ function (MSG::Type{<:AbstractChatMessage})(msg::AbstractTracerMessage{<:Abstrac MSG(; msg.content) end -## It checks types so it should be defined for all inputs +## It checks types so it should be defined for all inputs isusermessage(m::Any) = m isa UserMessage issystemmessage(m::Any) = m isa SystemMessage isdatamessage(m::Any) = m isa DataMessage diff --git a/src/utils.jl b/src/utils.jl index 9d93c7b34..36acc9547 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -468,7 +468,12 @@ function _report_stats(msg, model::String) cost = call_cost(msg, model) cost_str = iszero(cost) ? "" : " @ Cost: \$$(round(cost; digits=4))" - return "Tokens: $(sum(msg.tokens))$(cost_str) in $(round(msg.elapsed;digits=1)) seconds" + metadata_str = if !isnothing(msg.extras) && !isempty(msg.extras) + " (Metadata: $(join([string(k, " => ", v) for (k, v) in msg.extras if v isa Number && !iszero(v)], ", ")))" + else + "" + end + return "Tokens: $(sum(msg.tokens))$(cost_str) in $(round(msg.elapsed;digits=1)) seconds$(metadata_str)" end ## dispatch for array -> take last message function _report_stats(msg::AbstractVector, diff --git a/test.json b/test.json new file mode 100644 index 000000000..25c8cadbd --- /dev/null +++ b/test.json @@ -0,0 +1,17 @@ +[ + { + "content": "Sure, here's an example of how you can define a similarity retrieval function for Euclidean distance in Julia:\n\n```julia\nusing PromptingTools.Experimental.RAGTools\n\nstruct EuclideanSimilarity <: AbstractSimilarityFinder end\n\nfunction find_closest(finder::EuclideanSimilarity, embeddings::AbstractMatrix{<:Real}, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)\n dists = mapslices(v -> norm(v .- query_embedding), embeddings, dims=1)\n positions = sortperm(dists)[1:min(top_k, size(embeddings, 2))]\n scores = -dists[positions]\n mask = scores .>= minimum_similarity\n return CandidateChunks(positions[mask], scores[mask])\nend\n\nfunction find_closest(finder::EuclideanSimilarity, index::ChunkIndex, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)\n return find_closest(finder, index.embeddings, query_embedding; top_k=top_k, minimum_similarity=minimum_similarity)\nend\n\nfunction find_closest(finder::EuclideanSimilarity, index::MultiIndex, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)\n results = [find_closest(finder, idx, query_embedding; top_k=top_k, minimum_similarity=minimum_similarity) for idx in index.indexes]\n return MultiCandidateChunks(\n [r.index_id for r in results],\n [r.positions for r in results],\n [r.scores for r in results]\n )\nend\n```\n\nHere's a breakdown of the code:\n\n1. `EuclideanSimilarity <: AbstractSimilarityFinder`: This defines a new type `EuclideanSimilarity` that is a subtype of `AbstractSimilarityFinder`. This type will be used to represent the Euclidean distance similarity finder.\n\n2. 
`find_closest(finder::EuclideanSimilarity, embeddings::AbstractMatrix{<:Real}, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)`: This function implements the `find_closest` method for the `EuclideanSimilarity` type. It takes an embedding matrix, a query embedding vector, and optional parameters `top_k` and `minimum_similarity`. It calculates the Euclidean distances between the query embedding and each embedding in the matrix, sorts the positions by the distances, and returns a `CandidateChunks` object containing the top `top_k` positions and their corresponding scores (negative distances).\n\n3. `find_closest(finder::EuclideanSimilarity, index::ChunkIndex, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)`: This method implements the `find_closest` function for a `ChunkIndex` object, which simply delegates the call to the previous `find_closest` method using the `index.embeddings` matrix.\n\n4. `find_closest(finder::EuclideanSimilarity, index::MultiIndex, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)`: This method implements the `find_closest` function for a `MultiIndex` object. It calls the `find_closest` method for each sub-index in the `MultiIndex` and collects the results into a `MultiCandidateChunks` object.\n\nWith this implementation, you can now use the `EuclideanSimilarity` type and the `find_closest` methods in your retrieval pipeline, just like the other similarity finders provided by the `PromptingTools.Experimental.RAGTools` module.", + "status": 200, + "tokens": [ + 4, + 969 + ], + "elapsed": 10.802840083, + "cost": 0.00121225, + "log_prob": null, + "finish_reason": "end_turn", + "run_id": 4668, + "sample_id": null, + "_type": "aimessage" + } +] \ No newline at end of file diff --git a/test/llm_anthropic.jl b/test/llm_anthropic.jl index 2c7fdfb10..5ccb12a80 100644 --- a/test/llm_anthropic.jl +++ b/test/llm_anthropic.jl @@ -1,7 +1,8 @@ using PromptingTools: TestEchoAnthropicSchema, render, AnthropicSchema using PromptingTools: AIMessage, SystemMessage, AbstractMessage using PromptingTools: UserMessage, UserMessageWithImages, DataMessage -using PromptingTools: call_cost, anthropic_api, function_call_signature +using PromptingTools: call_cost, anthropic_api, function_call_signature, + anthropic_extra_headers @testset "render-Anthropic" begin schema = AnthropicSchema() @@ -11,7 +12,8 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature UserMessage("Hello, my name is {{name}}") ] expected_output = (; system = "Act as a helpful AI assistant", - conversation = [Dict("role" => "user", "content" => "Hello, my name is John")]) + conversation = [Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "Hello, my name is John")])]) conversation = render(schema, messages; name = "John") @test conversation == expected_output # Test with dry_run=true on ai* functions @@ -26,7 +28,8 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature ] expected_output = (; system = "Act as a helpful AI assistant", conversation = [Dict( - "role" => "assistant", "content" => "Hello, my name is {{name}}")]) + "role" => "assistant", + "content" => [Dict("type" => "text", "text" => "Hello, my name is {{name}}")])]) conversation = render(schema, messages; name = "John") # AIMessage does not replace handlebar variables @test conversation == expected_output @@ -37,7 +40,8 @@ using PromptingTools: call_cost, anthropic_api, 
function_call_signature ] conversation = render(schema, messages) expected_output = (; system = "Act as a helpful AI assistant", - conversation = [Dict("role" => "user", "content" => "User message")]) + conversation = [Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "User message")])]) @test conversation == expected_output # Given a schema and a vector of messages, it should return a conversation dictionary with the correct roles and contents for each message. @@ -49,10 +53,15 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature ] expected_output = (; system = "Act as a helpful AI assistant", conversation = [ - Dict("role" => "user", "content" => "Hello"), - Dict("role" => "assistant", "content" => "Hi there"), - Dict("role" => "user", "content" => "How are you?"), - Dict("role" => "assistant", "content" => "I'm doing well, thank you!") + Dict( + "role" => "user", "content" => [Dict("type" => "text", "text" => "Hello")]), + Dict("role" => "assistant", + "content" => [Dict("type" => "text", "text" => "Hi there")]), + Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "How are you?")]), + Dict("role" => "assistant", + "content" => [Dict( + "type" => "text", "text" => "I'm doing well, thank you!")]) ]) conversation = render(schema, messages) @test conversation == expected_output @@ -65,8 +74,10 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature ] expected_output = (; system = "This is a system message", conversation = [ - Dict("role" => "user", "content" => "Hello"), - Dict("role" => "assistant", "content" => "Hi there") + Dict( + "role" => "user", "content" => [Dict("type" => "text", "text" => "Hello")]), + Dict("role" => "assistant", + "content" => [Dict("type" => "text", "text" => "Hi there")]) ]) conversation = render(schema, messages) @test conversation == expected_output @@ -83,8 +94,10 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature ] expected_output = (; system = "Act as a helpful AI assistant", conversation = [ - Dict("role" => "user", "content" => "Hello"), - Dict("role" => "assistant", "content" => "Hi there") + Dict( + "role" => "user", "content" => [Dict("type" => "text", "text" => "Hello")]), + Dict("role" => "assistant", + "content" => [Dict("type" => "text", "text" => "Hi there")]) ]) conversation = render(schema, messages) @test conversation == expected_output @@ -117,6 +130,56 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature "input_schema" => "")] @test_logs (:warn, r"Multiple tools provided") match_mode=:any render( schema, messages; tools) + + ## Cache variables + messages = [ + SystemMessage("Act as a helpful AI assistant"), + UserMessage("Hello, my name is {{name}}") + ] + conversation = render(schema, messages; name = "John", cache = :system) + expected_output = (; + system = Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"), + "text" => "Act as a helpful AI assistant", "type" => "text")], + conversation = [Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "Hello, my name is John")])]) + @test conversation == expected_output + + conversation = render(schema, messages; name = "John", cache = :last) + expected_output = (; + system = "Act as a helpful AI assistant", + conversation = [Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "Hello, my name is John", + "cache_control" => Dict("type" => "ephemeral"))])]) + @test conversation == expected_output + + conversation = 
render(schema, messages; name = "John", cache = :all) + expected_output = (; + system = Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"), + "text" => "Act as a helpful AI assistant", "type" => "text")], + conversation = [Dict("role" => "user", + "content" => [Dict("type" => "text", "text" => "Hello, my name is John", + "cache_control" => Dict("type" => "ephemeral"))])]) + @test conversation == expected_output +end + +@testset "anthropic_extra_headers" begin + @test anthropic_extra_headers() == ["anthropic-version" => "2023-06-01"] + + @test anthropic_extra_headers(has_tools = true) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "tools-2024-04-04" + ] + + @test anthropic_extra_headers(has_cache = true) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "prompt-caching-2024-07-31" + ] + + @test anthropic_extra_headers(has_tools = true, has_cache = true) == [ + "anthropic-version" => "2023-06-01", + "anthropic-beta" => "tools-2024-04-04", + "anthropic-beta" => "prompt-caching-2024-07-31" + ] end @testset "anthropic_api" begin @@ -153,10 +216,12 @@ end tokens = (2, 1), finish_reason = "stop", cost = msg.cost, + extras = Dict{Symbol, Any}(), elapsed = msg.elapsed) @test msg == expected_output @test schema1.inputs.system == "Act as a helpful AI assistant" - @test schema1.inputs.messages == [Dict("role" => "user", "content" => "Hello World")] + @test schema1.inputs.messages == [Dict( + "role" => "user", "content" => [Dict("type" => "text", "text" => "Hello World")])] @test schema1.model_id == "claude-3-opus-20240229" # Test different input combinations and different prompts @@ -170,11 +235,47 @@ end tokens = (2, 1), finish_reason = "stop", cost = msg.cost, + extras = Dict{Symbol, Any}(), elapsed = msg.elapsed) @test msg == expected_output @test schema2.inputs.system == "Act as a helpful AI assistant" - @test schema2.inputs.messages == [Dict("role" => "user", "content" => "Hello World")] + @test schema2.inputs.messages == [Dict( + "role" => "user", "content" => [Dict("type" => "text", "text" => "Hello World")])] @test schema2.model_id == "claude-3-5-sonnet-20240620" + + # With caching + response3 = Dict( + :content => [ + Dict(:text => "Hello!")], + :stop_reason => "stop", + :usage => Dict(:input_tokens => 2, :output_tokens => 1, + :cache_creation_input_tokens => 1, :cache_read_input_tokens => 0)) + + schema3 = TestEchoAnthropicSchema(; response = response3, status = 200) + msg = aigenerate(schema3, UserMessage("Hello {{name}}"), + model = "claudes", http_kwargs = (; verbose = 3), api_kwargs = (; temperature = 0), + cache = :all, + name = "World") + expected_output = AIMessage(; + content = "Hello!" 
|> strip, + status = 200, + tokens = (2, 1), + finish_reason = "stop", + cost = msg.cost, + extras = Dict{Symbol, Any}( + :cache_read_input_tokens => 0, :cache_creation_input_tokens => 1), + elapsed = msg.elapsed) + @test msg == expected_output + @test schema3.inputs.system == [Dict("cache_control" => Dict("type" => "ephemeral"), + "text" => "Act as a helpful AI assistant", "type" => "text")] + @test schema3.inputs.messages == [Dict("role" => "user", + "content" => Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"), + "text" => "Hello World", "type" => "text")])] + @test schema3.model_id == "claude-3-5-sonnet-20240620" + + ## Bad cache + @test_throws AssertionError aigenerate( + schema3, UserMessage("Hello {{name}}"); model = "claudeo", cache = :bad) end @testset "aiextract-Anthropic" begin @@ -197,12 +298,15 @@ end tokens = (2, 1), finish_reason = "tool_use", cost = msg.cost, + extras = Dict{Symbol, Any}(), elapsed = msg.elapsed) @test msg == expected_output @test schema1.inputs.system == "Act as a helpful AI assistant\n\nUse the Fruit_extractor tool in your response." @test schema1.inputs.messages == - [Dict("role" => "user", "content" => "Hello World! Banana")] + [Dict("role" => "user", + "content" => Dict{String, Any}[Dict( + "text" => "Hello World! Banana", "type" => "text")])] @test schema1.model_id == "claude-3-opus-20240229" # Test badly formatted response @@ -225,6 +329,32 @@ end schema3 = TestEchoAnthropicSchema(; response, status = 200) msg = aiextract(schema3, "Hello World! Banana"; model = "claudeo", return_type = Fruit) @test msg.content == "No tools for you!" + + # With Cache + response4 = Dict( + :content => [ + Dict(:type => "tool_use", :input => Dict("name" => "banana"))], + :stop_reason => "tool_use", + :usage => Dict(:input_tokens => 2, :output_tokens => 1, + :cache_creation_input_tokens => 1, :cache_read_input_tokens => 0)) + schema4 = TestEchoAnthropicSchema(; response = response4, status = 200) + msg = aiextract( + schema4, "Hello World! Banana"; model = "claudeo", return_type = Fruit, cache = :all) + expected_output = DataMessage(; + content = Fruit("banana"), + status = 200, + tokens = (2, 1), + finish_reason = "tool_use", + cost = msg.cost, + extras = Dict{Symbol, Any}( + :cache_read_input_tokens => 0, :cache_creation_input_tokens => 1), + elapsed = msg.elapsed) + @test msg == expected_output + + # Bad cache + @test_throws AssertionError aiextract( + schema4, "Hello World! Banana"; model = "claudeo", + return_type = Fruit, cache = :bad) end @testset "not implemented ai* functions" begin diff --git a/test/utils.jl b/test/utils.jl index 93ff42899..ec64a87fc 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -236,6 +236,13 @@ end msg = AIMessage(; content = "", tokens = (1000, 5000), elapsed = 5.0) expected_output = "Tokens: 6000 @ Cost: \$0.008 in 5.0 seconds" @test _report_stats(msg, "gpt-3.5-turbo") == expected_output + + # Add extra metadata + msg = AIMessage(; content = "", tokens = (1000, 5000), elapsed = 5.0, + extras = Dict{Symbol, Any}( + :cache_read_input_tokens => 100, :cache_creation_input_tokens => 200)) + expected_output = "Tokens: 6000 @ Cost: \$0.008 in 5.0 seconds (Metadata: cache_read_input_tokens => 100, cache_creation_input_tokens => 200)" + @test _report_stats(msg, "gpt-3.5-turbo") == expected_output end @testset "_string_to_vector" begin
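
For illustration, a minimal usage sketch of the `cache` keyword and the `extras` metadata introduced in this diff, assuming an `ANTHROPIC_API_KEY` is configured; the file path, prompts and comments are placeholders, the `claudeo` alias must route to an Anthropic model in `MODEL_REGISTRY`, and the cached segment must exceed the ~1024-token minimum described above:

```julia
using PromptingTools
const PT = PromptingTools

# Placeholder: any sufficiently long context (over ~1024 tokens), otherwise the API
# rejects the cache request.
big_context = read("docs.txt", String)

msg = aigenerate(
    [PT.SystemMessage("You answer questions about this document:\n\n" * big_context),
     PT.UserMessage("What is the main topic?")];
    model = "claudeo",   # an Anthropic model alias
    cache = :system)     # cache the long system message; see also :tools, :last, :all

# Cache metadata is surfaced via the new `extras` field, eg,
# Dict(:cache_creation_input_tokens => ..., :cache_read_input_tokens => ...)
msg.extras

# Re-running the same call within the cache lifetime should report
# non-zero :cache_read_input_tokens (and the verbose stats line shows it as Metadata).
```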