Skip to content

Commit

Permalink
Update tool execution
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Nov 3, 2024
1 parent 8b6f951 commit b2bcb94
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added a new `extras` field to `ToolRef` to enable additional parameters in the tool signature (eg, `display_width_px`, `display_height_px` for the `:computer` tool).
- Added a new kwarg `unused_as_kwargs` to `execute_tool` to enable passing unused args as kwargs (see `?execute_tool` for more information). Helps with using kwarg-based functions.

### Updated
- Updated the compat bounds for `StreamCallbacks` to enable both v0.4 and v0.5 (Fixes Julia 1.9 compatibility).
Expand Down
44 changes: 28 additions & 16 deletions src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -721,13 +721,13 @@ end
### Processing utilities

"""
parse_tool(datatype::Type, blob::AbstractString)
parse_tool(datatype::Type, blob::AbstractString; kwargs...)
Parse the JSON blob into the specified datatype in try-catch mode.
If parsing fails, it tries to return the untyped JSON blob in a dictionary.
"""
function parse_tool(datatype::Type, blob::AbstractString)
function parse_tool(datatype::Type, blob::AbstractString; kwargs...)
try
return if blob == "{}"
## If empty, return empty datatype
Expand All @@ -743,16 +743,19 @@ function parse_tool(datatype::Type, blob::AbstractString)
end

## Utility for Anthropic - it returns a parsed dict and we need text for deserialization into an object
function parse_tool(datatype::Type, blob::AbstractDict)
isempty(blob) ? datatype() : parse_tool(datatype, JSON3.write(blob))
function parse_tool(datatype::Type, blob::AbstractDict; kwargs...)
isempty(blob) ? datatype() : parse_tool(datatype, JSON3.write(blob); kwargs...)
end
function parse_tool(tool::AbstractTool, input::Union{AbstractString, AbstractDict})
return parse_tool(tool.callable, input)
function parse_tool(
tool::AbstractTool, input::Union{AbstractString, AbstractDict}; kwargs...)
return parse_tool(tool.callable, input; kwargs...)
end

"""
execute_tool(f::Function, args::AbstractDict{Symbol, <:Any},
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}();
throw_on_error::Bool = true, unused_as_kwargs::Bool = false,
kwargs...)
Executes a function with the provided arguments.
Expand All @@ -767,6 +770,9 @@ Dictionary is un-ordered, so we need to sort the arguments first and then pass t
- `f::Function`: The function to execute.
- `args::AbstractDict{Symbol, <:Any}`: The arguments to pass to the function.
- `context::AbstractDict{Symbol, <:Any}`: Optional context to pass to the function, it will prioritized to get the argument values from.
- `throw_on_error::Bool`: Whether to throw an error if the tool execution fails. Defaults to `true`.
- `unused_as_kwargs::Bool`: Whether to pass unused arguments as keyword arguments. Defaults to `false`. Function must support keyword arguments!
- `kwargs...`: Additional keyword arguments to pass to the function.
# Example
```julia
Expand All @@ -787,9 +793,11 @@ PT.execute_tool(tool_map, PT.tool_calls(msg)[1])
"""
function execute_tool(f::Function, args::AbstractDict{Symbol, <:Any},
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}();
throw_on_error::Bool = true)
throw_on_error::Bool = true, unused_as_kwargs::Bool = false,
kwargs...)
args_sorted = []
for arg in get_arg_names(f)
arg_names = get_arg_names(f)
for arg in arg_names
if arg == :context
push!(args_sorted, context)
elseif haskey(context, arg)
Expand All @@ -798,30 +806,34 @@ function execute_tool(f::Function, args::AbstractDict{Symbol, <:Any},
push!(args_sorted, args[arg])
end
end
if unused_as_kwargs
unused_args = setdiff(keys(args), arg_names)
kwargs = merge(NamedTuple(kwargs), (; [arg => args[arg] for arg in unused_args]...))
end

result = try
f(args_sorted...)
f(args_sorted...; kwargs...)
catch e
ToolExecutionError("Tool execution of `$(f)` failed", e)
end
throw_on_error && result isa AbstractToolError && throw(result)
return result
end
function execute_tool(tool::AbstractTool, args::AbstractDict{Symbol, <:Any},
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
return execute_tool(tool.callable, args, context)
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...)
return execute_tool(tool.callable, args, context; kwargs...)
end
function execute_tool(tool::AbstractTool, msg::ToolMessage,
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
return execute_tool(tool.callable, msg.args, context)
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...)
return execute_tool(tool.callable, msg.args, context; kwargs...)
end
function execute_tool(tool_map::AbstractDict{String, <:AbstractTool}, msg::ToolMessage,
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...)
if !haskey(tool_map, msg.name)
throw(ToolNotFoundError("Tool `$(msg.name)` not found"))
end
tool = tool_map[msg.name]
return execute_tool(tool, msg, context)
return execute_tool(tool, msg, context; kwargs...)
end

"""
Expand Down
28 changes: 28 additions & 0 deletions test/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ function context_test_function2(x::Int, y::String, context::Dict)
return "Context test: $x, $y, $(context)"
end

# Test function that accepts kwargs
function kwarg_test_function(x::Int; y::Int = 0, z::Int = 0, kwargs...)
return x + y + z
end
# Test with function that has no kwargs
function no_kwarg_function(x::Int)
return x
end

@testset "ToolErrors" begin
e = ToolNotFoundError("Tool `xyz` not found")
@test e isa AbstractToolError
Expand Down Expand Up @@ -860,4 +869,23 @@ end
@test_throws ToolNotFoundError execute_tool(tool_map,
ToolMessage(;
tool_call_id = "1", name = "wrong_tool_name", raw = "", args = args))

# Test passing kwargs directly
args = Dict(:x => 1)
@test execute_tool(kwarg_test_function, args; y = 2, z = 3) == 6 # 1 + 2 + 3

# Test unused args passed as kwargs when unused_as_kwargs=true
args = Dict(:x => 1, :y => 2, :z => 3, :extra => 4)
@test execute_tool(kwarg_test_function, args; unused_as_kwargs = true) == 6 # 1 + 2 + 3

# Test that extra args are ignored when unused_as_kwargs=false
args = Dict(:x => 1, :y => 2, :z => 3, :extra => 4)
@test execute_tool(kwarg_test_function, args; unused_as_kwargs = false) == 1

# Test that args override kwargs when unused_as_kwargs=true
args = Dict(:x => 1, :y => 2, :z => 3)
@test execute_tool(kwarg_test_function, args; unused_as_kwargs = true, y = 5) == 6 # args

args = Dict(:x => 1, :extra => 2)
@test execute_tool(no_kwarg_function, args; unused_as_kwargs = false) == 1
end

0 comments on commit b2bcb94

Please sign in to comment.