Skip to content

Commit

Permalink
[Serving] Apply tree structure in draft token verification (#2563)
Browse files Browse the repository at this point in the history
This adds an interface to the draft token state and the sampler that allows the
tree structure to be recorded and used for verification
  • Loading branch information
vinx13 committed Jun 12, 2024
1 parent 873827c commit dcece51
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 39 deletions.
3 changes: 2 additions & 1 deletion cpp/serve/engine_actions/batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ class BatchDraftActionObj : public EngineActionObj {
models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_,
&model_workspaces_[0].draft_probs_storage);
for (int i = 0; i < num_rsentries; ++i) {
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]);
int64_t parent_idx = static_cast<int64_t>(mstates[i]->draft_output_tokens.size()) - 1;
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i], parent_idx);
}

auto tdraft_end = std::chrono::high_resolution_clock::now();
Expand Down
16 changes: 5 additions & 11 deletions cpp/serve/engine_actions/batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class BatchVerifyActionObj : public EngineActionObj {
Array<GenerationConfig> generation_cfg;
std::vector<RandomGenerator*> rngs;
std::vector<std::vector<SampleResult>> draft_output_tokens;
std::vector<int64_t> token_tree_parent_ptr;
token_tree_parent_ptr.reserve(total_verify_length);
request_internal_ids.reserve(num_rsentries);
all_tokens_to_verify.reserve(total_verify_length);
verify_request_mstates.reserve(num_rsentries);
Expand All @@ -83,9 +85,11 @@ class BatchVerifyActionObj : public EngineActionObj {
// the last committed token + all the draft tokens.
draft_token_slots_.push_back(0); // placeholder for the last committed token
all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().GetTokenId());
token_tree_parent_ptr.push_back(-1);
for (int j = 0; j < static_cast<int>(draft_mstate->draft_output_tokens.size()); ++j) {
all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].GetTokenId());
draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]);
token_tree_parent_ptr.push_back(draft_mstate->draft_token_parent_idx[j] + 1);
}
verify_request_mstates.push_back(verify_mstate);
generation_cfg.push_back(rsentries[i]->request->generation_cfg);
Expand All @@ -101,16 +105,6 @@ class BatchVerifyActionObj : public EngineActionObj {
{IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");

// Construct the token tree. Right now only chains are supported.
std::vector<int64_t> token_tree_parent_ptr;
token_tree_parent_ptr.reserve(total_verify_length);
for (int i = 0; i < num_rsentries; ++i) {
for (int pos = 0; pos < verify_lengths[i]; ++pos) {
token_tree_parent_ptr.push_back(pos - 1);
}
}
ICHECK_EQ(token_tree_parent_ptr.size(), total_verify_length);

RECORD_EVENT(trace_recorder_, request_ids, "start verify");
NDArray logits = models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids,
verify_lengths, token_tree_parent_ptr);
Expand Down Expand Up @@ -140,7 +134,7 @@ class BatchVerifyActionObj : public EngineActionObj {
std::vector<std::vector<SampleResult>> sample_results_arr =
sampler_->BatchVerifyDraftTokensWithProbAfterTopP(
renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,
draft_output_tokens, draft_probs_on_device);
draft_output_tokens, token_tree_parent_ptr, draft_probs_on_device);
ICHECK_EQ(sample_results_arr.size(), num_rsentries);

// We collect the requests whose drafts are fully accepted.
Expand Down
3 changes: 2 additions & 1 deletion cpp/serve/engine_actions/eagle_batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ class EagleBatchDraftActionObj : public EngineActionObj {
&model_workspaces_[0].draft_probs_storage);
// No need to save hidden states as they are not used by subsequent engine actions
for (int i = 0; i < num_rsentries; ++i) {
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]);
int64_t parent_idx = static_cast<int64_t>(mstates[i]->draft_output_tokens.size()) - 1;
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i], parent_idx);
}

auto tdraft_end = std::chrono::high_resolution_clock::now();
Expand Down
20 changes: 8 additions & 12 deletions cpp/serve/engine_actions/eagle_batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
Array<GenerationConfig> generation_cfg;
std::vector<RandomGenerator*> rngs;
std::vector<std::vector<SampleResult>> draft_output_tokens;
std::vector<int64_t> token_tree_parent_ptr;
request_internal_ids.reserve(num_rsentries);
all_tokens_to_verify.reserve(total_draft_length);
token_tree_parent_ptr.reserve(total_draft_length);
verify_request_mstates.reserve(num_rsentries);
rngs.reserve(num_rsentries);
generation_cfg.reserve(num_rsentries);
Expand All @@ -83,9 +85,12 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
// the last committed token + all the draft tokens but the last one.
all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().GetTokenId());
draft_token_slots_.push_back(0); // placeholder for the last committed token
token_tree_parent_ptr.push_back(-1);

for (int j = 0; j < static_cast<int>(draft_mstate->draft_output_tokens.size()); ++j) {
all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].GetTokenId());
draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]);
token_tree_parent_ptr.push_back(draft_mstate->draft_token_parent_idx[j] + 1);
}
verify_request_mstates.push_back(verify_mstate);
generation_cfg.push_back(rsentries[i]->request->generation_cfg);
Expand All @@ -111,16 +116,6 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
{IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});
RECORD_EVENT(trace_recorder_, request_ids, "finish verify embedding");

