diff --git a/CHANGELOG.md b/CHANGELOG.md index 43f156a8a..00fc0a244 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.48.0] + +### Added +- Implements the new OpenAI structured output mode for `aiextract` (just provide kwarg `strict=true`). Reference [blog post](https://openai.com/index/introducing-structured-outputs-in-the-api/). + ## [0.47.0] ### Added diff --git a/Project.toml b/Project.toml index 7c3cd6b13..c2ca2ab49 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.47.0" +version = "0.48.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/extraction.jl b/src/extraction.jl index bd77aa1f4..0f591fdeb 100644 --- a/src/extraction.jl +++ b/src/extraction.jl @@ -75,18 +75,18 @@ function to_json_schema(orig_type; max_description_length::Int = 100) return schema end function to_json_schema(type::Type{<:AbstractString}; max_description_length::Int = 100) - Dict("type" => to_json_type(type)) + Dict{String, Any}("type" => to_json_type(type)) end function to_json_schema(type::Type{T}; max_description_length::Int = 100) where {T <: Union{AbstractSet, Tuple, AbstractArray}} element_type = eltype(type) - return Dict("type" => "array", + return Dict{String, Any}("type" => "array", "items" => to_json_schema(remove_null_types(element_type))) end function to_json_schema(type::Type{<:Enum}; max_description_length::Int = 100) enum_options = Base.Enums.namemap(type) |> values .|> string - return Dict("type" => "string", + return Dict{String, Any}("type" => "string", "enum" => enum_options) end function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int = 100) @@ -94,7 +94,9 @@ function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int end """ - function_call_signature(datastructtype::Struct; max_description_length::Int = 100) + function_call_signature( + datastructtype::Type; strict::Union{Nothing, Bool} = nothing, + max_description_length::Int = 100) Extract the argument names, types and docstrings from a struct to create the function call signature in JSON schema. @@ -123,7 +125,8 @@ signature = function_call_signature(MyMeasurement) # Dict{String, Any} with 3 entries: # "name" => "MyMeasurement_extractor" # "parameters" => Dict{String, Any}("properties"=>Dict{String, Any}("height"=>Dict{String, Any}("type"=>"integer"), "weight"=>Dic… -# "description" => "Represents person's age, height, and weight\n" +# "description" => "Represents person's age, height, and weight +" ``` You can see that only the field `age` does not allow null values, hence, it's "required". @@ -161,7 +164,9 @@ msg = aiextract("Extract measurements from the text: I am giraffe", type) ``` That way, you can handle the error gracefully and get a reason why extraction failed. """ -function function_call_signature(datastructtype::Type; max_description_length::Int = 100) +function function_call_signature( + datastructtype::Type; strict::Union{Nothing, Bool} = nothing, + max_description_length::Int = 100) !isstructtype(datastructtype) && error("Only Structs are supported (provided type: $datastructtype") ## Standardize the name @@ -177,9 +182,56 @@ function function_call_signature(datastructtype::Type; max_description_length::I schema["parameters"]["description"] == docs delete!(schema["parameters"], "description") end + ## strict mode // see https://platform.openai.com/docs/guides/structured-outputs/supported-schemas + if strict == false + schema["strict"] = false + elseif strict == true + schema["strict"] = true + if haskey(schema["parameters"], "properties") + set_properties_strict!(schema["parameters"]) + end + end return schema end +""" + set_properties_strict!(properties::AbstractDict) + +Sets strict mode for the properties of a JSON schema. + +Changes: +- Sets `additionalProperties` to `false`. +- All keys must be included in `required`. +- All optional keys will have `null` added to their type. + +Reference: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas +""" +function set_properties_strict!(parameters::AbstractDict) + parameters["additionalProperties"] = false + required_fields = get(parameters, "required", String[]) + optional_fields = String[] + + for (key, value) in parameters["properties"] + if key ∉ required_fields + push!(optional_fields, key) + if haskey(value, "type") + value["type"] = [value["type"], "null"] + end + end + + # Recursively apply to nested properties + if haskey(value, "properties") + set_properties_strict!(value) + elseif haskey(value, "items") && haskey(value["items"], "properties") + ## if it's an array, we need to skip inside "items" + set_properties_strict!(value["items"]) + end + end + + parameters["required"] = vcat(required_fields, optional_fields) + return parameters +end + ###################### # 2) Anthropic / XML format ###################### diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 4bee1ced2..fe866f022 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -969,6 +969,7 @@ end retries = 5, readtimeout = 120), api_kwargs::NamedTuple = (; tool_choice = "exact"), + strict::Union{Nothing, Bool} = nothing, kwargs...) Extract required information (defined by a struct **`return_type`**) from the provided prompt by leveraging OpenAI function calling mode. @@ -994,6 +995,7 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi - `tool_choice`: A string representing the tool choice to use for the API call. Usually, one of "auto","any","exact". Defaults to `"exact"`, which is a made-up value to enforce the OpenAI requirements if we want one exact function. Providers like Mistral, Together, etc. use `"any"` instead. +- `strict::Union{Nothing, Bool} = nothing`: A boolean indicating whether to enforce strict generation of the response (supported only for OpenAI models). It has additional latency for the first request. If `nothing`, standard function calling is used. - `kwargs`: Prompt variables to be used to fill the prompt/template # Returns @@ -1100,11 +1102,13 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T retries = 5, readtimeout = 120), api_kwargs::NamedTuple = (; tool_choice = "exact"), + strict::Union{Nothing, Bool} = nothing, kwargs...) ## global MODEL_ALIASES ## Function calling specifics - tools = [Dict(:type => "function", :function => function_call_signature(return_type))] + tools = [Dict( + :type => "function", :function => function_call_signature(return_type; strict))] ## force our function to be used tool_choice_ = get(api_kwargs, :tool_choice, "exact") tool_choice = if tool_choice_ == "exact" diff --git a/test/extraction.jl b/test/extraction.jl index 476487321..e24cb9fa1 100644 --- a/test/extraction.jl +++ b/test/extraction.jl @@ -1,6 +1,6 @@ using PromptingTools: MaybeExtract, extract_docstring, ItemsExtract using PromptingTools: has_null_type, is_required_field, remove_null_types, to_json_schema -using PromptingTools: function_call_signature +using PromptingTools: function_call_signature, set_properties_strict! # TODO: check more edge cases like empty structs @@ -216,6 +216,100 @@ end @test schema_measurement["description"] == "Represents person's age, height, and weight\n" end +@testset "set_properties_strict!" begin + # Test 1: Basic functionality + params = Dict( + "properties" => Dict{String, Any}( + "name" => Dict{String, Any}("type" => "string"), + "age" => Dict{String, Any}("type" => "integer") + ), + "required" => ["name"] + ) + set_properties_strict!(params) + @test params["additionalProperties"] == false + @test Set(params["required"]) == Set(["name", "age"]) + @test params["properties"]["age"]["type"] == ["integer", "null"] + + # Test 2: Nested properties + params = Dict{String, Any}( + "properties" => Dict{String, Any}( + "person" => Dict{String, Any}( + "type" => "object", + "properties" => Dict{String, Any}( + "name" => Dict{String, Any}("type" => "string"), + "age" => Dict{String, Any}("type" => "integer") + ) + ) + ) + ) + set_properties_strict!(params) + @test params["properties"]["person"]["additionalProperties"] == false + @test Set(params["properties"]["person"]["required"]) == + Set(["name", "age"]) + + # Test 3: Array of objects + params = Dict{String, Any}( + "properties" => Dict{String, Any}( + "people" => Dict{String, Any}( + "type" => "array", + "items" => Dict{String, Any}( + "type" => "object", + "properties" => Dict{String, Any}( + "name" => Dict{String, Any}("type" => "string"), + "age" => Dict{String, Any}("type" => "integer") + ) + ) + ) + ) + ) + set_properties_strict!(params) + @test params["properties"]["people"]["items"]["additionalProperties"] == false + @test Set(params["properties"]["people"]["items"]["required"]) == Set(["name", "age"]) + + # Test 4: Multiple levels of nesting + params = Dict{String, Any}( + "properties" => Dict{String, Any}( + "company" => Dict{String, Any}( + "type" => "object", + "properties" => Dict{String, Any}( + "name" => Dict{String, Any}("type" => "string"), + "employees" => Dict{String, Any}( + "type" => "array", + "items" => Dict{String, Any}( + "type" => "object", + "properties" => Dict{String, Any}( + "name" => Dict{String, Any}("type" => "string"), + "position" => Dict{String, Any}("type" => "string") + ) + ) + ) + ) + ) + ) + ) + set_properties_strict!(params) + @test params["properties"]["company"]["additionalProperties"] == false + @test params["properties"]["company"]["properties"]["employees"]["items"]["additionalProperties"] == + false + @test Set(params["properties"]["company"]["properties"]["employees"]["items"]["required"]) == + Set(["name", "position"]) + + # Test 5: Handling of existing required fields + params = Dict{String, Any}( + "properties" => Dict{String, Any}( + "name" => Dict{String, Any}("type" => "string"), + "age" => Dict{String, Any}("type" => "integer"), + "email" => Dict{String, Any}("type" => "string") + ), + "required" => ["name", "email"] + ) + set_properties_strict!(params) + @test Set(params["required"]) == Set(["name", "email", "age"]) + @test params["properties"]["age"]["type"] == ["integer", "null"] + @test !haskey(params["properties"]["name"], "null") + @test !haskey(params["properties"]["email"], "null") +end + @testset "function_call_signature" begin "Some docstring" struct MyMeasurement2 @@ -241,4 +335,49 @@ end ## MaybeWraper name cleanup schema = function_call_signature(MaybeExtract{MyMeasurement2}) @test schema["name"] == "MaybeExtractMyMeasurement2_extractor" -end + + ## Test with strict = true + + "Person's age, height, and weight." + struct MyMeasurement3 + age::Int + height::Union{Int, Nothing} + weight::Union{Nothing, Float64} + end + + # Test with strict = nothing (default behavior) + output_default = function_call_signature(MyMeasurement3) + @test !haskey(output_default, "strict") + @test output_default["name"] == "MyMeasurement3_extractor" + @test output_default["parameters"]["type"] == "object" + @test output_default["parameters"]["required"] == ["age"] + @test !haskey(output_default["parameters"], "additionalProperties") + + # Test with strict =false + output_not_strict = function_call_signature(MyMeasurement3; strict = false) + @test haskey(output_not_strict, "strict") + @test output_not_strict["strict"] == false + @test output_not_strict["name"] == "MyMeasurement3_extractor" + @test output_not_strict["parameters"]["type"] == "object" + @test output_not_strict["parameters"]["required"] == ["age"] + @test !haskey(output_default["parameters"], "additionalProperties") + + # Test with strict = true + output_strict = function_call_signature(MyMeasurement3; strict = true) + @test output_strict["strict"] == true + @test output_strict["name"] == "MyMeasurement3_extractor" + @test output_strict["parameters"]["type"] == "object" + @test Set(output_strict["parameters"]["required"]) == Set(["age", "height", "weight"]) + @test output_strict["parameters"]["additionalProperties"] == false + @test output_strict["parameters"]["properties"]["height"]["type"] == ["integer", "null"] + @test output_strict["parameters"]["properties"]["weight"]["type"] == ["number", "null"] + @test output_strict["parameters"]["properties"]["age"]["type"] == "integer" + + # Test with MaybeExtract wrapper + output_maybe = function_call_signature(MaybeExtract{MyMeasurement3}; strict = true) + @test output_maybe["name"] == "MaybeExtractMyMeasurement3_extractor" + @test output_maybe["parameters"]["properties"]["result"]["type"] == ["object", "null"] + @test output_maybe["parameters"]["properties"]["error"]["type"] == "boolean" + @test output_maybe["parameters"]["properties"]["message"]["type"] == ["string", "null"] + @test Set(output_maybe["parameters"]["required"]) == Set(["result", "error", "message"]) +end \ No newline at end of file