Skip to content

Commit

Permalink
Add OpenAI structured outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Aug 9, 2024
1 parent 7d6a8d8 commit 8f697f6
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 10 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.48.0]

### Added
- Implements the new OpenAI structured output mode for `aiextract` (just provide kwarg `strict=true`). Reference [blog post](https://openai.com/index/introducing-structured-outputs-in-the-api/).

## [0.47.0]

### Added
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.47.0"
version = "0.48.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
64 changes: 58 additions & 6 deletions src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,26 +75,28 @@ function to_json_schema(orig_type; max_description_length::Int = 100)
return schema
end
function to_json_schema(type::Type{<:AbstractString}; max_description_length::Int = 100)
Dict("type" => to_json_type(type))
Dict{String, Any}("type" => to_json_type(type))
end
function to_json_schema(type::Type{T};
max_description_length::Int = 100) where {T <:
Union{AbstractSet, Tuple, AbstractArray}}
element_type = eltype(type)
return Dict("type" => "array",
return Dict{String, Any}("type" => "array",
"items" => to_json_schema(remove_null_types(element_type)))
end
function to_json_schema(type::Type{<:Enum}; max_description_length::Int = 100)
enum_options = Base.Enums.namemap(type) |> values .|> string
return Dict("type" => "string",
return Dict{String, Any}("type" => "string",
"enum" => enum_options)
end
function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int = 100)
throw(ArgumentError("Dicts are not supported yet as we cannot analyze their keys/values on a type-level. Use a nested Struct instead!"))
end

"""
function_call_signature(datastructtype::Struct; max_description_length::Int = 100)
function_call_signature(
datastructtype::Type; strict::Union{Nothing, Bool} = nothing,
max_description_length::Int = 100)
Extract the argument names, types and docstrings from a struct to create the function call signature in JSON schema.
Expand Down Expand Up @@ -123,7 +125,8 @@ signature = function_call_signature(MyMeasurement)
# Dict{String, Any} with 3 entries:
# "name" => "MyMeasurement_extractor"
# "parameters" => Dict{String, Any}("properties"=>Dict{String, Any}("height"=>Dict{String, Any}("type"=>"integer"), "weight"=>Dic…
# "description" => "Represents person's age, height, and weight\n"
# "description" => "Represents person's age, height, and weight
"
```
You can see that only the field `age` does not allow null values, hence, it's "required".
Expand Down Expand Up @@ -161,7 +164,9 @@ msg = aiextract("Extract measurements from the text: I am giraffe", type)
```
That way, you can handle the error gracefully and get a reason why extraction failed.
"""
function function_call_signature(datastructtype::Type; max_description_length::Int = 100)
function function_call_signature(
datastructtype::Type; strict::Union{Nothing, Bool} = nothing,
max_description_length::Int = 100)
!isstructtype(datastructtype) &&
error("Only Structs are supported (provided type: $datastructtype")
## Standardize the name
Expand All @@ -177,9 +182,56 @@ function function_call_signature(datastructtype::Type; max_description_length::I
schema["parameters"]["description"] == docs
delete!(schema["parameters"], "description")
end
## strict mode // see https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
if strict == false
schema["strict"] = false
elseif strict == true
schema["strict"] = true
if haskey(schema["parameters"], "properties")
set_properties_strict!(schema["parameters"])
end
end
return schema
end

"""
set_properties_strict!(properties::AbstractDict)
Sets strict mode for the properties of a JSON schema.
Changes:
- Sets `additionalProperties` to `false`.
- All keys must be included in `required`.
- All optional keys will have `null` added to their type.
Reference: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
"""
function set_properties_strict!(parameters::AbstractDict)
parameters["additionalProperties"] = false
required_fields = get(parameters, "required", String[])
optional_fields = String[]

for (key, value) in parameters["properties"]
if key required_fields
push!(optional_fields, key)
if haskey(value, "type")
value["type"] = [value["type"], "null"]
end
end

# Recursively apply to nested properties
if haskey(value, "properties")
set_properties_strict!(value)
elseif haskey(value, "items") && haskey(value["items"], "properties")
## if it's an array, we need to skip inside "items"
set_properties_strict!(value["items"])
end
end

parameters["required"] = vcat(required_fields, optional_fields)
return parameters
end

######################
# 2) Anthropic / XML format
######################
Expand Down
6 changes: 5 additions & 1 deletion src/llm_openai.jl
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,7 @@ end
retries = 5,
readtimeout = 120), api_kwargs::NamedTuple = (;
tool_choice = "exact"),
strict::Union{Nothing, Bool} = nothing,
kwargs...)
Extract required information (defined by a struct **`return_type`**) from the provided prompt by leveraging OpenAI function calling mode.
Expand All @@ -994,6 +995,7 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi
- `tool_choice`: A string representing the tool choice to use for the API call. Usually, one of "auto","any","exact".
Defaults to `"exact"`, which is a made-up value to enforce the OpenAI requirements if we want one exact function.
Providers like Mistral, Together, etc. use `"any"` instead.
- `strict::Union{Nothing, Bool} = nothing`: A boolean indicating whether to enforce strict generation of the response (supported only for OpenAI models). It has additional latency for the first request. If `nothing`, standard function calling is used.
- `kwargs`: Prompt variables to be used to fill the prompt/template
# Returns
Expand Down Expand Up @@ -1100,11 +1102,13 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T
retries = 5,
readtimeout = 120), api_kwargs::NamedTuple = (;
tool_choice = "exact"),
strict::Union{Nothing, Bool} = nothing,
kwargs...)
##
global MODEL_ALIASES
## Function calling specifics
tools = [Dict(:type => "function", :function => function_call_signature(return_type))]
tools = [Dict(
:type => "function", :function => function_call_signature(return_type; strict))]
## force our function to be used
tool_choice_ = get(api_kwargs, :tool_choice, "exact")
tool_choice = if tool_choice_ == "exact"
Expand Down
143 changes: 141 additions & 2 deletions test/extraction.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using PromptingTools: MaybeExtract, extract_docstring, ItemsExtract
using PromptingTools: has_null_type, is_required_field, remove_null_types, to_json_schema
using PromptingTools: function_call_signature
using PromptingTools: function_call_signature, set_properties_strict!

