diff --git a/llamafile/server/atom.cpp b/llamafile/server/atom.cpp index e627345738..7dc6fa4b3c 100644 --- a/llamafile/server/atom.cpp +++ b/llamafile/server/atom.cpp @@ -56,6 +56,22 @@ Atom::Atom(const Atom& other) word_ = 2ull << 56 | (uintptr_t)image; } +Atom& +Atom::operator=(const Atom& other) +{ + if (this != &other) { + if (is_image()) + delete (Image*)(word_ & 0x00ffffffffffffff); + if (!other.is_image()) { + word_ = other.word_; + } else { + Image* image = new Image(other.image()); + word_ = 2ull << 56 | (uintptr_t)image; + } + } + return *this; +} + Atom::~Atom() { if (is_image()) diff --git a/llamafile/server/atom.h b/llamafile/server/atom.h index a09f40ed99..bb517b9397 100644 --- a/llamafile/server/atom.h +++ b/llamafile/server/atom.h @@ -31,6 +31,7 @@ class Atom Atom(const Atom&); Atom(Atom&&); ~Atom(); + Atom& operator=(const Atom&); int token() const; bool empty() const; int ctx_used() const; diff --git a/llamafile/server/atomize.cpp b/llamafile/server/atomize.cpp index a80391142a..6c8cfdf08a 100644 --- a/llamafile/server/atomize.cpp +++ b/llamafile/server/atomize.cpp @@ -92,5 +92,14 @@ remove_old_image_atoms(const std::vector& atoms) return result; } +int +count_tokens(const std::vector& atoms) +{ + int n = 0; + for (const Atom& atom : atoms) + n += atom.ctx_used(); + return n; +} + } // namespace server } // namespace lf diff --git a/llamafile/server/slot.cpp b/llamafile/server/slot.cpp index 78a9d9d23e..d858f0fa0c 100644 --- a/llamafile/server/slot.cpp +++ b/llamafile/server/slot.cpp @@ -25,6 +25,7 @@ #include "llamafile/server/atom.h" #include "llamafile/server/image.h" #include "llamafile/server/log.h" +#include "llamafile/server/utils.h" #include "llamafile/vector.h" #include "llamafile/version.h" #include @@ -290,43 +291,135 @@ Slot::prefill(const std::vector& atoms, const ProgressCallback& progress) { if (!ctx_) return uninitialized; - int used_tokens = ctx_used(); - int reuse_atoms = 0; - int reuse_tokens = 0; - int erase_tokens = 0; + + // handle special case of empty prefill + if (atoms.empty()) { + llama_kv_cache_clear(ctx_); + history_.clear(); + return 0; + } + + // when a prefill request comes in, chances are the system prompt + // will already be loaded and the unique user request in atoms is + // going to have something different that follows. in such a case + // we'll rapidly delete the latter portion from the KV cache, and + // then we won't need the cost of prefilling the earlier portion. + // + // "hello world i love you!" <-- atoms + // "hello world how are you" <-- history + // "hello world " <-- keep + // "how are you" <-- evaluated + // + // when context runs out the completions interface or user client + // might delete content in the middle, in which case we can shift + // content backwards based on the matching suffix. + // + // "sysprompt msg2 msg3 msg4" <-- atoms + // └──┬────┘ └──────┬┘ + // │ │ + // ┌──┴────┐ ┌─┴─────┐ + // "sysprompt msg1 msg2 msg3" <-- history + // "sysprompt " <-- keep + // "msg1 " <-- discard + // "msg2 msg3" <-- relocate + // "sysprompt msg2 msg3" <-- llama_kv_cache_seq_rm + // "sysprompt msg2 msg3" <-- llama_kv_cache_seq_add + // "msg4" <-- evaluated + // + int keep = 0; int n = std::min(atoms.size(), history_.size()); - for (int i = 0; i < n && atoms[i] == history_[i]; ++i) { - reuse_tokens += history_[i].ctx_used(); - reuse_atoms += 1; + for (int i = 0; i < n && atoms[i] == history_[i]; ++i) + ++keep; + int relocate_p0 = -1; + int relocate_p1 = -1; + int skipped = keep; + for (int i = keep + 1; i < history_.size(); ++i) { + if (history_.size() - i > atoms.size() - keep) + continue; + if (std::equal(history_.begin() + i, // + history_.end(), + atoms.begin() + keep)) { + relocate_p0 = i; + relocate_p1 = history_.size(); + skipped += history_.size() - i; + break; + } } - // xxx: ensure we prefill at least one token (prevents badness) - if (reuse_tokens >= 1) { - reuse_atoms -= 1; - reuse_tokens -= history_[reuse_atoms].ctx_used(); + + // xxx: ensure we eval at least one token + // this prevents an observed badness + if (skipped == atoms.size()) { + if (relocate_p0 != -1) { + --relocate_p1; + } else { + --keep; + } + --skipped; + } + + // now count tokens + int keep_tokens = 0; + int history_tokens = ctx_used(); + for (int i = 0; i < keep; ++i) + keep_tokens += history_[i].ctx_used(); + int relocate_p0_tokens = -1; + int relocate_p1_tokens = -1; + if (relocate_p0 != -1) { + relocate_p0_tokens = 0; + for (int i = 0; i < relocate_p0; ++i) + relocate_p0_tokens += history_[i].ctx_used(); + relocate_p1_tokens = 0; + for (int i = 0; i < relocate_p1; ++i) + relocate_p1_tokens += history_[i].ctx_used(); } - if (used_tokens > reuse_tokens) { - erase_tokens = used_tokens - reuse_tokens; - if (llama_kv_cache_seq_rm(ctx_, 0, reuse_tokens, -1)) { - history_.resize(reuse_atoms); + int skipped_tokens = 0; + for (int i = 0; i < skipped; ++i) + skipped_tokens += atoms[i].ctx_used(); + + // discard tokens from kv cache + int discarded_tokens; + int relocated_tokens = 0; + if (llama_kv_cache_seq_rm(ctx_, 0, keep_tokens, relocate_p0_tokens)) { + if (relocate_p0 == -1) { + discarded_tokens = history_.size() - keep_tokens; + history_.resize(keep); } else { - SLOG("failed to remove tokens from KV cache"); - reuse_atoms = 0; - reuse_tokens = 0; - erase_tokens = used_tokens; - llama_kv_cache_clear(ctx_); - history_.clear(); + discarded_tokens = (history_.size() - relocate_p1) + + (relocate_p0_tokens - keep_tokens); + relocated_tokens = relocate_p1_tokens - relocate_p0_tokens; + history_.resize(relocate_p1); + history_.erase(history_.begin() + keep, + history_.begin() + relocate_p0); + // memmove relocated tokens in kv cache + llama_kv_cache_seq_add(ctx_, + 0, + relocate_p0_tokens, + relocate_p1_tokens, + -(relocate_p0_tokens - keep_tokens)); } + } else { + // models like Mamba can't be partially erased + SLOG("failed to remove tokens from KV cache"); + discarded_tokens = history_.size(); + llama_kv_cache_clear(ctx_); + history_.clear(); + skipped = 0; } - std::vector new_atoms(atoms.begin() + reuse_atoms, atoms.end()); + + // evaluate tokens + std::vector new_atoms(atoms.begin() + skipped, atoms.end()); int rc; if ((rc = eval_atoms(new_atoms, progress)) < 0) return rc; - int token_count = reuse_tokens + rc; - SLOG("prefilled %zu tokens (after removing %zu and reusing %zu)", - token_count, - erase_tokens, - reuse_tokens); - return token_count; + int total_tokens = keep_tokens + relocated_tokens + rc; + SLOG("prefilled %d tokens (after keeping %d, discarding %d, " + "relocating %d, and evaluating %d)", + total_tokens, + keep_tokens, + discarded_tokens, + relocated_tokens, + count_tokens(new_atoms)); + return total_tokens; } void diff --git a/llamafile/server/utils.h b/llamafile/server/utils.h index 4ee1cea229..5343758bf6 100644 --- a/llamafile/server/utils.h +++ b/llamafile/server/utils.h @@ -47,5 +47,8 @@ atomize(const llama_model* model, std::vector remove_old_image_atoms(const std::vector&); +int +count_tokens(const std::vector&); + } // namespace server } // namespace lf diff --git a/llamafile/server/v1_chat_completions.cpp b/llamafile/server/v1_chat_completions.cpp index ab8aedf662..aa7f65ee8e 100644 --- a/llamafile/server/v1_chat_completions.cpp +++ b/llamafile/server/v1_chat_completions.cpp @@ -173,15 +173,6 @@ has_images(const std::vector& atoms) return false; } -static int -count_tokens(const std::vector& atoms) -{ - int n = 0; - for (const Atom& atom : atoms) - n += atom.ctx_used(); - return n; -} - static int count_bytes(const std::vector& messages) { @@ -548,8 +539,10 @@ Client::v1_chat_completions() ++last; } while (bytes_deleted < bytes_to_delete && forgotten_msgs < max_forget_msgs); + SLOG("forgot %d / %zu old messages", + forgotten_msgs, + params->messages.size()); params->messages.erase(first, last); - SLOG("forgot %d old messages", forgotten_msgs); } // init sampling diff --git a/llamafile/server/www/chatbot.js b/llamafile/server/www/chatbot.js index ad68669b6f..544fbfa986 100644 --- a/llamafile/server/www/chatbot.js +++ b/llamafile/server/www/chatbot.js @@ -106,6 +106,7 @@ async function handleChatStream(response) { let buffer = ""; let currentMessageElement = null; let messageAppended = false; + let finishReason = null; let hdom = null; let high = null; streamingMessageContent = []; @@ -132,7 +133,7 @@ async function handleChatStream(response) { try { const parsed = JSON.parse(data); const content = parsed.choices[0]?.delta?.content || ""; - const finishReason = parsed.choices[0]?.finish_reason; + finishReason = parsed.choices[0]?.finish_reason; // handle prefill progress if (parsed.x_prefill_progress !== undefined) { @@ -155,18 +156,6 @@ async function handleChatStream(response) { high.feed(content); scrollToBottom(); } - - // we don't supply max_tokens, so "length" can - // only mean that we ran out of context window - if (finishReason === "length" && hdom) { - let img = hdom.push("IMG", "ooc"); - img.src = "ooc.svg"; - img.alt = "🚫"; - img.title = "Message truncated due to running out of context window. Consider tuning --ctx-size and/or --reserve-tokens"; - img.width = 16; - img.height = 16; - hdom.pop(); - } } catch (e) { console.error("Error parsing JSON:", e); } @@ -183,6 +172,18 @@ async function handleChatStream(response) { } finally { if (messageAppended) { high.flush(); + // we don't supply max_tokens, so "length" can + // only mean that we ran out of context window + if (finishReason === "length") { + let img = document.createElement("IMG"); + img.className = "ooc"; + img.src = "ooc.svg"; + img.alt = "🚫"; + img.title = "Message truncated due to running out of context window. Consider tuning --ctx-size and/or --reserve-tokens"; + img.width = 16; + img.height = 16; + hdom.lastElement.appendChild(img); + } } prefillStatus.style.display = "none"; cleanupAfterMessage(); @@ -248,7 +249,8 @@ async function sendMessage() { if (response.ok) { await handleChatStream(response); const lastMessage = streamingMessageContent.join(""); - chatHistory.push({ role: "assistant", content: lastMessage }); + if (lastMessage) + chatHistory.push({ role: "assistant", content: lastMessage }); } else { console.error("sendMessage() failed due to server error", response); chatMessages.appendChild(createMessageElement( diff --git a/llamafile/server/www/highlight.js b/llamafile/server/www/highlight.js index 92f5214968..35d87691df 100644 --- a/llamafile/server/www/highlight.js +++ b/llamafile/server/www/highlight.js @@ -40,6 +40,7 @@ class HighlightDom extends Highlight { super(); this.currentElement = containerElement; this.containerElement = containerElement; + this.lastElement = containerElement; this.text = ''; } @@ -59,6 +60,7 @@ class HighlightDom extends Highlight { elem.className = className; this.currentElement.appendChild(elem); this.currentElement = elem; + this.lastElement = elem; return elem; } @@ -80,6 +82,7 @@ class HighlightDom extends Highlight { flushText() { if (this.text) { this.currentElement.appendChild(document.createTextNode(this.text)); + this.lastElement = this.currentElement; this.text = ''; } }