Skip to content

Commit

Permalink
Add a new models, clean up tools (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp authored Oct 20, 2024
1 parent f945ab9 commit 5280ee8
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Extends support for structured extraction with multiple "tools" definitions (see `?aiextract`).
- Added new primitives `Tool` (to re-use tool definitions) and a function `aitools` to support mixed structured and non-structured workflows, eg, agentic workflows (see `?aitools`).
- Added a field `name` to `AbstractChatMessage` and `AIToolRequest` messages to enable role-based workflows.
- Added support for partial argument execution with the `execute_tool` function (provide your own context to override the argument values).
- Added support for [SambaNova](https://sambanova.ai/) hosted models (set your ENV `SAMBANOVA_API_KEY`).
- Added many new models from Mistral, Groq, SambaNova, and OpenAI.

### Updated
- Renamed `function_call_signature` to `tool_call_signature` to better reflect that it's used for tools, but kept a link to the old name for back-compatibility.
Expand Down
102 changes: 93 additions & 9 deletions src/extraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,41 @@ function set_properties_strict!(parameters::AbstractDict)
return parameters
end

"""
remove_field!(parameters::AbstractDict, field::AbstractString)
Utility to remove a specific top-level field from the parameters (and the `required` list if present) of the JSON schema.
"""
function remove_field!(parameters::AbstractDict, field::AbstractString)
    # Drop the field from the schema's property map, if there is one.
    if haskey(parameters, "properties")
        props = parameters["properties"]
        haskey(props, field) && delete!(props, field)
    end
    # Keep the `required` list in sync so it never names a removed property.
    if haskey(parameters, "required")
        required = parameters["required"]
        field in required && filter!(!=(field), required)
    end
    # The input is mutated in place and handed back for chaining.
    return parameters
end

"""
    remove_field!(parameters::AbstractDict, pattern::Regex)

Remove every top-level property whose name matches `pattern` from
`parameters["properties"]`, and drop the matching names from
`parameters["required"]`. Either key may be absent, in which case that step is skipped.

Mutates `parameters` in place and returns it.
"""
function remove_field!(parameters::AbstractDict, pattern::Regex)
    if haskey(parameters, "properties")
        properties = parameters["properties"]
        # Snapshot the matching keys before deleting: removing entries while iterating
        # the same dictionary is not guaranteed to be safe for arbitrary `AbstractDict`
        # implementations.
        matching = [key for key in keys(properties) if occursin(pattern, key)]
        for key in matching
            delete!(properties, key)
        end
    end
    if haskey(parameters, "required")
        # Keep `required` consistent with the properties that remain.
        filter!(name -> !occursin(pattern, name), parameters["required"])
    end
    return parameters
end

"""
tool_call_signature(
type_or_method::Union{Type, Method}; strict::Union{Nothing, Bool} = nothing,
max_description_length::Int = 200, name::Union{Nothing, String} = nothing,
docs::Union{Nothing, String} = nothing)
docs::Union{Nothing, String} = nothing, hidden_fields::AbstractVector{<:Union{
AbstractString, Regex}} = String[])
Extract the argument names, types and docstrings from a struct to create the function call signature in JSON schema.
Expand All @@ -356,9 +386,10 @@ Note: Fairly experimental, but works for combination of structs, arrays, strings
- `max_description_length::Int`: Maximum length for descriptions. Defaults to 200.
- `name::Union{Nothing, String}`: The name of the tool. Defaults to the name of the struct.
- `docs::Union{Nothing, String}`: The description of the tool. Defaults to the docstring of the struct/overall function.
- `hidden_fields::AbstractVector{<:Union{AbstractString, Regex}}`: A list of fields to hide from the LLM (eg, `["ctx_user_id"]` or `r"ctx"`).
# Returns
- `Dict{String, Any}`: A dictionary representing the function call signature schema.
- `Dict{String, Tool}`: A dictionary representing the function call signature schema.
# Tips
- You can improve the quality of the extraction by writing a helpful docstring for your struct (or any nested struct). It will be provided as a description.
Expand Down Expand Up @@ -421,11 +452,18 @@ msg = aiextract("Extract measurements from the text: I am giraffe", type)
# :error => true
```
That way, you can handle the error gracefully and get a reason why extraction failed.
You can also hide certain fields in your function call signature with Strings or Regex patterns (eg, `r"ctx"`).
```
tool_map = tool_call_signature(MyMeasurement; hidden_fields = ["ctx_user_id"])
```
"""
function tool_call_signature(
type_or_method::Union{Type, Method}; strict::Union{Nothing, Bool} = nothing,
max_description_length::Int = 200, name::Union{Nothing, String} = nothing,
docs::Union{Nothing, String} = nothing)
docs::Union{Nothing, String} = nothing, hidden_fields::AbstractVector{<:Union{
AbstractString, Regex}} = String[])
## Asserts
if type_or_method isa Type && !isstructtype(type_or_method)
error("Only Structs are supported (provided type: $type_or_method)")
Expand Down Expand Up @@ -460,6 +498,12 @@ function tool_call_signature(
end
end
call_type = type_or_method isa Type ? type_or_method : get_function(type_or_method)
## Remove hidden fields
if !isempty(hidden_fields)
for field in hidden_fields
remove_field!(schema["parameters"], field)
end
end
tool = Tool(; name = schema["name"], parameters = schema["parameters"],
description = haskey(schema, "description") ? schema["description"] : nothing,
strict = haskey(schema, "strict") ? schema["strict"] : nothing,
Expand Down Expand Up @@ -646,19 +690,59 @@ function parse_tool(tool::AbstractTool, input::Union{AbstractString, AbstractDic
end

"""
execute_tool(f::Function, args::AbstractDict)
execute_tool(f::Function, args::AbstractDict{Symbol, <:Any},
context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
Executes a function with the provided arguments.
Dictionary is un-ordered, so we need to sort the arguments first and then pass them to the function.
# Arguments
- `f::Function`: The function to execute.
- `args::AbstractDict{Symbol, <:Any}`: The arguments to pass to the function.
- `context::AbstractDict{Symbol, <:Any}`: Optional context. When an argument name is present in both `context` and `args`, the value from `context` is used.
# Example
```julia
my_function(x, y) = x + y
execute_tool(my_function, Dict(:x => 1, :y => 2))
```
```julia
get_weather(date, location) = "The weather in \$location on \$date is 70 degrees."
tool_map = PT.tool_call_signature(get_weather)
msg = aitools("What's the weather in Tokyo on May 3rd, 2023?";
tools = collect(values(tool_map)))
PT.execute_tool(tool_map, PT.tool_calls(msg)[1])
# "The weather in Tokyo on 2023-05-03 is 70 degrees."
```
"""
function execute_tool(f::Function, args::AbstractDict)
args_sorted = [args[arg]
for arg in get_arg_names(f) if haskey(args, arg)]
function execute_tool(f::Function, args::AbstractDict{Symbol, <:Any},
        context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
    # Build the positional arguments in the order yielded by `get_arg_names(f)`.
    # For each name, `context` wins over `args`; names found in neither are skipped.
    call_args = Any[haskey(context, name) ? context[name] : args[name]
                    for name in get_arg_names(f)
                    if haskey(context, name) || haskey(args, name)]
    return f(call_args...)
end
function execute_tool(tool::AbstractTool, args::AbstractDict)
return execute_tool(tool.callable, args)
function execute_tool(tool::AbstractTool, args::AbstractDict{Symbol, <:Any},
        context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
    # Unwrap the tool's callable and forward the arguments and context unchanged.
    callable = tool.callable
    return execute_tool(callable, args, context)
end
function execute_tool(tool::AbstractTool, msg::ToolMessage,
        context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
    # The tool call's arguments come from the message; the callable comes from the tool.
    callable = tool.callable
    call_args = msg.args
    return execute_tool(callable, call_args, context)
end
function execute_tool(tool_map::AbstractDict{String, <:AbstractTool}, msg::ToolMessage,
        context::AbstractDict{Symbol, <:Any} = Dict{Symbol, Any}())
    # Look the tool up by the name carried on the message, then delegate to the
    # single-tool method. Plain indexing is used for the lookup.
    return execute_tool(tool_map[msg.name], msg, context)
end

"""
Expand Down
14 changes: 14 additions & 0 deletions src/llm_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,20 @@ Requires one environment variable to be set:
"""
struct CerebrasOpenAISchema <: AbstractOpenAISchema end

"""
SambaNovaOpenAISchema
Schema to call the [SambaNova](https://sambanova.ai/) API.
Links:
- [Get your API key](https://cloud.sambanova.ai/apis)
- [API Reference](https://community.sambanova.ai/c/docs)
Requires one environment variable to be set:
- `SAMBANOVA_API_KEY`: Your API key
"""
struct SambaNovaOpenAISchema <: AbstractOpenAISchema end # field-less type; used as a dispatch tag (e.g. by `OpenAI.create_chat`)

abstract type AbstractOllamaSchema <: AbstractPromptSchema end

"""
Expand Down
21 changes: 14 additions & 7 deletions src/llm_openai.jl
Original file line number Diff line number Diff line change
Expand Up @@ -829,8 +829,10 @@ function response_to_message(schema::AbstractOpenAISchema,
# "Safe" parsing of the response - it still fails if JSON is invalid
tools_array = if json_mode == true
name, tool = only(tool_map)
content_blob = choice[:message][:content]
content_obj = content_blob isa String ? JSON3.read(content_blob) : content_blob
[parse_tool(
tool.callable, choice[:message][:content])]
tool.callable, content_obj)]
else
## If name does not match, we use the callable from the tool_map
## Can happen only in testing with auto-generated struct
Expand Down Expand Up @@ -1472,10 +1474,13 @@ function response_to_message(schema::AbstractOpenAISchema,
!isempty(choice[:message][:tool_calls])
tools_array = if json_mode == true
tool_name, tool = only(tool_map)
## Note, JSON mode doesn't have tool_call_id so we mock it
content_blob = choice[:message][:content]
[ToolMessage(;
content = nothing, req_id = run_id, tool_call_id = choice[:id],
raw = JSON3.write(choice[:message][:content]),
args = choice[:message][:content], name = tool_name)]
content = nothing, req_id = run_id, tool_call_id = string("call_", run_id),
raw = content_blob isa String ? content_blob : JSON3.write(content_blob),
args = content_blob isa String ? JSON3.read(content_blob) : content_blob,
name = tool_name)]
elseif has_tools
[ToolMessage(; raw = tool_call[:function][:arguments],
args = JSON3.read(tool_call[:function][:arguments]),
Expand All @@ -1488,7 +1493,9 @@ function response_to_message(schema::AbstractOpenAISchema,
else
ToolMessage[]
end
content = json_mode != true ? choice[:message][:content] : nothing
## Check if content key was provided (not required for tool calls)
content = json_mode != true && haskey(choice[:message], :content) ?
choice[:message][:content] : nothing
## Remember the tools
extras = Dict{Symbol, Any}()
if has_tools
Expand Down Expand Up @@ -1631,7 +1638,7 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP
##
global MODEL_ALIASES
## Function calling specifics // get the tool map (signatures)
## Set strict mode on for JSON mode
## Set strict mode on for JSON mode as Structured outputs
strict_ = json_mode == true ? true : strict
tool_map = tool_call_signature(tools; strict = strict_)
tools = render(prompt_schema, tool_map; json_mode)
Expand All @@ -1653,7 +1660,7 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP
@assert length(tools)==1 "Only 1 tool definition is allowed in JSON mode."
(; [k => v for (k, v) in pairs(api_kwargs) if k != :tool_choice]...,
response_format = (;
type = "json_schema", json_schema = only(tools)))
type = "json_schema", json_schema = only(tools)[:function]))
elseif isempty(tools)
api_kwargs
else
Expand Down
9 changes: 9 additions & 0 deletions src/llm_openai_schema_defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ function OpenAI.create_chat(schema::CerebrasOpenAISchema,
api_key = isempty(CEREBRAS_API_KEY) ? api_key : CEREBRAS_API_KEY
OpenAI.create_chat(CustomOpenAISchema(), api_key, model, conversation; url, kwargs...)
end
function OpenAI.create_chat(schema::SambaNovaOpenAISchema,
        api_key::AbstractString,
        model::AbstractString,
        conversation;
        url::String = "https://api.sambanova.ai/v1",
        kwargs...)
    # A non-empty `SAMBANOVA_API_KEY` takes precedence over the `api_key` argument;
    # the argument is only used when that key is empty.
    resolved_key = if isempty(SAMBANOVA_API_KEY)
        api_key
    else
        SAMBANOVA_API_KEY
    end
    # Delegate to the `CustomOpenAISchema` method, pointing it at the given `url`.
    return OpenAI.create_chat(
        CustomOpenAISchema(), resolved_key, model, conversation; url = url, kwargs...)
end
function OpenAI.create_chat(schema::DatabricksOpenAISchema,
api_key::AbstractString,
model::AbstractString,
Expand Down
44 changes: 44 additions & 0 deletions src/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,21 @@ Base.@kwdef struct DataMessage{T <: Any} <: AbstractDataMessage
_type::Symbol = :datamessage
end

"""
ToolMessage
A message type for tool calls.
It represents both the request (fields `args`, `name`) and the response (field `content`).
# Fields
- `content::Any`: The content of the message.
- `req_id::Union{Nothing, Int}`: The unique ID of the request.
- `tool_call_id::String`: The unique ID of the tool call.
- `raw::AbstractString`: The raw JSON string of the tool call request.
- `args::Union{Nothing, Dict{Symbol, Any}}`: The arguments of the tool call request.
- `name::Union{Nothing, String}`: The name of the tool call request.
"""
Base.@kwdef mutable struct ToolMessage <: AbstractDataMessage
content::Any = nothing
req_id::Union{Nothing, Int} = nothing
Expand All @@ -170,6 +185,30 @@ Base.@kwdef mutable struct ToolMessage <: AbstractDataMessage
_type::Symbol = :toolmessage
end

"""
AIToolRequest
A message type for AI-generated tool requests.
Returned by `aitools` functions.
# Fields
- `content::Union{AbstractString, Nothing}`: The content of the message.
- `tool_calls::Vector{ToolMessage}`: The vector of tool call requests.
- `name::Union{Nothing, String}`: The name of the `role` in the conversation.
- `status::Union{Int, Nothing}`: The status of the message from the API.
- `tokens::Tuple{Int, Int}`: The number of tokens used (prompt,completion).
- `elapsed::Float64`: The time taken to generate the response in seconds.
- `cost::Union{Nothing, Float64}`: The cost of the API call (calculated with information from `MODEL_REGISTRY`).
- `log_prob::Union{Nothing, Float64}`: The log probability of the response.
- `extras::Union{Nothing, Dict{Symbol, Any}}`: A dictionary for additional metadata that is not part of the key message fields. Try to limit to a small number of items and singletons to be serializable.
- `finish_reason::Union{Nothing, String}`: The reason the response was finished.
- `run_id::Union{Nothing, Int}`: The unique ID of the run.
- `sample_id::Union{Nothing, Int}`: The unique ID of the sample (if multiple samples are generated, they will all have the same `run_id`).
See `ToolMessage` for the fields of the tool call requests.
See also: [`tool_calls`](@ref), [`execute_tool`](@ref), [`parse_tool`](@ref)
"""
Base.@kwdef struct AIToolRequest{T <: Union{AbstractString, Nothing}} <: AbstractDataMessage
content::T = nothing
tool_calls::Vector{ToolMessage} = ToolMessage[]
Expand All @@ -185,6 +224,11 @@ Base.@kwdef struct AIToolRequest{T <: Union{AbstractString, Nothing}} <: Abstrac
sample_id::Union{Nothing, Int} = nothing
_type::Symbol = :aitoolrequest
end
"Get the vector of tool call requests from an AIToolRequest/message."
function tool_calls(msg::AIToolRequest)
    # The request already carries its tool calls as a field.
    return msg.tool_calls
end
function tool_calls(msg::ToolMessage)
    # A single tool message is returned wrapped in a one-element vector.
    return [msg]
end
function tool_calls(msg::AbstractTracerMessage)
    # Tracer messages forward to the message stored in their `object` field.
    return tool_calls(msg.object)
end
function tool_calls(msg::AbstractMessage)
    # Fallback for any other message type: no tool calls.
    return ToolMessage[]
end

### Other Message methods
# content-only constructor
Expand Down
Loading

0 comments on commit 5280ee8

Please sign in to comment.