diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d75e209d..c51c070da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added a new `extras` field to `ToolRef` to enable additional parameters in the tool signature (eg, `display_width_px`, `display_height_px` for the `:computer` tool). +- Added a new kwarg `unused_as_kwargs` to `execute_tool` to enable passing unused args as kwargs (see `?execute_tool` for more information). Helps with using kwarg-based functions. ### Updated - Updated the compat bounds for `StreamCallbacks` to enable both v0.4 and v0.5 (Fixes Julia 1.9 compatibility). diff --git a/src/extraction.jl b/src/extraction.jl index 8f155bc83..a3d144de1 100644 --- a/src/extraction.jl +++ b/src/extraction.jl @@ -721,13 +721,13 @@ end ### Processing utilities """ - parse_tool(datatype::Type, blob::AbstractString) + parse_tool(datatype::Type, blob::AbstractString; kwargs...) Parse the JSON blob into the specified datatype in try-catch mode. If parsing fails, it tries to return the untyped JSON blob in a dictionary. """ -function parse_tool(datatype::Type, blob::AbstractString) +function parse_tool(datatype::Type, blob::AbstractString; kwargs...) try return if blob == "{}" ## If empty, return empty datatype @@ -743,16 +743,19 @@ function parse_tool(datatype::Type, blob::AbstractString) end ## Utility for Anthropic - it returns a parsed dict and we need text for deserialization into an object -function parse_tool(datatype::Type, blob::AbstractDict) - isempty(blob) ? datatype() : parse_tool(datatype, JSON3.write(blob)) +function parse_tool(datatype::Type, blob::AbstractDict; kwargs...) + isempty(blob) ? datatype() : parse_tool(datatype, JSON3.write(blob); kwargs...) end -function parse_tool(tool::AbstractTool, input::Union{AbstractString, AbstractDict}) - return parse_tool(tool.callable, input) +function parse_tool( + tool::AbstractTool, input::Union{AbstractString, AbstractDict}; kwargs...) + return parse_tool(tool.callable, input; kwargs...) end """ execute_tool(f::Function, args::AbstractDict{Symbol, <:Any}, - context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}()) + context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); + throw_on_error::Bool = true, unused_as_kwargs::Bool = false, + kwargs...) Executes a function with the provided arguments. @@ -767,6 +770,9 @@ Dictionary is un-ordered, so we need to sort the arguments first and then pass t - `f::Function`: The function to execute. - `args::AbstractDict{Symbol, <:Any}`: The arguments to pass to the function. - `context::AbstractDict{Symbol, <:Any}`: Optional context to pass to the function, it will prioritized to get the argument values from. +- `throw_on_error::Bool`: Whether to throw an error if the tool execution fails. Defaults to `true`. +- `unused_as_kwargs::Bool`: Whether to pass unused arguments as keyword arguments. Defaults to `false`. Function must support keyword arguments! +- `kwargs...`: Additional keyword arguments to pass to the function. # Example ```julia @@ -787,9 +793,11 @@ PT.execute_tool(tool_map, PT.tool_calls(msg)[1]) """ function execute_tool(f::Function, args::AbstractDict{Symbol, <:Any}, context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); - throw_on_error::Bool = true) + throw_on_error::Bool = true, unused_as_kwargs::Bool = false, + kwargs...) args_sorted = [] - for arg in get_arg_names(f) + arg_names = get_arg_names(f) + for arg in arg_names if arg == :context push!(args_sorted, context) elseif haskey(context, arg) @@ -798,9 +806,13 @@ function execute_tool(f::Function, args::AbstractDict{Symbol, <:Any}, push!(args_sorted, args[arg]) end end + if unused_as_kwargs + unused_args = setdiff(keys(args), arg_names) + kwargs = merge(NamedTuple(kwargs), (; [arg => args[arg] for arg in unused_args]...)) + end result = try - f(args_sorted...) + f(args_sorted...; kwargs...) catch e ToolExecutionError("Tool execution of `$(f)` failed", e) end @@ -808,20 +820,20 @@ function execute_tool(f::Function, args::AbstractDict{Symbol, <:Any}, return result end function execute_tool(tool::AbstractTool, args::AbstractDict{Symbol, <:Any}, - context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}()) - return execute_tool(tool.callable, args, context) + context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...) + return execute_tool(tool.callable, args, context; kwargs...) end function execute_tool(tool::AbstractTool, msg::ToolMessage, - context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}()) - return execute_tool(tool.callable, msg.args, context) + context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...) + return execute_tool(tool.callable, msg.args, context; kwargs...) end function execute_tool(tool_map::AbstractDict{String, <:AbstractTool}, msg::ToolMessage, - context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}()) + context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}(); kwargs...) if !haskey(tool_map, msg.name) throw(ToolNotFoundError("Tool `$(msg.name)` not found")) end tool = tool_map[msg.name] - return execute_tool(tool, msg, context) + return execute_tool(tool, msg, context; kwargs...) end """ diff --git a/test/extraction.jl b/test/extraction.jl index 2c6590478..cffd8d441 100644 --- a/test/extraction.jl +++ b/test/extraction.jl @@ -21,6 +21,15 @@ function context_test_function2(x::Int, y::String, context::Dict) return "Context test: $x, $y, $(context)" end +# Test function that accepts kwargs +function kwarg_test_function(x::Int; y::Int = 0, z::Int = 0, kwargs...) + return x + y + z +end +# Test with function that has no kwargs +function no_kwarg_function(x::Int) + return x +end + @testset "ToolErrors" begin e = ToolNotFoundError("Tool `xyz` not found") @test e isa AbstractToolError @@ -860,4 +869,23 @@ end @test_throws ToolNotFoundError execute_tool(tool_map, ToolMessage(; tool_call_id = "1", name = "wrong_tool_name", raw = "", args = args)) + + # Test passing kwargs directly + args = Dict(:x => 1) + @test execute_tool(kwarg_test_function, args; y = 2, z = 3) == 6 # 1 + 2 + 3 + + # Test unused args passed as kwargs when unused_as_kwargs=true + args = Dict(:x => 1, :y => 2, :z => 3, :extra => 4) + @test execute_tool(kwarg_test_function, args; unused_as_kwargs = true) == 6 # 1 + 2 + 3 + + # Test that extra args are ignored when unused_as_kwargs=false + args = Dict(:x => 1, :y => 2, :z => 3, :extra => 4) + @test execute_tool(kwarg_test_function, args; unused_as_kwargs = false) == 1 + + # Test that args override kwargs when unused_as_kwargs=true + args = Dict(:x => 1, :y => 2, :z => 3) + @test execute_tool(kwarg_test_function, args; unused_as_kwargs = true, y = 5) == 6 # args + + args = Dict(:x => 1, :extra => 2) + @test execute_tool(no_kwarg_function, args; unused_as_kwargs = false) == 1 end