Commit 8e880f8

Add experimental prompt cache support for Anthropic

svilupp authored Aug 16, 2024
1 parent a2cde30 commit 8e880f8
Showing 8 changed files with 285 additions and 39 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.50.0]

### Breaking Changes
- `AIMessage` and `DataMessage` now have a new field `extras` to hold any API-specific metadata in a simple dictionary. The change is backward compatible (the field defaults to `nothing`).

### Added
- Added EXPERIMENTAL support for Anthropic's new prompt caching (see `?aigenerate` and look for the `cache` kwarg; a usage sketch follows this list). Note that the COST estimate will be overstated for now, as it ignores the caching discount.
- Added a new `extras` field to `AIMessage` and `DataMessage` to hold any API-specific metadata in a simple dictionary (e.g., used for reporting on cache hits/misses).
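
A minimal usage sketch (the prompt and log call are illustrative; the `cache` values, the `claudes` model alias, and the `extras` keys are taken from this PR's tests):

```julia
using PromptingTools

# EXPERIMENTAL: mark the system prompt as cacheable (:system, :last, or :all)
msg = aigenerate("What is the capital of France?"; model = "claudes", cache = :system)

# API-specific cache metadata is reported in the new `extras` field (a Dict or `nothing`)
if !isnothing(msg.extras)
    @info "Cache usage" get(msg.extras, :cache_creation_input_tokens, 0) get(
        msg.extras, :cache_read_input_tokens, 0)
end
```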

## [0.49.0]

### Added
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.49.0"
version = "0.50.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
116 changes: 95 additions & 21 deletions src/llm_anthropic.jl

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion src/messages.jl
@@ -78,6 +78,7 @@ Returned by `aigenerate`, `aiclassify`, and `aiscan` functions.
- `elapsed::Float64`: The time taken to generate the response in seconds.
- `cost::Union{Nothing, Float64}`: The cost of the API call (calculated with information from `MODEL_REGISTRY`).
- `log_prob::Union{Nothing, Float64}`: The log probability of the response.
- `extras::Union{Nothing, Dict{Symbol, Any}}`: A dictionary for additional metadata that is not part of the key message fields. Keep it to a small number of scalar values so it remains serializable.
- `finish_reason::Union{Nothing, String}`: The reason the response was finished.
- `run_id::Union{Nothing, Int}`: The unique ID of the run.
- `sample_id::Union{Nothing, Int}`: The unique ID of the sample (if multiple samples are generated, they will all have the same `run_id`).
@@ -89,6 +90,7 @@ Base.@kwdef struct AIMessage{T <: Union{AbstractString, Nothing}} <: AbstractChatMessage
elapsed::Float64 = -1.0
cost::Union{Nothing, Float64} = nothing
log_prob::Union{Nothing, Float64} = nothing
extras::Union{Nothing, Dict{Symbol, Any}} = nothing
finish_reason::Union{Nothing, String} = nothing
run_id::Union{Nothing, Int} = Int(rand(Int16))
sample_id::Union{Nothing, Int} = nothing
@@ -108,6 +110,7 @@ Returned by `aiextract` and `aiembed` functions.
- `elapsed::Float64`: The time taken to generate the response in seconds.
- `cost::Union{Nothing, Float64}`: The cost of the API call (calculated with information from `MODEL_REGISTRY`).
- `log_prob::Union{Nothing, Float64}`: The log probability of the response.
- `extras::Union{Nothing, Dict{Symbol, Any}}`: A dictionary for additional metadata that is not part of the key message fields. Keep it to a small number of scalar values so it remains serializable.
- `finish_reason::Union{Nothing, String}`: The reason the response was finished.
- `run_id::Union{Nothing, Int}`: The unique ID of the run.
- `sample_id::Union{Nothing, Int}`: The unique ID of the sample (if multiple samples are generated, they will all have the same `run_id`).
@@ -119,6 +122,7 @@ Base.@kwdef struct DataMessage{T <: Any} <: AbstractDataMessage
elapsed::Float64 = -1.0
cost::Union{Nothing, Float64} = nothing
log_prob::Union{Nothing, Float64} = nothing
extras::Union{Nothing, Dict{Symbol, Any}} = nothing
finish_reason::Union{Nothing, String} = nothing
run_id::Union{Nothing, Int} = Int(rand(Int16))
sample_id::Union{Nothing, Int} = nothing
@@ -137,7 +141,7 @@ function (MSG::Type{<:AbstractChatMessage})(msg::AbstractTracerMessage{<:AbstractChatMessage})
MSG(; msg.content)
end

## It checks types so it should be defined for all inputs
isusermessage(m::Any) = m isa UserMessage
issystemmessage(m::Any) = m isa SystemMessage
isdatamessage(m::Any) = m isa DataMessage
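
For illustration, a sketch of constructing a message with the new field by hand (the values are made up; since `extras` defaults to `nothing`, existing call sites are unaffected):

```julia
using PromptingTools: AIMessage

msg = AIMessage(; content = "Hello!",
    extras = Dict{Symbol, Any}(:cache_read_input_tokens => 128))
# Guard against the `nothing` default before reading metadata
isnothing(msg.extras) || println(msg.extras[:cache_read_input_tokens])
```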
7 changes: 6 additions & 1 deletion src/utils.jl
@@ -468,7 +468,12 @@ function _report_stats(msg,
model::String)
cost = call_cost(msg, model)
cost_str = iszero(cost) ? "" : " @ Cost: \$$(round(cost; digits=4))"
return "Tokens: $(sum(msg.tokens))$(cost_str) in $(round(msg.elapsed;digits=1)) seconds"
metadata_str = if !isnothing(msg.extras) && !isempty(msg.extras)
" (Metadata: $(join([string(k, " => ", v) for (k, v) in msg.extras if v isa Number && !iszero(v)], ", ")))"
else
""
end
return "Tokens: $(sum(msg.tokens))$(cost_str) in $(round(msg.elapsed;digits=1)) seconds$(metadata_str)"
end
## dispatch for array -> take last message
function _report_stats(msg::AbstractVector,
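
A sketch of the new metadata formatting in isolation (the filter keeps only non-zero numeric entries, matching the diff above; the sample report line is illustrative):

```julia
extras = Dict{Symbol, Any}(:cache_creation_input_tokens => 1, :cache_read_input_tokens => 0)
# Zero-valued counters are dropped from the report
metadata_str = join(
    [string(k, " => ", v) for (k, v) in extras if v isa Number && !iszero(v)], ", ")
# metadata_str == "cache_creation_input_tokens => 1", so the report reads roughly:
# "Tokens: 971 @ Cost: $0.0012 in 10.8 seconds (Metadata: cache_creation_input_tokens => 1)"
```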
17 changes: 17 additions & 0 deletions test.json
@@ -0,0 +1,17 @@
[
{
"content": "Sure, here's an example of how you can define a similarity retrieval function for Euclidean distance in Julia:\n\n```julia\nusing PromptingTools.Experimental.RAGTools\n\nstruct EuclideanSimilarity <: AbstractSimilarityFinder end\n\nfunction find_closest(finder::EuclideanSimilarity, embeddings::AbstractMatrix{<:Real}, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)\n dists = mapslices(v -> norm(v .- query_embedding), embeddings, dims=1)\n positions = sortperm(dists)[1:min(top_k, size(embeddings, 2))]\n scores = -dists[positions]\n mask = scores .>= minimum_similarity\n return CandidateChunks(positions[mask], scores[mask])\nend\n\nfunction find_closest(finder::EuclideanSimilarity, index::ChunkIndex, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)\n return find_closest(finder, index.embeddings, query_embedding; top_k=top_k, minimum_similarity=minimum_similarity)\nend\n\nfunction find_closest(finder::EuclideanSimilarity, index::MultiIndex, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)\n results = [find_closest(finder, idx, query_embedding; top_k=top_k, minimum_similarity=minimum_similarity) for idx in index.indexes]\n return MultiCandidateChunks(\n [r.index_id for r in results],\n [r.positions for r in results],\n [r.scores for r in results]\n )\nend\n```\n\nHere's a breakdown of the code:\n\n1. `EuclideanSimilarity <: AbstractSimilarityFinder`: This defines a new type `EuclideanSimilarity` that is a subtype of `AbstractSimilarityFinder`. This type will be used to represent the Euclidean distance similarity finder.\n\n2. `find_closest(finder::EuclideanSimilarity, embeddings::AbstractMatrix{<:Real}, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)`: This function implements the `find_closest` method for the `EuclideanSimilarity` type. It takes an embedding matrix, a query embedding vector, and optional parameters `top_k` and `minimum_similarity`. It calculates the Euclidean distances between the query embedding and each embedding in the matrix, sorts the positions by the distances, and returns a `CandidateChunks` object containing the top `top_k` positions and their corresponding scores (negative distances).\n\n3. `find_closest(finder::EuclideanSimilarity, index::ChunkIndex, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)`: This method implements the `find_closest` function for a `ChunkIndex` object, which simply delegates the call to the previous `find_closest` method using the `index.embeddings` matrix.\n\n4. `find_closest(finder::EuclideanSimilarity, index::MultiIndex, query_embedding::AbstractVector{<:Real}; top_k::Integer = 5, minimum_similarity::Real = 0.0)`: This method implements the `find_closest` function for a `MultiIndex` object. It calls the `find_closest` method for each sub-index in the `MultiIndex` and collects the results into a `MultiCandidateChunks` object.\n\nWith this implementation, you can now use the `EuclideanSimilarity` type and the `find_closest` methods in your retrieval pipeline, just like the other similarity finders provided by the `PromptingTools.Experimental.RAGTools` module.",
"status": 200,
"tokens": [
4,
969
],
"elapsed": 10.802840083,
"cost": 0.00121225,
"log_prob": null,
"finish_reason": "end_turn",
"run_id": 4668,
"sample_id": null,
"_type": "aimessage"
}
]
160 changes: 145 additions & 15 deletions test/llm_anthropic.jl
@@ -1,7 +1,8 @@
using PromptingTools: TestEchoAnthropicSchema, render, AnthropicSchema
using PromptingTools: AIMessage, SystemMessage, AbstractMessage
using PromptingTools: UserMessage, UserMessageWithImages, DataMessage
using PromptingTools: call_cost, anthropic_api, function_call_signature
using PromptingTools: call_cost, anthropic_api, function_call_signature,
anthropic_extra_headers

@testset "render-Anthropic" begin
schema = AnthropicSchema()
@@ -11,7 +12,8 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature
UserMessage("Hello, my name is {{name}}")
]
expected_output = (; system = "Act as a helpful AI assistant",
conversation = [Dict("role" => "user", "content" => "Hello, my name is John")])
conversation = [Dict("role" => "user",
"content" => [Dict("type" => "text", "text" => "Hello, my name is John")])])
conversation = render(schema, messages; name = "John")
@test conversation == expected_output
# Test with dry_run=true on ai* functions
@@ -26,7 +28,8 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature
]
expected_output = (; system = "Act as a helpful AI assistant",
conversation = [Dict(
"role" => "assistant", "content" => "Hello, my name is {{name}}")])
"role" => "assistant",
"content" => [Dict("type" => "text", "text" => "Hello, my name is {{name}}")])])
conversation = render(schema, messages; name = "John")
# AIMessage does not replace handlebar variables
@test conversation == expected_output
@@ -37,7 +40,8 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature
]
conversation = render(schema, messages)
expected_output = (; system = "Act as a helpful AI assistant",
conversation = [Dict("role" => "user", "content" => "User message")])
conversation = [Dict("role" => "user",
"content" => [Dict("type" => "text", "text" => "User message")])])
@test conversation == expected_output

# Given a schema and a vector of messages, it should return a conversation dictionary with the correct roles and contents for each message.
@@ -49,10 +53,15 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature
]
expected_output = (; system = "Act as a helpful AI assistant",
conversation = [
Dict("role" => "user", "content" => "Hello"),
Dict("role" => "assistant", "content" => "Hi there"),
Dict("role" => "user", "content" => "How are you?"),
Dict("role" => "assistant", "content" => "I'm doing well, thank you!")
Dict(
"role" => "user", "content" => [Dict("type" => "text", "text" => "Hello")]),
Dict("role" => "assistant",
"content" => [Dict("type" => "text", "text" => "Hi there")]),
Dict("role" => "user",
"content" => [Dict("type" => "text", "text" => "How are you?")]),
Dict("role" => "assistant",
"content" => [Dict(
"type" => "text", "text" => "I'm doing well, thank you!")])
])
conversation = render(schema, messages)
@test conversation == expected_output
@@ -65,8 +74,10 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature
]
expected_output = (; system = "This is a system message",
conversation = [
Dict("role" => "user", "content" => "Hello"),
Dict("role" => "assistant", "content" => "Hi there")
Dict(
"role" => "user", "content" => [Dict("type" => "text", "text" => "Hello")]),
Dict("role" => "assistant",
"content" => [Dict("type" => "text", "text" => "Hi there")])
])
conversation = render(schema, messages)
@test conversation == expected_output
@@ -83,8 +94,10 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature
]
expected_output = (; system = "Act as a helpful AI assistant",
conversation = [
Dict("role" => "user", "content" => "Hello"),
Dict("role" => "assistant", "content" => "Hi there")
Dict(
"role" => "user", "content" => [Dict("type" => "text", "text" => "Hello")]),
Dict("role" => "assistant",
"content" => [Dict("type" => "text", "text" => "Hi there")])
])
conversation = render(schema, messages)
@test conversation == expected_output
@@ -117,6 +130,56 @@ using PromptingTools: call_cost, anthropic_api, function_call_signature
"input_schema" => "")]
@test_logs (:warn, r"Multiple tools provided") match_mode=:any render(
schema, messages; tools)

## Cache variables
messages = [
SystemMessage("Act as a helpful AI assistant"),
UserMessage("Hello, my name is {{name}}")
]
conversation = render(schema, messages; name = "John", cache = :system)
expected_output = (;
system = Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"),
"text" => "Act as a helpful AI assistant", "type" => "text")],
conversation = [Dict("role" => "user",
"content" => [Dict("type" => "text", "text" => "Hello, my name is John")])])
@test conversation == expected_output

conversation = render(schema, messages; name = "John", cache = :last)
expected_output = (;
system = "Act as a helpful AI assistant",
conversation = [Dict("role" => "user",
"content" => [Dict("type" => "text", "text" => "Hello, my name is John",
"cache_control" => Dict("type" => "ephemeral"))])])
@test conversation == expected_output

conversation = render(schema, messages; name = "John", cache = :all)
expected_output = (;
system = Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"),
"text" => "Act as a helpful AI assistant", "type" => "text")],
conversation = [Dict("role" => "user",
"content" => [Dict("type" => "text", "text" => "Hello, my name is John",
"cache_control" => Dict("type" => "ephemeral"))])])
@test conversation == expected_output
end
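
The three cache modes exercised above imply the following mapping; below is a sketch of the validation they suggest (the allowed set is inferred from these tests plus the `AssertionError` cases further down, not copied from `src/llm_anthropic.jl`):

```julia
# :system -> cache_control on the system prompt block
# :last   -> cache_control on the last message's text block
# :all    -> both of the above
function validate_cache(cache::Union{Nothing, Symbol})
    allowed = (nothing, :system, :last, :all)  # assumed allowed values
    @assert cache in allowed "Invalid cache option: $(cache)"
    return cache
end
```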

@testset "anthropic_extra_headers" begin
@test anthropic_extra_headers() == ["anthropic-version" => "2023-06-01"]

@test anthropic_extra_headers(has_tools = true) == [
"anthropic-version" => "2023-06-01",
"anthropic-beta" => "tools-2024-04-04"
]

@test anthropic_extra_headers(has_cache = true) == [
"anthropic-version" => "2023-06-01",
"anthropic-beta" => "prompt-caching-2024-07-31"
]

@test anthropic_extra_headers(has_tools = true, has_cache = true) == [
"anthropic-version" => "2023-06-01",
"anthropic-beta" => "tools-2024-04-04",
"anthropic-beta" => "prompt-caching-2024-07-31"
]
end
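
A minimal implementation consistent with these expectations might look as follows (a sketch, not necessarily the shipped code; note the tests expect a Vector of Pairs, so duplicate "anthropic-beta" keys are permitted):

```julia
function anthropic_extra_headers(; has_tools::Bool = false, has_cache::Bool = false)
    extra_headers = ["anthropic-version" => "2023-06-01"]
    has_tools && push!(extra_headers, "anthropic-beta" => "tools-2024-04-04")
    has_cache && push!(extra_headers, "anthropic-beta" => "prompt-caching-2024-07-31")
    return extra_headers
end
```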

@testset "anthropic_api" begin
@@ -153,10 +216,12 @@ end
tokens = (2, 1),
finish_reason = "stop",
cost = msg.cost,
extras = Dict{Symbol, Any}(),
elapsed = msg.elapsed)
@test msg == expected_output
@test schema1.inputs.system == "Act as a helpful AI assistant"
@test schema1.inputs.messages == [Dict("role" => "user", "content" => "Hello World")]
@test schema1.inputs.messages == [Dict(
"role" => "user", "content" => [Dict("type" => "text", "text" => "Hello World")])]
@test schema1.model_id == "claude-3-opus-20240229"

# Test different input combinations and different prompts
@@ -170,11 +235,47 @@ end
tokens = (2, 1),
finish_reason = "stop",
cost = msg.cost,
extras = Dict{Symbol, Any}(),
elapsed = msg.elapsed)
@test msg == expected_output
@test schema2.inputs.system == "Act as a helpful AI assistant"
@test schema2.inputs.messages == [Dict("role" => "user", "content" => "Hello World")]
@test schema2.inputs.messages == [Dict(
"role" => "user", "content" => [Dict("type" => "text", "text" => "Hello World")])]
@test schema2.model_id == "claude-3-5-sonnet-20240620"

# With caching
response3 = Dict(
:content => [
Dict(:text => "Hello!")],
:stop_reason => "stop",
:usage => Dict(:input_tokens => 2, :output_tokens => 1,
:cache_creation_input_tokens => 1, :cache_read_input_tokens => 0))

schema3 = TestEchoAnthropicSchema(; response = response3, status = 200)
msg = aigenerate(schema3, UserMessage("Hello {{name}}"),
model = "claudes", http_kwargs = (; verbose = 3), api_kwargs = (; temperature = 0),
cache = :all,
name = "World")
expected_output = AIMessage(;
content = "Hello!" |> strip,
status = 200,
tokens = (2, 1),
finish_reason = "stop",
cost = msg.cost,
extras = Dict{Symbol, Any}(
:cache_read_input_tokens => 0, :cache_creation_input_tokens => 1),
elapsed = msg.elapsed)
@test msg == expected_output
@test schema3.inputs.system == [Dict("cache_control" => Dict("type" => "ephemeral"),
"text" => "Act as a helpful AI assistant", "type" => "text")]
@test schema3.inputs.messages == [Dict("role" => "user",
"content" => Dict{String, Any}[Dict("cache_control" => Dict("type" => "ephemeral"),
"text" => "Hello World", "type" => "text")])]
@test schema3.model_id == "claude-3-5-sonnet-20240620"

## Bad cache
@test_throws AssertionError aigenerate(
schema3, UserMessage("Hello {{name}}"); model = "claudeo", cache = :bad)
end

@testset "aiextract-Anthropic" begin
@@ -197,12 +298,15 @@ end
tokens = (2, 1),
finish_reason = "tool_use",
cost = msg.cost,
extras = Dict{Symbol, Any}(),
elapsed = msg.elapsed)
@test msg == expected_output
@test schema1.inputs.system ==
"Act as a helpful AI assistant\n\nUse the Fruit_extractor tool in your response."
@test schema1.inputs.messages ==
[Dict("role" => "user", "content" => "Hello World! Banana")]
[Dict("role" => "user",
"content" => Dict{String, Any}[Dict(
"text" => "Hello World! Banana", "type" => "text")])]
@test schema1.model_id == "claude-3-opus-20240229"

# Test badly formatted response
@@ -225,6 +329,32 @@ end
schema3 = TestEchoAnthropicSchema(; response, status = 200)
msg = aiextract(schema3, "Hello World! Banana"; model = "claudeo", return_type = Fruit)
@test msg.content == "No tools for you!"

# With Cache
response4 = Dict(
:content => [
Dict(:type => "tool_use", :input => Dict("name" => "banana"))],
:stop_reason => "tool_use",
:usage => Dict(:input_tokens => 2, :output_tokens => 1,
:cache_creation_input_tokens => 1, :cache_read_input_tokens => 0))
schema4 = TestEchoAnthropicSchema(; response = response4, status = 200)
msg = aiextract(
schema4, "Hello World! Banana"; model = "claudeo", return_type = Fruit, cache = :all)
expected_output = DataMessage(;
content = Fruit("banana"),
status = 200,
tokens = (2, 1),
finish_reason = "tool_use",
cost = msg.cost,
extras = Dict{Symbol, Any}(
:cache_read_input_tokens => 0, :cache_creation_input_tokens => 1),
elapsed = msg.elapsed)
@test msg == expected_output

# Bad cache
@test_throws AssertionError aiextract(
schema4, "Hello World! Banana"; model = "claudeo",
return_type = Fruit, cache = :bad)
end

@testset "not implemented ai* functions" begin