Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tool execution with kwargs #228

Merged
merged 1 commit into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added a new `extras` field to `ToolRef` to enable additional parameters in the tool signature (eg, `display_width_px`, `display_height_px` for the `:computer` tool).
- Added a new kwarg `unused_as_kwargs` to `execute_tool` to enable passing unused args as kwargs (see `?execute_tool` for more information). Helps with using kwarg-based functions.

### Updated
- Updated the compat bounds for `StreamCallbacks` to enable both v0.4 and v0.5 (Fixes Julia 1.9 compatibility).
Expand Down
44 changes: 28 additions & 16 deletions src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -721,13 +721,13 @@ end
### Processing utilities

"""
parse_tool(datatype::Type, blob::AbstractString)
parse_tool(datatype::Type, blob::AbstractString; kwargs...)

Parse the JSON blob into the specified datatype in try-catch mode.

If parsing fails, it tries to return the untyped JSON blob in a dictionary.
"""
function parse_tool(datatype::Type, blob::AbstractString)
function parse_tool(datatype::Type, blob::AbstractString; kwargs...)
try
return if blob == "{}"
## If empty, return empty datatype
Expand All @@ -743,16 +743,19 @@ function parse_tool(datatype::Type, blob::AbstractString)
end

## Utility for Anthropic - it returns a parsed dict and we need text for deserialization into an object
function parse_tool(datatype::Type, blob::AbstractDict)
isempty(blob) ? datatype() : parse_tool(datatype, JSON3.write(blob))
function parse_tool(datatype::Type, blob::AbstractDict; kwargs...)
isempty(blob) ? datatype() : parse_tool(datatype, JSON3.write(blob); kwargs...)
end
function parse_tool(tool::AbstractTool, input::Union{AbstractString, AbstractDict})
return parse_tool(tool.callable, input)
function parse_tool(
tool::AbstractTool, input::Union{AbstractString, AbstractDict}; kwargs...)
return parse_tool(tool.callable, input; kwargs...)
end

"""
execute_tool(f::Function, args::AbstractDict{Symbol, <:Any},
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}();
throw_on_error::Bool = true, unused_as_kwargs::Bool = false,
kwargs...)

Executes a function with the provided arguments.

Expand All @@ -767,6 +770,9 @@ Dictionary is un-ordered, so we need to sort the arguments first and then pass t
- `f::Function`: The function to execute.
- `args::AbstractDict{Symbol, <:Any}`: The arguments to pass to the function.
- `context::AbstractDict{Symbol, <:Any}`: Optional context to pass to the function, it will prioritized to get the argument values from.
- `throw_on_error::Bool`: Whether to throw an error if the tool execution fails. Defaults to `true`.
- `unused_as_kwargs::Bool`: Whether to pass unused arguments as keyword arguments. Defaults to `false`. Function must support keyword arguments!
- `kwargs...`: Additional keyword arguments to pass to the function.

# Example
```julia
Expand All @@ -787,9 +793,11 @@ PT.execute_tool(tool_map, PT.tool_calls(msg)[1])
"""
function execute_tool(f::Function, args::AbstractDict{Symbol, <:Any},
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}();
throw_on_error::Bool = true)
throw_on_error::Bool = true, unused_as_kwargs::Bool = false,
kwargs...)
args_sorted = []
for arg in get_arg_names(f)
arg_names = get_arg_names(f)
for arg in arg_names
if arg == :context
push!(args_sorted, context)
elseif haskey(context, arg)
Expand All @@ -798,30 +806,34 @@ function execute_tool(f::Function, args::AbstractDict{Symbol, <:Any},
push!(args_sorted, args[arg])
end
end
if unused_as_kwargs
unused_args = setdiff(keys(args), arg_names)
kwargs = merge(NamedTuple(kwargs), (; [arg => args[arg] for arg in unused_args]...))
end

result = try
f(args_sorted...)
f(args_sorted...; kwargs...)
catch e
ToolExecutionError("Tool execution of `$(f)` failed", e)
end
throw_on_error && result isa AbstractToolError && throw(result)
return result
end
function execute_tool(tool::AbstractTool, args::AbstractDict{Symbol, <:Any},
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
return execute_tool(tool.callable, args, context)
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...)
return execute_tool(tool.callable, args, context; kwargs...)
end
function execute_tool(tool::AbstractTool, msg::ToolMessage,
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
return execute_tool(tool.callable, msg.args, context)
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...)
return execute_tool(tool.callable, msg.args, context; kwargs...)
end
function execute_tool(tool_map::AbstractDict{String, <:AbstractTool}, msg::ToolMessage,
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...)
if !haskey(tool_map, msg.name)
throw(ToolNotFoundError("Tool `$(msg.name)` not found"))
end
tool = tool_map[msg.name]
return execute_tool(tool, msg, context)
return execute_tool(tool, msg, context; kwargs...)
end

"""
Expand Down
28 changes: 28 additions & 0 deletions test/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ function context_test_function2(x::Int, y::String, context::Dict)
return "Context test: $x, $y, $(context)"
end

# Test function that accepts kwargs
function kwarg_test_function(x::Int; y::Int = 0, z::Int = 0, kwargs...)
return x + y + z
end
# Test with function that has no kwargs
function no_kwarg_function(x::Int)
return x
end

@testset "ToolErrors" begin
e = ToolNotFoundError("Tool `xyz` not found")
@test e isa AbstractToolError
Expand Down Expand Up @@ -860,4 +869,23 @@ end
@test_throws ToolNotFoundError execute_tool(tool_map,
ToolMessage(;
tool_call_id = "1", name = "wrong_tool_name", raw = "", args = args))

# Test passing kwargs directly
args = Dict(:x => 1)
@test execute_tool(kwarg_test_function, args; y = 2, z = 3) == 6 # 1 + 2 + 3

# Test unused args passed as kwargs when unused_as_kwargs=true
args = Dict(:x => 1, :y => 2, :z => 3, :extra => 4)
@test execute_tool(kwarg_test_function, args; unused_as_kwargs = true) == 6 # 1 + 2 + 3

# Test that extra args are ignored when unused_as_kwargs=false
args = Dict(:x => 1, :y => 2, :z => 3, :extra => 4)
@test execute_tool(kwarg_test_function, args; unused_as_kwargs = false) == 1

# Test that args override kwargs when unused_as_kwargs=true
args = Dict(:x => 1, :y => 2, :z => 3)
@test execute_tool(kwarg_test_function, args; unused_as_kwargs = true, y = 5) == 6 # args

args = Dict(:x => 1, :extra => 2)
@test execute_tool(no_kwarg_function, args; unused_as_kwargs = false) == 1
end
Loading