Fix to-json-type (#248)
svilupp authored Dec 5, 2024
1 parent f73e683 commit b95f6c9
Showing 7 changed files with 160 additions and 10 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.65.1]

### Fixed
- Removed unnecessary printing to `stdout` during precompilation in `precompile.jl`.
- Fixed a "bug-waiting-to-happen" in tool use. `to_json_type` now enforces users to provide concrete types, because abstract types can lead to errors during JSON3 deserialization.
- Flowed through a bug fix in `StreamCallback` where the usage information was being included in the response even when `usage=nothing`. Lower bound of `StreamCallbacks` was bumped to `0.5.1`.

## [0.65.0]

### Breaking
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.65.0"
version = "0.65.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -55,7 +55,7 @@ REPL = "<0.0.1, 1"
Random = "<0.0.1, 1"
SparseArrays = "<0.0.1, 1"
Statistics = "<0.0.1, 1"
StreamCallbacks = "0.4, 0.5"
StreamCallbacks = "0.5.1"
Test = "<0.0.1, 1"
julia = "1.9, 1.10"

21 changes: 17 additions & 4 deletions src/extraction.jl
@@ -109,12 +109,24 @@ end
# 1) OpenAI / JSON format
######################

to_json_type(s::Type{<:AbstractString}) = "string"
to_json_type(n::Type{<:Real}) = "number"
to_json_type(n::Type{<:Integer}) = "integer"
"Check if a type is concrete."
function is_concrete_type(s::Type)
    isconcretetype(s) ||
        throw(ArgumentError("Cannot convert abstract type $s to JSON type. You must provide concrete types!"))
end

function is_not_union_type(s::Type)
    !isa(s, Union) ||
        throw(ArgumentError("Cannot convert $s to JSON type. The only supported union types are Union{..., Nothing}. Please pick a concrete type (`::String` is generic if you cannot pick)!"))
end

to_json_type(s::Type{<:AbstractString}) = (is_concrete_type(s); "string")
to_json_type(n::Type{<:Real}) = (is_concrete_type(n); "number")
to_json_type(n::Type{<:Integer}) = (is_concrete_type(n); "integer")
to_json_type(b::Type{Bool}) = "boolean"
to_json_type(t::Type{<:Union{Missing, Nothing}}) = "null"
to_json_type(t::Type{<:Any}) = "string" # object?
to_json_type(t::Type{<:Any}) = (is_not_union_type(t); is_concrete_type(t); "string") # object?
to_json_type(t::Type{Any}) = "string" # Allow explicit Any as it can be deserialized by JSON3

has_null_type(T::Type{Missing}) = true
has_null_type(T::Type{Nothing}) = true
@@ -457,6 +469,7 @@ end
Extract the argument names, types and docstrings from a struct to create the function call signature in JSON schema.
You must provide a Struct type (not an instance of it) with some fields.
The types must be CONCRETE, it helps with correct conversion to JSON schema and then conversion back to the struct.
Note: Fairly experimental, but works for combination of structs, arrays, strings and singletons.
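
As a standalone illustration of the guards added in this hunk, the sketch below re-creates the pattern using only Base Julia. The `_sketch` names and the shortened error messages are hypothetical stand-ins, not the package's functions. It only aims to show how dispatch combined with `isconcretetype` rejects abstract types and plain unions while still accepting an explicit `Any`.

```julia
# Hypothetical stand-ins for the guards in this commit (not the package's API).
function require_concrete_sketch(T::Type)
    isconcretetype(T) ||
        throw(ArgumentError("Cannot convert abstract type $T to JSON type."))
end

function require_not_union_sketch(T::Type)
    !isa(T, Union) ||
        throw(ArgumentError("Cannot convert union type $T to JSON type."))
end

# Each method validates its argument before returning the JSON type name.
json_type_sketch(T::Type{<:AbstractString}) = (require_concrete_sketch(T); "string")
json_type_sketch(T::Type{<:Real}) = (require_concrete_sketch(T); "number")
json_type_sketch(T::Type{<:Integer}) = (require_concrete_sketch(T); "integer")
json_type_sketch(::Type{Bool}) = "boolean"
json_type_sketch(T::Type{<:Any}) =
    (require_not_union_sketch(T); require_concrete_sketch(T); "string")
# `Type{Any}` is more specific than `Type{<:Any}`, so an explicit `Any` skips the checks.
json_type_sketch(::Type{Any}) = "string"

# Small helper: does calling `f` raise an ArgumentError?
function throws_argument_error(f)
    try
        f()
        return false
    catch err
        return err isa ArgumentError
    end
end

struct CustomSketch end  # a concrete user-defined type

json_type_sketch(Int64)         # "integer"
json_type_sketch(Float64)       # "number"
json_type_sketch(CustomSketch)  # "string"
json_type_sketch(Any)           # "string"
throws_argument_error(() -> json_type_sketch(Real))                # true
throws_argument_error(() -> json_type_sketch(Union{Int, String}))  # true
```

The same reasoning explains the test change in this commit from `calculator(x::Number, y::Number)` to `Float64` arguments: an abstract annotation such as `Number` now fails fast instead of producing a schema that may not deserialize cleanly later.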
2 changes: 2 additions & 0 deletions src/llm_openai.jl
@@ -909,6 +909,8 @@ This is a perfect solution for extracting structured information from text (eg,
It's effectively a light wrapper around `aigenerate` call, which requires additional keyword argument `return_type` to be provided
and will enforce the model outputs to adhere to it.
!!! Note: The types must be CONCRETE, it helps with correct conversion to JSON schema and then conversion back to the struct.
# Arguments
- `prompt_schema`: An optional object to specify which prompt template should be applied (Default to `PROMPT_SCHEMA = OpenAISchema`)
- `prompt`: Can be a string representing the prompt for the AI conversation, a `UserMessage`, a vector of `AbstractMessage` or an `AITemplate`
2 changes: 1 addition & 1 deletion src/precompilation.jl
@@ -28,7 +28,7 @@ messages = [
_ = render(OpenAISchema(), messages)

## Utilities
pprint(messages)
pprint(devnull, messages)
last_output(messages)
last_message(messages)
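
The `pprint(devnull, messages)` change relies on a general Julia pattern: a printing function that takes an `IO` as its first argument can be run during precompilation without producing output. A minimal sketch, assuming a hypothetical printer `show_items_sketch` (not part of the package):

```julia
# Hypothetical printer that takes the target stream as its first argument.
function show_items_sketch(io::IO, items)
    for (i, item) in enumerate(items)
        println(io, i, ": ", item)
    end
end

items = ["system prompt", "user question"]

# Writing to an IOBuffer captures the text for inspection.
buf = IOBuffer()
show_items_sketch(buf, items)
captured = String(take!(buf))

# Writing to `devnull` still runs (and compiles) the method, but discards the output.
show_items_sketch(devnull, items)
```

Passing the stream explicitly keeps the printing method in the precompile workload while leaving `stdout` untouched.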

132 changes: 130 additions & 2 deletions test/extraction.jl
@@ -1,6 +1,7 @@
using PromptingTools: MaybeExtract, extract_docstring, ItemsExtract, ToolMessage
using PromptingTools: has_null_type, is_required_field, remove_null_types, to_json_schema
using PromptingTools: tool_call_signature, set_properties_strict!,
using PromptingTools: tool_call_signature, set_properties_strict!, is_concrete_type,
to_json_type, is_not_union_type,
update_field_descriptions!, generate_struct
using PromptingTools: Tool, isabstracttool, execute_tool, parse_tool, get_arg_names,
get_arg_types, get_method, get_function, remove_field!,
@@ -99,6 +100,46 @@ end
@test occursin("computer", output)
end

@testset "is_concrete_type" begin
@test is_concrete_type(Int) == true
@test_throws ArgumentError is_concrete_type(AbstractString)
end

@testset "is_not_union_type" begin
@test_throws ArgumentError is_not_union_type(Union{Int, String})
@test is_not_union_type(Int) == true
end

@testset "to_json_type" begin
# Test string types
@test to_json_type(String) == "string"
@test to_json_type(SubString{String}) == "string"

# Test number types
@test to_json_type(Float64) == "number"
@test to_json_type(Float32) == "number"
@test to_json_type(Int64) == "integer"
@test to_json_type(Int32) == "integer"
@test to_json_type(UInt8) == "integer"

# Test boolean type
@test to_json_type(Bool) == "boolean"

# Test null types
@test to_json_type(Nothing) == "null"
@test to_json_type(Missing) == "null"

# Test concrete Any types
struct CustomType end
@test to_json_type(CustomType) == "string"

# Test error cases for abstract types
@test_throws ArgumentError to_json_type(AbstractString)
@test_throws ArgumentError to_json_type(Real)
@test_throws ArgumentError to_json_type(Integer)
@test_throws ArgumentError to_json_type(AbstractArray)
end

@testset "has_null_type" begin
@test has_null_type(Number) == false
@test has_null_type(Nothing) == true
@@ -236,6 +277,10 @@ end
@test schema["type"] == "array"
@test schema["items"]["type"] == "number"

schema = to_json_schema(Tuple{Int64, String})
@test schema["type"] == "array"
@test schema["items"]["type"] == "string"

## Special types
@enum TemperatureUnits celsius fahrenheit
schema = to_json_schema(TemperatureUnits)
@@ -273,7 +318,8 @@ end

## Fallback to string (for tough unions)
@test to_json_schema(Any) == Dict("type" => "string")
@test to_json_schema(Union{Int, String, Real}) == Dict("type" => "string")
## We force user to be explicit about the type, so it fails with a clear error
@test_throws ArgumentError to_json_schema(Union{Int, String, Float64})

## Disallowed types
@test_throws ArgumentError to_json_schema(Dict{String, Int})
@@ -309,6 +355,88 @@ end
@test schema["required"] == ["x", "y"]
@test haskey(schema, "description")
@test schema["description"] == "This is a test function.\n"

## Round trip on complicated types
struct MockTypeTest4
# Test concrete types
int_field::Int64
float_field::Float64
string_field::String
bool_field::Bool

# Test unions with Nothing/Missing
nullable_int::Union{Int64, Nothing}
missing_float::Union{Float64, Missing}

# No type
no_type_field::Any

# Test nested types
array_field::Vector{Float64}
tuple_field::Tuple{Int64, String}

## Not supported
# dict_field::Dict{String, Int64}
# union_field::Union{Int64, String}
end

# Test basic schema structure for MockTypeTest4
schema = to_json_schema(MockTypeTest4)
@test schema isa Dict{String, Any}
@test schema["type"] == "object"
@test haskey(schema, "properties")
@test haskey(schema, "required")

# Test concrete type fields
props = schema["properties"]
@test props["int_field"]["type"] == "integer"
@test props["float_field"]["type"] == "number"
@test props["string_field"]["type"] == "string"
@test props["bool_field"]["type"] == "boolean"

# Test nullable/missing fields
@test props["nullable_int"]["type"] == "integer"

@test props["missing_float"]["type"] == "number"

# Test Any field
@test props["no_type_field"]["type"] == "string"

# Test array field
@test props["array_field"]["type"] == "array"
@test props["array_field"]["items"]["type"] == "number"

# Test tuple field
@test props["tuple_field"]["type"] == "array"
@test props["tuple_field"]["items"]["type"] == "string"

# Test required fields
@test "int_field" in schema["required"]
@test "float_field" in schema["required"]
@test "string_field" in schema["required"]
@test "bool_field" in schema["required"]
@test "nullable_int" ∉ schema["required"]

## Round-trip test with JSON3
str = JSON3.write(MockTypeTest4(
1, 2.0, "3", true, nothing, missing, "any", [4.0], (5, "6")))
@test_nowarn instance = JSON3.read(str, MockTypeTest4)
@test instance.int_field == 1
@test instance.float_field == 2.0
@test instance.string_field == "3"
@test instance.bool_field == true
@test instance.nullable_int == nothing
@test instance.missing_float === missing
@test instance.no_type_field == "any"
@test instance.array_field == [4.0]
@test instance.tuple_field == (5, "6")

struct TestTuple1
x::Tuple{Int64, String}
end
str = JSON3.write(TestTuple1((1, "2")))
instance = JSON3.read(str, TestTuple1)
@test instance.x == (1, "2")
end

@testset "to_json_schema-MaybeExtract" begin
2 changes: 1 addition & 1 deletion test/llm_tracer.jl
@@ -270,7 +270,7 @@ end
end

# TODO: add aitools tracer tests
function calculator(x::Number, y::Number; operation::String = "add")
function calculator(x::Float64, y::Float64; operation::String = "add")
operation == "add" ?
x + y :
throw(ArgumentError("Unsupported operation"))

2 comments on commit b95f6c9

@svilupp (Owner, Author) commented on b95f6c9 Dec 5, 2024

@JuliaRegistrator register

Release notes:

Fixed

  • Removed unnecessary printing to `stdout` during precompilation in `precompile.jl`.
  • Fixed a "bug-waiting-to-happen" in tool use. `to_json_type` now enforces users to provide concrete types, because abstract types can lead to errors during JSON3 deserialization.
  • Flowed through a bug fix in `StreamCallback` where the usage information was being included in the response even when `usage=nothing`. Lower bound of `StreamCallbacks` was bumped to `0.5.1`.

Commits

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/120717

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.65.1 -m "<description of version>" b95f6c90d40fe6b07fe6eb93970b897aeeb20ec2
git push origin v0.65.1