# TODO: check more edge cases like empty structs

Expand Down Expand Up @@ -216,6 +216,100 @@ end
@test schema_measurement["description"] ==
"Represents person's age, height, and weight\n"
end
@testset "set_properties_strict!" begin
# Test 1: Basic functionality
params = Dict(
"properties" => Dict{String, Any}(
"name" => Dict{String, Any}("type" => "string"),
"age" => Dict{String, Any}("type" => "integer")
),
"required" => ["name"]
)
set_properties_strict!(params)
@test params["additionalProperties"] == false
@test Set(params["required"]) == Set(["name", "age"])
@test params["properties"]["age"]["type"] == ["integer", "null"]

# Test 2: Nested properties
params = Dict{String, Any}(
"properties" => Dict{String, Any}(
"person" => Dict{String, Any}(
"type" => "object",
"properties" => Dict{String, Any}(
"name" => Dict{String, Any}("type" => "string"),
"age" => Dict{String, Any}("type" => "integer")
)
)
)
)
set_properties_strict!(params)
@test params["properties"]["person"]["additionalProperties"] == false
@test Set(params["properties"]["person"]["required"]) ==
Set(["name", "age"])

# Test 3: Array of objects
params = Dict{String, Any}(
"properties" => Dict{String, Any}(
"people" => Dict{String, Any}(
"type" => "array",
"items" => Dict{String, Any}(
"type" => "object",
"properties" => Dict{String, Any}(
"name" => Dict{String, Any}("type" => "string"),
"age" => Dict{String, Any}("type" => "integer")
)
)
)
)
)
set_properties_strict!(params)
@test params["properties"]["people"]["items"]["additionalProperties"] == false
@test Set(params["properties"]["people"]["items"]["required"]) == Set(["name", "age"])

