Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
svilupp committed Nov 26, 2024
1 parent cbb769a commit 309a6ce
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 199 deletions.
9 changes: 1 addition & 8 deletions src/PromptingTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,14 @@ include("llm_interface.jl")
include("user_preferences.jl")

## Conversation history / Prompt elements
export AIMessage
include("messages.jl")
include("memory.jl")

# Export message types and predicates
export SystemMessage, UserMessage, AIMessage, AnnotationMessage, issystemmessage, isusermessage, isaimessage, isabstractannotationmessage, annotate!
# Export memory-related functionality
export ConversationMemory, get_last, last_message, last_output
# export UserMessage, UserMessageWithImages, SystemMessage, DataMessage # for debugging only

# export ConversationMemory
include("memory.jl")
# export annotate!
include("annotation.jl")


export aitemplates, AITemplate
include("templates.jl")

Expand Down
49 changes: 49 additions & 0 deletions src/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,55 @@ which is a vector of `AbstractMessage`s. It's built to support intelligent trunc
behavior (`get_last`).
You can also use it as a functor to have extended conversations (easier than constantly passing `conversation` kwarg)
# Examples
Basic usage
```julia
mem = ConversationMemory()
push!(mem, SystemMessage("You are a helpful assistant"))
push!(mem, UserMessage("Hello!"))
push!(mem, AIMessage("Hi there!"))
# or simply
mem = ConversationMemory(conv)
```
Check memory stats
```julia
println(mem) # ConversationMemory(2 messages) - doesn't count system message
@show length(mem) # 3 - counts all messages
@show last_message(mem) # gets last message
@show last_output(mem) # gets last content
```
Get recent messages with different options (System message, User message, ... + the most recent)
```julia
recent = get_last(mem, 5) # get last 5 messages (including system)
recent = get_last(mem, 20, batch_size=10) # align to batches of 10 for caching
recent = get_last(mem, 5, explain=true) # adds truncation explanation
recent = get_last(mem, 5, verbose=true) # prints truncation info
```
Append multiple messages at once (with deduplication to keep the memory complete)
```julia
msgs = [
UserMessage("How are you?"),
AIMessage("I'm good!"; run_id=1),
UserMessage("Great!"),
AIMessage("Indeed!"; run_id=2)
]
append!(mem, msgs) # Will only append new messages based on run_ids etc.
```
Use for AI conversations (easier to manage conversations)
```julia
response = mem("Tell me a joke"; model="gpt4o") # Automatically manages context
response = mem("Another one"; last=3, model="gpt4o") # Use only last 3 messages (uses `get_last`)
# Direct generation from the memory
result = aigenerate(mem) # Generate using full context
```
"""
Base.@kwdef mutable struct ConversationMemory
conversation::Vector{AbstractMessage} = AbstractMessage[]
Expand Down
26 changes: 1 addition & 25 deletions src/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ end
"Helpful accessor for the last generated output (`msg.content`) in `conversation`. Returns the last output in the conversation (eg, the string/data in the last message)."
function last_output(conversation::AbstractVector{<:AbstractMessage})
msg = last_message(conversation)
return isnothing(msg) ? nothing : msg.content
return isnothing(msg) ? nothing : last_output(msg)
end
last_message(msg::AbstractMessage) = msg
last_output(msg::AbstractMessage) = msg.content
Expand Down Expand Up @@ -653,30 +653,6 @@ StructTypes.StructType(::Type{AnnotationMessage}) = StructTypes.Struct()
StructTypes.StructType(::Type{TracerMessage}) = StructTypes.Struct() # Ignore mutability once we serialize
StructTypes.StructType(::Type{TracerMessageLike}) = StructTypes.Struct() # Ignore mutability once we serialize

### Message Access Utilities

"""
last_message(messages::Vector{<:AbstractMessage})
Get the last message in a conversation, regardless of type.
"""
function last_message(messages::Vector{<:AbstractMessage})
isempty(messages) && return nothing
return last(messages)
end

"""
last_output(messages::Vector{<:AbstractMessage})
Get the last AI-generated message (AIMessage) in a conversation.
"""
function last_output(messages::Vector{<:AbstractMessage})
isempty(messages) && return nothing
last_ai_idx = findlast(isaimessage, messages)
isnothing(last_ai_idx) && return nothing
return messages[last_ai_idx]
end

### Utilities for Pretty Printing
"""
pprint(io::IO, msg::AbstractMessage; text_width::Int = displaysize(io)[2])
Expand Down
5 changes: 1 addition & 4 deletions test/memory.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
using PromptingTools: SystemMessage, UserMessage, AIMessage, AbstractMessage
using PromptingTools: TestEchoOpenAISchema, ConversationMemory
using PromptingTools: issystemmessage, isusermessage, isaimessage, last_message,
last_output, register_model!, batch_start_index
using HTTP, JSON3
last_output, register_model!, batch_start_index, get_last

@testset "batch_start_index" begin
# Test basic batch calculation
Expand Down Expand Up @@ -30,8 +29,6 @@ using HTTP, JSON3
@test_throws AssertionError batch_start_index(3, 5, 10)
end

# @testset "ConversationMemory" begin

@testset "ConversationMemory-type" begin
# Test constructor and empty initialization
mem = ConversationMemory()
Expand Down
162 changes: 0 additions & 162 deletions test/test_annotation_messages.jl

This file was deleted.

0 comments on commit 309a6ce

Please sign in to comment.