diff --git a/src/PromptingTools.jl b/src/PromptingTools.jl
index 4f911ab35..214cc91e3 100644
--- a/src/PromptingTools.jl
+++ b/src/PromptingTools.jl
@@ -66,21 +66,14 @@ include("llm_interface.jl")
 include("user_preferences.jl")
 
 ## Conversation history / Prompt elements
+export AIMessage
 include("messages.jl")
 
-include("memory.jl")
-
-# Export message types and predicates
-export SystemMessage, UserMessage, AIMessage, AnnotationMessage, issystemmessage, isusermessage, isaimessage, isabstractannotationmessage, annotate!
-# Export memory-related functionality
-export ConversationMemory, get_last, last_message, last_output
-
-# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only
+# export ConversationMemory
 include("memory.jl")
 
+# export annotate!
 include("annotation.jl")
-
 export aitemplates, AITemplate
 include("templates.jl")
diff --git a/src/memory.jl b/src/memory.jl
index 121c267d8..df99f9fe0 100644
--- a/src/memory.jl
+++ b/src/memory.jl
@@ -6,6 +6,55 @@
 which is a vector of `AbstractMessage`s. It's built to support intelligent truncation
 behavior (see `get_last`).
 You can also use it as a functor for extended conversations (easier than constantly passing the `conversation` kwarg).
+
+# Examples
+
+Basic usage
+```julia
+mem = ConversationMemory()
+push!(mem, SystemMessage("You are a helpful assistant"))
+push!(mem, UserMessage("Hello!"))
+push!(mem, AIMessage("Hi there!"))
+
+# or simply, from an existing conversation vector
+mem = ConversationMemory(conv)
+```
+
+Check memory stats
+```julia
+println(mem)            # ConversationMemory(2 messages) - the system message is not counted
+@show length(mem)       # 3 - counts all messages
+@show last_message(mem) # gets the last message
+@show last_output(mem)  # gets the content of the last message
+```
+
+Get recent messages with different options (keeps the system message and the first user message, plus the most recent ones)
+```julia
+recent = get_last(mem, 5)                 # last 5 messages (including system)
+recent = get_last(mem, 20, batch_size=10) # align to batches of 10 for caching
+recent = get_last(mem, 5, explain=true)   # adds a truncation explanation
+recent = get_last(mem, 5, verbose=true)   # prints truncation info
+```
+
+Append multiple messages at once (with deduplication to keep the memory complete)
+```julia
+msgs = [
+    UserMessage("How are you?"),
+    AIMessage("I'm good!"; run_id=1),
+    UserMessage("Great!"),
+    AIMessage("Indeed!"; run_id=2)
+]
+append!(mem, msgs) # appends only new messages, deduplicating on run_id etc.
+```
+
+Use for AI conversations (easier to manage conversations)
+```julia
+response = mem("Tell me a joke"; model="gpt4o")      # automatically manages the context
+response = mem("Another one"; last=3, model="gpt4o") # use only the last 3 messages (via `get_last`)
+
+# Direct generation from the memory
+result = aigenerate(mem) # generate using the full context
+```
 """
 Base.@kwdef mutable struct ConversationMemory
     conversation::Vector{AbstractMessage} = AbstractMessage[]
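Reviewer note on the `batch_size` option advertised in the docstring above: aligning the start of the kept window to batch boundaries means the rendered prompt prefix only changes once per `batch_size` new messages, so provider-side prompt caches can keep hitting. A minimal standalone sketch of such a start-index rule (`demo_batch_start` is an invented helper for illustration, not the package's `batch_start_index`):

```julia
# Illustrative only: pick the first index of the kept window, rounded down
# to the previous batch boundary so the window start stays stable.
function demo_batch_start(n_msgs::Int, n_last::Int; batch_size::Int = 10)
    earliest = max(1, n_msgs - n_last + 1)  # naive "last n" start index
    return max(1, earliest - mod(earliest - 1, batch_size))
end

demo_batch_start(30, 20)  # -> 11: start sits on a batch boundary
demo_batch_start(35, 20)  # -> 11: unchanged, so the prompt prefix stays cacheable
demo_batch_start(40, 20)  # -> 21: the window advances by one full batch
```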
diff --git a/src/messages.jl b/src/messages.jl
index a3c446e66..50c0ec8d0 100644
--- a/src/messages.jl
+++ b/src/messages.jl
@@ -527,7 +527,7 @@ end
 "Helpful accessor for the last generated output (`msg.content`) in `conversation`. Returns the last output in the conversation (eg, the string/data in the last message)."
 function last_output(conversation::AbstractVector{<:AbstractMessage})
     msg = last_message(conversation)
-    return isnothing(msg) ? nothing : msg.content
+    return isnothing(msg) ? nothing : last_output(msg)
 end
 last_message(msg::AbstractMessage) = msg
 last_output(msg::AbstractMessage) = msg.content
@@ -653,30 +653,6 @@ StructTypes.StructType(::Type{AnnotationMessage}) = StructTypes.Struct()
 StructTypes.StructType(::Type{TracerMessage}) = StructTypes.Struct() # Ignore mutability once we serialize
 StructTypes.StructType(::Type{TracerMessageLike}) = StructTypes.Struct() # Ignore mutability once we serialize
 
-### Message Access Utilities
-
-"""
-    last_message(messages::Vector{<:AbstractMessage})
-
-Get the last message in a conversation, regardless of type.
-"""
-function last_message(messages::Vector{<:AbstractMessage})
-    isempty(messages) && return nothing
-    return last(messages)
-end
-
-"""
-    last_output(messages::Vector{<:AbstractMessage})
-
-Get the last AI-generated message (AIMessage) in a conversation.
-"""
-function last_output(messages::Vector{<:AbstractMessage})
-    isempty(messages) && return nothing
-    last_ai_idx = findlast(isaimessage, messages)
-    isnothing(last_ai_idx) && return nothing
-    return messages[last_ai_idx]
-end
-
 ### Utilities for Pretty Printing
 """
     pprint(io::IO, msg::AbstractMessage; text_width::Int = displaysize(io)[2])
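The one-line change in `last_output` above is a dispatch fix: returning `last_output(msg)` instead of reaching into `msg.content` lets wrapper message types (e.g. tracer messages that hold the real message in a field) unwrap themselves via their own method. A self-contained toy of the pattern — `PlainMsg`, `WrapperMsg`, and `toy_last_output` are invented names, not the package API:

```julia
abstract type AbstractMsg end

struct PlainMsg <: AbstractMsg
    content::String
end

struct WrapperMsg <: AbstractMsg
    object::AbstractMsg   # the wrapped message, e.g. a traced call
end

# Leaf messages expose their content field.
toy_last_output(msg::PlainMsg) = msg.content
# Wrappers delegate to whatever they wrap instead of exposing their own fields.
toy_last_output(msg::WrapperMsg) = toy_last_output(msg.object)
# The vector method mirrors the fix: dispatch to the message method,
# never touch `.content` directly.
toy_last_output(conv::AbstractVector{<:AbstractMsg}) =
    isempty(conv) ? nothing : toy_last_output(last(conv))

toy_last_output([PlainMsg("hi"), WrapperMsg(PlainMsg("bye"))])  # -> "bye"
```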
diff --git a/test/memory.jl b/test/memory.jl
index 43c24bcfc..9e1d46729 100644
--- a/test/memory.jl
+++ b/test/memory.jl
@@ -1,8 +1,7 @@
 using PromptingTools: SystemMessage, UserMessage, AIMessage, AbstractMessage
 using PromptingTools: TestEchoOpenAISchema, ConversationMemory
 using PromptingTools: issystemmessage, isusermessage, isaimessage, last_message,
-    last_output, register_model!, batch_start_index
-using HTTP, JSON3
+    last_output, register_model!, batch_start_index, get_last
 
 @testset "batch_start_index" begin
     # Test basic batch calculation
@@ -30,8 +29,6 @@
     @test_throws AssertionError batch_start_index(3, 5, 10)
 end
 
-# @testset "ConversationMemory" begin
-
 @testset "ConversationMemory-type" begin
     # Test constructor and empty initialization
     mem = ConversationMemory()
diff --git a/test/test_annotation_messages.jl b/test/test_annotation_messages.jl
deleted file mode 100644
index 4db9c5c2c..000000000
--- a/test/test_annotation_messages.jl
+++ /dev/null
@@ -1,162 +0,0 @@
-using Test
-using PromptingTools
-using PromptingTools: TestEchoOpenAISchema, render, SystemMessage, UserMessage, AIMessage, AnnotationMessage
-using PromptingTools: OpenAISchema, AnthropicSchema, OllamaSchema, GoogleSchema
-
-@testset "AnnotationMessage" begin
-    # Test creation and basic properties
-    @testset "Basic Construction" begin
-        msg = AnnotationMessage(content="Test content")
-        @test msg.content == "Test content"
-        @test isempty(msg.extras)
-        @test !isnothing(msg.run_id)
-    end
-
-    # Test with all fields
-    @testset "Full Construction" begin
-        msg = AnnotationMessage(
-            content="Full test",
-            extras=Dict{Symbol,Any}(:key => "value"),
-            tags=[:test, :example],
-            comment="Test comment"
-        )
-        @test msg.content == "Full test"
-        @test msg.extras[:key] == "value"
-        @test msg.tags == [:test, :example]
-        @test msg.comment == "Test comment"
-    end
-
-    # Test the annotate! utility
-    @testset "annotate! utility" begin
-        # Test with a vector of messages
-        messages = [SystemMessage("System"), UserMessage("User")]
-        annotated = annotate!(messages, "Annotation")
-        @test length(annotated) == 3
-        @test annotated[1] isa AnnotationMessage
-        @test annotated[1].content == "Annotation"
-
-        # Test with a single message
-        message = UserMessage("Single")
-        annotated = annotate!(message, "Single annotation")
-        @test length(annotated) == 2
-        @test annotated[1] isa AnnotationMessage
-        @test annotated[1].content == "Single annotation"
-
-        # Test annotation placement with existing annotations
-        messages = [
-            AnnotationMessage("First"),
-            SystemMessage("System"),
-            UserMessage("User")
-        ]
-        annotated = annotate!(messages, "Second")
-        @test length(annotated) == 4
-        @test annotated[2] isa AnnotationMessage
-        @test annotated[2].content == "Second"
-    end
-
-    # Test serialization
-    @testset "Serialization" begin
-        original = AnnotationMessage(
-            content="Test",
-            extras=Dict{Symbol,Any}(:key => "value"),
-            tags=[:test],
-            comment="Comment"
-        )
-
-        # Convert to Dict and back
-        dict = Dict(original)
-        reconstructed = convert(AnnotationMessage, dict)
-
-        @test reconstructed.content == original.content
-        @test reconstructed.extras == original.extras
-        @test reconstructed.tags == original.tags
-        @test reconstructed.comment == original.comment
-    end
-
-    # Test that rendering skips annotations across all providers
-    @testset "Render Skipping" begin
-        # Create a mix of messages including annotation messages
-        messages = [
-            SystemMessage("Be helpful"),
-            AnnotationMessage("This is metadata", extras=Dict{Symbol,Any}(:key => "value")),
-            UserMessage("Hello"),
-            AnnotationMessage("More metadata"),
-            AIMessage("Hi there!")
-        ]
-
-        # Additional edge cases
-        messages_complex = [
-            AnnotationMessage("Metadata 1", extras=Dict{Symbol,Any}(:key => "value")),
-            AnnotationMessage("Metadata 2", extras=Dict{Symbol,Any}(:key2 => "value2")),
-            SystemMessage("Be helpful"),
-            AnnotationMessage("Metadata 3", tags=[:important]),
-            UserMessage("Hello"),
-            AnnotationMessage("Metadata 4", comment="For debugging"),
-            AIMessage("Hi there!"),
-            AnnotationMessage("Metadata 5", extras=Dict{Symbol,Any}(:key3 => "value3"))
-        ]
-
-        # Test OpenAI Schema with TestEcho
-        schema = TestEchoOpenAISchema(
-            response=Dict(
-                "choices" => [Dict("message" => Dict("content" => "Test response", "role" => "assistant"), "index" => 0, "finish_reason" => "stop")],
-                "usage" => Dict("prompt_tokens" => 10, "completion_tokens" => 20, "total_tokens" => 30),
-                "model" => "gpt-3.5-turbo",
-                "id" => "test-id",
-                "object" => "chat.completion",
-                "created" => 1234567890
-            ),
-            status=200
-        )
-        rendered = render(schema, messages)
-        @test length(rendered) == 3 # Should only have system, user, and AI messages
-        @test all(msg["role"] in ["system", "user", "assistant"] for msg in rendered)
-        @test !any(msg -> contains(msg["content"], "metadata"), rendered)
-
-        # Test Anthropic Schema
-        rendered = render(AnthropicSchema(), messages)
-        @test length(rendered.conversation) == 2 # Should have user and AI messages
-        @test !isnothing(rendered.system) # System message should be preserved separately
-        @test all(msg["role"] in ["user", "assistant"] for msg in rendered.conversation)
-        @test !contains(rendered.system, "metadata") # Check the system message
-        @test !any(msg -> any(content -> contains(content["text"], "metadata"), msg["content"]), rendered.conversation)
-
-        # Test Ollama Schema
-        rendered = render(OllamaSchema(), messages)
-        @test length(rendered) == 3 # Should only have system, user, and AI messages
-        @test all(msg["role"] in ["system", "user", "assistant"] for msg in rendered)
-        @test !any(msg -> contains(msg["content"], "metadata"), rendered)
-
-        # Test Google Schema
-        rendered = render(GoogleSchema(), messages)
-        @test length(rendered) == 2 # Google schema combines the system message with the first user message
-        @test all(msg[:role] in ["user", "model"] for msg in rendered) # Google uses "model" instead of "assistant"
-        @test !any(msg -> any(part -> contains(part["text"], "metadata"), msg[:parts]), rendered)
-
-        # Test complex edge cases
-        @testset "Complex Edge Cases" begin
-            for schema in [TestEchoOpenAISchema(), AnthropicSchema(), OllamaSchema(), GoogleSchema()]
-                rendered = render(schema, messages_complex)
-
-                if schema isa AnthropicSchema
-                    @test length(rendered.conversation) == 2 # user and AI only
-                    @test !isnothing(rendered.system) # system preserved
-                else
-                    @test length(rendered) == (schema isa GoogleSchema ? 2 : 3) # Google combines system with the user message
-                end
-
-                # Test that no metadata leaks through
-                for i in 1:5
-                    if schema isa GoogleSchema
-                        @test !any(msg -> any(part -> contains(part["text"], "Metadata $i"), msg[:parts]), rendered)
-                    elseif schema isa AnthropicSchema
-                        @test !any(msg -> any(content -> contains(content["text"], "Metadata $i"), msg["content"]), rendered.conversation)
-                        @test !contains(rendered.system, "Metadata $i")
-                    else
-                        @test !any(msg -> contains(msg["content"], "Metadata $i"), rendered)
-                    end
-                end
-            end
-        end
-    end
-end
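The deleted file exercised one central invariant: annotation messages are local metadata and must never reach a provider payload. Conceptually that behavior reduces to filtering before rendering — a minimal sketch using the package's `isabstractannotationmessage` predicate (`strip_annotations` is an invented name for illustration, not the package API):

```julia
using PromptingTools: AbstractMessage, isabstractannotationmessage

# Drop annotation messages before building a provider payload; the
# provider-specific `render` methods tested above behave as if this ran first.
strip_annotations(msgs::AbstractVector{<:AbstractMessage}) =
    filter(!isabstractannotationmessage, msgs)
```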