From 83676d11bc43b6e1e8bdbb66823aef2b757ab7f0 Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Fri, 13 Dec 2024 17:55:40 -0800 Subject: [PATCH] Support relocating matching suffixes in KV cache The chat completions endpoint is now able to relocate matching suffixes. It'll happen automatically when your conversation history gets too long. You can also remove messages from the middle of your chat history as the client, and the same optimization will still take place. This grants the new server a higher degree of user control and flexibility than upstream in llama.cpp whose server allows only a single hard-coded system prompt. This change fixes an issue created by the previous change where upon the context window being filled up, each new chat message typed into the web ui would result in the prefill progress bar appearing because the latter portion of the conversation would need to be reprocessed, after old chat history was deleted. Now all the unpleasant user visible latency is gone --- llamafile/server/atom.cpp | 16 +++ llamafile/server/atom.h | 1 + llamafile/server/atomize.cpp | 9 ++ llamafile/server/slot.cpp | 149 ++++++++++++++++++----- llamafile/server/utils.h | 3 + llamafile/server/v1_chat_completions.cpp | 13 +- llamafile/server/www/chatbot.js | 30 ++--- llamafile/server/www/highlight.js | 3 + 8 files changed, 172 insertions(+), 52 deletions(-) diff --git a/llamafile/server/atom.cpp b/llamafile/server/atom.cpp index e627345738..7dc6fa4b3c 100644 --- a/llamafile/server/atom.cpp +++ b/llamafile/server/atom.cpp @@ -56,6 +56,22 @@ Atom::Atom(const Atom& other) word_ = 2ull << 56 | (uintptr_t)image; } +Atom& +Atom::operator=(const Atom& other) +{ + if (this != &other) { + if (is_image()) + delete (Image*)(word_ & 0x00ffffffffffffff); + if (!other.is_image()) { + word_ = other.word_; + } else { + Image* image = new Image(other.image()); + word_ = 2ull << 56 | (uintptr_t)image; + } + } + return *this; +} + Atom::~Atom() { if (is_image()) diff --git a/llamafile/server/atom.h b/llamafile/server/atom.h index a09f40ed99..bb517b9397 100644 --- a/llamafile/server/atom.h +++ b/llamafile/server/atom.h @@ -31,6 +31,7 @@ class Atom Atom(const Atom&); Atom(Atom&&); ~Atom(); + Atom& operator=(const Atom&); int token() const; bool empty() const; int ctx_used() const; diff --git a/llamafile/server/atomize.cpp b/llamafile/server/atomize.cpp index a80391142a..6c8cfdf08a 100644 --- a/llamafile/server/atomize.cpp +++ b/llamafile/server/atomize.cpp @@ -92,5 +92,14 @@ remove_old_image_atoms(const std::vector& atoms) return result; } +int +count_tokens(const std::vector& atoms) +{ + int n = 0; + for (const Atom& atom : atoms) + n += atom.ctx_used(); + return n; +} + } // namespace server } // namespace lf diff --git a/llamafile/server/slot.cpp b/llamafile/server/slot.cpp index 78a9d9d23e..d858f0fa0c 100644 --- a/llamafile/server/slot.cpp +++ b/llamafile/server/slot.cpp @@ -25,6 +25,7 @@ #include "llamafile/server/atom.h" #include "llamafile/server/image.h" #include "llamafile/server/log.h" +#include "llamafile/server/utils.h" #include "llamafile/vector.h" #include "llamafile/version.h" #include @@ -290,43 +291,135 @@ Slot::prefill(const std::vector& atoms, const ProgressCallback& progress) { if (!ctx_) return uninitialized; - int used_tokens = ctx_used(); - int reuse_atoms = 0; - int reuse_tokens = 0; - int erase_tokens = 0; + + // handle special case of empty prefill + if (atoms.empty()) { + llama_kv_cache_clear(ctx_); + history_.clear(); + return 0; + } + + // when a prefill request comes in, chances are the system prompt + // will already be loaded and the unique user request in atoms is + // going to have something different that follows. in such a case + // we'll rapidly delete the latter portion from the KV cache, and + // then we won't need the cost of prefilling the earlier portion. + // + // "hello world i love you!" <-- atoms + // "hello world how are you" <-- history + // "hello world " <-- keep + // "how are you" <-- evaluated + // + // when context runs out the completions interface or user client + // might delete content in the middle, in which case we can shift + // content backwards based on the matching suffix. + // + // "sysprompt msg2 msg3 msg4" <-- atoms + // └──┬────┘ └──────┬┘ + // │ │ + // ┌──┴────┐ ┌─┴─────┐ + // "sysprompt msg1 msg2 msg3" <-- history + // "sysprompt " <-- keep + // "msg1 " <-- discard + // "msg2 msg3" <-- relocate + // "sysprompt msg2 msg3" <-- llama_kv_cache_seq_rm + // "sysprompt msg2 msg3" <-- llama_kv_cache_seq_add + // "msg4" <-- evaluated + // + int keep = 0; int n = std::min(atoms.size(), history_.size()); - for (int i = 0; i < n && atoms[i] == history_[i]; ++i) { - reuse_tokens += history_[i].ctx_used(); - reuse_atoms += 1; + for (int i = 0; i < n && atoms[i] == history_[i]; ++i) + ++keep; + int relocate_p0 = -1; + int relocate_p1 = -1; + int skipped = keep; + for (int i = keep + 1; i < history_.size(); ++i) { + if (history_.size() - i > atoms.size() - keep) + continue; + if (std::equal(history_.begin() + i, // + history_.end(), + atoms.begin() + keep)) { + relocate_p0 = i; + relocate_p1 = history_.size(); + skipped += history_.size() - i; + break; + } } - // xxx: ensure we prefill at least one token (prevents badness) - if (reuse_tokens >= 1) { - reuse_atoms -= 1; - reuse_tokens -= history_[reuse_atoms].ctx_used(); + + // xxx: ensure we eval at least one token + // this prevents an observed badness + if (skipped == atoms.size()) { + if (relocate_p0 != -1) { + --relocate_p1; + } else { + --keep; + } + --skipped; + } + + // now count tokens + int keep_tokens = 0; + int history_tokens = ctx_used(); + for (int i = 0; i < keep; ++i) + keep_tokens += history_[i].ctx_used(); + int relocate_p0_tokens = -1; + int relocate_p1_tokens = -1; + if (relocate_p0 != -1) { + relocate_p0_tokens = 0; + for (int i = 0; i < relocate_p0; ++i) + relocate_p0_tokens += history_[i].ctx_used(); + relocate_p1_tokens = 0; + for (int i = 0; i < relocate_p1; ++i) + relocate_p1_tokens += history_[i].ctx_used(); } - if (used_tokens > reuse_tokens) { - erase_tokens = used_tokens - reuse_tokens; - if (llama_kv_cache_seq_rm(ctx_, 0, reuse_tokens, -1)) { - history_.resize(reuse_atoms); + int skipped_tokens = 0; + for (int i = 0; i < skipped; ++i) + skipped_tokens += atoms[i].ctx_used(); + + // discard tokens from kv cache + int discarded_tokens; + int relocated_tokens = 0; + if (llama_kv_cache_seq_rm(ctx_, 0, keep_tokens, relocate_p0_tokens)) { + if (relocate_p0 == -1) { + discarded_tokens = history_.size() - keep_tokens; + history_.resize(keep); } else { - SLOG("failed to remove tokens from KV cache"); - reuse_atoms = 0; - reuse_tokens = 0; - erase_tokens = used_tokens; - llama_kv_cache_clear(ctx_); - history_.clear(); + discarded_tokens = (history_.size() - relocate_p1) + + (relocate_p0_tokens - keep_tokens); + relocated_tokens = relocate_p1_tokens - relocate_p0_tokens; + history_.resize(relocate_p1); + history_.erase(history_.begin() + keep, + history_.begin() + relocate_p0); + // memmove relocated tokens in kv cache + llama_kv_cache_seq_add(ctx_, + 0, + relocate_p0_tokens, + relocate_p1_tokens, + -(relocate_p0_tokens - keep_tokens)); } + } else { + // models like Mamba can't be partially erased + SLOG("failed to remove tokens from KV cache"); + discarded_tokens = history_.size(); + llama_kv_cache_clear(ctx_); + history_.clear(); + skipped = 0; } - std::vector new_atoms(atoms.begin() + reuse_atoms, atoms.end()); + + // evaluate tokens + std::vector new_atoms(atoms.begin() + skipped, atoms.end()); int rc; if ((rc = eval_atoms(new_atoms, progress)) < 0) return rc; - int token_count = reuse_tokens + rc; - SLOG("prefilled %zu tokens (after removing %zu and reusing %zu)", - token_count, - erase_tokens, - reuse_tokens); - return token_count; + int total_tokens = keep_tokens + relocated_tokens + rc; + SLOG("prefilled %d tokens (after keeping %d, discarding %d, " + "relocating %d, and evaluating %d)", + total_tokens, + keep_tokens, + discarded_tokens, + relocated_tokens, + count_tokens(new_atoms)); + return total_tokens; } void diff --git a/llamafile/server/utils.h b/llamafile/server/utils.h index 4ee1cea229..5343758bf6 100644 --- a/llamafile/server/utils.h +++ b/llamafile/server/utils.h @@ -47,5 +47,8 @@ atomize(const llama_model* model, std::vector remove_old_image_atoms(const std::vector&); +int +count_tokens(const std::vector&); + } // namespace server } // namespace lf diff --git a/llamafile/server/v1_chat_completions.cpp b/llamafile/server/v1_chat_completions.cpp index ab8aedf662..aa7f65ee8e 100644 --- a/llamafile/server/v1_chat_completions.cpp +++ b/llamafile/server/v1_chat_completions.cpp @@ -173,15 +173,6 @@ has_images(const std::vector& atoms) return false; } -static int -count_tokens(const std::vector& atoms) -{ - int n = 0; - for (const Atom& atom : atoms) - n += atom.ctx_used(); - return n; -} - static int count_bytes(const std::vector& messages) { @@ -548,8 +539,10 @@ Client::v1_chat_completions() ++last; } while (bytes_deleted < bytes_to_delete && forgotten_msgs < max_forget_msgs); + SLOG("forgot %d / %zu old messages", + forgotten_msgs, + params->messages.size()); params->messages.erase(first, last); - SLOG("forgot %d old messages", forgotten_msgs); } // init sampling diff --git a/llamafile/server/www/chatbot.js b/llamafile/server/www/chatbot.js index ad68669b6f..544fbfa986 100644 --- a/llamafile/server/www/chatbot.js +++ b/llamafile/server/www/chatbot.js @@ -106,6 +106,7 @@ async function handleChatStream(response) { let buffer = ""; let currentMessageElement = null; let messageAppended = false; + let finishReason = null; let hdom = null; let high = null; streamingMessageContent = []; @@ -132,7 +133,7 @@ async function handleChatStream(response) { try { const parsed = JSON.parse(data); const content = parsed.choices[0]?.delta?.content || ""; - const finishReason = parsed.choices[0]?.finish_reason; + finishReason = parsed.choices[0]?.finish_reason; // handle prefill progress if (parsed.x_prefill_progress !== undefined) { @@ -155,18 +156,6 @@ async function handleChatStream(response) { high.feed(content); scrollToBottom(); } - - // we don't supply max_tokens, so "length" can - // only mean that we ran out of context window - if (finishReason === "length" && hdom) { - let img = hdom.push("IMG", "ooc"); - img.src = "ooc.svg"; - img.alt = "🚫"; - img.title = "Message truncated due to running out of context window. Consider tuning --ctx-size and/or --reserve-tokens"; - img.width = 16; - img.height = 16; - hdom.pop(); - } } catch (e) { console.error("Error parsing JSON:", e); } @@ -183,6 +172,18 @@ async function handleChatStream(response) { } finally { if (messageAppended) { high.flush(); + // we don't supply max_tokens, so "length" can + // only mean that we ran out of context window + if (finishReason === "length") { + let img = document.createElement("IMG"); + img.className = "ooc"; + img.src = "ooc.svg"; + img.alt = "🚫"; + img.title = "Message truncated due to running out of context window. Consider tuning --ctx-size and/or --reserve-tokens"; + img.width = 16; + img.height = 16; + hdom.lastElement.appendChild(img); + } } prefillStatus.style.display = "none"; cleanupAfterMessage(); @@ -248,7 +249,8 @@ async function sendMessage() { if (response.ok) { await handleChatStream(response); const lastMessage = streamingMessageContent.join(""); - chatHistory.push({ role: "assistant", content: lastMessage }); + if (lastMessage) + chatHistory.push({ role: "assistant", content: lastMessage }); } else { console.error("sendMessage() failed due to server error", response); chatMessages.appendChild(createMessageElement( diff --git a/llamafile/server/www/highlight.js b/llamafile/server/www/highlight.js index 92f5214968..35d87691df 100644 --- a/llamafile/server/www/highlight.js +++ b/llamafile/server/www/highlight.js @@ -40,6 +40,7 @@ class HighlightDom extends Highlight { super(); this.currentElement = containerElement; this.containerElement = containerElement; + this.lastElement = containerElement; this.text = ''; } @@ -59,6 +60,7 @@ class HighlightDom extends Highlight { elem.className = className; this.currentElement.appendChild(elem); this.currentElement = elem; + this.lastElement = elem; return elem; } @@ -80,6 +82,7 @@ class HighlightDom extends Highlight { flushText() { if (this.text) { this.currentElement.appendChild(document.createTextNode(this.text)); + this.lastElement = this.currentElement; this.text = ''; } }