From 31b1c27964b054bfd1f3dd85c608226d35a782d9 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Wed, 6 Nov 2024 08:00:25 -0500 Subject: [PATCH] Fix tool dict (#230) --- CHANGELOG.md | 5 +++++ Project.toml | 2 +- src/extraction.jl | 46 +++++++++++++++++++++++++++++++++++----------- test/extraction.jl | 46 +++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 86 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94519551b..1118085b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.62.1] + +### Fixed +- Fixed a bug in `tool_call_signature` where hidden fields were removed only after the schema had been generated, so the call failed whenever a `Dict` argument was provided (even a hidden one). Dicts cannot be converted to a JSON schema, so hidden fields are now masked upfront, before any schema processing. + ## [0.62.0] ### Added diff --git a/Project.toml b/Project.toml index 096f7d153..c2d232ab7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.62.0" +version = "0.62.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/extraction.jl b/src/extraction.jl index a3d144de1..780aaf121 100644 --- a/src/extraction.jl +++ b/src/extraction.jl @@ -181,7 +181,17 @@ function extract_docstring(m::Method; max_description_length::Int = 100) return extract_docstring(get_function(m); max_description_length) end -function to_json_schema(orig_type; max_description_length::Int = 100) +@inline function is_hidden_field(field_name::AbstractString, + hidden_fields::AbstractVector{<:Union{AbstractString, Regex}}) + any(x -> occursin(x, field_name), hidden_fields) +end +@inline function is_hidden_field(field_name::Symbol, + hidden_fields::AbstractVector{<:Union{AbstractString, Regex}}) +
is_hidden_field(string(field_name), hidden_fields) +end + +function to_json_schema(orig_type; max_description_length::Int = 100, + hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[]) schema = Dict{String, Any}() type = remove_null_types(orig_type) if isstructtype(type) @@ -190,9 +200,12 @@ function to_json_schema(orig_type; max_description_length::Int = 100) ## extract the field names and types required_types = String[] for (field_name, field_type) in zip(fieldnames(type), fieldtypes(type)) + if is_hidden_field(field_name, hidden_fields) + continue + end schema["properties"][string(field_name)] = to_json_schema( remove_null_types(field_type); - max_description_length) + max_description_length, hidden_fields) ## Hack: no null type (Nothing, Missing) implies it it is a required field is_required_field(field_type) && push!(required_types, string(field_name)) end @@ -205,23 +218,29 @@ function to_json_schema(orig_type; max_description_length::Int = 100) end return schema end -function to_json_schema(type::Type{<:AbstractString}; max_description_length::Int = 100) +function to_json_schema(type::Type{<:AbstractString}; max_description_length::Int = 100, + hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[]) Dict{String, Any}("type" => to_json_type(type)) end function to_json_schema(type::Type{T}; - max_description_length::Int = 100) where {T <: - Union{AbstractSet, Tuple, AbstractArray}} + max_description_length::Int = 100, + hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[]) where {T <: + Union{ + AbstractSet, Tuple, AbstractArray}} element_type = eltype(type) return Dict{String, Any}("type" => "array", - "items" => to_json_schema(remove_null_types(element_type))) + "items" => to_json_schema(remove_null_types(element_type); + max_description_length, hidden_fields)) end -function to_json_schema(type::Type{<:Enum}; max_description_length::Int = 100) +function to_json_schema(type::Type{<:Enum}; 
max_description_length::Int = 100, + hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[]) enum_options = Base.Enums.namemap(type) |> values .|> string return Dict{String, Any}("type" => "string", "enum" => enum_options) end ## Dispatch for method of a function -- grabs only arguments!! Not kwargs!! -function to_json_schema(m::Method; max_description_length::Int = 100) +function to_json_schema(m::Method; max_description_length::Int = 100, + hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[]) ## Warning: We cannot extract keyword arguments from the method signature kwargs = Base.kwarg_decl(m) !isempty(kwargs) && @@ -233,9 +252,12 @@ function to_json_schema(m::Method; max_description_length::Int = 100) ## extract the field names and types required_types = String[] for (field_name, field_type) in zip(get_arg_names(m), get_arg_types(m)) + if is_hidden_field(field_name, hidden_fields) + continue + end schema["properties"][string(field_name)] = to_json_schema( remove_null_types(field_type); - max_description_length) + max_description_length, hidden_fields) ## Hack: no null type (Nothing, Missing) implies it it is a required field is_required_field(field_type) && push!(required_types, string(field_name)) end @@ -245,7 +267,8 @@ function to_json_schema(m::Method; max_description_length::Int = 100) !isempty(docs) && (schema["description"] = docs) return schema end -function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int = 100) +function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int = 100, + hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[]) throw(ArgumentError("Dicts are not supported yet as we cannot analyze their keys/values on a type-level. 
Use a nested Struct instead!")) end @@ -535,7 +558,8 @@ function tool_call_signature( name end schema = Dict{String, Any}("name" => name, - "parameters" => to_json_schema(type_or_method; max_description_length)) + "parameters" => to_json_schema(type_or_method; max_description_length, + hidden_fields)) ## docstrings docs = isnothing(docs) ? extract_docstring(type_or_method; max_description_length) : docs diff --git a/test/extraction.jl b/test/extraction.jl index cffd8d441..3f8617558 100644 --- a/test/extraction.jl +++ b/test/extraction.jl @@ -6,7 +6,7 @@ using PromptingTools: Tool, isabstracttool, execute_tool, parse_tool, get_arg_na get_arg_types, get_method, get_function, remove_field!, tool_call_signature, ToolRef using PromptingTools: AbstractToolError, ToolNotFoundError, ToolExecutionError, - ToolGenericError + ToolGenericError, is_hidden_field # TODO: check more edge cases like empty structs @@ -358,6 +358,36 @@ end @test schema_measurement["description"] == "Represents person's age, height, and weight\n" end + +@testset "is_hidden_field" begin + # Test basic string matching + @test is_hidden_field("context", ["context"]) == true + @test is_hidden_field("data", ["context"]) == false + + # Test regex matching + @test is_hidden_field("my_context", [r"context$"]) == true + @test is_hidden_field("context_var", [r"^context"]) == true + @test is_hidden_field("mydata", [r"context"]) == false + + # Test multiple patterns + @test is_hidden_field("context", ["data", "context", "temp"]) == true + @test is_hidden_field("context", [r"^data", r"temp$", r"context"]) == true + + # Test mixed string and regex patterns + @test is_hidden_field( + "my_context", Union{AbstractString, Regex}["data", r"context$"]) == true + @test is_hidden_field( + "context_var", Union{AbstractString, Regex}[r"^context", "temp"]) == true + + # Test empty patterns list + @test is_hidden_field("context", String[]) == false + @test is_hidden_field("context", Regex[]) == false + + # Test with Symbol 
input + @test is_hidden_field(:context, ["context"]) == true + @test is_hidden_field(:my_context, [r"context$"]) == true + @test is_hidden_field(:data, ["context"]) == false +end @testset "set_properties_strict!" begin # Test 1: Basic functionality params = Dict( @@ -769,6 +799,20 @@ end tool = ToolRef(; ref = :computer, callable = println) tool_map = tool_call_signature(tool) @test tool_map == Dict("computer" => tool) + + ## accepting dictionary when it's hidden // it would fail otherwise + tool_map = tool_call_signature(context_test_function2; hidden_fields = ["context"]) + @test tool_map isa Dict + + @test_throws ArgumentError tool_call_signature(context_test_function2) + + # for struct + mutable struct MyStruct1234 + context::Dict{String, Any} + end + @test_throws ArgumentError tool_call_signature(MyStruct1234) + tool_map = tool_call_signature(MyStruct1234; hidden_fields = ["context"]) + @test tool_map isa Dict end @testset "parse_tool" begin