Skip to content

Commit

Permalink
Fix tool dict (#230)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Nov 6, 2024
1 parent bcc7c81 commit 31b1c27
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 13 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.62.1]

### Fixed
- Fixed a bug in `tool_call_signature` where hidden fields were removed too late: they were stripped only after the schema had been generated, so a hidden `Dict` argument still caused a failure (Dicts cannot be converted to a schema). Hidden fields are now masked upfront, before the schema is generated.

## [0.62.0]

### Added
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.62.0"
version = "0.62.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
46 changes: 35 additions & 11 deletions src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,17 @@ function extract_docstring(m::Method; max_description_length::Int = 100)
return extract_docstring(get_function(m); max_description_length)
end

function to_json_schema(orig_type; max_description_length::Int = 100)
"""
    is_hidden_field(field_name, hidden_fields)

Return `true` if `field_name` matches at least one entry of `hidden_fields`,
otherwise `false`.

Each entry is tested with `occursin(entry, field_name)`: a string entry matches when
it is a substring of `field_name`, a `Regex` entry matches when it is found anywhere
in `field_name`. An empty `hidden_fields` never matches.

`field_name` may be an `AbstractString` or a `Symbol`; a `Symbol` is converted to its
string form before matching.
"""
@inline function is_hidden_field(field_name::AbstractString,
        hidden_fields::AbstractVector{<:Union{AbstractString, Regex}})
    for pattern in hidden_fields
        # Stop at the first matching pattern.
        occursin(pattern, field_name) && return true
    end
    return false
end
@inline function is_hidden_field(field_name::Symbol,
        hidden_fields::AbstractVector{<:Union{AbstractString, Regex}})
    # Delegate to the string method so both entry points share one matching rule.
    return is_hidden_field(String(field_name), hidden_fields)
end

function to_json_schema(orig_type; max_description_length::Int = 100,
hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[])
schema = Dict{String, Any}()
type = remove_null_types(orig_type)
if isstructtype(type)
Expand All @@ -190,9 +200,12 @@ function to_json_schema(orig_type; max_description_length::Int = 100)
## extract the field names and types
required_types = String[]
for (field_name, field_type) in zip(fieldnames(type), fieldtypes(type))
if is_hidden_field(field_name, hidden_fields)
continue
end
schema["properties"][string(field_name)] = to_json_schema(
remove_null_types(field_type);
max_description_length)
max_description_length, hidden_fields)
## Hack: no null type (Nothing, Missing) implies it is a required field
is_required_field(field_type) && push!(required_types, string(field_name))
end
Expand All @@ -205,23 +218,29 @@ function to_json_schema(orig_type; max_description_length::Int = 100)
end
return schema
end
function to_json_schema(type::Type{<:AbstractString}; max_description_length::Int = 100)
function to_json_schema(type::Type{<:AbstractString}; max_description_length::Int = 100,
hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[])
Dict{String, Any}("type" => to_json_type(type))
end
function to_json_schema(type::Type{T};
max_description_length::Int = 100) where {T <:
Union{AbstractSet, Tuple, AbstractArray}}
max_description_length::Int = 100,
hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[]) where {T <:
Union{
AbstractSet, Tuple, AbstractArray}}
element_type = eltype(type)
return Dict{String, Any}("type" => "array",
"items" => to_json_schema(remove_null_types(element_type)))
"items" => to_json_schema(remove_null_types(element_type);
max_description_length, hidden_fields))
end
function to_json_schema(type::Type{<:Enum}; max_description_length::Int = 100)
function to_json_schema(type::Type{<:Enum}; max_description_length::Int = 100,
hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[])
enum_options = Base.Enums.namemap(type) |> values .|> string
return Dict{String, Any}("type" => "string",
"enum" => enum_options)
end
## Dispatch for method of a function -- grabs only arguments!! Not kwargs!!
function to_json_schema(m::Method; max_description_length::Int = 100)
function to_json_schema(m::Method; max_description_length::Int = 100,
hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[])
## Warning: We cannot extract keyword arguments from the method signature
kwargs = Base.kwarg_decl(m)
!isempty(kwargs) &&
Expand All @@ -233,9 +252,12 @@ function to_json_schema(m::Method; max_description_length::Int = 100)
## extract the field names and types
required_types = String[]
for (field_name, field_type) in zip(get_arg_names(m), get_arg_types(m))
if is_hidden_field(field_name, hidden_fields)
continue
end
schema["properties"][string(field_name)] = to_json_schema(
remove_null_types(field_type);
max_description_length)
max_description_length, hidden_fields)
## Hack: no null type (Nothing, Missing) implies it is a required field
is_required_field(field_type) && push!(required_types, string(field_name))
end
Expand All @@ -245,7 +267,8 @@ function to_json_schema(m::Method; max_description_length::Int = 100)
!isempty(docs) && (schema["description"] = docs)
return schema
end
function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int = 100)
function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int = 100,
hidden_fields::AbstractVector{<:Union{AbstractString, Regex}} = String[])
throw(ArgumentError("Dicts are not supported yet as we cannot analyze their keys/values on a type-level. Use a nested Struct instead!"))
end

