diff --git a/python/sglang/srt/layers/attention/minicpm_backend.py b/python/sglang/srt/layers/attention/minicpm_backend.py
index 70683277c0ce..473dd6841010 100644
--- a/python/sglang/srt/layers/attention/minicpm_backend.py
+++ b/python/sglang/srt/layers/attention/minicpm_backend.py
@@ -1756,6 +1756,19 @@ def init_forward_metadata_replay_cuda_graph(
         kv_indptr_view = self.decode_cuda_graph_metadata[
             "flashinfer_kv_indptr"
         ][: sparse_bs + 1]
+        kv_indptr_view[0] = 0
+        if sparse_real_bs > 0:
+            actual_seqlens = metadata.sparse_cache_seqlens_int32[
+                :sparse_real_bs
+            ].clone()
+            actual_seqlens = torch.clamp(
+                actual_seqlens, max=self.num_sparse_topk_tokens
+            )
+            kv_indptr_view[1 : sparse_real_bs + 1] = torch.cumsum(
+                actual_seqlens, dim=0
+            )
+        kv_indptr_view[sparse_real_bs:].fill_(kv_indptr_view[sparse_real_bs])
+
         # kv_indices only needs num_sparse_topk_tokens per batch
         kv_indices_view = self.decode_cuda_graph_metadata[
             "flashinfer_kv_indices"
@@ -1763,7 +1776,6 @@
         kv_last_page_len_view = self.decode_cuda_graph_metadata[
             "flashinfer_kv_last_page_len"
         ][:sparse_bs]
-        kv_indptr_view[sparse_real_bs:].fill_(kv_indptr_view[-1])
         kv_last_page_len_view[sparse_real_bs:].fill_(0)
 
         # Retrieve the wrapper stored during capture
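
For reference, below is a standalone sketch of the indptr-padding scheme the first hunk implements. The helper name build_sparse_kv_indptr and the parameters seqlens, padded_bs, and topk are illustrative stand-ins for metadata.sparse_cache_seqlens_int32, sparse_real_bs/sparse_bs, and self.num_sparse_topk_tokens; this is not sglang API. The idea: real requests get a cumulative-sum prefix over KV lengths clamped to the top-k budget, while the slots that merely pad the batch up to the CUDA-graph capture size reuse the last real offset and so become empty [start == end) ranges, instead of inheriting whatever stale value kv_indptr_view[-1] held from capture time (the bug the removed line in the second hunk caused).

import torch


def build_sparse_kv_indptr(seqlens, padded_bs, topk):
    """Hypothetical helper mirroring the hunk above; not part of sglang.

    seqlens:   int32 per-request KV lengths, shape [real_bs]
    padded_bs: batch size the CUDA graph was captured with (>= real_bs)
    topk:      stand-in for self.num_sparse_topk_tokens
    """
    real_bs = seqlens.numel()
    # The real code writes into a persistent capture-time buffer, hence
    # its explicit kv_indptr_view[0] = 0; torch.zeros covers that here.
    indptr = torch.zeros(padded_bs + 1, dtype=torch.int32)
    if real_bs > 0:
        # Each request contributes at most `topk` sparse top-k tokens.
        clamped = torch.clamp(seqlens, max=topk)
        indptr[1 : real_bs + 1] = torch.cumsum(clamped, dim=0)
    # Padding slots reuse the last real offset, i.e. zero-length KV
    # ranges, rather than whatever indptr[-1] held at capture time.
    indptr[real_bs:].fill_(indptr[real_bs])
    return indptr


# Two real requests replayed in a graph captured for batch size 4, topk=8:
# seqlens [5, 12] -> clamped [5, 8] -> indptr [0, 5, 13, 13, 13]
print(build_sparse_kv_indptr(torch.tensor([5, 12], dtype=torch.int32), 4, 8))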