Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi-message types for multi-modal content #259

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/llm_openai.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,95 @@ function render(schema::AbstractOpenAISchema,
throw(ArgumentError("Function `render` is not implemented for the provided schema ($(typeof(schema))) and $(typeof(tool))."))
end

"""
render(schema::AbstractOpenAISchema, msg::AIMultiMessage; kwargs...)

Render an AIMultiMessage for the OpenAI API schema. Combines multiple content blocks
into a single message with appropriate formatting for each content type.
"""
function render(schema::AbstractOpenAISchema,
msg::AIMultiMessage;
image_detail::AbstractString = "auto",
kwargs...)
@assert image_detail in ["auto", "high", "low"] "Image detail must be one of: auto, high, low"

# Initialize the content array for multi-modal content
content = []

# Process each content block
for block in msg.content
if block isa TextContent
push!(content, Dict("type" => "text",
"text" => block.text))
elseif block isa ImageContent
push!(content, Dict("type" => "image_url",
"image_url" => Dict("url" => block.url,
"detail" => image_detail)))
elseif block isa AudioContent
# Currently, OpenAI doesn't support audio in messages directly
# Convert to text description or URL reference
push!(content, Dict("type" => "text",
"text" => "Audio content: $(block.url)"))
elseif block isa DataContent
# Convert data content to string representation
push!(content, Dict("type" => "text",
"text" => string(block.data)))
end
end

Dict("role" => role4render(schema, msg),
"content" => content)
end

"""
render(schema::AbstractOpenAISchema, msg::UserMultiMessage; kwargs...)

Render a UserMultiMessage for the OpenAI API schema. Handles both content blocks
and tool messages in the appropriate format.
"""
function render(schema::AbstractOpenAISchema,
msg::UserMultiMessage;
image_detail::AbstractString = "auto",
kwargs...)
@assert image_detail in ["auto", "high", "low"] "Image detail must be one of: auto, high, low"

# Initialize the message dictionary
message = Dict{String, Any}("role" => role4render(schema, msg))

# Process content blocks similar to AIMultiMessage
content = []
for block in msg.content
if block isa TextContent
push!(content, Dict("type" => "text",
"text" => block.text))
elseif block isa ImageContent
push!(content, Dict("type" => "image_url",
"image_url" => Dict("url" => block.url,
"detail" => image_detail)))
elseif block isa AudioContent
push!(content, Dict("type" => "text",
"text" => "Audio content: $(block.url)"))
elseif block isa DataContent
push!(content, Dict("type" => "text",
"text" => string(block.data)))
end
end
message["content"] = content

# Add tools if present
if !isempty(msg.tools)
message["tool_calls"] = [
Dict("id" => tool.tool_call_id,
"type" => "function",
"function" => Dict("name" => tool.name,
"arguments" => tool.raw))
for tool in msg.tools
]
end

message
end

"""
response_to_message(schema::AbstractOpenAISchema,
MSG::Type{AIMessage},
Expand Down
105 changes: 105 additions & 0 deletions src/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,111 @@
abstract type AbstractMessage end
abstract type AbstractChatMessage <: AbstractMessage end # with text-based content
abstract type AbstractDataMessage <: AbstractMessage end # with data-based content, eg, embeddings

# Content block types for multi-modal messages
abstract type AbstractContentBlock end

"""
TextContent <: AbstractContentBlock

A content block type for text-based content.

# Fields
- `text::String`: The text content
"""
struct TextContent <: AbstractContentBlock
text::String
end

"""
ImageContent <: AbstractContentBlock

A content block type for image-based content.

# Fields
- `url::String`: The URL of the image
"""
struct ImageContent <: AbstractContentBlock
url::String
end

"""
AudioContent <: AbstractContentBlock

A content block type for audio-based content.

# Fields
- `url::String`: The URL of the audio file
"""
struct AudioContent <: AbstractContentBlock
url::String
end

"""
DataContent <: AbstractContentBlock

A content block type for data-based content (e.g., embeddings).

# Fields
- `data::Any`: The data content
"""
struct DataContent <: AbstractContentBlock
data::Any
end

"""
AIMultiMessage <: AbstractChatMessage

A message type for AI-generated responses that can contain multiple content types.
Extends AbstractChatMessage to support multi-modal responses.

# Fields
- `content::Vector{AbstractContentBlock}`: Vector of content blocks (text, images, audio, data)
- `status::Union{Int, Nothing}`: The status of the message from the API
- `tokens::Tuple{Int, Int}`: The number of tokens used (prompt, completion)
- `elapsed::Float64`: The time taken to generate the response in seconds
- `cost::Union{Nothing, Float64}`: The cost of the API call
- `log_prob::Union{Nothing, Float64}`: The log probability of the response
- `extras::Union{Nothing, Dict{Symbol, Any}}`: Additional metadata
- `finish_reason::Union{Nothing, String}`: The reason the response was finished
- `run_id::Union{Nothing, Int}`: The unique ID of the run
- `sample_id::Union{Nothing, Int}`: The unique ID of the sample
- `name::Union{Nothing, String}`: The name of the role in the conversation
"""
Base.@kwdef mutable struct AIMultiMessage <: AbstractChatMessage
content::Vector{AbstractContentBlock}
status::Union{Int, Nothing} = nothing
tokens::Tuple{Int, Int} = (0, 0)
elapsed::Float64 = 0.0
cost::Union{Nothing, Float64} = nothing
log_prob::Union{Nothing, Float64} = nothing
extras::Union{Nothing, Dict{Symbol, Any}} = nothing
finish_reason::Union{Nothing, String} = nothing
run_id::Union{Nothing, Int} = Int(rand(Int16))
sample_id::Union{Nothing, Int} = nothing
name::Union{Nothing, String} = nothing
_type::Symbol = :aimultimessage
end

"""
UserMultiMessage <: AbstractChatMessage

A message type for user-generated responses that can contain multiple content types.
Extends AbstractChatMessage to support multi-modal inputs and tool specifications.

# Fields
- `content::Vector{AbstractContentBlock}`: Vector of content blocks (text, images, audio, data)
- `tools::Vector{ToolMessage}`: Vector of tool messages that AI can use
- `run_id::Union{Nothing, Int}`: The unique ID of the run
- `name::Union{Nothing, String}`: The name of the role in the conversation
"""
Base.@kwdef mutable struct UserMultiMessage <: AbstractChatMessage
content::Vector{AbstractContentBlock}
tools::Vector{ToolMessage} = ToolMessage[]
run_id::Union{Nothing, Int} = Int(rand(Int16))
name::Union{Nothing, String} = nothing
_type::Symbol = :usermultimessage
end
"""
AbstractAnnotationMessage

Expand Down
86 changes: 86 additions & 0 deletions test/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,89 @@ end
@test occursin("TracerMessageLike with:", pprint_output)
@test occursin("Test Message", pprint_output)
end

@testset "MultiMessage Types" begin
# Test AIMultiMessage
text_content = TextContent("Hello, AI!")
image_content = ImageContent("https://example.com/image.jpg")
audio_content = AudioContent("https://example.com/audio.mp3")
data_content = DataContent([1, 2, 3])

# Test constructor and type inheritance
ai_msg = AIMultiMessage(content=[text_content, image_content])
@test ai_msg isa AbstractChatMessage
@test ai_msg isa AIMultiMessage
@test length(ai_msg.content) == 2
@test ai_msg.content[1] isa TextContent
@test ai_msg.content[2] isa ImageContent

# Test all content types
ai_msg = AIMultiMessage(content=[text_content, image_content, audio_content, data_content])
@test length(ai_msg.content) == 4
@test ai_msg.content[1].text == "Hello, AI!"
@test ai_msg.content[2].url == "https://example.com/image.jpg"
@test ai_msg.content[3].url == "https://example.com/audio.mp3"
@test ai_msg.content[4].data == [1, 2, 3]

# Test metadata fields
ai_msg = AIMultiMessage(
content=[text_content],
status=200,
tokens=(10, 20),
elapsed=1.5,
cost=0.001,
extras=Dict(:key => "value")
)
@test ai_msg.status == 200
@test ai_msg.tokens == (10, 20)
@test ai_msg.elapsed == 1.5
@test ai_msg.cost == 0.001
@test ai_msg.extras[:key] == "value"

# Test UserMultiMessage
user_msg = UserMultiMessage(content=[text_content])
@test user_msg isa AbstractChatMessage
@test user_msg isa UserMultiMessage
@test length(user_msg.content) == 1
@test isempty(user_msg.tools)

# Test with tools
tool_msg = ToolMessage(
tool_call_id="1",
name="test_tool",
raw="test args",
content="test output"
)
user_msg = UserMultiMessage(
content=[text_content, image_content],
tools=[tool_msg]
)
@test length(user_msg.content) == 2
@test length(user_msg.tools) == 1
@test user_msg.tools[1].tool_call_id == "1"
@test user_msg.tools[1].name == "test_tool"

# Test show methods
io = IOBuffer()
show(io, MIME("text/plain"), ai_msg)
output = String(take!(io))
@test occursin("AIMultiMessage", output)
@test occursin("Hello, AI!", output)

show(io, MIME("text/plain"), user_msg)
output = String(take!(io))
@test occursin("UserMultiMessage", output)
@test occursin("Hello, AI!", output)

# Test pprint methods
pprint(io, ai_msg)
output = String(take!(io))
@test occursin("AI Multi Message", output)
@test occursin("Hello, AI!", output)

pprint(io, user_msg)
output = String(take!(io))
@test occursin("User Multi Message", output)
@test occursin("Hello, AI!", output)
@test occursin("test_tool", output)
end
Loading