Skip to content

Commit 1f3ce65

Browse files
authored
[Feature] support mtp distribution equivalence verification (#4699)
1 parent 28de91b commit 1f3ce65

File tree

6 files changed

+257
-88
lines changed

6 files changed

+257
-88
lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
748748
const paddle::Tensor& stop_seqs_len,
749749
const paddle::Tensor& end_ids);
750750

751-
void SpeculateVerify(const paddle::Tensor& accept_tokens,
751+
void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
752+
const paddle::Tensor& accept_tokens,
752753
const paddle::Tensor& accept_num,
753754
const paddle::Tensor& step_idx,
754755
const paddle::Tensor& stop_flags,

0 commit comments

Comments
 (0)