Skip to content

Commit

Permalink
Support relocating matching suffixes in KV cache
Browse files Browse the repository at this point in the history
The chat completions endpoint is now able to relocate matching suffixes.
It'll happen automatically when your conversation history gets too long.
You can also remove messages from the middle of your chat history as the
client, and the same optimization will still take place. This grants the
new server a higher degree of user control and flexibility than upstream
in llama.cpp whose server allows only a single hard-coded system prompt.

This change fixes an issue created by the previous change where upon the
context window being filled up, each new chat message typed into the web
ui would result in the prefill progress bar appearing because the latter
portion of the conversation would need to be reprocessed, after old chat
history was deleted. Now all the unpleasant user visible latency is gone
  • Loading branch information
jart committed Dec 14, 2024
1 parent 956e62c commit 83676d1
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 52 deletions.
16 changes: 16 additions & 0 deletions llamafile/server/atom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ Atom::Atom(const Atom& other)
word_ = 2ull << 56 | (uintptr_t)image;
}

Atom&
Atom::operator=(const Atom& other)
{
if (this != &other) {
if (is_image())
delete (Image*)(word_ & 0x00ffffffffffffff);
if (!other.is_image()) {
word_ = other.word_;
} else {
Image* image = new Image(other.image());
word_ = 2ull << 56 | (uintptr_t)image;
}
}
return *this;
}

Atom::~Atom()
{
if (is_image())
Expand Down
1 change: 1 addition & 0 deletions llamafile/server/atom.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Atom
Atom(const Atom&);
Atom(Atom&&);
~Atom();
Atom& operator=(const Atom&);
int token() const;
bool empty() const;
int ctx_used() const;
Expand Down
9 changes: 9 additions & 0 deletions llamafile/server/atomize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,5 +92,14 @@ remove_old_image_atoms(const std::vector<Atom>& atoms)
return result;
}

