Skip to content

Commit

Permalink
Update Structured Extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Sep 3, 2024
1 parent 5fe770a commit 1c8fb7d
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 61 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

## [0.51.0]

### Added
- Added more flexible structured extraction with `aiextract` -> now you can simply provide the field names and, optionally, their types without specifying the struct itself (in `aiextract`, provide the fields like `return_type = [:field_name => field_type]`).
- Added a way to attach field-level descriptions to the generated JSON schemas to better structured extraction (see `?update_schema_descriptions!` to see the syntax), which was not possible with struct-only extraction.

## [0.50.0]

### Breaking Changes
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PromptingTools"
uuid = "670122d1-24a8-4d70-bfce-740807c42192"
authors = ["J S @svilupp and contributors"]
version = "0.50.0"
version = "0.51.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
4 changes: 2 additions & 2 deletions src/Experimental/RAGTools/evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ end
context::AbstractString
question::AbstractString
answer::AbstractString
retrieval_score::Union{Number, Nothing} = nothing
retrieval_score::Union{Float64, Nothing} = nothing
retrieval_rank::Union{Int, Nothing} = nothing
answer_score::Union{Number, Nothing} = nothing
answer_score::Union{Float64, Nothing} = nothing
parameters::Dict{Symbol, Any} = Dict{Symbol, Any}()
end

Expand Down
222 changes: 190 additions & 32 deletions src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,159 @@ function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int
throw(ArgumentError("Dicts are not supported yet as we cannot analyze their keys/values on a type-level. Use a nested Struct instead!"))
end

### Type conversion / Schema generation
"""
generate_struct(fields::Vector)
Generate a struct with the given name and fields. Fields can be specified simply as symbols (with default type `String`) or pairs of symbol and type.
Field descriptions can be provided by adding a pair with the field name suffixed with "__description" (eg, `:myfield__description => "My field description"`).
Returns: A tuple of (struct type, descriptions)
# Examples
```julia
Weather, descriptions = generate_struct(
[:location,
:temperature=>Float64,
:temperature__description=>"Temperature in degrees Fahrenheit",
:condition=>String,
:condition__description=>"Current weather condition (e.g., sunny, rainy, cloudy)"
])
```
"""
function generate_struct(fields::Vector)
name = gensym("ExtractedData")
struct_fields = []
descriptions = Dict{Symbol, String}()

for field in fields
if field isa Symbol
push!(struct_fields, :($field::String))
elseif field isa Pair
field_name, field_value = field
if endswith(string(field_name), "__description")
base_field = Symbol(replace(string(field_name), "__description" => ""))
descriptions[base_field] = field_value
elseif field_name isa Symbol &&
(field_value isa Type || field_value isa AbstractString)
push!(struct_fields, :($field_name::$field_value))
else
error("Invalid field specification: $(field). It must be a Symbol or a Pair{Symbol, Type} or Pair{Symbol, Pair{Type, String}}.")
end
else
error("Invalid field specification: $(field). It must be a Symbol or a Pair{Symbol, Type} or Pair{Symbol, Pair{Type, String}}.")
end
end

struct_def = quote
@kwdef struct $name <: AbstractExtractedData
$(struct_fields...)
end
end

# Evaluate the struct definition
eval(struct_def)

return eval(name), descriptions
end

"""
update_schema_descriptions!(
schema::Dict{String, <:Any}, descriptions::Dict{Symbol, <:AbstractString};
max_description_length::Int = 200)
Update the given JSON schema with descriptions from the `descriptions` dictionary.
This function modifies the schema in-place, adding a "description" field to each property
that has a corresponding entry in the `descriptions` dictionary.
Note: It modifies the schema in place. Only the top-level "properties" are updated!
Returns: The modified schema dictionary.
# Arguments
- `schema`: A dictionary representing the JSON schema to be updated.
- `descriptions`: A dictionary mapping field names (as symbols) to their descriptions.
- `max_description_length::Int`: Maximum length for descriptions. Defaults to 200.
# Examples
```julia
schema = Dict{String, Any}(
"name" => "varExtractedData235_extractor",
"parameters" => Dict{String, Any}(
"properties" => Dict{String, Any}(
"location" => Dict{String, Any}("type" => "string"),
"condition" => Dict{String, Any}("type" => "string"),
"temperature" => Dict{String, Any}("type" => "number")
),
"required" => ["location", "temperature", "condition"],
"type" => "object"
)
)
descriptions = Dict{Symbol, String}(
:temperature => "Temperature in degrees Fahrenheit",
:condition => "Current weather condition (e.g., sunny, rainy, cloudy)"
)
update_schema_descriptions!(schema, descriptions)
```
"""
function update_schema_descriptions!(
schema::Dict{String, <:Any}, descriptions::Dict{Symbol, <:AbstractString};
max_description_length::Int = 200)
properties = get(get(schema, "parameters", Dict()), "properties", Dict())

