[eagle2] fix end check when target model verify #2723

Open · wants to merge 4 commits into main
57 changes: 38 additions & 19 deletions python/sglang/srt/speculative/eagle_utils.py
@@ -550,8 +550,44 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
             triton.next_power_of_2(max_draft_len),
         )
 
-        accept_index = accept_index[accept_index != -1]
+        draft_input = EAGLEDraftInput()
+        new_accept_index = []
+        unfinished_index = []
+        finished_extend_len = {}  # {rid: accept_length + 1}
+        accept_index_cpu = accept_index.tolist()
+        predict_cpu = predict.tolist()
+        # Iterate over every accepted token and check whether the request has
+        # finished after the token is appended. This must happen BEFORE the
+        # KV cache slots are freed.
+        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            new_accept_index_ = []
+            for j, idx in enumerate(accept_index_row):
+                if idx == -1:
+                    break
+                id = predict_cpu[idx]
+                req.output_ids.append(id)
+                finished_extend_len[req.rid] = j + 1
+                req.check_finished()
+                if req.finished():
+                    draft_input.has_finished = True
+                    # Mark all tokens after the finished token as rejected and stop.
+                    accept_index[i, j + 1 :] = -1
+                    break
+                else:
+                    new_accept_index_.append(idx)
+            if not req.finished():
+                new_accept_index.extend(new_accept_index_)
+                unfinished_index.append(i)
+        accept_length = (accept_index != -1).sum(dim=1) - 1
+
+        accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
         verified_id_cpu = verified_id.tolist()
@@ -570,26 +606,9 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
             triton.next_power_of_2(bs),
         )
         batch.seq_lens.add_(accept_length + 1)
-        new_accept_index = []
-        unfinished_index = []
-        finished_extend_len = {}  # {rid:accept_length + 1}
-        # retracted_reqs, new_token_ratio = batch.retract_decode()
-
-        low = 0
-        draft_input = EAGLEDraftInput()
-        for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
-            req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
-            req.check_finished()
-            if req.finished():
-                draft_input.has_finished = True
-            else:
-                new_accept_index.append(accept_index[low : low + verified_len + 1])
-                unfinished_index.append(i)
-            low += verified_len + 1
-            finished_extend_len[req.rid] = verified_len + 1
 
         if len(new_accept_index) > 0:
-            new_accept_index = torch.cat(new_accept_index, dim=0)
+            new_accept_index = torch.tensor(new_accept_index, device="cuda")
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
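Why the check moved: the old code first extended req.output_ids with all verified tokens and only then called req.check_finished(), so draft tokens accepted after an EOS could leak into the output. The new loop checks after each appended token and marks the rest of the row as rejected. Below is a minimal sketch of that truncation on toy tensors; a plain EOS comparison stands in for req.check_finished(), and the data values are made up for illustration.

import torch

predict = torch.tensor([11, 22, 2, 33, 44])  # token id 2 plays the role of EOS
accept_index = torch.tensor([[0, 1, 2, 3]])  # accepted draft positions, one request
eos_token_id = 2

for j, idx in enumerate(accept_index[0].tolist()):
    if idx == -1:
        break
    if predict[idx].item() == eos_token_id:  # stand-in for req.check_finished()
        accept_index[0, j + 1 :] = -1  # reject everything drafted after the EOS
        break

accept_length = (accept_index != -1).sum(dim=1) - 1
print(accept_index.tolist())   # [[0, 1, 2, -1]]
print(accept_length.tolist())  # [2]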
25 changes: 24 additions & 1 deletion test/srt/test_eagle_infer.py
@@ -1,7 +1,7 @@
 import unittest
 
 import sglang as sgl
-
+from transformers import AutoConfig, AutoTokenizer
 
 class TestEAGLEEngine(unittest.TestCase):
 
@@ -34,6 +34,29 @@ def test_eagle_accuracy(self):
         print(out2)
         self.assertEqual(out1, out2)
 
+    def test_eagle_eos_token(self):
+        prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+        target_model_path = "meta-llama/Llama-2-7b-chat"
+        tokenizer = AutoTokenizer.from_pretrained(target_model_path)
+        speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
+
+        sampling_params = {"temperature": 0, "max_new_tokens": 1024}
+
+        engine = sgl.Engine(
+            model_path=target_model_path,
+            speculative_draft_model_path=speculative_draft_model_path,
+            speculative_algorithm="EAGLE",
+            speculative_num_steps=3,
+            speculative_eagle_topk=4,
+            speculative_num_draft_tokens=16,
+        )
+        out1 = engine.generate(prompt, sampling_params)["text"]
+        engine.shutdown()
+        print("==== Answer 1 ====")
+        print(repr(out1))
+        tokens = tokenizer.encode(out1, truncation=False)
+        assert tokenizer.eos_token_id not in tokens
+
 
 if __name__ == "__main__":
     unittest.main()
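The test re-encodes the generated text with the target model's tokenizer: if verification had kept tokens drafted past the EOS, the decoded EOS marker would survive in the text and re-encode to eos_token_id. To run just this test, here is a sketch using only the standard unittest loader (the dotted module path assumes the repo layout shown above and may need adjusting):

import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test.srt.test_eagle_infer.TestEAGLEEngine.test_eagle_eos_token"
)
unittest.TextTestRunner(verbosity=2).run(suite)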