Fix to-json-type #248

Merged 4 commits on Dec 5, 2024
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.65.1]

### Fixed
- Removed unnecessary printing to `stdout` during precompilation in `precompilation.jl`.
- Fixed a "bug-waiting-to-happen" in tool use. `to_json_type` now requires users to provide concrete types, because abstract types can lead to errors during JSON3 deserialization.
- Picked up an upstream bug fix in `StreamCallbacks`, where usage information was included in the response even when `usage=nothing`. The lower bound of `StreamCallbacks` was bumped to `0.5.1`.
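
As a rough illustration of the concrete-type requirement described in the `to_json_type` entry above, the sketch below uses only Base Julia. `require_concrete` is a hypothetical stand-in for the kind of guard this release adds; it is not the package's own function, and the error text is only an approximation.

```julia
# Hypothetical helper for illustration only (not part of PromptingTools).
# It rejects any type that `isconcretetype` reports as non-concrete.
function require_concrete(T::Type)
    isconcretetype(T) ||
        throw(ArgumentError("Cannot convert abstract type $T to JSON type."))
    return true
end

# Concrete types pass the guard.
require_concrete(Int64)
require_concrete(String)

# Abstract types and multi-member unions are not concrete, so the guard throws.
for T in (Real, AbstractString, Union{Int64, String})
    try
        require_concrete(T)
    catch err
        println(T, " rejected with ArgumentError: ", err isa ArgumentError)
    end
end
```

Failing at schema-construction time, as in this sketch, surfaces the problem before a model response has to be deserialized back into the struct.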

## [0.65.0]

### Breaking
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.65.0"
version = "0.65.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -55,7 +55,7 @@ REPL = "<0.0.1, 1"
Random = "<0.0.1, 1"
SparseArrays = "<0.0.1, 1"
Statistics = "<0.0.1, 1"
StreamCallbacks = "0.4, 0.5"
StreamCallbacks = "0.5.1"
Test = "<0.0.1, 1"
julia = "1.9, 1.10"

21 changes: 17 additions & 4 deletions src/extraction.jl
@@ -109,12 +109,24 @@ end
# 1) OpenAI / JSON format
######################

to_json_type(s::Type{<:AbstractString}) = "string"
to_json_type(n::Type{<:Real}) = "number"
to_json_type(n::Type{<:Integer}) = "integer"
"Check if a type is concrete."
function is_concrete_type(s::Type)
    isconcretetype(s) ||
        throw(ArgumentError("Cannot convert abstract type $s to JSON type. You must provide concrete types!"))
end

function is_not_union_type(s::Type)
    !isa(s, Union) ||
        throw(ArgumentError("Cannot convert $s to JSON type. The only supported union types are Union{..., Nothing}. Please pick a concrete type (`::String` is generic if you cannot pick)!"))
end

to_json_type(s::Type{<:AbstractString}) = (is_concrete_type(s); "string")
to_json_type(n::Type{<:Real}) = (is_concrete_type(n); "number")
to_json_type(n::Type{<:Integer}) = (is_concrete_type(n); "integer")
to_json_type(b::Type{Bool}) = "boolean"
to_json_type(t::Type{<:Union{Missing, Nothing}}) = "null"
to_json_type(t::Type{<:Any}) = "string" # object?
to_json_type(t::Type{<:Any}) = (is_not_union_type(t); is_concrete_type(t); "string") # object?
to_json_type(t::Type{Any}) = "string" # Allow explicit Any as it can be deserialized by JSON3

has_null_type(T::Type{Missing}) = true
has_null_type(T::Type{Nothing}) = true
@@ -457,6 +469,7 @@ end
Extract the argument names, types and docstrings from a struct to create the function call signature in JSON schema.

You must provide a Struct type (not an instance of it) with some fields.
The types must be CONCRETE; this helps with correct conversion to JSON schema and then conversion back to the struct.

Note: Fairly experimental, but works for combination of structs, arrays, strings and singletons.

2 changes: 2 additions & 0 deletions src/llm_openai.jl
@@ -909,6 +909,8 @@ This is a perfect solution for extracting structured information from text (eg,
It's effectively a light wrapper around `aigenerate` call, which requires additional keyword argument `return_type` to be provided
and will enforce the model outputs to adhere to it.

!!! Note: The types must be CONCRETE; this helps with correct conversion to JSON schema and then conversion back to the struct.

# Arguments
- `prompt_schema`: An optional object to specify which prompt template should be applied (Default to `PROMPT_SCHEMA = OpenAISchema`)
- `prompt`: Can be a string representing the prompt for the AI conversation, a `UserMessage`, a vector of `AbstractMessage` or an `AITemplate`
2 changes: 1 addition & 1 deletion src/precompilation.jl
@@ -28,7 +28,7 @@ messages = [
_ = render(OpenAISchema(), messages)

## Utilities
pprint(messages)
pprint(devnull, messages)
last_output(messages)
last_message(messages)

132 changes: 130 additions & 2 deletions test/extraction.jl
@@ -1,6 +1,7 @@
using PromptingTools: MaybeExtract, extract_docstring, ItemsExtract, ToolMessage
using PromptingTools: has_null_type, is_required_field, remove_null_types, to_json_schema
using PromptingTools: tool_call_signature, set_properties_strict!,
using PromptingTools: tool_call_signature, set_properties_strict!, is_concrete_type,
                      to_json_type, is_not_union_type,
                      update_field_descriptions!, generate_struct
using PromptingTools: Tool, isabstracttool, execute_tool, parse_tool, get_arg_names,
                      get_arg_types, get_method, get_function, remove_field!,
@@ -99,6 +100,46 @@ end
    @test occursin("computer", output)
end

@testset "is_concrete_type" begin
    @test is_concrete_type(Int) == true
    @test_throws ArgumentError is_concrete_type(AbstractString)
end

@testset "is_not_union_type" begin
    @test_throws ArgumentError is_not_union_type(Union{Int, String})
    @test is_not_union_type(Int) == true
end

@testset "to_json_type" begin
    # Test string types
    @test to_json_type(String) == "string"
    @test to_json_type(SubString{String}) == "string"

    # Test number types
    @test to_json_type(Float64) == "number"
    @test to_json_type(Float32) == "number"
    @test to_json_type(Int64) == "integer"
    @test to_json_type(Int32) == "integer"
    @test to_json_type(UInt8) == "integer"

    # Test boolean type
    @test to_json_type(Bool) == "boolean"

    # Test null types
    @test to_json_type(Nothing) == "null"
    @test to_json_type(Missing) == "null"

    # Test concrete Any types
    struct CustomType end
    @test to_json_type(CustomType) == "string"

    # Test error cases for abstract types
    @test_throws ArgumentError to_json_type(AbstractString)
    @test_throws ArgumentError to_json_type(Real)
    @test_throws ArgumentError to_json_type(Integer)
    @test_throws ArgumentError to_json_type(AbstractArray)
end

@testset "has_null_type" begin
    @test has_null_type(Number) == false
    @test has_null_type(Nothing) == true
@@ -236,6 +277,10 @@ end
    @test schema["type"] == "array"
    @test schema["items"]["type"] == "number"

    schema = to_json_schema(Tuple{Int64, String})
    @test schema["type"] == "array"
    @test schema["items"]["type"] == "string"

    ## Special types
    @enum TemperatureUnits celsius fahrenheit
    schema = to_json_schema(TemperatureUnits)
@@ -273,7 +318,8 @@ end

    ## Fallback to string (for tough unions)
    @test to_json_schema(Any) == Dict("type" => "string")
    @test to_json_schema(Union{Int, String, Real}) == Dict("type" => "string")
    ## We force the user to be explicit about the type, so it fails with a clear error
    @test_throws ArgumentError to_json_schema(Union{Int, String, Float64})

    ## Disallowed types
    @test_throws ArgumentError to_json_schema(Dict{String, Int})
@@ -309,6 +355,88 @@ end
    @test schema["required"] == ["x", "y"]
    @test haskey(schema, "description")
    @test schema["description"] == "This is a test function.\n"

    ## Round trip on complicated types
    struct MockTypeTest4
        # Test concrete types
        int_field::Int64
        float_field::Float64
        string_field::String
        bool_field::Bool

        # Test unions with Nothing/Missing
        nullable_int::Union{Int64, Nothing}
        missing_float::Union{Float64, Missing}

        # No type
        no_type_field::Any

        # Test nested types
        array_field::Vector{Float64}
        tuple_field::Tuple{Int64, String}

        ## Not supported
        # dict_field::Dict{String, Int64}
        # union_field::Union{Int64, String}
    end

    # Test basic schema structure for MockTypeTest4
    schema = to_json_schema(MockTypeTest4)
    @test schema isa Dict{String, Any}
    @test schema["type"] == "object"
    @test haskey(schema, "properties")
    @test haskey(schema, "required")

    # Test concrete type fields
    props = schema["properties"]
    @test props["int_field"]["type"] == "integer"
    @test props["float_field"]["type"] == "number"
    @test props["string_field"]["type"] == "string"
    @test props["bool_field"]["type"] == "boolean"

    # Test nullable/missing fields
    @test props["nullable_int"]["type"] == "integer"

    @test props["missing_float"]["type"] == "number"

    # Test Any field
    @test props["no_type_field"]["type"] == "string"

    # Test array field
    @test props["array_field"]["type"] == "array"
    @test props["array_field"]["items"]["type"] == "number"

    # Test tuple field
    @test props["tuple_field"]["type"] == "array"
    @test props["tuple_field"]["items"]["type"] == "string"

    # Test required fields
    @test "int_field" in schema["required"]
    @test "float_field" in schema["required"]
    @test "string_field" in schema["required"]
    @test "bool_field" in schema["required"]
    @test "nullable_int" ∉ schema["required"]

    ## Round-trip test with JSON3
    str = JSON3.write(MockTypeTest4(
        1, 2.0, "3", true, nothing, missing, "any", [4.0], (5, "6")))
    @test_nowarn instance = JSON3.read(str, MockTypeTest4)
    @test instance.int_field == 1
    @test instance.float_field == 2.0
    @test instance.string_field == "3"
    @test instance.bool_field == true
    @test instance.nullable_int == nothing
    @test instance.missing_float === missing
    @test instance.no_type_field == "any"
    @test instance.array_field == [4.0]
    @test instance.tuple_field == (5, "6")

    struct TestTuple1
        x::Tuple{Int64, String}
    end
    str = JSON3.write(TestTuple1((1, "2")))
    instance = JSON3.read(str, TestTuple1)
    @test instance.x == (1, "2")
end

@testset "to_json_schema-MaybeExtract" begin
2 changes: 1 addition & 1 deletion test/llm_tracer.jl
@@ -270,7 +270,7 @@ end
end

# TODO: add aitools tracer tests
function calculator(x::Number, y::Number; operation::String = "add")
function calculator(x::Float64, y::Float64; operation::String = "add")
    operation == "add" ?
    x + y :
    throw(ArgumentError("Unsupported operation"))