for (field, field_schema) in properties
field_sym = Symbol(field)
if haskey(descriptions, field_sym)
field_schema["description"] = first(
descriptions[field_sym], max_description_length)
end
end

return schema
end

"""
set_properties_strict!(properties::AbstractDict)
Sets strict mode for the properties of a JSON schema.
Changes:
- Sets `additionalProperties` to `false`.
- All keys must be included in `required`.
- All optional keys will have `null` added to their type.
Reference: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
"""
function set_properties_strict!(parameters::AbstractDict)
parameters["additionalProperties"] = false
required_fields = get(parameters, "required", String[])
optional_fields = String[]

for (key, value) in parameters["properties"]
if key required_fields
push!(optional_fields, key)
if haskey(value, "type")
value["type"] = [value["type"], "null"]
end
end

# Recursively apply to nested properties
if haskey(value, "properties")
set_properties_strict!(value)
elseif haskey(value, "items") && haskey(value["items"], "properties")
## if it's an array, we need to skip inside "items"
set_properties_strict!(value["items"])
end
end

parameters["required"] = vcat(required_fields, optional_fields)
return parameters
end

"""
function_call_signature(
datastructtype::Type; strict::Union{Nothing, Bool} = nothing,
max_description_length::Int = 100)
max_description_length::Int = 200)
Extract the argument names, types and docstrings from a struct to create the function call signature in JSON schema.
Expand All @@ -120,7 +269,7 @@ struct MyMeasurement
height::Union{Int,Nothing}
weight::Union{Nothing,Float64}
end
signature = function_call_signature(MyMeasurement)
signature, t = function_call_signature(MyMeasurement)
#
# Dict{String, Any} with 3 entries:
# "name" => "MyMeasurement_extractor"
Expand Down Expand Up @@ -166,7 +315,7 @@ That way, you can handle the error gracefully and get a reason why extraction fa
"""
function function_call_signature(
datastructtype::Type; strict::Union{Nothing, Bool} = nothing,
max_description_length::Int = 100)
max_description_length::Int = 200)
!isstructtype(datastructtype) &&
error("Only Structs are supported (provided type: $datastructtype")
## Standardize the name
Expand All @@ -191,45 +340,54 @@ function function_call_signature(
set_properties_strict!(schema["parameters"])
end
end
return schema
return schema, datastructtype
end

"""
set_properties_strict!(properties::AbstractDict)
function_call_signature(fields::Vector; strict::Union{Nothing, Bool} = nothing, max_description_length::Int = 200)
Sets strict mode for the properties of a JSON schema.
Generate a function call signature schema for a dynamically generated struct based on the provided fields.
Changes:
- Sets `additionalProperties` to `false`.
- All keys must be included in `required`.
- All optional keys will have `null` added to their type.
# Arguments
- `fields::Vector{Union{Symbol, Pair{Symbol, Type}, Pair{Symbol, String}}}`: A vector of field names or pairs of field name and type or string description, eg, `[:field1, :field2, :field3]` or `[:field1 => String, :field2 => Int, :field3 => Float64]` or `[:field1 => String, :field1__description => "Field 1 has the name"]`.
- `strict::Union{Nothing, Bool}`: Whether to enforce strict mode for the schema. Defaults to `nothing`.
- `max_description_length::Int`: Maximum length for descriptions. Defaults to 200.
Reference: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas
# Returns a tuple of (schema, struct type)
- `Dict{String, Any}`: A dictionary representing the function call signature schema.
- `Type`: The struct type to create instance of the result.
See also `generate_struct`, `aiextract`, `update_schema_descriptions!`.
# Examples
```julia
schema, return_type = function_call_signature([:field1, :field2, :field3])
```
With the field types:
```julia
schema, return_type = function_call_signature([:field1 => String, :field2 => Int, :field3 => Float64])
```
And with the field descriptions:
```julia
schema, return_type = function_call_signature([:field1 => String, :field1__description => "Field 1 has the name"])
```
"""
function set_properties_strict!(parameters::AbstractDict)
parameters["additionalProperties"] = false
required_fields = get(parameters, "required", String[])
optional_fields = String[]
function function_call_signature(fields::Vector;
strict::Union{Nothing, Bool} = nothing, max_description_length::Int = 200)
@assert all(x -> x isa Symbol || x isa Pair, fields) "Invalid return types provided. All fields must be either Symbols or Pairs of Symbol and Type or String"
# Generate the struct and descriptions
datastructtype, descriptions = generate_struct(fields)

for (key, value) in parameters["properties"]
if key required_fields
push!(optional_fields, key)
if haskey(value, "type")
value["type"] = [value["type"], "null"]
end
end
# Create the schema
schema, _ = function_call_signature(
datastructtype; strict, max_description_length)

# Recursively apply to nested properties
if haskey(value, "properties")
set_properties_strict!(value)
elseif haskey(value, "items") && haskey(value["items"], "properties")
## if it's an array, we need to skip inside "items"
set_properties_strict!(value["items"])
end
end
# Update the schema with descriptions
update_schema_descriptions!(schema, descriptions; max_description_length)

parameters["required"] = vcat(required_fields, optional_fields)
return parameters
return schema, datastructtype
end

######################
Expand Down
31 changes: 27 additions & 4 deletions src/llm_anthropic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ end

"""
aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE;
return_type::Type,
return_type::Union{Type, Vector},
verbose::Bool = true,
api_key::String = ANTHROPIC_API_KEY,
model::String = MODEL_CHAT,
Expand All @@ -338,6 +338,7 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi
- `prompt`: Can be a string representing the prompt for the AI conversation, a `UserMessage`, a vector of `AbstractMessage` or an `AITemplate`
- `return_type`: A **struct** TYPE representing the the information we want to extract. Do not provide a struct instance, only the type.
If the struct has a docstring, it will be provided to the model as well. It's used to enforce structured model outputs or provide more information.
Alternatively, you can provide a vector of field names and their types (see `?generate_struct` function for the syntax).
- `verbose`: A boolean indicating whether to print additional information.
- `api_key`: A string representing the API key for accessing the OpenAI API.
- `model`: A string representing the model to use for generating the response. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`.
Expand Down Expand Up @@ -443,10 +444,32 @@ Note that when using a prompt template, we provide `data` for the extraction as
Note that the error message refers to a giraffe not being a human,
because in our `MyMeasurement` docstring, we said that it's for people!
Example of using a vector of field names with `aiextract`
```julia
fields = [:location, :temperature => Float64, :condition => String]
msg = aiextract("Extract the following information from the text: location, temperature, condition. Text: The weather in New York is sunny and 72.5 degrees Fahrenheit.";
return_type = fields, model="claudeh")
```
Or simply call `aiextract("some text"; return_type = [:reasoning,:answer], model="claudeh")` to get a Chain of Thought reasoning for extraction task.
It will be returned it a new generated type, which you can check with `PromptingTools.isextracted(msg.content) == true` to confirm the data has been extracted correctly.
This new syntax also allows you to provide field-level descriptions, which will be passed to the model.
```julia
fields_with_descriptions = [
:location,
:temperature => Float64,
:temperature__description => "Temperature in degrees Fahrenheit",
:condition => String,
:condition__description => "Current weather condition (e.g., sunny, rainy, cloudy)"
]
msg = aiextract("The weather in New York is sunny and 72.5 degrees Fahrenheit."; return_type = fields_with_descriptions, model="claudeh")
```
"""
function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE;
return_type::Type,
return_type::Union{Type, Vector},
verbose::Bool = true,
api_key::String = ANTHROPIC_API_KEY,
model::String = MODEL_CHAT,
Expand All @@ -465,7 +488,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP
model_id = get(MODEL_ALIASES, model, model)

## Tools definition
sig = function_call_signature(return_type; max_description_length = 100)
sig, datastructtype = function_call_signature(return_type; max_description_length = 100)
tools = [Dict("name" => sig["name"], "description" => get(sig, "description", ""),
"input_schema" => sig["parameters"])]
## update tools to use caching
Expand Down Expand Up @@ -493,7 +516,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP
## parse it into object
arguments = JSON3.write(contents[1][:input])
try
JSON3.read(arguments, return_type)
Base.invokelatest(JSON3.read, arguments, datastructtype)
catch e
@warn "There was an error parsing the response: $e. Using the raw response instead."
JSON3.read(arguments) |> copy
Expand Down
7 changes: 7 additions & 0 deletions src/llm_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,10 @@ function response_to_message(schema::AbstractPromptSchema,
sample_id::Union{Nothing, Integer} = nothing) where {T}
throw(ArgumentError("Response unwrapping not implemented for $(typeof(schema)) and $MSG"))
end

### For structured extraction
# We can generate fields, they will all share this parent type
abstract type AbstractExtractedData end
Base.show(io::IO, x::AbstractExtractedData) = dump(io, x; maxdepth = 1)
"Check if the object is an instance of `AbstractExtractedData`"
isextracted(x) = x isa AbstractExtractedData
Loading

0 comments on commit 1c8fb7d

Please sign in to comment.