// Construct the token tree. Right now only chains are supported.
std::vector<int64_t> token_tree_parent_ptr;
token_tree_parent_ptr.reserve(cum_verify_lengths.back());
for (int i = 0; i < num_rsentries; ++i) {
for (int pos = 0; pos < verify_lengths[i]; ++pos) {
token_tree_parent_ptr.push_back(pos - 1);
}
}
ICHECK_EQ(token_tree_parent_ptr.size(), cum_verify_lengths.back());

RECORD_EVENT(trace_recorder_, request_ids, "start verify");
ObjectRef hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(
embeddings, request_internal_ids, verify_lengths, token_tree_parent_ptr);
Expand All @@ -143,7 +138,7 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
std::vector<std::vector<SampleResult>> sample_results_arr =
sampler_->BatchVerifyDraftTokensWithProbAfterTopP(
renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,
draft_output_tokens, draft_probs_on_device);
draft_output_tokens, token_tree_parent_ptr, draft_probs_on_device);
ICHECK_EQ(sample_results_arr.size(), num_rsentries);

// We collect the requests whose drafts are fully accepted.
Expand Down Expand Up @@ -398,7 +393,8 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
&model_workspaces_[0].draft_hidden_states_storage);
}
for (int i = 0; i < static_cast<int>(mstates.size()); ++i) {
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i]);
int64_t parent_idx = static_cast<int64_t>(mstates[i]->draft_output_tokens.size()) - 1;
mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i], parent_idx);
}
}
/*!
Expand Down
6 changes: 5 additions & 1 deletion cpp/serve/engine_actions/eagle_new_request_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,12 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
&model_workspaces_[0].draft_hidden_states_storage);
}
for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
int parent_idx =
rsentries_for_sample[i]->mstates[model_id]->draft_output_tokens.empty()
? -1
: rsentries_for_sample[i]->mstates[model_id]->draft_output_tokens.size() - 1;
rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(
sample_results[i], draft_token_slots_[sample_indices[i]]);
sample_results[i], draft_token_slots_[sample_indices[i]], parent_idx);
}
}

Expand Down
7 changes: 5 additions & 2 deletions cpp/serve/logit_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ class LogitProcessorImpl : public LogitProcessorObj {
p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty;
++num_token_for_penalty;
if (j > 0) {
mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1);
// Assume chain-style token tree.
mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1,
j - 1 - 1);
}
}
if (num_token_to_process != 1) {
Expand Down Expand Up @@ -379,7 +381,8 @@ class LogitProcessorImpl : public LogitProcessorObj {
p_seq_ids[token_start_offset + j] = 1;
}
if (j > 0) {
mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1);
// Assume chain-style token tree.
mstates[i]->AddDraftToken(draft_tokens->at(i)[j - 1], /*draft_token_slot=*/-1, j - 1 - 1);
}
}
if (token_number != 1) {
Expand Down
5 changes: 4 additions & 1 deletion cpp/serve/request_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,19 @@ void RequestModelStateNode::RollbackTokens(int count) {
}
}

