Skip to content

Commit 5c3b94d

Browse files
authored
spec decode: move ops.advance_step to flash attention backend (aphrodite-engine#1005)
1 parent 135dfd6 commit 5c3b94d

File tree

3 files changed

+21
-33
lines changed

3 files changed

+21
-33
lines changed

aphrodite/attention/backends/flash_attn.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
1919

2020
if TYPE_CHECKING:
21-
from aphrodite.worker.model_runner import ModelInputForGPUBuilder
21+
from aphrodite.worker.model_runner import (ModelInputForGPUBuilder,
22+
ModelInputForGPUWithSamplingMetadata)
2223

2324
from aphrodite_flash_attn import (
2425
flash_attn_varlen_func as _flash_attn_varlen_func)
@@ -305,13 +306,12 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
305306
)
306307
return self._cached_decode_metadata
307308

308-
def advance_step(self, num_seqs: int, num_queries: int):
309+
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
310+
sampled_token_ids: Optional[torch.Tensor],
311+
block_size: int, num_seqs: int, num_queries: int):
309312
"""
310313
Update metadata in-place to advance one decode step.
311314
"""
312-
# GPU in-place update is currently called separately through
313-
# custom_ops.advance_step(). See draft_model_runner.
314-
# TODO: Move this logic to the backend.
315315

316316
# When using cudagraph, the num_seqs is padded to the next captured
317317
# batch size, but num_queries tracks the actual number of requests in
@@ -350,6 +350,16 @@ def advance_step(self, num_seqs: int, num_queries: int):
350350
self.seq_lens[i] += 1
351351
self.max_decode_seq_len = max(self.seq_lens)
352352

353+
ops.advance_step(num_seqs=num_seqs,
354+
num_queries=num_queries,
355+
block_size=block_size,
356+
input_tokens=model_input.input_tokens,
357+
sampled_token_ids=sampled_token_ids,
358+
input_positions=model_input.input_positions,
359+
seq_lens=self.seq_lens_tensor,
360+
slot_mapping=self.slot_mapping,
361+
block_tables=self.block_tables)
362+
353363

354364
class FlashAttentionMetadataBuilder(
355365
AttentionMetadataBuilder[FlashAttentionMetadata]):

aphrodite/spec_decode/draft_model_runner.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import torch
44
from loguru import logger
55

6-
from aphrodite import _custom_ops as ops
7-
86
try:
97
from aphrodite.attention.backends.flash_attn import FlashAttentionMetadata
108
except ModuleNotFoundError:
@@ -114,18 +112,8 @@ def _gpu_advance_step(
114112
# Update attn_metadata
115113
attn_metadata = model_input.attn_metadata
116114
assert isinstance(attn_metadata, FlashAttentionMetadata)
117-
attn_metadata.advance_step(num_seqs, num_queries)
118-
119-
# Update GPU tensors
120-
ops.advance_step(num_seqs=num_seqs,
121-
num_queries=num_queries,
122-
block_size=self.block_size,
123-
input_tokens=model_input.input_tokens,
124-
sampled_token_ids=sampled_token_ids,
125-
input_positions=model_input.input_positions,
126-
seq_lens=attn_metadata.seq_lens_tensor,
127-
slot_mapping=attn_metadata.slot_mapping,
128-
block_tables=attn_metadata.block_tables)
115+
attn_metadata.advance_step(model_input, sampled_token_ids,
116+
self.block_size, num_seqs, num_queries)
129117

130118
# Update sampling_metadata
131119
sampling_metadata = model_input.sampling_metadata

aphrodite/worker/multi_step_model_runner.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import torch
1616

17-
from aphrodite import _custom_ops as ops
1817
from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
1918
IntermediateTensors, Logprob,
2019
SequenceGroupMetadata, SequenceOutput)
@@ -490,19 +489,10 @@ def _advance_step(
490489
assert num_seqs >= num_queries
491490
attn_metadata = frozen_model_input.attn_metadata
492491
assert isinstance(attn_metadata, FlashAttentionMetadata)
493-
attn_metadata.advance_step(num_seqs, num_queries)
494-
# Update GPU tensors
495-
ops.advance_step(
496-
num_seqs=num_seqs,
497-
num_queries=num_queries,
498-
block_size=self.block_size,
499-
input_tokens=frozen_model_input.input_tokens,
500-
sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids,
501-
input_positions=frozen_model_input.input_positions,
502-
seq_lens=attn_metadata.seq_lens_tensor,
503-
slot_mapping=attn_metadata.slot_mapping,
504-
block_tables=attn_metadata.block_tables,
505-
)
492+
attn_metadata.advance_step(
493+
frozen_model_input,
494+
model_input.cached_outputs[-1].sampled_token_ids, self.block_size,
495+
num_seqs, num_queries)
506496
if frozen_model_input.seq_lens is not None:
507497
for i in range(num_queries):
508498
frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]

0 commit comments

Comments
 (0)