From 1c8fb7df26f16fc4406d0a041d96973fa54335a1 Mon Sep 17 00:00:00 2001 From: J S <49557684+svilupp@users.noreply.github.com> Date: Tue, 3 Sep 2024 09:28:27 +0100 Subject: [PATCH] Update Structured Extraction --- CHANGELOG.md | 6 + Project.toml | 2 +- src/Experimental/RAGTools/evaluation.jl | 4 +- src/extraction.jl | 222 +++++++++++++++++++---- src/llm_anthropic.jl | 31 +++- src/llm_interface.jl | 7 + src/llm_openai.jl | 45 +++-- test/Experimental/RAGTools/evaluation.jl | 5 +- test/extraction.jl | 136 +++++++++++++- test/llm_openai.jl | 11 +- 10 files changed, 408 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8566ddb05..374b56cb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +## [0.51.0] + +### Added +- Added more flexible structured extraction with `aiextract` -> now you can simply provide the field names and, optionally, their types without specifying the struct itself (in `aiextract`, provide the fields like `return_type = [:field_name => field_type]`). +- Added a way to attach field-level descriptions to the generated JSON schemas to better structured extraction (see `?update_schema_descriptions!` to see the syntax), which was not possible with struct-only extraction. + ## [0.50.0] ### Breaking Changes diff --git a/Project.toml b/Project.toml index 372b15dbb..eec2298ba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PromptingTools" uuid = "670122d1-24a8-4d70-bfce-740807c42192" authors = ["J S @svilupp and contributors"] -version = "0.50.0" +version = "0.51.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/Experimental/RAGTools/evaluation.jl b/src/Experimental/RAGTools/evaluation.jl index fb69c87a5..47531f5ed 100644 --- a/src/Experimental/RAGTools/evaluation.jl +++ b/src/Experimental/RAGTools/evaluation.jl @@ -17,9 +17,9 @@ end context::AbstractString question::AbstractString answer::AbstractString - retrieval_score::Union{Number, Nothing} = nothing + retrieval_score::Union{Float64, Nothing} = nothing retrieval_rank::Union{Int, Nothing} = nothing - answer_score::Union{Number, Nothing} = nothing + answer_score::Union{Float64, Nothing} = nothing parameters::Dict{Symbol, Any} = Dict{Symbol, Any}() end diff --git a/src/extraction.jl b/src/extraction.jl index 0f591fdeb..2903919be 100644 --- a/src/extraction.jl +++ b/src/extraction.jl @@ -93,10 +93,159 @@ function to_json_schema(type::Type{<:AbstractDict}; max_description_length::Int throw(ArgumentError("Dicts are not supported yet as we cannot analyze their keys/values on a type-level. Use a nested Struct instead!")) end +### Type conversion / Schema generation +""" + generate_struct(fields::Vector) + +Generate a struct with the given name and fields. Fields can be specified simply as symbols (with default type `String`) or pairs of symbol and type. +Field descriptions can be provided by adding a pair with the field name suffixed with "__description" (eg, `:myfield__description => "My field description"`). + +Returns: A tuple of (struct type, descriptions) + +# Examples +```julia +Weather, descriptions = generate_struct( + [:location, + :temperature=>Float64, + :temperature__description=>"Temperature in degrees Fahrenheit", + :condition=>String, + :condition__description=>"Current weather condition (e.g., sunny, rainy, cloudy)" + ]) +``` +""" +function generate_struct(fields::Vector) + name = gensym("ExtractedData") + struct_fields = [] + descriptions = Dict{Symbol, String}() + + for field in fields + if field isa Symbol + push!(struct_fields, :($field::String)) + elseif field isa Pair + field_name, field_value = field + if endswith(string(field_name), "__description") + base_field = Symbol(replace(string(field_name), "__description" => "")) + descriptions[base_field] = field_value + elseif field_name isa Symbol && + (field_value isa Type || field_value isa AbstractString) + push!(struct_fields, :($field_name::$field_value)) + else + error("Invalid field specification: $(field). It must be a Symbol or a Pair{Symbol, Type} or Pair{Symbol, Pair{Type, String}}.") + end + else + error("Invalid field specification: $(field). It must be a Symbol or a Pair{Symbol, Type} or Pair{Symbol, Pair{Type, String}}.") + end + end + + struct_def = quote + @kwdef struct $name <: AbstractExtractedData + $(struct_fields...) + end + end + + # Evaluate the struct definition + eval(struct_def) + + return eval(name), descriptions +end + +""" + update_schema_descriptions!( + schema::Dict{String, <:Any}, descriptions::Dict{Symbol, <:AbstractString}; + max_description_length::Int = 200) + +Update the given JSON schema with descriptions from the `descriptions` dictionary. +This function modifies the schema in-place, adding a "description" field to each property +that has a corresponding entry in the `descriptions` dictionary. + +Note: It modifies the schema in place. Only the top-level "properties" are updated! + +Returns: The modified schema dictionary. + +# Arguments +- `schema`: A dictionary representing the JSON schema to be updated. +- `descriptions`: A dictionary mapping field names (as symbols) to their descriptions. +- `max_description_length::Int`: Maximum length for descriptions. Defaults to 200. + +# Examples +```julia + schema = Dict{String, Any}( + "name" => "varExtractedData235_extractor", + "parameters" => Dict{String, Any}( + "properties" => Dict{String, Any}( + "location" => Dict{String, Any}("type" => "string"), + "condition" => Dict{String, Any}("type" => "string"), + "temperature" => Dict{String, Any}("type" => "number") + ), + "required" => ["location", "temperature", "condition"], + "type" => "object" + ) + ) + descriptions = Dict{Symbol, String}( + :temperature => "Temperature in degrees Fahrenheit", + :condition => "Current weather condition (e.g., sunny, rainy, cloudy)" + ) + update_schema_descriptions!(schema, descriptions) +``` +""" +function update_schema_descriptions!( + schema::Dict{String, <:Any}, descriptions::Dict{Symbol, <:AbstractString}; + max_description_length::Int = 200) + properties = get(get(schema, "parameters", Dict()), "properties", Dict()) + + for (field, field_schema) in properties + field_sym = Symbol(field) + if haskey(descriptions, field_sym) + field_schema["description"] = first( + descriptions[field_sym], max_description_length) + end + end + + return schema +end + +""" + set_properties_strict!(properties::AbstractDict) + +Sets strict mode for the properties of a JSON schema. + +Changes: +- Sets `additionalProperties` to `false`. +- All keys must be included in `required`. +- All optional keys will have `null` added to their type. + +Reference: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas +""" +function set_properties_strict!(parameters::AbstractDict) + parameters["additionalProperties"] = false + required_fields = get(parameters, "required", String[]) + optional_fields = String[] + + for (key, value) in parameters["properties"] + if key ∉ required_fields + push!(optional_fields, key) + if haskey(value, "type") + value["type"] = [value["type"], "null"] + end + end + + # Recursively apply to nested properties + if haskey(value, "properties") + set_properties_strict!(value) + elseif haskey(value, "items") && haskey(value["items"], "properties") + ## if it's an array, we need to skip inside "items" + set_properties_strict!(value["items"]) + end + end + + parameters["required"] = vcat(required_fields, optional_fields) + return parameters +end + """ function_call_signature( datastructtype::Type; strict::Union{Nothing, Bool} = nothing, - max_description_length::Int = 100) + max_description_length::Int = 200) Extract the argument names, types and docstrings from a struct to create the function call signature in JSON schema. @@ -120,7 +269,7 @@ struct MyMeasurement height::Union{Int,Nothing} weight::Union{Nothing,Float64} end -signature = function_call_signature(MyMeasurement) +signature, t = function_call_signature(MyMeasurement) # # Dict{String, Any} with 3 entries: # "name" => "MyMeasurement_extractor" @@ -166,7 +315,7 @@ That way, you can handle the error gracefully and get a reason why extraction fa """ function function_call_signature( datastructtype::Type; strict::Union{Nothing, Bool} = nothing, - max_description_length::Int = 100) + max_description_length::Int = 200) !isstructtype(datastructtype) && error("Only Structs are supported (provided type: $datastructtype") ## Standardize the name @@ -191,45 +340,54 @@ function function_call_signature( set_properties_strict!(schema["parameters"]) end end - return schema + return schema, datastructtype end """ - set_properties_strict!(properties::AbstractDict) + function_call_signature(fields::Vector; strict::Union{Nothing, Bool} = nothing, max_description_length::Int = 200) -Sets strict mode for the properties of a JSON schema. +Generate a function call signature schema for a dynamically generated struct based on the provided fields. -Changes: -- Sets `additionalProperties` to `false`. -- All keys must be included in `required`. -- All optional keys will have `null` added to their type. +# Arguments +- `fields::Vector{Union{Symbol, Pair{Symbol, Type}, Pair{Symbol, String}}}`: A vector of field names or pairs of field name and type or string description, eg, `[:field1, :field2, :field3]` or `[:field1 => String, :field2 => Int, :field3 => Float64]` or `[:field1 => String, :field1__description => "Field 1 has the name"]`. +- `strict::Union{Nothing, Bool}`: Whether to enforce strict mode for the schema. Defaults to `nothing`. +- `max_description_length::Int`: Maximum length for descriptions. Defaults to 200. -Reference: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas +# Returns a tuple of (schema, struct type) +- `Dict{String, Any}`: A dictionary representing the function call signature schema. +- `Type`: The struct type to create instance of the result. + +See also `generate_struct`, `aiextract`, `update_schema_descriptions!`. + +# Examples +```julia +schema, return_type = function_call_signature([:field1, :field2, :field3]) +``` + +With the field types: +```julia +schema, return_type = function_call_signature([:field1 => String, :field2 => Int, :field3 => Float64]) +``` + +And with the field descriptions: +```julia +schema, return_type = function_call_signature([:field1 => String, :field1__description => "Field 1 has the name"]) +``` """ -function set_properties_strict!(parameters::AbstractDict) - parameters["additionalProperties"] = false - required_fields = get(parameters, "required", String[]) - optional_fields = String[] +function function_call_signature(fields::Vector; + strict::Union{Nothing, Bool} = nothing, max_description_length::Int = 200) + @assert all(x -> x isa Symbol || x isa Pair, fields) "Invalid return types provided. All fields must be either Symbols or Pairs of Symbol and Type or String" + # Generate the struct and descriptions + datastructtype, descriptions = generate_struct(fields) - for (key, value) in parameters["properties"] - if key ∉ required_fields - push!(optional_fields, key) - if haskey(value, "type") - value["type"] = [value["type"], "null"] - end - end + # Create the schema + schema, _ = function_call_signature( + datastructtype; strict, max_description_length) - # Recursively apply to nested properties - if haskey(value, "properties") - set_properties_strict!(value) - elseif haskey(value, "items") && haskey(value["items"], "properties") - ## if it's an array, we need to skip inside "items" - set_properties_strict!(value["items"]) - end - end + # Update the schema with descriptions + update_schema_descriptions!(schema, descriptions; max_description_length) - parameters["required"] = vcat(required_fields, optional_fields) - return parameters + return schema, datastructtype end ###################### diff --git a/src/llm_anthropic.jl b/src/llm_anthropic.jl index ac893d651..4629fdaa4 100644 --- a/src/llm_anthropic.jl +++ b/src/llm_anthropic.jl @@ -312,7 +312,7 @@ end """ aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE; - return_type::Type, + return_type::Union{Type, Vector}, verbose::Bool = true, api_key::String = ANTHROPIC_API_KEY, model::String = MODEL_CHAT, @@ -338,6 +338,7 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi - `prompt`: Can be a string representing the prompt for the AI conversation, a `UserMessage`, a vector of `AbstractMessage` or an `AITemplate` - `return_type`: A **struct** TYPE representing the the information we want to extract. Do not provide a struct instance, only the type. If the struct has a docstring, it will be provided to the model as well. It's used to enforce structured model outputs or provide more information. + Alternatively, you can provide a vector of field names and their types (see `?generate_struct` function for the syntax). - `verbose`: A boolean indicating whether to print additional information. - `api_key`: A string representing the API key for accessing the OpenAI API. - `model`: A string representing the model to use for generating the response. Can be an alias corresponding to a model ID defined in `MODEL_ALIASES`. @@ -443,10 +444,32 @@ Note that when using a prompt template, we provide `data` for the extraction as Note that the error message refers to a giraffe not being a human, because in our `MyMeasurement` docstring, we said that it's for people! + +Example of using a vector of field names with `aiextract` +```julia +fields = [:location, :temperature => Float64, :condition => String] +msg = aiextract("Extract the following information from the text: location, temperature, condition. Text: The weather in New York is sunny and 72.5 degrees Fahrenheit."; +return_type = fields, model="claudeh") +``` + +Or simply call `aiextract("some text"; return_type = [:reasoning,:answer], model="claudeh")` to get a Chain of Thought reasoning for extraction task. + +It will be returned it a new generated type, which you can check with `PromptingTools.isextracted(msg.content) == true` to confirm the data has been extracted correctly. + +This new syntax also allows you to provide field-level descriptions, which will be passed to the model. +```julia +fields_with_descriptions = [ + :location, + :temperature => Float64, + :temperature__description => "Temperature in degrees Fahrenheit", + :condition => String, + :condition__description => "Current weather condition (e.g., sunny, rainy, cloudy)" +] +msg = aiextract("The weather in New York is sunny and 72.5 degrees Fahrenheit."; return_type = fields_with_descriptions, model="claudeh") ``` """ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMPT_TYPE; - return_type::Type, + return_type::Union{Type, Vector}, verbose::Bool = true, api_key::String = ANTHROPIC_API_KEY, model::String = MODEL_CHAT, @@ -465,7 +488,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP model_id = get(MODEL_ALIASES, model, model) ## Tools definition - sig = function_call_signature(return_type; max_description_length = 100) + sig, datastructtype = function_call_signature(return_type; max_description_length = 100) tools = [Dict("name" => sig["name"], "description" => get(sig, "description", ""), "input_schema" => sig["parameters"])] ## update tools to use caching @@ -493,7 +516,7 @@ function aiextract(prompt_schema::AbstractAnthropicSchema, prompt::ALLOWED_PROMP ## parse it into object arguments = JSON3.write(contents[1][:input]) try - JSON3.read(arguments, return_type) + Base.invokelatest(JSON3.read, arguments, datastructtype) catch e @warn "There was an error parsing the response: $e. Using the raw response instead." JSON3.read(arguments) |> copy diff --git a/src/llm_interface.jl b/src/llm_interface.jl index 926487342..a56e84e10 100644 --- a/src/llm_interface.jl +++ b/src/llm_interface.jl @@ -439,3 +439,10 @@ function response_to_message(schema::AbstractPromptSchema, sample_id::Union{Nothing, Integer} = nothing) where {T} throw(ArgumentError("Response unwrapping not implemented for $(typeof(schema)) and $MSG")) end + +### For structured extraction +# We can generate fields, they will all share this parent type +abstract type AbstractExtractedData end +Base.show(io::IO, x::AbstractExtractedData) = dump(io, x; maxdepth = 1) +"Check if the object is an instance of `AbstractExtractedData`" +isextracted(x) = x isa AbstractExtractedData \ No newline at end of file diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 2f3f51f31..bc7ca3c15 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -980,13 +980,13 @@ function response_to_message(schema::AbstractOpenAISchema, tokens_completion = get(resp.response, :usage, Dict(:completion_tokens => 0))[:completion_tokens] cost = call_cost(tokens_prompt, tokens_completion, model_id) # "Safe" parsing of the response - it still fails if JSON is invalid + args = choice[:message][:tool_calls][1][:function][:arguments] content = try - choice[:message][:tool_calls][1][:function][:arguments] |> - x -> JSON3.read(x, return_type) + ## Must invoke latest because we might have generated the struct + Base.invokelatest(JSON3.read, args, return_type)::return_type catch e @warn "There was an error parsing the response: $e. Using the raw response instead." - choice[:message][:tool_calls][1][:function][:arguments] |> - JSON3.read |> copy + JSON3.read(args) |> copy end ## build DataMessage object msg = MSG(; @@ -1004,7 +1004,7 @@ end """ aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE; - return_type::Type, + return_type::Union{Type, Vector}, verbose::Bool = true, api_key::String = OPENAI_API_KEY, model::String = MODEL_CHAT, @@ -1027,7 +1027,7 @@ It's effectively a light wrapper around `aigenerate` call, which requires additi # Arguments - `prompt_schema`: An optional object to specify which prompt template should be applied (Default to `PROMPT_SCHEMA = OpenAISchema`) - `prompt`: Can be a string representing the prompt for the AI conversation, a `UserMessage`, a vector of `AbstractMessage` or an `AITemplate` -- `return_type`: A **struct** TYPE representing the the information we want to extract. Do not provide a struct instance, only the type. +- `return_type`: A **struct** TYPE representing the the information we want to extract. Do not provide a struct instance, only the type. Alternatively, you can provide a vector of field names and their types (see `?generate_struct` function for the syntax). If the struct has a docstring, it will be provided to the model as well. It's used to enforce structured model outputs or provide more information. - `verbose`: A boolean indicating whether to print additional information. - `api_key`: A string representing the API key for accessing the OpenAI API. @@ -1052,7 +1052,7 @@ If `return_all=true`: - `conversation`: A vector of `AbstractMessage` objects representing the full conversation history, including the response from the AI model (`DataMessage`). -See also: `function_call_signature`, `MaybeExtract`, `ItemsExtract`, `aigenerate` +See also: `function_call_signature`, `MaybeExtract`, `ItemsExtract`, `aigenerate`, `generate_struct` # Example @@ -1135,9 +1135,31 @@ end aiextract("I ate an apple",return_type=Fruit,api_kwargs=(;tool_choice="any"),model="mistrall") # Notice two differences: 1) struct MUST have a docstring, 2) tool_choice is set explicitly set to "any" ``` + +Example of using a vector of field names with `aiextract` +```julia +fields = [:location, :temperature => Float64, :condition => String] +msg = aiextract("Extract the following information from the text: location, temperature, condition. Text: The weather in New York is sunny and 72.5 degrees Fahrenheit."; return_type = fields) +``` + +Or simply call `aiextract("some text"; return_type = [:reasoning,:answer])` to get a Chain of Thought reasoning for extraction task. + +It will be returned it a new generated type, which you can check with `PromptingTools.isextracted(msg.content) == true` to confirm the data has been extracted correctly. + +This new syntax also allows you to provide field-level descriptions, which will be passed to the model. +```julia +fields_with_descriptions = [ + :location, + :temperature => Float64, + :temperature__description => "Temperature in degrees Fahrenheit", + :condition => String, + :condition__description => "Current weather condition (e.g., sunny, rainy, cloudy)" +] +msg = aiextract("The weather in New York is sunny and 72.5 degrees Fahrenheit."; return_type = fields_with_descriptions) +``` """ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYPE; - return_type::Type, + return_type::Union{Type, Vector}, verbose::Bool = true, api_key::String = OPENAI_API_KEY, model::String = MODEL_CHAT, @@ -1152,8 +1174,9 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T ## global MODEL_ALIASES ## Function calling specifics + schema, datastructtype = function_call_signature(return_type; strict) tools = [Dict( - :type => "function", :function => function_call_signature(return_type; strict))] + :type => "function", :function => schema)] ## force our function to be used tool_choice_ = get(api_kwargs, :tool_choice, "exact") tool_choice = if tool_choice_ == "exact" @@ -1182,7 +1205,7 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T run_id = Int(rand(Int32)) # remember one run ID ## extract all message msgs = [response_to_message(prompt_schema, DataMessage, choice, r; - return_type, time, model_id, run_id, sample_id = i) + return_type = datastructtype, time, model_id, run_id, sample_id = i) for (i, choice) in enumerate(r.response[:choices])] ## Order by log probability if available ## bigger is better, keep it last @@ -1195,7 +1218,7 @@ function aiextract(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_T ## only 1 sample / 1 completion choice = r.response[:choices][begin] response_to_message(prompt_schema, DataMessage, choice, r; - return_type, time, model_id) + return_type = datastructtype, time, model_id) end ## Reporting verbose && @info _report_stats(msg, model_id) diff --git a/test/Experimental/RAGTools/evaluation.jl b/test/Experimental/RAGTools/evaluation.jl index eb97c4329..7bda831c2 100644 --- a/test/Experimental/RAGTools/evaluation.jl +++ b/test/Experimental/RAGTools/evaluation.jl @@ -1,7 +1,8 @@ using PromptingTools.Experimental.RAGTools: QAItem, QAEvalItem, QAEvalResult using PromptingTools.Experimental.RAGTools: score_retrieval_hit, score_retrieval_rank using PromptingTools.Experimental.RAGTools: build_qa_evals, run_qa_evals, chunks, sources -using PromptingTools.Experimental.RAGTools: JudgeAllScores, Tag, MaybeTags +using PromptingTools.Experimental.RAGTools: JudgeAllScores, Tag, MaybeTags, ChunkIndex, + RAGConfig, airag @testset "QAEvalItem" begin empty_qa = QAEvalItem() @@ -75,7 +76,7 @@ end @testset "build_qa_evals" begin # test with a mock server - PORT = rand(10005:40001) + PORT = rand(10005:40010) PT.register_model!(; name = "mock-emb", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-meta", schema = PT.CustomOpenAISchema()) PT.register_model!(; name = "mock-gen", schema = PT.CustomOpenAISchema()) diff --git a/test/extraction.jl b/test/extraction.jl index e24cb9fa1..4a91bd320 100644 --- a/test/extraction.jl +++ b/test/extraction.jl @@ -1,6 +1,7 @@ using PromptingTools: MaybeExtract, extract_docstring, ItemsExtract using PromptingTools: has_null_type, is_required_field, remove_null_types, to_json_schema -using PromptingTools: function_call_signature, set_properties_strict! +using PromptingTools: function_call_signature, set_properties_strict!, + update_schema_descriptions!, generate_struct # TODO: check more edge cases like empty structs @@ -310,6 +311,77 @@ end @test !haskey(params["properties"]["email"], "null") end +@testset "generate_struct" begin + # Test with only field names + fields = [:field1, :field2, :field3] + struct_type, descriptions = generate_struct(fields) + @test fieldnames(struct_type) == (:field1, :field2, :field3) + @test descriptions == Dict{Symbol, String}() + + # Test with field names and types + fields = [:field1 => Int, :field2 => String, :field3 => Float64] + struct_type, descriptions = generate_struct(fields) + @test fieldnames(struct_type) == (:field1, :field2, :field3) + @test fieldtypes(struct_type) == (Int, String, Float64) + @test descriptions == Dict{Symbol, String}() + + # Test with field names, types, and descriptions + fields = [:field1 => Int, :field2 => String, :field3 => Float64, + :field1__description => "Field 1 description", + :field2__description => "Field 2 description"] + struct_type, descriptions = generate_struct(fields) + @test fieldnames(struct_type) == (:field1, :field2, :field3) + @test fieldtypes(struct_type) == (Int, String, Float64) + @test descriptions == + Dict(:field1 => "Field 1 description", :field2 => "Field 2 description") + + # Test with invalid field specification + fields = [:field1 => Int, :field2 => :InvalidType] + @test_throws ErrorException generate_struct(fields) +end + +@testset "update_schema_descriptions!" begin + # Test with empty descriptions + schema = Dict("parameters" => Dict("properties" => Dict("field1" => Dict("type" => "string")))) + descriptions = Dict{Symbol, String}() + updated_schema = update_schema_descriptions!(schema, descriptions) + @test !haskey(updated_schema["parameters"]["properties"]["field1"], "description") + + # Test with descriptions provided + schema = Dict("parameters" => Dict("properties" => Dict("field1" => Dict("type" => "string")))) + descriptions = Dict(:field1 => "Field 1 description") + updated_schema = update_schema_descriptions!(schema, descriptions) + @test updated_schema["parameters"]["properties"]["field1"]["description"] == + "Field 1 description" + + # Test with max_description_length + schema = Dict("parameters" => Dict("properties" => Dict("field1" => Dict("type" => "string")))) + descriptions = Dict(:field1 => "Field 1 description is very long and should be truncated") + updated_schema = update_schema_descriptions!( + schema, descriptions; max_description_length = 10) + @test updated_schema["parameters"]["properties"]["field1"]["description"] == + "Field 1 de" + + # Test with multiple fields + schema = Dict("parameters" => Dict("properties" => Dict( + "field1" => Dict("type" => "string"), "field2" => Dict("type" => "integer")))) + descriptions = Dict(:field1 => "Field 1 description", :field2 => "Field 2 description") + updated_schema = update_schema_descriptions!(schema, descriptions) + @test updated_schema["parameters"]["properties"]["field1"]["description"] == + "Field 1 description" + @test updated_schema["parameters"]["properties"]["field2"]["description"] == + "Field 2 description" + + # Test with missing field in descriptions + schema = Dict("parameters" => Dict("properties" => Dict( + "field1" => Dict("type" => "string"), "field2" => Dict("type" => "integer")))) + descriptions = Dict(:field1 => "Field 1 description") + updated_schema = update_schema_descriptions!(schema, descriptions) + @test updated_schema["parameters"]["properties"]["field1"]["description"] == + "Field 1 description" + @test !haskey(updated_schema["parameters"]["properties"]["field2"], "description") +end + @testset "function_call_signature" begin "Some docstring" struct MyMeasurement2 @@ -317,7 +389,7 @@ end height::Union{Int, Nothing} weight::Union{Nothing, Float64} end - output = function_call_signature(MyMeasurement2)#|> JSON3.pretty + output, t = function_call_signature(MyMeasurement2)#|> JSON3.pretty expected_output = Dict{String, Any}("name" => "MyMeasurement2_extractor", "parameters" => Dict{String, Any}( "properties" => Dict{String, Any}( @@ -333,7 +405,7 @@ end @test output == expected_output ## MaybeWraper name cleanup - schema = function_call_signature(MaybeExtract{MyMeasurement2}) + schema, t = function_call_signature(MaybeExtract{MyMeasurement2}) @test schema["name"] == "MaybeExtractMyMeasurement2_extractor" ## Test with strict = true @@ -346,7 +418,7 @@ end end # Test with strict = nothing (default behavior) - output_default = function_call_signature(MyMeasurement3) + output_default, t = function_call_signature(MyMeasurement3) @test !haskey(output_default, "strict") @test output_default["name"] == "MyMeasurement3_extractor" @test output_default["parameters"]["type"] == "object" @@ -354,7 +426,7 @@ end @test !haskey(output_default["parameters"], "additionalProperties") # Test with strict =false - output_not_strict = function_call_signature(MyMeasurement3; strict = false) + output_not_strict, t = function_call_signature(MyMeasurement3; strict = false) @test haskey(output_not_strict, "strict") @test output_not_strict["strict"] == false @test output_not_strict["name"] == "MyMeasurement3_extractor" @@ -363,7 +435,7 @@ end @test !haskey(output_default["parameters"], "additionalProperties") # Test with strict = true - output_strict = function_call_signature(MyMeasurement3; strict = true) + output_strict, t = function_call_signature(MyMeasurement3; strict = true) @test output_strict["strict"] == true @test output_strict["name"] == "MyMeasurement3_extractor" @test output_strict["parameters"]["type"] == "object" @@ -374,10 +446,58 @@ end @test output_strict["parameters"]["properties"]["age"]["type"] == "integer" # Test with MaybeExtract wrapper - output_maybe = function_call_signature(MaybeExtract{MyMeasurement3}; strict = true) + output_maybe, t = function_call_signature(MaybeExtract{MyMeasurement3}; strict = true) @test output_maybe["name"] == "MaybeExtractMyMeasurement3_extractor" @test output_maybe["parameters"]["properties"]["result"]["type"] == ["object", "null"] @test output_maybe["parameters"]["properties"]["error"]["type"] == "boolean" @test output_maybe["parameters"]["properties"]["message"]["type"] == ["string", "null"] @test Set(output_maybe["parameters"]["required"]) == Set(["result", "error", "message"]) -end \ No newline at end of file + + #### Test with generated structs and with descriptions + # Test with simple fields + fields = [:field1 => Int, :field2 => String] + schema, datastructtype = function_call_signature(fields) + @test haskey(schema, "name") + @test haskey(schema, "parameters") + @test haskey(schema["parameters"], "properties") + @test haskey(schema["parameters"]["properties"], "field1") + @test haskey(schema["parameters"]["properties"], "field2") + @test schema["parameters"]["properties"]["field1"]["type"] == "integer" + @test schema["parameters"]["properties"]["field2"]["type"] == "string" + + # Test with strict mode + fields = [:field1 => Int, :field2 => String] + schema, datastructtype = function_call_signature(fields; strict = true) + @test schema["strict"] == true + + # Test with descriptions and max_description_length + fields = [ + :field1 => Int, :field2 => String, :field1__description => "Field 1 description", + :field2__description => "Field 2 description"] + schema, datastructtype = function_call_signature(fields; max_description_length = 7) + @test haskey(schema, "name") + @test haskey(schema, "parameters") + @test haskey(schema["parameters"], "properties") + @test haskey(schema["parameters"]["properties"], "field1") + @test haskey(schema["parameters"]["properties"], "field2") + @test schema["parameters"]["properties"]["field1"]["type"] == "integer" + @test schema["parameters"]["properties"]["field2"]["type"] == "string" + @test schema["parameters"]["properties"]["field1"]["description"] == "Field 1" + @test schema["parameters"]["properties"]["field2"]["description"] == "Field 2" + + # Test with empty fields + fields = [] + schema, datastructtype = function_call_signature(fields) + @test haskey(schema, "name") + @test haskey(schema, "parameters") + @test haskey(schema["parameters"], "properties") + @test isempty(schema["parameters"]["properties"]) + + # Test with invalid field specification + fields = [:field1 => Int, :field2 => :InvalidType] + @test_throws ErrorException function_call_signature(fields) + fields = ["field1" => Int] + @test_throws ErrorException function_call_signature(fields) + fields = ["field1", "field2"] # caught earlier as an error so assertion error + @test_throws AssertionError function_call_signature(fields) +end diff --git a/test/llm_openai.jl b/test/llm_openai.jl index 56fd8dd8f..055f15ba4 100644 --- a/test/llm_openai.jl +++ b/test/llm_openai.jl @@ -4,7 +4,8 @@ using PromptingTools: UserMessage, UserMessageWithImages, DataMessage using PromptingTools: CustomProvider, CustomOpenAISchema, MistralOpenAISchema, MODEL_EMBEDDING, MODEL_IMAGE_GENERATION -using PromptingTools: encode_choices, decode_choices, response_to_message, call_cost +using PromptingTools: encode_choices, decode_choices, response_to_message, call_cost, + isextracted @testset "render-OpenAI" begin schema = OpenAISchema() @@ -629,6 +630,14 @@ end @test msg.content == RandomType1235(1) @test msg.log_prob ≈ -0.9 + ## Test with field descriptions + fields = [:x => Int, :x__description => "Field 1 description"] + msg = aiextract(schema1, "Extract number 1"; return_type = fields, + model = "gpt4", + api_kwargs = (; temperature = 0, n = 2)) + @test isextracted(msg.content) + @test msg.content.x == 1 + ## Test multiple samples -- mock_choice is less probable mock_choice2 = Dict( :message => Dict(:content => "Hello!",