Expand Down Expand Up @@ -535,7 +558,8 @@ function tool_call_signature(
name
end
schema = Dict{String, Any}("name" => name,
"parameters" => to_json_schema(type_or_method; max_description_length))
"parameters" => to_json_schema(type_or_method; max_description_length,
hidden_fields))
## docstrings
docs = isnothing(docs) ? extract_docstring(type_or_method; max_description_length) :
docs
Expand Down
46 changes: 45 additions & 1 deletion test/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using PromptingTools: Tool, isabstracttool, execute_tool, parse_tool, get_arg_na
get_arg_types, get_method, get_function, remove_field!,
tool_call_signature, ToolRef
using PromptingTools: AbstractToolError, ToolNotFoundError, ToolExecutionError,
ToolGenericError
ToolGenericError, is_hidden_field

# TODO: check more edge cases like empty structs

Expand Down Expand Up @@ -358,6 +358,36 @@ end
@test schema_measurement["description"] ==
"Represents person's age, height, and weight\n"
end

# Unit tests for `is_hidden_field(field_name, hidden_fields)`: a field counts as hidden
# when any pattern in `hidden_fields` occurs in the field name.
@testset "is_hidden_field" begin
# Plain string patterns: an identical name matches, an unrelated name does not
@test is_hidden_field("context", ["context"]) == true
@test is_hidden_field("data", ["context"]) == false

# Regex patterns: anchored at the end, anchored at the start, and no match at all
@test is_hidden_field("my_context", [r"context$"]) == true
@test is_hidden_field("context_var", [r"^context"]) == true
@test is_hidden_field("mydata", [r"context"]) == false

# Multiple patterns: a single matching entry anywhere in the list is enough
@test is_hidden_field("context", ["data", "context", "temp"]) == true
@test is_hidden_field("context", [r"^data", r"temp$", r"context"]) == true

# Mixed string and regex patterns in one explicitly typed vector
@test is_hidden_field(
"my_context", Union{AbstractString, Regex}["data", r"context$"]) == true
@test is_hidden_field(
"context_var", Union{AbstractString, Regex}[r"^context", "temp"]) == true

# Empty pattern lists never hide anything
@test is_hidden_field("context", String[]) == false
@test is_hidden_field("context", Regex[]) == false

# Symbol field names are accepted and matched by their string form
@test is_hidden_field(:context, ["context"]) == true
@test is_hidden_field(:my_context, [r"context$"]) == true
@test is_hidden_field(:data, ["context"]) == false
end
@testset "set_properties_strict!" begin
# Test 1: Basic functionality
params = Dict(
Expand Down Expand Up @@ -769,6 +799,20 @@ end
tool = ToolRef(; ref = :computer, callable = println)
tool_map = tool_call_signature(tool)
@test tool_map == Dict("computer" => tool)

## accepting dictionary when it's hidden // it would fail otherwise
tool_map = tool_call_signature(context_test_function2; hidden_fields = ["context"])
@test tool_map isa Dict

@test_throws ArgumentError tool_call_signature(context_test_function2)

# for struct
mutable struct MyStruct1234
context::Dict{String, Any}
end
@test_throws ArgumentError tool_call_signature(MyStruct1234)
tool_map = tool_call_signature(MyStruct1234; hidden_fields = ["context"])
@test tool_map isa Dict
end

@testset "parse_tool" begin
Expand Down

0 comments on commit 31b1c27

Please sign in to comment.