Anthropic computer use (#226)
svilupp authored Nov 3, 2024 · 1 parent 4c3516f · commit 158f409
Showing 11 changed files with 283 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for Ollama streaming with schema `OllamaSchema` (see `?StreamCallback` for more information). Schema `OllamaManaged` is NOT supported (it's legacy and will be removed in the future).
- Moved the implementation of streaming callbacks to a new `StreamCallbacks` package.
- Added new error types for tool execution to enable better error handling and reporting (see `?AbstractToolError`).
- Added support for Anthropic's new pre-trained tools via `ToolRef` (see `?ToolRef`). To enable the feature, pass the `:computer_use` beta header (eg, `aitools(..., betas = [:computer_use])`).

### Fixed
- Fixed a bug in `call_cost` where the cost was not calculated if any non-AIMessages were provided in the conversation.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.60.0-DEV"
version = "0.60.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
3 changes: 2 additions & 1 deletion src/PromptingTools.jl
@@ -49,7 +49,8 @@ const RESERVED_KWARGS = [
:no_system_message,
:aiprefill,
:name_user,
:name_assistant
:name_assistant,
:betas
]

# export replace_words, recursive_splitter, split_by_length, call_cost, auth_header # for debugging only
29 changes: 29 additions & 0 deletions src/extraction.jl
@@ -48,6 +48,31 @@ Base.@kwdef struct Tool <: AbstractTool
end
Base.show(io::IO, t::AbstractTool) = dump(io, t; maxdepth = 1)

"""
ToolRef(ref::Symbol, callable::Any)
Represents a reference to a tool with a symbolic name and a callable object (to call during tool execution).
It can be rendered with a `render` method and a prompt schema.
# Arguments
- `ref::Symbol`: The symbolic name of the tool.
- `callable::Any`: The callable object of the tool, eg, a type or a function.
# Examples
```julia
# Define a tool with a symbolic name and a callable object
tool = ToolRef(:computer, println)
# Show the rendered tool signature
PT.render(PT.AnthropicSchema(), tool)
```
"""
Base.@kwdef struct ToolRef <: AbstractTool
ref::Symbol
callable::Any = identity
end
Base.show(io::IO, t::ToolRef) = print(io, "ToolRef($(t.ref))")

### Useful Error Types
"""
AbstractToolError
@@ -556,6 +581,10 @@ function tool_call_signature(
end
return Dict(tool.name => tool)
end
function tool_call_signature(
tool::ToolRef; kwargs...)
Dict(string(tool.ref) => tool)
end

## Add support for function signatures
function tool_call_signature(f::Function; kwargs...)
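Below is a minimal sketch of how the new `ToolRef` plugs into `tool_call_signature` (assuming `PromptingTools` is imported as `PT`; the dictionary key mirrors `string(tool.ref)` from the method above):

```julia
using PromptingTools
const PT = PromptingTools

# Reference one of Anthropic's pre-trained tools by its symbolic name;
# `callable` defaults to `identity` and is only invoked during tool execution
tool = PT.ToolRef(; ref = :computer)

# The signature dictionary is keyed by the string form of the ref
sig = PT.tool_call_signature(tool)
@assert haskey(sig, "computer")
```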
112 changes: 99 additions & 13 deletions src/llm_anthropic.jl
@@ -102,35 +102,107 @@ Renders the tool signatures into the Anthropic format.
function render(schema::AbstractAnthropicSchema,
tools::Vector{<:AbstractTool};
kwargs...)
tools = [Dict(:name => tool.name,
:description => isnothing(tool.description) ? "" : tool.description,
:input_schema => tool.parameters) for tool in tools]
return tools
[render(schema, tool; kwargs...) for tool in tools]
end
function render(schema::AbstractAnthropicSchema,
tool::AbstractTool;
kwargs...)
return Dict(
:name => tool.name,
:description => isnothing(tool.description) ? "" : tool.description,
:input_schema => tool.parameters
)
end

"""
anthropic_extra_headers
render(schema::AbstractAnthropicSchema,
tool::ToolRef;
kwargs...)
Renders the tool reference into the Anthropic format.

Available tools:
- `:computer`: A tool for controlling the computer (screenshots, mouse, and keyboard actions).
- `:str_replace_editor`: A text editor tool for viewing and editing files.
- `:bash`: A tool for running bash commands.
"""
function render(schema::AbstractAnthropicSchema,
tool::ToolRef;
kwargs...)
## WARNING: We ignore the callable's name here, because Anthropic requires these exact tool names
rendered = if tool.ref == :computer
Dict(
"type" => "computer_20241022",
"name" => "computer",
"display_width_px" => 1024,
"display_height_px" => 768,
"display_number" => 1
)
elseif tool.ref == :str_replace_editor
Dict(
"type" => "text_editor_20241022",
"name" => "str_replace_editor"
)
elseif tool.ref == :bash
Dict(
"type" => "bash_20241022",
"name" => "bash"
)
else
throw(ArgumentError("Unknown tool reference: $(tool.ref)"))
end
return rendered
end
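A quick sketch of what this method produces for the `:computer` reference (the display fields are the hard-coded defaults from the branch above):

```julia
using PromptingTools
const PT = PromptingTools

rendered = PT.render(PT.AnthropicSchema(), PT.ToolRef(; ref = :computer))
rendered["type"]  # "computer_20241022"
rendered["name"]  # "computer"
# also: "display_width_px" => 1024, "display_height_px" => 768, "display_number" => 1

# Unknown refs fail fast with an ArgumentError:
# PT.render(PT.AnthropicSchema(), PT.ToolRef(; ref = :browser))
```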

"""
BETA_HEADERS_ANTHROPIC
A vector of symbols representing the beta features to be used.
Allowed:
- `:tools`: Enables tools in the conversation.
- `:cache`: Enables prompt caching.
- `:long_output`: Enables long outputs (up to 8K tokens) with Anthropic's Sonnet 3.5.
- `:computer_use`: Enables the use of the computer tool.
"""
const BETA_HEADERS_ANTHROPIC = [:tools, :cache, :long_output, :computer_use]

"""
anthropic_extra_headers(;
has_tools = false, has_cache = false, has_long_output = false,
betas::Union{Nothing, Vector{Symbol}} = nothing)
Adds API version and beta headers to the request.
# Kwargs / Beta headers
- `has_tools`: Enables tools in the conversation.
- `has_cache`: Enables prompt caching.
- `has_long_output`: Enables long outputs (up to 8K tokens) with Anthropic's Sonnet 3.5.
- `betas`: A vector of symbols representing the beta features to be used. Currently only `:computer_use`, `:long_output`, `:tools` and `:cache` are supported.
Refer to `BETA_HEADERS_ANTHROPIC` for the allowed beta features.
"""
function anthropic_extra_headers(;
has_tools = false, has_cache = false, has_long_output = false)
has_tools = false, has_cache = false, has_long_output = false,
betas::Union{Nothing, Vector{Symbol}} = nothing)
global BETA_HEADERS_ANTHROPIC
betas_parsed = isnothing(betas) ? Symbol[] : betas
@assert all(b -> b in BETA_HEADERS_ANTHROPIC, betas_parsed) "Unknown beta feature: $(setdiff(betas_parsed, BETA_HEADERS_ANTHROPIC))"
##
extra_headers = ["anthropic-version" => "2023-06-01"]
beta_headers = String[]
if has_tools
if has_tools || :tools in betas_parsed
push!(beta_headers, "tools-2024-04-04")
end
if has_cache
if has_cache || :cache in betas_parsed
push!(beta_headers, "prompt-caching-2024-07-31")
end
if has_long_output
if has_long_output || :long_output in betas_parsed
push!(beta_headers, "max-tokens-3-5-sonnet-2024-07-15")
end
if :computer_use in betas_parsed
push!(beta_headers, "computer-use-2024-10-22")
end
if !isempty(beta_headers)
extra_headers = [extra_headers..., "anthropic-beta" => join(beta_headers, ",")]
end
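For illustration, the headers assembled when a legacy flag is combined with the new `betas` vector (the values follow directly from the branches above):

```julia
using PromptingTools
const PT = PromptingTools

PT.anthropic_extra_headers(; has_tools = true, betas = [:computer_use])
# 2-element Vector{Pair{String, String}}:
#  "anthropic-version" => "2023-06-01"
#  "anthropic-beta" => "tools-2024-04-04,computer-use-2024-10-22"
```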
@@ -150,6 +222,7 @@ end
stream::Bool = false,
url::String = "https://api.anthropic.com/v1",
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
kwargs...)
Simple wrapper for a call to Anthropic API.
@@ -165,6 +238,7 @@ Simple wrapper for a call to Anthropic API.
- `stream`: A boolean indicating whether to stream the response. Defaults to `false`.
- `url`: The URL of the Anthropic API. Defaults to `"https://api.anthropic.com/v1"`.
- `cache`: A symbol representing the caching strategy to be used. Currently only `nothing` (no caching), `:system`, `:tools`,`:last` and `:all` are supported.
- `betas`: A vector of symbols representing the beta features to be used. Refer to `BETA_HEADERS_ANTHROPIC` for the allowed values.
- `kwargs`: Prompt variables to be used to fill the prompt/template
"""
function anthropic_api(
@@ -179,6 +253,7 @@ function anthropic_api(
streamcallback::Any = nothing,
url::String = "https://api.anthropic.com/v1",
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
kwargs...)
@assert endpoint in ["messages"] "Only 'messages' endpoint is supported."
##
@@ -191,7 +266,8 @@ function anthropic_api(
## Build the headers
extra_headers = anthropic_extra_headers(;
has_tools = haskey(kwargs, :tools), has_cache = !isnothing(cache),
has_long_output = (max_tokens > 4096 && model in ["claude-3-5-sonnet-20240620"]))
has_long_output = (max_tokens > 4096 && model in ["claude-3-5-sonnet-20240620"]),
betas = betas)
headers = auth_header(
api_key; bearer = false, x_api_key = true,
extra_headers)
@@ -234,6 +310,7 @@ end
aiprefill::Union{Nothing, AbstractString} = nothing,
http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(),
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
kwargs...)
Generate an AI response based on a given prompt using the Anthropic API.
@@ -259,6 +336,7 @@ Generate an AI response based on a given prompt using the Anthropic API.
- `:tools`: Caches the tool definitions (and everything before them)
- `:last`: Caches the last message in the conversation (and everything before it)
- `:all`: Cache trigger points are inserted in all of the above places (ie, higher likelihood of cache hit, but also slightly higher cost)
- `betas::Union{Nothing, Vector{Symbol}}`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details.
- `kwargs`: Prompt variables to be used to fill the prompt/template
Note: At the moment, the cache is only allowed for prompt segments over 1024 tokens (in some cases, over 2048 tokens). You'll get an error if you try to cache short prompts.
@@ -351,6 +429,7 @@ function aigenerate(
aiprefill::Union{Nothing, AbstractString} = nothing,
http_kwargs::NamedTuple = NamedTuple(), api_kwargs::NamedTuple = NamedTuple(),
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
kwargs...)
##
global MODEL_ALIASES
@@ -364,7 +443,8 @@ function aigenerate(
if !dry_run
time = @elapsed resp = anthropic_api(
prompt_schema, conv_rendered.conversation; api_key,
conv_rendered.system, endpoint = "messages", model = model_id, streamcallback, http_kwargs, cache,
conv_rendered.system, endpoint = "messages", model = model_id,
streamcallback, http_kwargs, cache, betas,
api_kwargs...)
tokens_prompt = get(resp.response[:usage], :input_tokens, 0)
tokens_completion = get(resp.response[:usage], :output_tokens, 0)
@@ -420,6 +500,7 @@ end
retries = 5,
readtimeout = 120), api_kwargs::NamedTuple = NamedTuple(),
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
kwargs...)
Extract required information (defined by a struct **`return_type`**) from the provided prompt by leveraging Anthropic's function calling mode.
@@ -452,6 +533,7 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi
- `:tools`: Caches the tool definitions (and everything before them)
- `:last`: Caches the last message in the conversation (and everything before it)
- `:all`: Cache trigger points are inserted in all of the above places (ie, higher likelihood of cache hit, but also slightly higher cost)
- `betas::Union{Nothing, Vector{Symbol}}`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details.
- `kwargs`: Prompt variables to be used to fill the prompt/template
Note: At the moment, the cache is only allowed for prompt segments over 1024 tokens (in some cases, over 2048 tokens). You'll get an error if you try to cache short prompts.
@@ -580,6 +662,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP
retries = 5,
readtimeout = 120), api_kwargs::NamedTuple = (; tool_choice = nothing),
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
kwargs...)
##
global MODEL_ALIASES
Expand Down Expand Up @@ -622,7 +705,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP
if !dry_run
time = @elapsed resp = anthropic_api(
prompt_schema, conv_rendered.conversation; api_key,
conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs,
conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs, betas,
api_kwargs...)
tokens_prompt = get(resp.response[:usage], :input_tokens, 0)
tokens_completion = get(resp.response[:usage], :output_tokens, 0)
@@ -681,6 +764,7 @@ end
conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
no_system_message::Bool = false,
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
http_kwargs::NamedTuple = (retry_non_idempotent = true,
retries = 5,
readtimeout = 120), api_kwargs::NamedTuple = (;
@@ -706,6 +790,7 @@ Differences to `aiextract`: Can provide infinitely many tools (including Functio
- `conversation`: An optional vector of `AbstractMessage` objects representing the conversation history.
- `no_system_message::Bool = false`: Whether to exclude the system message from the conversation history.
- `cache::Union{Nothing, Symbol} = nothing`: Whether to cache the prompt. Defaults to `nothing`.
- `betas::Union{Nothing, Vector{Symbol}} = nothing`: A vector of symbols representing the beta features to be used. See `?anthropic_extra_headers` for details.
- `http_kwargs`: A named tuple of HTTP keyword arguments.
- `api_kwargs`: A named tuple of API keyword arguments. Several important arguments are highlighted below:
- `tool_choice`: The choice of tool mode. Can be "auto", "exact", or the name of one of the provided tools. Defaults to `nothing`, which translates to "auto".
@@ -761,6 +846,7 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_
conversation::AbstractVector{<:AbstractMessage} = AbstractMessage[],
no_system_message::Bool = false,
cache::Union{Nothing, Symbol} = nothing,
betas::Union{Nothing, Vector{Symbol}} = nothing,
http_kwargs::NamedTuple = (retry_non_idempotent = true,
retries = 5,
readtimeout = 120), api_kwargs::NamedTuple = (;
@@ -800,7 +886,7 @@ function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_
if !dry_run
time = @elapsed resp = anthropic_api(
prompt_schema, conv_rendered.conversation; api_key,
conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs,
conv_rendered.system, endpoint = "messages", model = model_id, cache, http_kwargs, betas,
api_kwargs...)
tokens_prompt = get(resp.response[:usage], :input_tokens, 0)
tokens_completion = get(resp.response[:usage], :output_tokens, 0)
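Putting the pieces together, a hypothetical end-to-end call (an untested sketch: the model name, the prompt, and beta access on the account are assumptions, not part of this diff):

```julia
using PromptingTools
const PT = PromptingTools

# Requires ANTHROPIC_API_KEY and a Claude model with computer-use support (assumed here)
msg = PT.aitools("Take a screenshot and describe what you see.";
    model = "claude-3-5-sonnet-20241022",
    tools = [PT.ToolRef(; ref = :computer)],
    betas = [:computer_use])

# Inspect any tool invocations the model requested
msg.tool_calls
```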
40 changes: 24 additions & 16 deletions src/llm_openai.jl
@@ -98,23 +98,31 @@ function render(schema::AbstractOpenAISchema,
tools::Vector{<:AbstractTool};
json_mode::Union{Nothing, Bool} = nothing,
kwargs...)
output = Dict{Symbol, Any}[]
for tool in tools
rendered = Dict(:type => "function",
:function => Dict(
:parameters => tool.parameters, :name => tool.name))
## Add strict flag
tool.strict == true && (rendered[:function][:strict] = tool.strict)
if json_mode == true
rendered[:function][:schema] = pop!(rendered[:function], :parameters)
else
## Add description if not in JSON mode
!isnothing(tool.description) &&
(rendered[:function][:description] = tool.description)
end
push!(output, rendered)
[render(schema, tool; json_mode, kwargs...) for tool in tools]
end
function render(schema::AbstractOpenAISchema,
tool::AbstractTool;
json_mode::Union{Nothing, Bool} = nothing,
kwargs...)
rendered = Dict(:type => "function",
:function => Dict(
:parameters => tool.parameters, :name => tool.name))
## Add strict flag
tool.strict == true && (rendered[:function][:strict] = tool.strict)
if json_mode == true
rendered[:function][:schema] = pop!(rendered[:function], :parameters)
else
## Add description if not in JSON mode
!isnothing(tool.description) &&
(rendered[:function][:description] = tool.description)
end
return output
return rendered
end
function render(schema::AbstractOpenAISchema,
tool::ToolRef;
json_mode::Union{Nothing, Bool} = nothing,
kwargs...)
throw(ArgumentError("Function `render` is not implemented for the provided schema ($(typeof(schema))) and $(typeof(tool))."))
end

"""
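Note the design choice here: `ToolRef` rendering is Anthropic-only for now, so the OpenAI path fails loudly rather than silently mis-rendering the tool:

```julia
using PromptingTools
const PT = PromptingTools

# Throws an ArgumentError -- no OpenAI rendering exists for ToolRef yet
PT.render(PT.OpenAISchema(), PT.ToolRef(; ref = :computer))
```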
6 changes: 6 additions & 0 deletions src/llm_shared.jl
@@ -99,6 +99,12 @@ function render(schema::AbstractPromptSchema,
kwargs...)
render(schema, collect(values(tools)); kwargs...)
end
# Fallback for single tools (eg, ToolRef) without a schema-specific method
function render(schema::AbstractPromptSchema,
tool::AbstractTool;
kwargs...)
throw(ArgumentError("Function `render` is not implemented for the provided schema ($(typeof(schema))) and $(typeof(tool))."))
end

"""
finalize_outputs(prompt::ALLOWED_PROMPT_TYPE, conv_rendered::Any,
