Skip to content

llama-server : implement universal assisted decoding #12635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,10 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
for (auto & seq_breaker : params.sampling.dry_sequence_breakers) {
string_process_escapes(seq_breaker);
}
for (auto & pair : params.speculative.replacements) {
string_process_escapes(pair.first);
string_process_escapes(pair.second);
}
}

if (!params.kv_overrides.empty()) {
Expand Down Expand Up @@ -3217,6 +3221,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.model.path = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
add_opt(common_arg(
{"--spec-replace"}, "TARGET", "DRAFT",
"translate the string in TARGET into DRAFT if the draft model and main model are not compatible",
[](common_params & params, const std::string & tgt, const std::string & dft) {
params.speculative.replacements.push_back({ tgt, dft });
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(
Expand Down
1 change: 1 addition & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ struct common_params_speculative {
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements

ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
Expand Down
187 changes: 133 additions & 54 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
@@ -1,30 +1,38 @@
#include "speculative.h"

#include "llama.h"
#include "log.h"
#include "common.h"
#include "sampling.h"

#include <cstring>
#include <algorithm>
#include <map>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

struct common_speculative {
struct llama_context * ctx;
struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
struct llama_context * ctx_dft;
struct common_sampler * smpl;

llama_batch batch;
llama_tokens prompt;
llama_tokens prompt_dft;
bool vocab_dft_compatible = true; // whether retokenization is needed
std::map<std::string, std::string> tgt_dft_replacements = {};
};

struct common_speculative * common_speculative_init(
struct llama_context * ctx_tgt,
struct llama_context * ctx_dft) {
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt = */ {},
/* .ctx_main = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .prompt_dft = */ {},
/* .vocab_dft_compatible = */ false,
};

// TODO: optimize or pass from outside?
Expand Down Expand Up @@ -59,6 +67,9 @@ struct common_speculative * common_speculative_init(
}
#endif

result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft);
LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible);

return result;
}

Expand All @@ -75,8 +86,8 @@ void common_speculative_free(struct common_speculative * spec) {
}

bool common_speculative_are_compatible(
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft) {
const struct llama_context * ctx_tgt,
const struct llama_context * ctx_dft) {
const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
const struct llama_model * model_dft = llama_get_model(ctx_dft);

Expand All @@ -90,40 +101,41 @@ bool common_speculative_are_compatible(
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);

if (vocab_type_tgt != vocab_type_dft) {
LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
LOG_DBG("%s: draft model vocab type must match target model to use speculation but ", __func__);
LOG_DBG("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return false;
}

if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
if (
llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
) {
LOG_DBG("%s: draft model special tokens must match target model to use speculation\n", __func__);
return false;
}

{
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);

const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;

if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
__func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}

for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", __func__, i,
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
common_token_to_piece(ctx_tgt, i).c_str(),
common_token_to_piece(ctx_dft, i).c_str());
return false;
Expand All @@ -134,32 +146,95 @@ bool common_speculative_are_compatible(
return true;
}

void common_speculative_add_replacement_tgt_dft(
struct common_speculative * spec,
const char *source, const char *dest) {
spec->tgt_dft_replacements[source] = dest;
}

static std::string replace_to_dft(
struct common_speculative * spec,
const std::string& input) {
std::string result = input;
for (const auto& pair : spec->tgt_dft_replacements) {
size_t pos = result.find(pair.first);
while (pos != std::string::npos) {
result.replace(pos, pair.first.length(), pair.second);
pos = result.find(pair.first, pos + pair.second.length());
}
}
return result;
}

static std::string replace_to_tgt(
struct common_speculative * spec,
const std::string& input) {
std::string result = input;
for (const auto& pair : spec->tgt_dft_replacements) {
size_t pos = result.find(pair.second);
while (pos != std::string::npos) {
result.replace(pos, pair.second.length(), pair.first);
pos = result.find(pair.second, pos + pair.first.length());
}
}
return result;
}


llama_tokens common_speculative_gen_draft(
struct common_speculative * spec,
struct common_speculative_params params,
const llama_tokens & prompt_tgt,
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
llama_token id_last) {
auto & batch = spec->batch;
auto & ctx = spec->ctx;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
auto & smpl = spec->smpl;
auto & prompt = spec->prompt;
auto & prompt_dft = spec->prompt_dft;

auto * mem = llama_get_memory(ctx);
auto * mem_dft = llama_get_memory(ctx_dft);

int reuse_i = 0;
int reuse_n = 0;

const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft;

llama_tokens prompt_tgt_draft_model;
if (!spec->vocab_dft_compatible) {
const llama_model * model_tgt = llama_get_model(ctx_tgt);

std::string text;
text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true);
text = replace_to_dft(spec, text);
LOG_DBG("main->draft detokenized string: '%s'\n", text.c_str());
prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true);
text.clear();

const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
int32_t n_chars;
n_chars = llama_detokenize(vocab_tgt, &id_last, 1, &text[0], text.size(), false, true);
if (n_chars < 0) {
text.resize(-n_chars);
n_chars = llama_detokenize(vocab_tgt, &id_last, 1, &text[0], text.size(), false, true);
}
text.resize(n_chars);
text = replace_to_dft(spec, text);
LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str());
id_last = common_tokenize(ctx_dft, text, false, true)[0];
}
// prompt_tgt's tokens will always be compatible with ctx_dft
const llama_tokens &prompt_tgt =
spec->vocab_dft_compatible ? prompt_tgt_main_model : prompt_tgt_draft_model;

const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);

// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
for (int i = 0; i < (int) prompt_dft.size(); ++i) {
int cur = 0;
while (i_start + cur < (int) prompt_tgt.size() &&
i + cur < (int) prompt.size() &&
prompt_tgt[i_start + cur] == prompt[i + cur]) {
i + cur < (int) prompt_dft.size() &&
prompt_tgt[i_start + cur] == prompt_dft[i + cur]) {
cur++;
}

Expand All @@ -169,21 +244,20 @@ llama_tokens common_speculative_gen_draft(
}
}

LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size());
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());

llama_tokens result;
result.reserve(params.n_draft);

if (reuse_n == 0) {
llama_memory_clear(mem, false);

prompt.clear();
llama_memory_clear(mem_dft, false);
prompt_dft.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
result.push_back(prompt[i]);
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
result.push_back(prompt_dft[i]);

if (params.n_draft <= (int) result.size()) {
break;
Expand All @@ -194,16 +268,15 @@ llama_tokens common_speculative_gen_draft(
}

if (reuse_i > 0) {
llama_memory_seq_rm (mem, 0, 0, reuse_i);
llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);

prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
}

if (reuse_n < (int) prompt.size()) {
llama_memory_seq_rm (mem, 0, reuse_n, -1);

prompt.erase(prompt.begin() + reuse_n, prompt.end());
if (reuse_n < (int) prompt_dft.size()) {
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
}
}

Expand All @@ -214,42 +287,42 @@ llama_tokens common_speculative_gen_draft(
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);

prompt.push_back(prompt_tgt[i]);
prompt_dft.push_back(prompt_tgt[i]);
}

// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

llama_decode(ctx, batch);
llama_decode(ctx_dft, batch);
}

const llama_pos n_past = prompt.size();
const llama_pos n_past = prompt_dft.size();

LOG_DBG("%s: n_past = %d\n", __func__, n_past);

common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);

prompt.push_back(id_last);
prompt_dft.push_back(id_last);

//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());

llama_decode(ctx, batch);
llama_decode(ctx_dft, batch);

common_sampler_reset(smpl);

// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);

common_sampler_sample(smpl, ctx, 0, true);
common_sampler_sample(smpl, ctx_dft, 0, true);

const auto * cur_p = common_sampler_get_candidates(smpl);

for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}

// add drafted token for each sequence
Expand All @@ -271,10 +344,16 @@ llama_tokens common_speculative_gen_draft(
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

// evaluate the drafted tokens on the draft model
llama_decode(ctx, batch);
llama_decode(ctx_dft, batch);

prompt.push_back(id);
prompt_dft.push_back(id);
}

if (!spec->vocab_dft_compatible) {
std::string detokenized = common_detokenize(ctx_dft, result, true);
detokenized = replace_to_tgt(spec, detokenized);
LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str());
result = common_tokenize(ctx_tgt, detokenized, false, true);
}
return result;
}
Loading