int
count_tokens(const std::vector<Atom>& atoms)
{
int n = 0;
for (const Atom& atom : atoms)
n += atom.ctx_used();
return n;
}

} // namespace server
} // namespace lf
149 changes: 121 additions & 28 deletions llamafile/server/slot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llamafile/server/atom.h"
#include "llamafile/server/image.h"
#include "llamafile/server/log.h"
#include "llamafile/server/utils.h"
#include "llamafile/vector.h"
#include "llamafile/version.h"
#include <algorithm>
Expand Down Expand Up @@ -290,43 +291,135 @@ Slot::prefill(const std::vector<Atom>& atoms, const ProgressCallback& progress)
{
if (!ctx_)
return uninitialized;
int used_tokens = ctx_used();
int reuse_atoms = 0;
int reuse_tokens = 0;
int erase_tokens = 0;

// handle special case of empty prefill
if (atoms.empty()) {
llama_kv_cache_clear(ctx_);
history_.clear();
return 0;
}

// when a prefill request comes in, chances are the system prompt
// will already be loaded and the unique user request in atoms is
// going to have something different that follows. in such a case
// we'll rapidly delete the latter portion from the KV cache, and
// then we won't need the cost of prefilling the earlier portion.
//
// "hello world i love you!" <-- atoms
// "hello world how are you" <-- history
// "hello world " <-- keep
// "how are you" <-- evaluated
//
// when context runs out the completions interface or user client
// might delete content in the middle, in which case we can shift
// content backwards based on the matching suffix.
//
// "sysprompt msg2 msg3 msg4" <-- atoms
// └──┬────┘ └──────┬┘
// │ │
// ┌──┴────┐ ┌─┴─────┐
// "sysprompt msg1 msg2 msg3" <-- history
// "sysprompt " <-- keep
// "msg1 " <-- discard
// "msg2 msg3" <-- relocate
// "sysprompt msg2 msg3" <-- llama_kv_cache_seq_rm
// "sysprompt msg2 msg3" <-- llama_kv_cache_seq_add
// "msg4" <-- evaluated
//
int keep = 0;
int n = std::min(atoms.size(), history_.size());
for (int i = 0; i < n && atoms[i] == history_[i]; ++i) {
reuse_tokens += history_[i].ctx_used();
reuse_atoms += 1;
for (int i = 0; i < n && atoms[i] == history_[i]; ++i)
++keep;
int relocate_p0 = -1;
int relocate_p1 = -1;
int skipped = keep;
for (int i = keep + 1; i < history_.size(); ++i) {
if (history_.size() - i > atoms.size() - keep)
continue;
if (std::equal(history_.begin() + i, //
history_.end(),
atoms.begin() + keep)) {
relocate_p0 = i;
relocate_p1 = history_.size();
skipped += history_.size() - i;
break;
}
}
// xxx: ensure we prefill at least one token (prevents badness)
if (reuse_tokens >= 1) {
reuse_atoms -= 1;
reuse_tokens -= history_[reuse_atoms].ctx_used();

// xxx: ensure we eval at least one token
// this prevents an observed badness
if (skipped == atoms.size()) {
if (relocate_p0 != -1) {
--relocate_p1;
} else {
--keep;
}
--skipped;
}

// now count tokens
int keep_tokens = 0;
int history_tokens = ctx_used();
for (int i = 0; i < keep; ++i)
keep_tokens += history_[i].ctx_used();
int relocate_p0_tokens = -1;
int relocate_p1_tokens = -1;
if (relocate_p0 != -1) {
relocate_p0_tokens = 0;
for (int i = 0; i < relocate_p0; ++i)
relocate_p0_tokens += history_[i].ctx_used();
relocate_p1_tokens = 0;
for (int i = 0; i < relocate_p1; ++i)
relocate_p1_tokens += history_[i].ctx_used();
}
if (used_tokens > reuse_tokens) {
erase_tokens = used_tokens - reuse_tokens;
if (llama_kv_cache_seq_rm(ctx_, 0, reuse_tokens, -1)) {
history_.resize(reuse_atoms);
int skipped_tokens = 0;
for (int i = 0; i < skipped; ++i)
skipped_tokens += atoms[i].ctx_used();

// discard tokens from kv cache
int discarded_tokens;
int relocated_tokens = 0;
if (llama_kv_cache_seq_rm(ctx_, 0, keep_tokens, relocate_p0_tokens)) {
if (relocate_p0 == -1) {
discarded_tokens = history_.size() - keep_tokens;
history_.resize(keep);
} else {
SLOG("failed to remove tokens from KV cache");
reuse_atoms = 0;
reuse_tokens = 0;
erase_tokens = used_tokens;
llama_kv_cache_clear(ctx_);
history_.clear();
discarded_tokens = (history_.size() - relocate_p1) +
(relocate_p0_tokens - keep_tokens);
relocated_tokens = relocate_p1_tokens - relocate_p0_tokens;
history_.resize(relocate_p1);
history_.erase(history_.begin() + keep,
history_.begin() + relocate_p0);
// memmove relocated tokens in kv cache
llama_kv_cache_seq_add(ctx_,
0,
relocate_p0_tokens,
relocate_p1_tokens,
-(relocate_p0_tokens - keep_tokens));
}
} else {
// models like Mamba can't be partially erased
SLOG("failed to remove tokens from KV cache");
discarded_tokens = history_.size();
llama_kv_cache_clear(ctx_);
history_.clear();
skipped = 0;
}
std::vector<Atom> new_atoms(atoms.begin() + reuse_atoms, atoms.end());

// evaluate tokens
std::vector<Atom> new_atoms(atoms.begin() + skipped, atoms.end());
int rc;
if ((rc = eval_atoms(new_atoms, progress)) < 0)
return rc;
int token_count = reuse_tokens + rc;
SLOG("prefilled %zu tokens (after removing %zu and reusing %zu)",
token_count,
erase_tokens,
reuse_tokens);
return token_count;
int total_tokens = keep_tokens + relocated_tokens + rc;
SLOG("prefilled %d tokens (after keeping %d, discarding %d, "
"relocating %d, and evaluating %d)",
total_tokens,
keep_tokens,
discarded_tokens,
relocated_tokens,
count_tokens(new_atoms));
return total_tokens;
}

void
Expand Down
3 changes: 3 additions & 0 deletions llamafile/server/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,8 @@ atomize(const llama_model* model,
std::vector<Atom>
remove_old_image_atoms(const std::vector<Atom>&);

int
count_tokens(const std::vector<Atom>&);

} // namespace server
} // namespace lf
13 changes: 3 additions & 10 deletions llamafile/server/v1_chat_completions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,6 @@ has_images(const std::vector<Atom>& atoms)
return false;
}

static int
count_tokens(const std::vector<Atom>& atoms)
{
int n = 0;
for (const Atom& atom : atoms)
n += atom.ctx_used();
return n;
}

static int
count_bytes(const std::vector<llama_chat_msg>& messages)
{
Expand Down Expand Up @@ -548,8 +539,10 @@ Client::v1_chat_completions()
++last;
} while (bytes_deleted < bytes_to_delete &&
forgotten_msgs < max_forget_msgs);
SLOG("forgot %d / %zu old messages",
forgotten_msgs,
params->messages.size());
params->messages.erase(first, last);
SLOG("forgot %d old messages", forgotten_msgs);
}

// init sampling
Expand Down
30 changes: 16 additions & 14 deletions llamafile/server/www/chatbot.js
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ async function handleChatStream(response) {
let buffer = "";
let currentMessageElement = null;
let messageAppended = false;
let finishReason = null;
let hdom = null;
let high = null;
streamingMessageContent = [];
Expand All @@ -132,7 +133,7 @@ async function handleChatStream(response) {
try {
const parsed = JSON.parse(data);
const content = parsed.choices[0]?.delta?.content || "";
const finishReason = parsed.choices[0]?.finish_reason;
finishReason = parsed.choices[0]?.finish_reason;

// handle prefill progress
if (parsed.x_prefill_progress !== undefined) {
Expand All @@ -155,18 +156,6 @@ async function handleChatStream(response) {
high.feed(content);
scrollToBottom();
}

// we don't supply max_tokens, so "length" can
// only mean that we ran out of context window
if (finishReason === "length" && hdom) {
let img = hdom.push("IMG", "ooc");
img.src = "ooc.svg";
img.alt = "🚫";
img.title = "Message truncated due to running out of context window. Consider tuning --ctx-size and/or --reserve-tokens";
img.width = 16;
img.height = 16;
hdom.pop();
}
} catch (e) {
console.error("Error parsing JSON:", e);
}
Expand All @@ -183,6 +172,18 @@ async function handleChatStream(response) {
} finally {
if (messageAppended) {
high.flush();
// we don't supply max_tokens, so "length" can
// only mean that we ran out of context window
if (finishReason === "length") {
let img = document.createElement("IMG");
img.className = "ooc";
img.src = "ooc.svg";
img.alt = "🚫";
img.title = "Message truncated due to running out of context window. Consider tuning --ctx-size and/or --reserve-tokens";
img.width = 16;
img.height = 16;
hdom.lastElement.appendChild(img);
}
}
prefillStatus.style.display = "none";
cleanupAfterMessage();
Expand Down Expand Up @@ -248,7 +249,8 @@ async function sendMessage() {
if (response.ok) {
await handleChatStream(response);
const lastMessage = streamingMessageContent.join("");
chatHistory.push({ role: "assistant", content: lastMessage });
if (lastMessage)
chatHistory.push({ role: "assistant", content: lastMessage });
} else {
console.error("sendMessage() failed due to server error", response);
chatMessages.appendChild(createMessageElement(
Expand Down
3 changes: 3 additions & 0 deletions llamafile/server/www/highlight.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class HighlightDom extends Highlight {
super();
this.currentElement = containerElement;
this.containerElement = containerElement;
this.lastElement = containerElement;
this.text = '';
}

Expand All @@ -59,6 +60,7 @@ class HighlightDom extends Highlight {
elem.className = className;
this.currentElement.appendChild(elem);
this.currentElement = elem;
this.lastElement = elem;
return elem;
}

Expand All @@ -80,6 +82,7 @@ class HighlightDom extends Highlight {
flushText() {
if (this.text) {
this.currentElement.appendChild(document.createTextNode(this.text));
this.lastElement = this.currentElement;
this.text = '';
}
}
Expand Down

0 comments on commit 83676d1

Please sign in to comment.