diff --git a/.gitignore b/.gitignore index a71aa4928..c02992ac1 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,7 @@ docs/package-lock.json # Ignore Cursor rules -.cursorrules \ No newline at end of file +.cursorrules + +# Ignore any local preferences +**/LocalPreferences.toml \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index b303f1007..90a27cca7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added a new Gemini Experimental model from November 2024 (`gemini-exp-1121` with alias `gemexp`). +- Added a new `AnnotationMessage` type for keeping human-only information in the message changes. See `?annotate!` on how to use it. +- Added a new `ConversationMemory` type to enable long multi-turn conversations with a truncated memory of the conversation history. Truncation works in "batches" to not prevent caching. See `?ConversationMemory` and `get_last` for more information. + ### Updated - Changed the ENV variable for MistralAI API from `MISTRALAI_API_KEY` to `MISTRAL_API_KEY` to be compatible with the Mistral docs. diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl index cc7cce720..214cc91e3 100644 --- a/src/PromptingTools.jl +++ b/src/PromptingTools.jl @@ -67,9 +67,13 @@ include("user_preferences.jl") ## Conversation history / Prompt elements export AIMessage -# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only include("messages.jl") +# export ConversationMemory +include("memory.jl") +# export annotate! +include("annotation.jl") + export aitemplates, AITemplate include("templates.jl") diff --git a/src/annotation.jl b/src/annotation.jl new file mode 100644 index 000000000..f1597381e --- /dev/null +++ b/src/annotation.jl @@ -0,0 +1,40 @@ +""" + annotate!(messages::AbstractVector{<:AbstractMessage}, content; kwargs...) + annotate!(message::AbstractMessage, content; kwargs...) 
+ +Add an annotation message to a vector of messages or wrap a single message in a vector with an annotation. +The annotation is always inserted after any existing annotation messages. + +# Arguments +- `messages`: Vector of messages or single message to annotate +- `content`: Content of the annotation +- `kwargs...`: Additional fields for the AnnotationMessage (extras, tags, comment) + +# Returns +Vector{AbstractMessage} with the annotation message inserted + +# Example +```julia +messages = [SystemMessage("Assistant"), UserMessage("Hello")] +annotate!(messages, "This is important"; tags=[:important], comment="For review") +``` +""" +function annotate!(messages::AbstractVector{T}, content::AbstractString; + kwargs...) where {T <: AbstractMessage} + # Convert to Vector{AbstractMessage} if needed + messages_abstract = T == AbstractMessage ? messages : + convert(Vector{AbstractMessage}, messages) + + # Find last annotation message index + last_anno_idx = findlast(isabstractannotationmessage, messages_abstract) + insert_idx = isnothing(last_anno_idx) ? 1 : last_anno_idx + 1 + + # Create and insert annotation message + anno = AnnotationMessage(; content = content, kwargs...) + insert!(messages_abstract, insert_idx, anno) + return messages_abstract +end + +function annotate!(message::AbstractMessage, content::AbstractString; kwargs...) + return annotate!(AbstractMessage[message], content; kwargs...) +end \ No newline at end of file diff --git a/src/llm_anthropic.jl b/src/llm_anthropic.jl index 7b020b21b..513950d29 100644 --- a/src/llm_anthropic.jl +++ b/src/llm_anthropic.jl @@ -28,10 +28,13 @@ function render(schema::AbstractAnthropicSchema, no_system_message::Bool = false, cache::Union{Nothing, Symbol} = nothing, kwargs...) 
- ## + ## @assert count(issystemmessage, messages)<=1 "AbstractAnthropicSchema only supports at most 1 System message" @assert (isnothing(cache)||cache in [:system, :tools, :last, :all]) "Currently only `:system`, `:tools`, `:last`, `:all` are supported for Anthropic Prompt Caching" + # Filter out annotation messages before any processing + messages = filter(!isabstractannotationmessage, messages) + system = nothing ## First pass: keep the message types but make the replacements provided in `kwargs` @@ -44,6 +47,8 @@ function render(schema::AbstractAnthropicSchema, for msg in messages_replaced if issystemmessage(msg) system = msg.content + elseif isabstractannotationmessage(msg) + continue elseif isusermessage(msg) || isaimessage(msg) content = msg.content push!(conversation, diff --git a/src/llm_google.jl b/src/llm_google.jl index 64e3568ea..c27fd9213 100644 --- a/src/llm_google.jl +++ b/src/llm_google.jl @@ -25,6 +25,9 @@ function render(schema::AbstractGoogleSchema, no_system_message::Bool = false, kwargs...) ## + # Filter out annotation messages before any processing + messages = filter(!isabstractannotationmessage, messages) + ## First pass: keep the message types but make the replacements provided in `kwargs` messages_replaced = render( NoSchema(), messages; conversation, no_system_message, kwargs...) @@ -34,6 +37,9 @@ function render(schema::AbstractGoogleSchema, # replace any handlebar variables in the messages for msg in messages_replaced + if isabstractannotationmessage(msg) + continue + end push!(conversation, Dict( :role => role4render(schema, msg), :parts => [Dict("text" => msg.content)])) diff --git a/src/llm_interface.jl b/src/llm_interface.jl index 245c20c8e..3f2dddb6a 100644 --- a/src/llm_interface.jl +++ b/src/llm_interface.jl @@ -41,14 +41,21 @@ struct OpenAISchema <: AbstractOpenAISchema end "Echoes the user's input back to them. 
Used for testing the implementation" @kwdef mutable struct TestEchoOpenAISchema <: AbstractOpenAISchema - response::AbstractDict - status::Integer + response::AbstractDict = Dict( + "choices" => [Dict("message" => Dict("content" => "Test response", "role" => "assistant"), "index" => 0, "finish_reason" => "stop")], + "usage" => Dict("prompt_tokens" => 10, "completion_tokens" => 20, "total_tokens" => 30), + "model" => "gpt-3.5-turbo", + "id" => "test-id", + "object" => "chat.completion", + "created" => 1234567890 + ) + status::Integer = 200 model_id::String = "" inputs::Any = nothing end """ - CustomOpenAISchema + CustomOpenAISchema CustomOpenAISchema() allows user to call any OpenAI-compatible API. diff --git a/src/llm_ollama.jl b/src/llm_ollama.jl index 07166944b..c95a3b6f5 100644 --- a/src/llm_ollama.jl +++ b/src/llm_ollama.jl @@ -27,6 +27,9 @@ function render(schema::AbstractOllamaSchema, no_system_message::Bool = false, kwargs...) ## + # Filter out annotation messages before any processing + messages = filter(!isabstractannotationmessage, messages) + ## First pass: keep the message types but make the replacements provided in `kwargs` messages_replaced = render( NoSchema(), messages; conversation, no_system_message, kwargs...) @@ -36,6 +39,9 @@ function render(schema::AbstractOllamaSchema, # replace any handlebar variables in the messages for msg in messages_replaced + if isabstractannotationmessage(msg) + continue + end new_message = Dict{String, Any}( "role" => role4render(schema, msg), "content" => msg.content) ## Special case for images @@ -376,4 +382,4 @@ end function aitools(prompt_schema::AbstractOllamaSchema, prompt::ALLOWED_PROMPT_TYPE; kwargs...) error("Managed schema does not support aitools. 
Please use OpenAISchema instead.") -end \ No newline at end of file +end diff --git a/src/llm_ollama_managed.jl b/src/llm_ollama_managed.jl index 669a58296..d26e9eec1 100644 --- a/src/llm_ollama_managed.jl +++ b/src/llm_ollama_managed.jl @@ -40,6 +40,8 @@ function render(schema::AbstractOllamaManagedSchema, system = msg.content elseif msg isa UserMessage prompt = msg.content + elseif isabstractannotationmessage(msg) + continue elseif msg isa UserMessageWithImages error("Managed schema does not support UserMessageWithImages. Please use OpenAISchema instead.") elseif msg isa AIMessage diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 429d86325..3e565c2aa 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -33,6 +33,10 @@ function render(schema::AbstractOpenAISchema, kwargs...) ## @assert image_detail in ["auto", "high", "low"] "Image detail must be one of: auto, high, low" + + # Filter out annotation messages before any processing + messages = filter(!isabstractannotationmessage, messages) + ## First pass: keep the message types but make the replacements provided in `kwargs` messages_replaced = render( NoSchema(), messages; conversation, no_system_message, kwargs...) @@ -71,6 +75,8 @@ function render(schema::AbstractOpenAISchema, content = msg.content isa AbstractString ? msg.content : string(msg.content) Dict("role" => role4render(schema, msg), "content" => content, "tool_call_id" => msg.tool_call_id) + elseif isabstractannotationmessage(msg) + continue else ## Vanilla assistant message Dict("role" => role4render(schema, msg), @@ -1733,4 +1739,4 @@ function aitools(prompt_schema::AbstractOpenAISchema, prompt::ALLOWED_PROMPT_TYP kwargs...) 
return output -end \ No newline at end of file +end diff --git a/src/llm_shared.jl b/src/llm_shared.jl index fcb6e5702..831281214 100644 --- a/src/llm_shared.jl +++ b/src/llm_shared.jl @@ -8,6 +8,7 @@ role4render(schema::AbstractPromptSchema, msg::UserMessageWithImages) = "user" role4render(schema::AbstractPromptSchema, msg::AIMessage) = "assistant" role4render(schema::AbstractPromptSchema, msg::AIToolRequest) = "assistant" role4render(schema::AbstractPromptSchema, msg::ToolMessage) = "tool" +role4render(schema::AbstractPromptSchema, msg::AbstractAnnotationMessage) = "annotation" """ render(schema::NoSchema, messages::Vector{<:AbstractMessage}; @@ -39,6 +40,9 @@ function render(schema::NoSchema, count_system_msg = count(issystemmessage, conversation) # TODO: concat multiple system messages together (2nd pass) + # Filter out annotation messages from input messages + messages = filter(!isabstractannotationmessage, messages) + # replace any handlebar variables in the messages for msg in messages if issystemmessage(msg) || isusermessage(msg) || isusermessagewithimages(msg) @@ -73,6 +77,9 @@ function render(schema::NoSchema, count_system_msg += 1 # move to the front pushfirst!(conversation, msg) + elseif isabstractannotationmessage(msg) + # Ignore annotation messages + continue else # Note: Ignores any DataMessage or other types for the prompt/conversation history @warn "Unexpected message type: $(typeof(msg)). Skipping." 
diff --git a/src/llm_sharegpt.jl b/src/llm_sharegpt.jl index 24ed65733..01ad9ad22 100644 --- a/src/llm_sharegpt.jl +++ b/src/llm_sharegpt.jl @@ -9,7 +9,7 @@ end function render(schema::AbstractShareGPTSchema, conv::AbstractVector{<:AbstractMessage}) Dict("conversations" => [Dict("from" => role4render(schema, msg), "value" => msg.content) - for msg in conv]) + for msg in conv if !isabstractannotationmessage(msg)]) end ### AI Functions diff --git a/src/llm_tracer.jl b/src/llm_tracer.jl index 352dd7de0..33b920d8e 100644 --- a/src/llm_tracer.jl +++ b/src/llm_tracer.jl @@ -16,6 +16,9 @@ end function role4render(schema::AbstractTracerSchema, msg::AIMessage) role4render(schema.schema, msg) end +function role4render(schema::AbstractTracerSchema, msg::AbstractAnnotationMessage) + role4render(schema.schema, msg) +end """ render(tracer_schema::AbstractTracerSchema, conv::AbstractVector{<:AbstractMessage}; kwargs...) diff --git a/src/memory.jl b/src/memory.jl new file mode 100644 index 000000000..f84ebc166 --- /dev/null +++ b/src/memory.jl @@ -0,0 +1,308 @@ +""" + ConversationMemory + +A structured container for managing conversation history. It has only one field `:conversation` +which is a vector of `AbstractMessage`s. It's built to support intelligent truncation and caching +behavior (`get_last`). 
+ +You can also use it as a functor to have extended conversations (easier than constantly passing `conversation` kwarg) + +# Examples + +Basic usage +```julia +mem = ConversationMemory() +push!(mem, SystemMessage("You are a helpful assistant")) +push!(mem, UserMessage("Hello!")) +push!(mem, AIMessage("Hi there!")) + +# or simply +mem = ConversationMemory(conv) +``` + +Check memory stats +```julia +println(mem) # ConversationMemory(2 messages) - doesn't count system message +@show length(mem) # 3 - counts all messages +@show last_message(mem) # gets last message +@show last_output(mem) # gets last content +``` + +Get recent messages with different options (System message, User message, ... + the most recent) +```julia +recent = get_last(mem, 5) # get last 5 messages (including system) +recent = get_last(mem, 20, batch_size=10) # align to batches of 10 for caching +recent = get_last(mem, 5, explain=true) # adds truncation explanation +recent = get_last(mem, 5, verbose=true) # prints truncation info +``` + +Append multiple messages at once (with deduplication to keep the memory complete) +```julia +msgs = [ + UserMessage("How are you?"), + AIMessage("I'm good!"; run_id=1), + UserMessage("Great!"), + AIMessage("Indeed!"; run_id=2) +] +append!(mem, msgs) # Will only append new messages based on run_ids etc. +``` + +Use for AI conversations (easier to manage conversations) +```julia +response = mem("Tell me a joke"; model="gpt4o") # Automatically manages context +response = mem("Another one"; last=3, model="gpt4o") # Use only last 3 messages (uses `get_last`) + +# Direct generation from the memory +result = aigenerate(mem) # Generate using full context +``` +""" +Base.@kwdef mutable struct ConversationMemory + conversation::Vector{AbstractMessage} = AbstractMessage[] +end + +""" + show(io::IO, mem::ConversationMemory) + +Display the number of non-system/non-annotation messages in the conversation memory. 
+""" +function Base.show(io::IO, mem::ConversationMemory) + n_msgs = count( + x -> !issystemmessage(x) && !isabstractannotationmessage(x), mem.conversation) + print(io, "ConversationMemory($(n_msgs) messages)") +end + +""" + length(mem::ConversationMemory) + +Return the number of messages. All of them. +""" +function Base.length(mem::ConversationMemory) + length(mem.conversation) +end + +""" + last_message(mem::ConversationMemory) + +Get the last message in the conversation. +""" +function last_message(mem::ConversationMemory) + last_message(mem.conversation) +end + +""" + last_output(mem::ConversationMemory) + +Get the last AI message in the conversation. +""" +function last_output(mem::ConversationMemory) + last_output(mem.conversation) +end + +function pprint( + io::IO, mem::ConversationMemory; + text_width::Int = displaysize(io)[2]) + pprint(io, mem.conversation; text_width) +end + +""" + get_last(mem::ConversationMemory, n::Integer=20; + batch_size::Union{Nothing,Integer}=nothing, + verbose::Bool=false, + explain::Bool=false) + +Get the last n messages (but including system message) with intelligent batching to preserve caching. + +Arguments: +- n::Integer: Maximum number of messages to return (default: 20) +- batch_size::Union{Nothing,Integer}: If provided, ensures messages are truncated in fixed batches +- verbose::Bool: Print detailed information about truncation +- explain::Bool: Add explanation about truncation in the response + +Returns: +Vector{AbstractMessage} with the selected messages, always including: +1. The system message (if present) +2. First user message +3. Messages up to n, respecting batch_size boundaries + +Once you get your full conversation back, you can use `append!(mem, conversation)` to merge the new messages into the memory. 
+ +# Examples: +```julia +# Basic usage - get last 3 messages +mem = ConversationMemory() +push!(mem, SystemMessage("You are helpful")) +push!(mem, UserMessage("Hello")) +push!(mem, AIMessage("Hi!")) +push!(mem, UserMessage("How are you?")) +push!(mem, AIMessage("I'm good!")) +messages = get_last(mem, 3) + +# Using batch_size for caching efficiency +messages = get_last(mem, 10; batch_size=5) # Aligns to 5-message batches for caching + +# Add explanation about truncation +messages = get_last(mem, 3; explain=true) # Adds truncation note to first AI message so the model knows it's truncated + +# Get verbose output about truncation +messages = get_last(mem, 3; verbose=true) # Prints info about truncation +``` +""" +function get_last(mem::ConversationMemory, n::Integer = 20; + batch_size::Union{Nothing, Integer} = nothing, + verbose::Bool = false, + explain::Bool = false) + messages = mem.conversation + isempty(messages) && return AbstractMessage[] + + # Always include system message and first user message + system_idx = findfirst(issystemmessage, messages) + first_user_idx = findfirst(isusermessage, messages) + + # Initialize result with required messages + result = AbstractMessage[] + if !isnothing(system_idx) + push!(result, messages[system_idx]) + end + if !isnothing(first_user_idx) + push!(result, messages[first_user_idx]) + end + + # Calculate remaining message budget + remaining_budget = n - length(result) + visible_messages = findall( + x -> !issystemmessage(x) && !isabstractannotationmessage(x), messages) + + if remaining_budget > 0 + default_start_idx = max(1, length(visible_messages) - remaining_budget + 1) + start_idx = !isnothing(batch_size) ? 
+ batch_start_index( + length(visible_messages), remaining_budget, batch_size) : default_start_idx + ## find first AIMessage after that (it must be aligned to follow after UserMessage) + valid_idxs = @view(visible_messages[start_idx:end]) + ai_msg_idx = findfirst(isaimessage, @view(messages[valid_idxs])) + !isnothing(ai_msg_idx) && + append!(result, messages[@view(valid_idxs[ai_msg_idx:end])]) + end + + verbose && + @info "ConversationMemory truncated to $(length(result))/$(length(messages)) messages" + + # Add explanation if requested and we truncated messages + if explain && (length(visible_messages) + 1) > length(result) + # Find first AI message in result after required messages + ai_msg_idx = findfirst(x -> isaimessage(x) || isaitoolrequest(x), result) + trunc_count = length(visible_messages) + 1 - length(result) + if !isnothing(ai_msg_idx) + ai_msg = result[ai_msg_idx] + orig_content = ai_msg.content + explanation = "[This is an automatically added explanation to inform you that for efficiency reasons, the user has truncated the preceding $(trunc_count) messages.]\n\n$orig_content" + ai_msg_type = typeof(ai_msg) + result[ai_msg_idx] = ai_msg_type(; + [f => getfield(ai_msg, f) + for f in fieldnames(ai_msg_type) if f != :content]..., + content = explanation) + end + end + + return result +end + +""" + batch_start_index(array_length::Integer, n::Integer, batch_size::Integer) -> Integer + +Compute the starting index for retrieving the most recent data, adjusting in blocks of `batch_size`. +The function accumulates messages until hitting a batch boundary, then jumps to the next batch. 
+ +For example, with n=20 and batch_size=10: +- At length 90-99: returns 80 (allowing accumulation of 11-20 messages) +- At length 100-109: returns 90 (allowing accumulation of 11-20 messages) +- At length 110: returns 100 (resetting to 11 messages) +""" +function batch_start_index(array_length::Integer, n::Integer, batch_size::Integer)::Integer + @assert n>=batch_size "n must be >= batch_size" + # Calculate which batch we're in + batch_number = (array_length - (n - batch_size)) ÷ batch_size + # Calculate the start of the current batch + batch_start = batch_number * batch_size + + # Ensure we don't go before the first element + return max(1, batch_start) +end + +""" + append!(mem::ConversationMemory, msgs::Vector{<:AbstractMessage}) + +Smart append that handles duplicate messages based on run IDs. +Only appends messages that are newer than the latest matching message in memory. +""" +function Base.append!(mem::ConversationMemory, msgs::Vector{<:AbstractMessage}) + isempty(msgs) && return mem + isempty(mem.conversation) && return append!(mem.conversation, msgs) + + # get all messages in mem.conversation with run_id + run_id_indices = findall(x -> hasproperty(x, :run_id), mem.conversation) + + # Search backwards through messages to find matching point + for idx in reverse(eachindex(msgs)) + msg = msgs[idx] + + # Find matching message in memory based on run_id if present + match_idx = if hasproperty(msg, :run_id) + findlast( + m -> hasproperty(m, :run_id) && m.run_id == msg.run_id, @view(mem.conversation[run_id_indices])) + else + findlast(m -> m == msg, mem.conversation) + end + + if !isnothing(match_idx) + # Found match - append everything after this message + (idx + 1 <= length(msgs)) && append!(mem.conversation, msgs[(idx + 1):end]) + return mem + end + end + + @warn "No matching messages found in memory, appending all" + return append!(mem.conversation, msgs) +end + +""" + push!(mem::ConversationMemory, msg::AbstractMessage) + +Add a single message to the 
conversation memory. +""" +function Base.push!(mem::ConversationMemory, msg::AbstractMessage) + push!(mem.conversation, msg) + return mem +end + +""" + (mem::ConversationMemory)(prompt::AbstractString; last::Union{Nothing,Integer}=nothing, kwargs...) + +Functor interface for direct generation using the conversation memory. +Optionally, specify the number of last messages to include in the context (uses `get_last`). +""" +function (mem::ConversationMemory)( + prompt::AbstractString; last::Union{Nothing, Integer} = nothing, kwargs...) + # Get conversation context + context = isnothing(last) ? mem.conversation : get_last(mem, last) + + # Add user message to memory first + user_msg = UserMessage(prompt) + push!(mem, user_msg) + + # Generate response with context + response = aigenerate(context; return_all = true, kwargs...) + append!(mem, response) + return last_message(response) +end + +""" + aigenerate(schema::AbstractPromptSchema, + mem::ConversationMemory; kwargs...) + +Generate a response using the conversation memory context. +""" +function aigenerate(schema::AbstractPromptSchema, + mem::ConversationMemory; kwargs...) + aigenerate(schema, mem.conversation; kwargs...) +end diff --git a/src/messages.jl b/src/messages.jl index acc6e2a39..f1a8e6b91 100644 --- a/src/messages.jl +++ b/src/messages.jl @@ -4,6 +4,17 @@ abstract type AbstractMessage end abstract type AbstractChatMessage <: AbstractMessage end # with text-based content abstract type AbstractDataMessage <: AbstractMessage end # with data-based content, eg, embeddings +""" + AbstractAnnotationMessage + +Messages that provide extra information without being sent to LLMs. + +Required fields: `content`, `tags`, `comment`, `run_id`. + +Note: `comment` is intended for human readers only and should never be used. +`run_id` should be a unique identifier for the annotation, typically a random number. 
+""" +abstract type AbstractAnnotationMessage <: AbstractMessage end # messages that provide extra information without being sent to LLMs abstract type AbstractTracerMessage{T <: AbstractMessage} <: AbstractMessage end # message with annotation that exposes the underlying message # Complementary type for tracing, follows the same API as TracerMessage abstract type AbstractTracer{T <: Any} end @@ -22,6 +33,7 @@ Base.@kwdef struct MetadataMessage{T <: AbstractString} <: AbstractChatMessage description::String = "" version::String = "1" source::String = "" + run_id::Union{Nothing, Int} = Int(rand(Int16)) _type::Symbol = :metadatamessage end Base.@kwdef struct SystemMessage{T <: AbstractString} <: AbstractChatMessage @@ -230,10 +242,47 @@ tool_calls(msg::AbstractMessage) = ToolMessage[] tool_calls(msg::ToolMessage) = [msg] tool_calls(msg::AbstractTracerMessage) = tool_calls(msg.object) +""" + AnnotationMessage + +A message type for providing extra information in the conversation history without being sent to LLMs. +These messages are filtered out during rendering to ensure they don't affect the LLM's context. + +Used to bundle key information and documentation for colleagues and future reference together with the data. + +# Fields +- `content::T`: The content of the annotation (can be used for inputs to airag etc.) +- `extras::Dict{Symbol,Any}`: Additional metadata with symbol keys and any values +- `tags::Vector{Symbol}`: Vector of tags for categorization (default: empty) +- `comment::String`: Human-readable comment, never used for automatic operations (default: empty) +- `run_id::Union{Nothing,Int}`: The unique ID of the annotation + +Note: The comment field is intended for human readers only and should never be used +for automatic operations. 
+""" +Base.@kwdef struct AnnotationMessage{T <: AbstractString} <: AbstractAnnotationMessage + content::T + extras::Union{Nothing, Dict{Symbol, Any}} = nothing + tags::Vector{Symbol} = Symbol[] + comment::String = "" + run_id::Union{Nothing, Int} = Int(rand(Int32)) + _type::Symbol = :annotationmessage +end + ### Other Message methods # content-only constructor -function (MSG::Type{<:AbstractChatMessage})(prompt::AbstractString) - MSG(; content = prompt) +function (MSG::Type{<:AbstractChatMessage})(prompt::AbstractString; kwargs...) + MSG(; content = prompt, kwargs...) +end +function (MSG::Type{<:AbstractAnnotationMessage})(content::AbstractString; kwargs...) + ## Re-type extras to be generic Dict{Symbol, Any} + new_kwargs = if haskey(kwargs, :extras) + [f == :extras ? f => convert(Dict{Symbol, Any}, kwargs[f]) : f => kwargs[f] + for f in keys(kwargs)] + else + kwargs + end + MSG(; content, new_kwargs...) end function (MSG::Type{<:AbstractChatMessage})(msg::AbstractChatMessage) MSG(; msg.content) @@ -250,6 +299,7 @@ isdatamessage(m::Any) = m isa DataMessage isaimessage(m::Any) = m isa AIMessage istoolmessage(m::Any) = m isa ToolMessage isaitoolrequest(m::Any) = m isa AIToolRequest +isabstractannotationmessage(msg::Any) = msg isa AbstractAnnotationMessage istracermessage(m::Any) = m isa AbstractTracerMessage isusermessage(m::AbstractTracerMessage) = isusermessage(m.object) isusermessagewithimages(m::AbstractTracerMessage) = isusermessagewithimages(m.object) @@ -258,6 +308,9 @@ isdatamessage(m::AbstractTracerMessage) = isdatamessage(m.object) isaimessage(m::AbstractTracerMessage) = isaimessage(m.object) istoolmessage(m::AbstractTracerMessage) = istoolmessage(m.object) isaitoolrequest(m::AbstractTracerMessage) = isaitoolrequest(m.object) +function isabstractannotationmessage(m::AbstractTracerMessage) + isabstractannotationmessage(m.object) +end # equality check for testing, only equal if all fields are equal and type is the same Base.var"=="(m1::AbstractMessage, 
m2::AbstractMessage) = false @@ -481,7 +534,7 @@ end "Helpful accessor for the last generated output (`msg.content`) in `conversation`. Returns the last output in the conversation (eg, the string/data in the last message)." function last_output(conversation::AbstractVector{<:AbstractMessage}) msg = last_message(conversation) - return isnothing(msg) ? nothing : msg.content + return isnothing(msg) ? nothing : last_output(msg) end last_message(msg::AbstractMessage) = msg last_output(msg::AbstractMessage) = msg.content @@ -521,6 +574,11 @@ function Base.show(io::IO, ::MIME"text/plain", m::AbstractDataMessage) print(io, "(", typeof(m.content), ")") end end +function Base.show(io::IO, ::MIME"text/plain", m::AbstractAnnotationMessage) + type_ = string(typeof(m)) |> x -> split(x, "{")[begin] + printstyled(io, type_; color = :light_blue) + print(io, "(\"", m.content, "\")") +end function Base.show(io::IO, ::MIME"text/plain", t::AbstractTracerMessage) dump(IOContext(io, :limit => true), t, maxdepth = 1) end @@ -557,7 +615,8 @@ function StructTypes.subtypes(::Type{AbstractMessage}) systemmessage = SystemMessage, metadatamessage = MetadataMessage, datamessage = DataMessage, - tracermessage = TracerMessage) + tracermessage = TracerMessage, + annotationmessage = AnnotationMessage) end StructTypes.StructType(::Type{AbstractChatMessage}) = StructTypes.AbstractType() @@ -567,7 +626,14 @@ function StructTypes.subtypes(::Type{AbstractChatMessage}) usermessagewithimages = UserMessageWithImages, aimessage = AIMessage, systemmessage = SystemMessage, - metadatamessage = MetadataMessage) + metadatamessage = MetadataMessage, + annotationmessage = AnnotationMessage) +end + +StructTypes.StructType(::Type{AbstractAnnotationMessage}) = StructTypes.AbstractType() +StructTypes.subtypekey(::Type{AbstractAnnotationMessage}) = :_type +function StructTypes.subtypes(::Type{AbstractAnnotationMessage}) + (annotationmessage = AnnotationMessage,) end StructTypes.StructType(::Type{AbstractTracerMessage}) 
= StructTypes.AbstractType() @@ -590,6 +656,7 @@ StructTypes.StructType(::Type{ToolMessage}) = StructTypes.Struct() StructTypes.StructType(::Type{AIToolRequest}) = StructTypes.Struct() StructTypes.StructType(::Type{AIMessage}) = StructTypes.Struct() StructTypes.StructType(::Type{DataMessage}) = StructTypes.Struct() +StructTypes.StructType(::Type{AnnotationMessage}) = StructTypes.Struct() StructTypes.StructType(::Type{TracerMessage}) = StructTypes.Struct() # Ignore mutability once we serialize StructTypes.StructType(::Type{TracerMessageLike}) = StructTypes.Struct() # Ignore mutability once we serialize @@ -615,6 +682,8 @@ function pprint(io::IO, msg::AbstractMessage; text_width::Int = displaysize(io)[ "AI Tool Request" elseif msg isa ToolMessage "Tool Message" + elseif msg isa AnnotationMessage + "Annotation Message" else "Unknown Message" end @@ -633,6 +702,10 @@ function pprint(io::IO, msg::AbstractMessage; text_width::Int = displaysize(io)[ elseif istoolmessage(msg) isnothing(msg.content) ? string("Name: ", msg.name, ", Args: ", msg.raw) : string(msg.content) + elseif isabstractannotationmessage(msg) + tags_str = isempty(msg.tags) ? "" : "\n [$(join(msg.tags, ", "))]" + comment_str = isempty(msg.comment) ? 
"" : "\n ($(msg.comment))" + "$(msg.content)$tags_str$comment_str" else wrap_string(msg.content, text_width) end diff --git a/src/precompilation.jl b/src/precompilation.jl index eb91c6afc..3dfe85339 100644 --- a/src/precompilation.jl +++ b/src/precompilation.jl @@ -1,6 +1,40 @@ +# Basic Message Types precompilation - moved to top +sys_msg = SystemMessage("You are a helpful assistant") +user_msg = UserMessage("Hello!") +ai_msg = AIMessage(content = "Test response") + +# Annotation Message precompilation - after basic types +annotation_msg = AnnotationMessage("Test metadata"; + extras = Dict{Symbol, Any}(:key => "value"), + tags = Symbol[:test], + comment = "Test comment") +_ = isabstractannotationmessage(annotation_msg) + +# ConversationMemory precompilation +memory = ConversationMemory() +push!(memory, sys_msg) +push!(memory, user_msg) +_ = get_last(memory, 2) +_ = length(memory) +_ = last_message(memory) + +# Test message rendering with all types - moved before API calls +messages = [ + sys_msg, + annotation_msg, + user_msg, + ai_msg +] +_ = render(OpenAISchema(), messages) + +## Utilities +pprint(messages) +last_output(messages) +last_message(messages) + # Load templates load_template(joinpath(@__DIR__, "..", "templates", "general", "BlankSystemUser.json")) -load_templates!(); +load_templates!() # Preferences @load_preference("MODEL_CHAT", default="x") @@ -50,4 +84,4 @@ msg = aiscan(schema, ## Streaming configuration cb = StreamCallback() -configure_callback!(cb, OpenAISchema()) \ No newline at end of file +configure_callback!(cb, OpenAISchema()) diff --git a/src/user_preferences.jl b/src/user_preferences.jl index 166754ee5..95aa2c2a4 100644 --- a/src/user_preferences.jl +++ b/src/user_preferences.jl @@ -194,7 +194,7 @@ function load_api_keys!() default=get(ENV, "MISTRAL_API_KEY", get(ENV, "MISTRALAI_API_KEY", ""))) if !isempty(get(ENV, "MISTRALAI_API_KEY", "")) - @warn "The MISTRALAI_API_KEY environment variable is deprecated. Use MISTRAL_API_KEY instead." 
+        @debug "The MISTRALAI_API_KEY environment variable is deprecated. Use MISTRAL_API_KEY instead."
     end
     global COHERE_API_KEY
     COHERE_API_KEY = @load_preference("COHERE_API_KEY",
diff --git a/test/annotation.jl b/test/annotation.jl
new file mode 100644
index 000000000..878cf4294
--- /dev/null
+++ b/test/annotation.jl
@@ -0,0 +1,113 @@
+using PromptingTools: isabstractannotationmessage, annotate!, pprint
+using PromptingTools: OpenAISchema, AnthropicSchema, OllamaSchema, GoogleSchema,
+                      TestEchoOpenAISchema, render, NoSchema
+using PromptingTools: AnnotationMessage, SystemMessage, TracerMessage, UserMessage,
+                      AIMessage
+
+@testset "Annotation Message Rendering" begin
+    # Create a mix of messages including annotation messages
+    messages = [
+        SystemMessage("Be helpful"),
+        AnnotationMessage("This is metadata", extras = Dict{Symbol, Any}(:key => "value")),
+        UserMessage("Hello"),
+        AnnotationMessage("More metadata"),
+        AIMessage("Hi there!") # No status needed for basic message
+    ]
+
+    @testset "Basic Message Filtering" begin
+        # Test OpenAI Schema with TestEcho
+        schema = TestEchoOpenAISchema(;
+            response = Dict(
+                "choices" => [Dict(
+                    "message" => Dict("content" => "Test response", "role" => "assistant"),
+                    "index" => 0, "finish_reason" => "stop")],
+                "usage" => Dict(
+                    "prompt_tokens" => 10, "completion_tokens" => 20, "total_tokens" => 30),
+                "model" => "gpt-3.5-turbo",
+                "id" => "test-id",
+                "object" => "chat.completion",
+                "created" => 1234567890
+            ),
+            status = 200
+        )
+        rendered = render(schema, messages)
+        @test length(rendered) == 3 # Should only have system, user, and AI messages
+        @test all(msg["role"] in ["system", "user", "assistant"] for msg in rendered)
+        @test !any(msg -> contains(msg["content"], "metadata"), rendered)
+
+        # Test Anthropic Schema
+        rendered = render(AnthropicSchema(), messages)
+        @test length(rendered.conversation) == 2 # Should have user and AI messages
+        @test !isnothing(rendered.system) # System message should be preserved separately
+        @test all(msg["role"] in ["user", "assistant"] for msg in rendered.conversation)
+        @test !contains(rendered.system, "metadata") # Check system message
+        @test !any(
+            msg -> any(content -> contains(content["text"], "metadata"), msg["content"]),
+            rendered.conversation)
+
+        # Test Ollama Schema
+        rendered = render(OllamaSchema(), messages)
+        @test length(rendered) == 3 # Should only have system, user, and AI messages
+        @test all(msg["role"] in ["system", "user", "assistant"] for msg in rendered)
+        @test !any(msg -> contains(msg["content"], "metadata"), rendered)
+
+        # Test Google Schema
+        rendered = render(GoogleSchema(), messages)
+        @test length(rendered) == 2 # Google schema combines system message with first user message
+        @test all(msg[:role] in ["user", "model"] for msg in rendered) # Google uses "model" instead of "assistant"
+        @test !any(
+            msg -> any(part -> contains(part["text"], "metadata"), msg[:parts]), rendered)
+
+        # Create a basic NoSchema
+        schema = NoSchema()
+        rendered = render(schema, messages)
+        @test length(rendered) == 3
+        @test all(!isabstractannotationmessage, rendered)
+    end
+end
+
+@testset "annotate!" begin
+    # Test basic annotation with single message
+    msg = UserMessage("Hello")
+    annotated = annotate!(msg, "metadata"; tags = [:test])
+    @test length(annotated) == 2
+    @test isabstractannotationmessage(annotated[1])
+    @test annotated[1].content == "metadata"
+    @test annotated[1].tags == [:test]
+    @test annotated[2].content == "Hello"
+
+    # Test annotation with vector of messages
+    messages = [
+        SystemMessage("System"),
+        UserMessage("User"),
+        AIMessage("AI")
+    ]
+    annotated = annotate!(messages, "metadata"; comment = "test comment")
+    @test length(annotated) == 4
+    @test isabstractannotationmessage(annotated[1])
+    @test annotated[1].content == "metadata"
+    @test annotated[1].comment == "test comment"
+    @test annotated[2:end] == messages
+
+    # Test annotation with existing annotations
+    messages = [
+        AnnotationMessage("First annotation"),
+        SystemMessage("System"),
+        UserMessage("User"),
+        AnnotationMessage("Second annotation"),
+        AIMessage("AI")
+    ]
+    annotated = annotate!(messages, "new metadata")
+    @test length(annotated) == 6
+    @test isabstractannotationmessage(annotated[1])
+    @test isabstractannotationmessage(annotated[4])
+    @test annotated[5].content == "new metadata"
+    @test annotated[6].content == "AI"
+
+    # Test annotation with extras
+    extras = Dict{Symbol, Any}(:key => "value")
+    annotated = annotate!(UserMessage("Hello"), "metadata"; extras = extras)
+    @test length(annotated) == 2
+    @test annotated[1].content == "metadata"
+    @test annotated[1].extras == extras
+end
diff --git a/test/memory.jl b/test/memory.jl
new file mode 100644
index 000000000..c7fcc5026
--- /dev/null
+++ b/test/memory.jl
@@ -0,0 +1,174 @@
+using PromptingTools: SystemMessage, UserMessage, AIMessage, AbstractMessage
+using PromptingTools: TestEchoOpenAISchema, ConversationMemory
+using PromptingTools: issystemmessage, isusermessage, isaimessage, last_message,
+                      last_output, register_model!, batch_start_index,
+                      get_last, pprint
+
+@testset "batch_start_index" begin
+    # Test basic batch calculation
+    @test batch_start_index(30, 10, 10) == 30 # Last batch of size 10
+    @test batch_start_index(31, 10, 10) == 30 # Last batch of size 10
+    @test batch_start_index(32, 10, 10) == 30 # Last batch of size 10
+    @test batch_start_index(30, 20, 10) == 20 # Middle batch
+    @test batch_start_index(31, 20, 10) == 20 # Middle batch
+    @test batch_start_index(32, 20, 10) == 20 # Middle batch
+    @test batch_start_index(30, 30, 10) == 10
+    @test batch_start_index(31, 30, 10) == 10
+    @test batch_start_index(32, 30, 10) == 10
+
+    # Test edge cases
+    @test batch_start_index(10, 10, 5) == 5 # Last batch with exact fit
+    @test batch_start_index(11, 10, 5) == 5 # Last batch with exact fit
+    @test batch_start_index(12, 10, 5) == 5 # Last batch with exact fit
+    @test batch_start_index(13, 10, 5) == 5 # Last batch with exact fit
+    @test batch_start_index(14, 10, 5) == 5 # Last batch with exact fit
+    @test batch_start_index(15, 10, 5) == 10
+
+    # Test minimum bound
+    @test batch_start_index(5, 10, 10) == 1 # Should not go below 1
+
+    @test_throws AssertionError batch_start_index(3, 5, 10)
+end
+
+@testset "ConversationMemory-type" begin
+    # Test constructor and empty initialization
+    mem = ConversationMemory()
+    @test length(mem) == 0
+    @test isempty(mem.conversation)
+
+    # Test show method
+    io = IOBuffer()
+    show(io, mem)
+    @test String(take!(io)) == "ConversationMemory(0 messages)"
+    pprint(io, mem)
+    @test String(take!(io)) == ""
+
+    # Test push! and length
+    push!(mem, SystemMessage("System prompt"))
+    show(io, mem)
+    @test String(take!(io)) == "ConversationMemory(0 messages)" # don't count system messages
+    @test length(mem) == 1
+    push!(mem, UserMessage("Hello"))
+    @test length(mem) == 2
+    push!(mem, AIMessage("Hi there"))
+    @test length(mem) == 3
+
+    # Test last_message and last_output
+    @test last_message(mem).content == "Hi there"
+    @test last_output(mem) == "Hi there"
+
+    # Test with non-AI last message
+    push!(mem, UserMessage("How are you?"))
+    @test last_message(mem).content == "How are you?"
+    @test last_output(mem) == "How are you?"
+
+    pprint(io, mem)
+    output = String(take!(io))
+    @test occursin("How are you?", output)
+end
+
+@testset "get_last" begin
+    mem = ConversationMemory()
+
+    # Add test messages
+    push!(mem, SystemMessage("System prompt"))
+    push!(mem, UserMessage("First user"))
+    for i in 1:15
+        push!(mem, AIMessage("AI message $i"))
+        push!(mem, UserMessage("User message $i"))
+    end
+
+    # Test get_last without batch_size
+    recent = get_last(mem, 5)
+    @test length(recent) == 4 # requested 5, truncated to 4: system + first user are always kept
+    @test recent[1].content == "System prompt"
+    @test recent[2].content == "First user"
+
+    # Test get_last with batch_size=10
+    recent = get_last(mem, 20; batch_size = 10)
+    # @test 11 <= length(recent) <= 20 # Should be between 11-20 messages
+    @test length(recent) == 14
+    @test recent[1].content == "System prompt"
+    @test recent[2].content == "First user"
+    recent = get_last(mem, 14; batch_size = 10)
+    @test length(recent) == 14
+    # @test 11 <= length(recent) <= 14 # Should be between 11-20 messages
+    @test recent[1].content == "System prompt"
+    @test recent[2].content == "First user"
+
+    # Test get_last with explanation
+    recent = get_last(mem, 5; explain = true)
+    @test startswith(recent[3].content, "[This is an automatically added explanation")
+
+    # Test get_last with verbose
+    @test_logs (:info, r"truncated to 4/32") get_last(mem, 5; verbose = true)
+end
+
+@testset "ConversationMemory-append!" begin
+    mem = ConversationMemory()
+
+    # Test append! with empty memory
+    msgs = [
+        SystemMessage("System prompt"),
+        UserMessage("User 1"),
+        AIMessage("AI 1"; run_id = 1)
+    ]
+    append!(mem, msgs)
+    @test length(mem) == 3
+
+    # Run again, changes nothing
+    append!(mem, msgs)
+    @test length(mem) == 3
+
+    # Test append! with run_id based deduplication
+    msgs = [
+        SystemMessage("System prompt"),
+        UserMessage("User 1"),
+        AIMessage("AI 1"; run_id = 1),
+        UserMessage("User 2"),
+        AIMessage("AI 2"; run_id = 2)
+    ]
+    append!(mem, msgs)
+    @test length(mem) == 5
+
+    # Test append! with overlapping messages
+    msgs_overlap = [
+        SystemMessage("System prompt 2"),
+        UserMessage("User 3"),
+        AIMessage("AI 3"; run_id = 3),
+        UserMessage("User 4"),
+        AIMessage("AI 4"; run_id = 4)
+    ]
+    append!(mem, msgs_overlap)
+    @test length(mem) == 10
+end
+
+@testset "ConversationMemory-aigenerate" begin
+    # Setup mock response
+    response = Dict(
+        :choices => [
+            Dict(:message => Dict(:content => "Hello World!"),
+                :finish_reason => "stop")
+        ],
+        :usage => Dict(:total_tokens => 3, :prompt_tokens => 2, :completion_tokens => 1))
+
+    schema = TestEchoOpenAISchema(; response = response, status = 200)
+    register_model!(; name = "memory-echo", schema)
+
+    mem = ConversationMemory()
+    push!(mem, SystemMessage("You are a helpful assistant"))
+    result = mem("Hello!"; model = "memory-echo")
+    @test result.content == "Hello World!"
+    @test length(mem) == 3
+
+    # Test functor interface with history truncation
+    for i in 1:5
+        result = mem("Message $i"; model = "memory-echo")
+    end
+    result = mem("Final message"; last = 3, model = "memory-echo")
+    @test length(mem) == 15 # 5x2 + final x2 + 3
+
+    # Test aigenerate method integration
+    result = aigenerate(mem; model = "memory-echo")
+    @test result.content == "Hello World!"
+end
diff --git a/test/messages.jl b/test/messages.jl
index 2ece68c37..7efb90414 100644
--- a/test/messages.jl
+++ b/test/messages.jl
@@ -1,12 +1,13 @@
 using PromptingTools: AIMessage, SystemMessage, MetadataMessage, AbstractMessage
 using PromptingTools: UserMessage, UserMessageWithImages, DataMessage, AIToolRequest,
-                      ToolMessage
+                      ToolMessage, AnnotationMessage
 using PromptingTools: _encode_local_image, attach_images_to_user_message, last_message,
                       last_output, tool_calls
 using PromptingTools: isusermessage, issystemmessage, isdatamessage, isaimessage,
-                      istracermessage, isaitoolrequest, istoolmessage
+                      istracermessage, isaitoolrequest, istoolmessage,
+                      isabstractannotationmessage
 using PromptingTools: TracerMessageLike, TracerMessage, align_tracer!, unwrap,
-                      AbstractTracerMessage, AbstractTracer, pprint
+                      AbstractTracerMessage, AbstractTracer, pprint, annotate!
 using PromptingTools: TracerSchema, SaverSchema
 
 @testset "Message constructors" begin
@@ -47,6 +48,56 @@ using PromptingTools: TracerSchema, SaverSchema
     @test isaimessage(missing) == false
     @test istracermessage(1) == false
 end
+
+@testset "AnnotationMessage" begin
+    # Test creation and basic properties
+    annotation = AnnotationMessage(
+        content = "Test annotation",
+        extras = Dict{Symbol, Any}(:key => "value"),
+        tags = [:debug, :test],
+        comment = "Test comment"
+    )
+    @test annotation.content == "Test annotation"
+    @test annotation.extras[:key] == "value"
+    @test :debug in annotation.tags
+    @test annotation.comment == "Test comment"
+    @test isabstractannotationmessage(annotation)
+    @test !isabstractannotationmessage(UserMessage("test"))
+
+    # Test that annotations are filtered out during rendering
+    messages = [
+        SystemMessage("System prompt"),
+        UserMessage("User message"),
+        AnnotationMessage(content = "Debug info", comment = "Debug note"),
+        AIMessage("AI response")
+    ]
+
+    # Test annotate! utility
+    msgs = [UserMessage("Hello"), AIMessage("Hi")]
+    msgs = annotate!(msgs, "Debug info", tags = [:debug])
+    @test length(msgs) == 3
+    @test isabstractannotationmessage(msgs[1])
+    @test msgs[1].tags == [:debug]
+
+    # Test pretty printing
+    io = IOBuffer()
+    pprint(io, annotation)
+    output = String(take!(io))
+    @test occursin("Test annotation", output)
+    @test occursin("debug", output)
+    @test occursin("Test comment", output)
+
+    # Test show method
+    io = IOBuffer()
+    show(io, MIME("text/plain"), annotation)
+    output = String(take!(io))
+    @test occursin("AnnotationMessage", output)
+    @test occursin("Test annotation", output)
+    @test !occursin("extras", output) # Should only show type and content
+    @test !occursin("tags", output)
+    @test !occursin("comment", output)
+end
+
 @testset "UserMessageWithImages" begin
     content = "Hello, world!"
     image_path = joinpath(@__DIR__, "data", "julia.png")
diff --git a/test/runtests.jl b/test/runtests.jl
index 930a83861..31c1e69c5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -19,6 +19,8 @@ end
 @testset "PromptingTools.jl" begin
     include("utils.jl")
     include("messages.jl")
+    include("annotation.jl")
+    include("memory.jl")
     include("extraction.jl")
     include("user_preferences.jl")
     include("llm_interface.jl")
diff --git a/test/serialization.jl b/test/serialization.jl
index dd023b28d..b20bb0223 100644
--- a/test/serialization.jl
+++ b/test/serialization.jl
@@ -1,12 +1,17 @@
 using PromptingTools: AIMessage,
                       SystemMessage, UserMessage, UserMessageWithImages, AbstractMessage,
-                      DataMessage, ShareGPTSchema, Tool, ToolMessage, AIToolRequest
+                      DataMessage, ShareGPTSchema, Tool, ToolMessage, AIToolRequest,
+                      AnnotationMessage, AbstractAnnotationMessage
 using PromptingTools: save_conversation, load_conversation, save_conversations
 using PromptingTools: save_template, load_template
 
 @testset "Serialization - Messages" begin
     # Test save_conversation
-    messages = AbstractMessage[SystemMessage("System message 1"),
+    messages = AbstractMessage[AnnotationMessage(;
+            content = "Annotation message"),
+        AnnotationMessage(;
+            content = "Annotation message 2", extras = Dict{Symbol, Any}(:a => 1, :b => 2)),
+        SystemMessage("System message 1"),
         UserMessage("User message"),
         AIMessage("AI message"),
         UserMessageWithImages(; content = "a", image_url = String["b", "c"]),
@@ -22,6 +27,12 @@ using PromptingTools: save_template, load_template
     # Test load_conversation
     loaded_messages = load_conversation(tmp)
     @test loaded_messages == messages
+
+    # save and load AbstractAnnotationMessage
+    msg = AnnotationMessage("Annotation message"; extras = Dict(:a => 1))
+    JSON3.write(tmp, msg)
+    loaded_msg = JSON3.read(tmp, AbstractAnnotationMessage)
+    @test loaded_msg == msg
 end
 
 @testset "Serialization - Templates" begin