void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, int draft_token_slot) {
void RequestModelStateNode::AddDraftToken(SampleResult sampled_token, int draft_token_slot,
int64_t parent_idx) {
draft_output_tokens.push_back(std::move(sampled_token));
draft_token_slots.push_back(draft_token_slot);
draft_token_parent_idx.push_back(parent_idx);
appeared_token_ids[sampled_token.GetTokenId()] += 1;
}

void RequestModelStateNode::RemoveLastDraftToken() {
ICHECK(!draft_output_tokens.empty());
auto it = appeared_token_ids.find(draft_output_tokens.back().GetTokenId());
draft_output_tokens.pop_back();
draft_token_parent_idx.pop_back();
CHECK(it != appeared_token_ids.end());
if (--it->second == 0) {
appeared_token_ids.erase(it);
Expand Down
4 changes: 3 additions & 1 deletion cpp/serve/request_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class RequestModelStateNode : public Object {
std::vector<SampleResult> draft_output_tokens;
/*! \brief The storage slots for the associated states of draft tokens. */
std::vector<int> draft_token_slots;
/*! \brief The parent indices of the draft tokens. */
std::vector<int64_t> draft_token_parent_idx;
/*! \brief The appeared committed and draft tokens and their occurrence times. */
std::unordered_map<int32_t, int32_t> appeared_token_ids;

Expand Down Expand Up @@ -106,7 +108,7 @@ class RequestModelStateNode : public Object {
void RollbackTokens(int count);

/*! \brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */
void AddDraftToken(SampleResult sampled_token, int draft_token_slot);
void AddDraftToken(SampleResult sampled_token, int draft_token_slot, int64_t parent_idx);
/*! \brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. */
void RemoveAllDraftTokens(std::vector<int>* removed_draft_token_slots = nullptr);

Expand Down
8 changes: 7 additions & 1 deletion cpp/serve/sampler/cpu_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ class CPUSampler : public SamplerObj {
const std::vector<int>& cum_verify_lengths, const Array<GenerationConfig>& generation_cfg,
const std::vector<RandomGenerator*>& rngs,
const std::vector<std::vector<SampleResult>>& draft_output_tokens,
NDArray draft_probs_on_device) final {
const std::vector<int64_t>& token_tree_parent_ptr, NDArray draft_probs_on_device) final {
// probs_on_host: (n, v)
RECORD_EVENT(trace_recorder_, request_ids, "start draft verification");
CHECK_EQ(probs_on_host->ndim, 2);
Expand All @@ -435,6 +435,12 @@ class CPUSampler : public SamplerObj {
int verify_start = cum_verify_lengths[i];
int verify_end = cum_verify_lengths[i + 1];

CHECK_EQ(token_tree_parent_ptr[verify_start], -1);
for (int j = verify_start + 1; j < verify_end; ++j) {
CHECK_EQ(token_tree_parent_ptr[j], j - verify_start)
<< "CPU sampler only supports chain-style draft tokens.";
}

int cur_token_idx = 0;
// Sub 1 to ignore the last prediction.
for (; cur_token_idx < verify_end - verify_start - 1; ++cur_token_idx) {
Expand Down
22 changes: 15 additions & 7 deletions cpp/serve/sampler/gpu_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class GPUSampler : public SamplerObj {
const std::vector<int>& cum_verify_lengths, const Array<GenerationConfig>& generation_cfg,
const std::vector<RandomGenerator*>& rngs,
const std::vector<std::vector<SampleResult>>& draft_output_tokens,
NDArray draft_probs_on_device) final {
const std::vector<int64_t>& token_tree_parent_ptr, NDArray draft_probs_on_device) final {
NVTXScopedRange nvtx_scope("BatchVerifyDraftTokensWithProbAfterTopP");
std::vector<std::vector<SampleResult>> sample_results;
// probs_on_device: (n, v)
Expand Down Expand Up @@ -252,21 +252,29 @@ class GPUSampler : public SamplerObj {
token_tree_parent_ptr_device_.CreateView({num_sequence}, dtype_i32_);
std::vector<int> token_tree_child_to_parent(/*n=*/num_nodes);

int* token_tree_first_child_ptr_host = static_cast<int*>(token_tree_first_child_host->data);
int* token_tree_next_sibling_ptr_host = static_cast<int*>(token_tree_next_sibling_host->data);
// Build the tree structure on CPU
for (int i = 0; i < num_sequence; i++) {
// Assuming no tree structure for now
int start = cum_verify_lengths[i];
int end = cum_verify_lengths[i + 1];
ICHECK_GE(end - start, 2);
token_tree_child_to_parent[start] = -1; // root has no parent
for (int j = 0; j < end - start; j++) {
int cur_node = j + start;
int child_node = j + 1 >= end - start ? -1 : cur_node + 1;
static_cast<int*>(token_tree_first_child_host->data)[cur_node] = child_node;
if (child_node != -1) {
token_tree_child_to_parent[child_node] = cur_node;
int parent_node =
token_tree_parent_ptr[cur_node] != -1 ? token_tree_parent_ptr[cur_node] + start : -1;
token_tree_first_child_ptr_host[cur_node] = -1;
if (parent_node != -1 && token_tree_first_child_ptr_host[parent_node] == -1) {
token_tree_first_child_ptr_host[parent_node] = cur_node;
}
token_tree_child_to_parent[cur_node] = parent_node;
if (cur_node + 1 < end && token_tree_parent_ptr[cur_node - start + 1] ==
token_tree_parent_ptr[cur_node - start]) {
token_tree_next_sibling_ptr_host[cur_node] = cur_node + 1;
} else {
token_tree_next_sibling_ptr_host[cur_node] = -1;
}
static_cast<int*>(token_tree_next_sibling_host->data)[cur_node] = -1;
}
static_cast<int*>(token_tree_parent_ptr_host->data)[i] = start; // point to the root
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/serve/sampler/sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class SamplerObj : public Object {
* \param rngs The random number generator of each sequence.
* \param draft_output_tokens The draft tokens generated by the small model for
* each sequence.
* \param token_tree_parent_ptr The parent pointer of the token tree.
* \param draft_probs_on_device The probability distribution computed from the
* small model for each sequence. Concatenated tensor of shape (total_verify_length, vocab_size).
* It includes the slot for the last committed token that has undefined probability value.
Expand All @@ -115,7 +116,7 @@ class SamplerObj : public Object {
NDArray probs, const Array<String>& request_ids, const std::vector<int>& cum_verify_lengths,
const Array<GenerationConfig>& generation_cfg, const std::vector<RandomGenerator*>& rngs,
const std::vector<std::vector<SampleResult>>& draft_output_tokens,
NDArray draft_probs_on_device) = 0;
const std::vector<int64_t>& token_tree_parent_ptr, NDArray draft_probs_on_device) = 0;

static constexpr const char* _type_key = "mlc.serve.Sampler";
static constexpr const bool _type_has_method_sequal_reduce = false;
Expand Down

0 comments on commit dcece51

Please sign in to comment.