# Test 4: Multiple levels of nesting
params = Dict{String, Any}(
"properties" => Dict{String, Any}(
"company" => Dict{String, Any}(
"type" => "object",
"properties" => Dict{String, Any}(
"name" => Dict{String, Any}("type" => "string"),
"employees" => Dict{String, Any}(
"type" => "array",
"items" => Dict{String, Any}(
"type" => "object",
"properties" => Dict{String, Any}(
"name" => Dict{String, Any}("type" => "string"),
"position" => Dict{String, Any}("type" => "string")
)
)
)
)
)
)
)
set_properties_strict!(params)
@test params["properties"]["company"]["additionalProperties"] == false
@test params["properties"]["company"]["properties"]["employees"]["items"]["additionalProperties"] ==
false
@test Set(params["properties"]["company"]["properties"]["employees"]["items"]["required"]) ==
Set(["name", "position"])

# Test 5: Handling of existing required fields
params = Dict{String, Any}(
"properties" => Dict{String, Any}(
"name" => Dict{String, Any}("type" => "string"),
"age" => Dict{String, Any}("type" => "integer"),
"email" => Dict{String, Any}("type" => "string")
),
"required" => ["name", "email"]
)
set_properties_strict!(params)
@test Set(params["required"]) == Set(["name", "email", "age"])
@test params["properties"]["age"]["type"] == ["integer", "null"]
@test !haskey(params["properties"]["name"], "null")
@test !haskey(params["properties"]["email"], "null")
end

@testset "function_call_signature" begin
"Some docstring"
struct MyMeasurement2
Expand All @@ -241,4 +335,49 @@ end
## MaybeWraper name cleanup
schema = function_call_signature(MaybeExtract{MyMeasurement2})
@test schema["name"] == "MaybeExtractMyMeasurement2_extractor"
end

## Test with strict = true

"Person's age, height, and weight."
struct MyMeasurement3
age::Int
height::Union{Int, Nothing}
weight::Union{Nothing, Float64}
end

# Test with strict = nothing (default behavior)
output_default = function_call_signature(MyMeasurement3)
@test !haskey(output_default, "strict")
@test output_default["name"] == "MyMeasurement3_extractor"
@test output_default["parameters"]["type"] == "object"
@test output_default["parameters"]["required"] == ["age"]
@test !haskey(output_default["parameters"], "additionalProperties")

# Test with strict =false
output_not_strict = function_call_signature(MyMeasurement3; strict = false)
@test haskey(output_not_strict, "strict")
@test output_not_strict["strict"] == false
@test output_not_strict["name"] == "MyMeasurement3_extractor"
@test output_not_strict["parameters"]["type"] == "object"
@test output_not_strict["parameters"]["required"] == ["age"]
@test !haskey(output_default["parameters"], "additionalProperties")

# Test with strict = true
output_strict = function_call_signature(MyMeasurement3; strict = true)
@test output_strict["strict"] == true
@test output_strict["name"] == "MyMeasurement3_extractor"
@test output_strict["parameters"]["type"] == "object"
@test Set(output_strict["parameters"]["required"]) == Set(["age", "height", "weight"])
@test output_strict["parameters"]["additionalProperties"] == false
@test output_strict["parameters"]["properties"]["height"]["type"] == ["integer", "null"]
@test output_strict["parameters"]["properties"]["weight"]["type"] == ["number", "null"]
@test output_strict["parameters"]["properties"]["age"]["type"] == "integer"

# Test with MaybeExtract wrapper
output_maybe = function_call_signature(MaybeExtract{MyMeasurement3}; strict = true)
@test output_maybe["name"] == "MaybeExtractMyMeasurement3_extractor"
@test output_maybe["parameters"]["properties"]["result"]["type"] == ["object", "null"]
@test output_maybe["parameters"]["properties"]["error"]["type"] == "boolean"
@test output_maybe["parameters"]["properties"]["message"]["type"] == ["string", "null"]
@test Set(output_maybe["parameters"]["required"]) == Set(["result", "error", "message"])
end

0 comments on commit 8f697f6

Please sign in to comment.