Skip to content

Commit

Permalink
Update Callbacks compat + ToolRef definition (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Nov 3, 2024
1 parent 158f409 commit 8b6f951
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 16 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.61.0]

### Added
- Added a new `extras` field to `ToolRef` to enable additional parameters in the tool signature (eg, `display_width_px`, `display_height_px` for the `:computer` tool).

### Updated
- Updated the compat bounds for `StreamCallbacks` to enable both v0.4 and v0.5 (Fixes Julia 1.9 compatibility).
- Updated the return type of `tool_call_signature` to `Dict{String, AbstractTool}` to enable better interoperability with different tool types.

## [0.60.0]

### Added
Expand Down
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.60.0"
version = "0.61.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down Expand Up @@ -55,9 +55,9 @@ REPL = "<0.0.1, 1"
Random = "<0.0.1, 1"
SparseArrays = "<0.0.1, 1"
Statistics = "<0.0.1, 1"
StreamCallbacks = "0.4"
StreamCallbacks = "0.4, 0.5"
Test = "<0.0.1, 1"
julia = "1.9,1.10"
julia = "1.9, 1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
14 changes: 8 additions & 6 deletions src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,12 @@ It can be rendered with a `render` method and a prompt schema.
# Arguments
- `ref::Symbol`: The symbolic name of the tool.
- `callable::Any`: The callable object of the tool, eg, a type or a function.
- `extras::Dict{String, Any}`: Additional parameters to be included in the tool signature.
# Examples
```julia
# Define a tool with a symbolic name and a callable object
tool = ToolRef(:computer, println)
tool = ToolRef(;ref=:computer, callable=println)
# Show the rendered tool signature
PT.render(PT.AnthropicSchema(), tool)
Expand All @@ -70,6 +71,7 @@ PT.render(PT.AnthropicSchema(), tool)
Base.@kwdef struct ToolRef <: AbstractTool
ref::Symbol
callable::Any = identity
extras::Dict{String, Any} = Dict()
end
Base.show(io::IO, t::ToolRef) = print(io, "ToolRef($(t.ref))")

Expand Down Expand Up @@ -444,7 +446,7 @@ Note: Fairly experimental, but works for combination of structs, arrays, strings
- `hidden_fields::AbstractVector{<:Union{AbstractString, Regex}}`: A list of fields to hide from the LLM (eg, `["ctx_user_id"]` or `r"ctx"`).
# Returns
- `Dict{String, Tool}`: A dictionary representing the function call signature schema.
- `Dict{String, AbstractTool}`: A dictionary representing the function call signature schema.
# Tips
- You can improve the quality of the extraction by writing a helpful docstring for your struct (or any nested struct). It will be provided as a description.
Expand All @@ -464,7 +466,7 @@ struct MyMeasurement
end
tool_map = tool_call_signature(MyMeasurement)
#
# Dict{String, PromptingTools.Tool}("MyMeasurement" => PromptingTools.Tool
# Dict{String, PromptingTools.AbstractTool}("MyMeasurement" => PromptingTools.Tool
# name: String "MyMeasurement"
# parameters: Dict{String, Any}
# description: Nothing nothing
Expand Down Expand Up @@ -563,7 +565,7 @@ function tool_call_signature(
description = haskey(schema, "description") ? schema["description"] : nothing,
strict = haskey(schema, "strict") ? schema["strict"] : nothing,
callable = call_type)
return Dict(schema["name"] => tool)
return Dict{String, AbstractTool}(schema["name"] => tool)
end

## Only thing you can change is the "strict" setting
Expand All @@ -583,7 +585,7 @@ function tool_call_signature(
end
function tool_call_signature(
tool::ToolRef; kwargs...)
Dict(string(tool.ref) => tool)
return Dict{String, AbstractTool}(string(tool.ref) => tool)
end

## Add support for function signatures
Expand All @@ -594,7 +596,7 @@ end
function tool_call_signature(
tools::Vector{<:T}; kwargs...) where {T <:
Union{Type, Function, Method, AbstractTool}}
tool_map = Dict{String, Tool}()
tool_map = Dict{String, AbstractTool}()
for tool in tools
temp_map = tool_call_signature(tool; kwargs...)
for (name, tool) in temp_map
Expand Down
26 changes: 23 additions & 3 deletions src/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,13 @@ function render(schema::AbstractAnthropicSchema,
tool::ToolRef;
kwargs...)
## WARNING: We ignore the tool name here, because the names are strict
(; extras) = tool
rendered = if tool.ref == :computer
Dict(
"type" => "computer_20241022",
"name" => "computer",
"display_width_px" => 1024,
"display_height_px" => 768,
"display_number" => 1
"display_width_px" => get(extras, "display_width_px", 1024),
"display_height_px" => get(extras, "display_height_px", 768)
)
elseif tool.ref == :str_replace_editor
Dict(
Expand All @@ -151,6 +151,9 @@ function render(schema::AbstractAnthropicSchema,
else
throw(ArgumentError("Unknown tool reference: $(tool.ref)"))
end
if !isempty(extras)
merge!(rendered, extras)
end
return rendered
end

Expand Down Expand Up @@ -836,6 +839,23 @@ conv = aitools(
# UserMessage("And in New York?")
# AIToolRequest("-"; Tool Requests: 1)
```
Using the the new Computer Use beta feature:
```julia
# Define tools (and associated functions to call)
tool_map = Dict("bash" => PT.ToolRef(; ref=:bash, callable=bash_tool),
"computer" => PT.ToolRef(; ref=:computer, callable=computer_tool,
extras=Dict("display_width_px" => 1920, "display_height_px" => 1080)),
"str_replace_editor" => PT.ToolRef(; ref=:str_replace_editor, callable=edit_tool))
msg = aitools(prompt; tools=collect(values(tool_map)), model="claude", betas=[:computer_use])
PT.pprint(msg)
# --------------------
# AI Tool Request
# --------------------
# Tool Request: computer, args: Dict{Symbol, Any}(:action => "screenshot")
```
"""
function aitools(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE;
tools::Union{Type, Function, Method, AbstractTool, Vector} = Tool[],
Expand Down
4 changes: 2 additions & 2 deletions test/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
@test isabstracttool(my_test_function) == false

## ToolRef
tool = ToolRef(:computer, println)
tool = ToolRef(; ref = :computer, callable = println)
@test tool isa ToolRef
@test tool.ref == :computer
@test tool.callable == println
Expand Down Expand Up @@ -757,7 +757,7 @@ end
@test tool2.parameters["properties"]["weight"]["type"] == "number"

## ToolRef
tool = ToolRef(:computer, println)
tool = ToolRef(; ref = :computer, callable = println)
tool_map = tool_call_signature(tool)
@test tool_map == Dict("computer" => tool)
end
Expand Down
10 changes: 10 additions & 0 deletions test/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,16 @@ end
@test rendered["name"] == "computer"
@test rendered["display_width_px"] == 1024
@test rendered["display_height_px"] == 768
@test !haskey(rendered, "display_number")

computer_tool2 = ToolRef(ref = :computer,
extras = Dict("display_width_px" => 1920,
"display_height_px" => 1080, "display_number" => 1))
rendered = render(schema, computer_tool2)
@test rendered["type"] == "computer_20241022"
@test rendered["name"] == "computer"
@test rendered["display_width_px"] == 1920
@test rendered["display_height_px"] == 1080
@test rendered["display_number"] == 1

# Test text editor tool rendering
Expand Down
4 changes: 2 additions & 2 deletions test/llm_shared.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using PromptingTools: render, NoSchema, AbstractPromptSchema
using PromptingTools: render, NoSchema, AbstractPromptSchema, OpenAISchema
using PromptingTools: AIMessage, SystemMessage, AbstractMessage, AbstractChatMessage
using PromptingTools: UserMessage, UserMessageWithImages, DataMessage, AIToolRequest,
ToolMessage, ToolRef
Expand Down Expand Up @@ -222,7 +222,7 @@ using PromptingTools: finalize_outputs, role4render

## ToolRef
schema = NoSchema()
tool = ToolRef(:computer, println)
tool = ToolRef(; ref = :computer)
@test_throws ArgumentError render(schema, tool)
end

Expand Down

0 comments on commit 8b6f951

Please sign in to comment.