Skip to content

Commit

Permalink
[Bugfix] Fix Mamba multistep (#11071)
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
tlrmchlsmth authored Dec 11, 2024
1 parent 134810b commit 9a93973
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
64 changes: 63 additions & 1 deletion vllm/attention/backends/placeholder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.
Expand Down Expand Up @@ -186,6 +187,67 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph

assert not turn_prefills_into_decodes, \
("Multi-Step + Chunked-Prefill is not supported for attention-free"
"models. turn_prefills_into_decodes is a "
"Multi-Step + Chunked-Prefill specific parameter.")

assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs

assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0

assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )

assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )

assert self.block_tables is not None

# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)

# Update sequences, masking off entries greater than num_queries
device = self.seq_lens_tensor.device
mask = torch.arange(self.seq_lens_tensor.size(0),
device=device) < num_queries
self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
if sampled_token_ids is not None:
model_input.input_tokens.masked_scatter_(
mask, sampled_token_ids[:num_queries])


class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
Expand Down
4 changes: 3 additions & 1 deletion vllm/worker/multi_step_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

logger = init_logger(__name__)

MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
MULTI_STEP_ATTENTION_BACKENDS = [
"FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]

def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
Expand Down

0 comments on commit 9a93973

Please sign in to comment.