diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index a6fcf2e570..88c88c0724 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -550,8 +550,37 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
             triton.next_power_of_2(max_draft_len),
         )
 
-        accept_index = accept_index[accept_index != -1]
+        draft_input = EAGLEDraftInput()
+        new_accept_index = []
+        unfinished_index = []
+        finished_extend_len = {}  # {rid:accept_length + 1}
+        accept_index_cpu = accept_index.tolist()
+        predict_cpu = predict.tolist()
+        # iterate every accepted token and check if req has finished after append the token
+        # should be checked BEFORE free kv cache slots
+        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            new_accept_index_ = []
+            for j, idx in enumerate(accept_index_row):
+                if idx == -1:
+                    break
+                id = predict_cpu[idx]
+                # if not found_finished:
+                req.output_ids.append(id)
+                finished_extend_len[req.rid] = j + 1
+                req.check_finished()
+                if req.finished():
+                    draft_input.has_finished = True
+                    # set all tokens after finished token to -1 and break
+                    accept_index[i, j + 1 :] = -1
+                    break
+                else:
+                    new_accept_index_.append(idx)
+            if not req.finished():
+                new_accept_index.extend(new_accept_index_)
+                unfinished_index.append(i)
+        accept_length = (accept_index != -1).sum(dim=1) - 1
 
+        accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
         verified_id_cpu = verified_id.tolist()
@@ -570,26 +599,9 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
             triton.next_power_of_2(bs),
         )
         batch.seq_lens.add_(accept_length + 1)
-        new_accept_index = []
-        unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
-        # retracted_reqs, new_token_ratio = batch.retract_decode()
-
-        low = 0
-        draft_input = EAGLEDraftInput()
-        for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
-            req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
-            req.check_finished()
-            if req.finished():
-                draft_input.has_finished = True
-            else:
-                new_accept_index.append(accept_index[low : low + verified_len + 1])
-                unfinished_index.append(i)
-            low += verified_len + 1
-            finished_extend_len[req.rid] = verified_len + 1
 
         if len(new_accept_index) > 0:
-            new_accept_index = torch.cat(new_accept_index, dim=0)
+            new_accept_index = torch.tensor(new_accept_index, device="cuda")
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 609d4411d7..94ebc79ca7 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -1,5 +1,7 @@
 import unittest
 
+from transformers import AutoConfig, AutoTokenizer
+
 import sglang as sgl
 
 
@@ -34,6 +36,33 @@ def test_eagle_accuracy(self):
         print(out2)
         self.assertEqual(out1, out2)
 
+    def test_eagle_end_check(self):
+        prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+        target_model_path = "meta-llama/Llama-2-7b-chat-hf"
+        tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+        speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
+
+        sampling_params = {
+            "temperature": 0,
+            "max_new_tokens": 1024,
+            "skip_special_tokens": False,
+        }
+
+        engine = sgl.Engine(
+            model_path=target_model_path,
+            speculative_draft_model_path=speculative_draft_model_path,
+            speculative_algorithm="EAGLE",
+            speculative_num_steps=3,
+            speculative_eagle_topk=4,
+            speculative_num_draft_tokens=16,
+        )
+        out1 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+        print("==== Answer 1 ====")
+        print(repr(out1))
+        tokens = tokenizer.encode(out1, truncation=False)
+        assert tokenizer.eos_token_id not in tokens
+
 
 if __name__ == "__main__":
     unittest.main()