
Commit 3b68217

[Core] Add AttentionState abstraction (vllm-project#7663)
1 parent c6af027 commit 3b68217

16 files changed, +372 -247 lines changed

tests/worker/test_model_input.py

Lines changed: 6 additions & 1 deletion

@@ -5,6 +5,7 @@
 
 from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.worker.embedding_model_runner import (
@@ -29,7 +30,11 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
 
     @staticmethod
     def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
-        raise AttentionMetadataBuilder
+        return AttentionMetadataBuilder
+
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
 
     @staticmethod
     def get_kv_cache_shape(

vllm/attention/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,7 +1,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionState, AttentionType)
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 
@@ -12,5 +12,6 @@
     "AttentionType",
     "AttentionMetadataBuilder",
     "Attention",
+    "AttentionState",
     "get_attn_backend",
 ]
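
Note: with the re-export above, downstream code can reach the abstraction from the package root instead of vllm.attention.backends.abstract. A minimal sketch of that surface; the resolve_state_cls helper is illustrative only, not part of this commit.

from typing import Type

from vllm.attention import AttentionBackend, AttentionState


def resolve_state_cls(backend: Type[AttentionBackend]) -> Type[AttentionState]:
    # Illustrative helper: every backend touched by this commit is expected
    # to expose its state class via the new get_state_cls() static method.
    state_cls = backend.get_state_cls()
    assert issubclass(state_cls, AttentionState)
    return state_cls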

vllm/attention/backends/abstract.py

Lines changed: 50 additions & 1 deletion

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, fields
 from enum import Enum, auto
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
@@ -7,7 +8,9 @@
 import torch
 
 if TYPE_CHECKING:
-    from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
+    from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                               ModelRunnerInputBase,
+                                               ModelRunnerInputBuilderBase)
 
 
 class AttentionType(Enum):
@@ -34,6 +37,11 @@ def get_impl_cls() -> Type["AttentionImpl"]:
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError
 
+    @staticmethod
+    @abstractmethod
+    def get_state_cls() -> Type["AttentionState"]:
+        raise NotImplementedError
+
     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
@@ -126,6 +134,47 @@ def asdict_zerocopy(self,
 T = TypeVar("T", bound=AttentionMetadata)
 
 
+class AttentionState(ABC, Generic[T]):
+    """Holds attention backend-specific objects reused during the
+    lifetime of the model runner."""
+
+    @abstractmethod
+    def __init__(self, runner: "ModelRunnerBase"):
+        ...
+
+    @abstractmethod
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        """Context manager used when capturing CUDA graphs."""
+        yield
+
+    @abstractmethod
+    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
+        """Clone attention state to save in CUDA graph metadata."""
+        ...
+
+    @abstractmethod
+    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
+        """Get attention metadata for CUDA graph capture of batch_size."""
+        ...
+
+    @abstractmethod
+    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
+        """Get attention-specific input buffers for CUDA graph capture."""
+        ...
+
+    @abstractmethod
+    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
+                                    attn_metadata: T) -> None:
+        """In-place modify input buffers dict for CUDA graph replay."""
+        ...
+
+    @abstractmethod
+    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
+        """Prepare state for forward pass."""
+        ...
+
+
 class AttentionMetadataBuilder(ABC, Generic[T]):
     """Abstract class for attention metadata builders."""
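
Note: to make the contract concrete, here is a hedged sketch of the smallest subclass that would satisfy the new ABC. It is not the CommonAttentionState that ships in vllm/attention/backends/utils.py; the class name and its no-op behaviour are illustrative assumptions.

from contextlib import contextmanager
from typing import Any, Dict

from vllm.attention import AttentionMetadata, AttentionState


class MinimalAttentionState(AttentionState[AttentionMetadata]):
    """Hypothetical smallest implementation of the interface above.

    This is NOT the CommonAttentionState shipped in
    vllm/attention/backends/utils.py; it only shows the shape a backend
    state has to provide.
    """

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        # A real state allocates persistent buffers sized for max_batch_size
        # here and frees them after capture finishes.
        self._is_graph_capturing = True
        try:
            yield
        finally:
            self._is_graph_capturing = False

    def graph_clone(self, batch_size: int) -> "MinimalAttentionState":
        # Saved inside each captured graph runner.
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        # A real backend builds dummy decode-only metadata for batch_size;
        # the exact fields are backend-specific, so this sketch leaves it out.
        raise NotImplementedError

    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
        # A real backend returns the tensors the captured graph reads
        # (slot mapping, sequence lengths, block tables, ...).
        return {}

    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata) -> None:
        # A real backend copies fresh metadata into the buffers above before
        # every graph replay; nothing to do for this toy state.
        pass

    def begin_forward(self, model_input) -> None:
        # Stateless toy backend: nothing to prepare per forward pass.
        pass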

vllm/attention/backends/blocksparse_attn.py

Lines changed: 6 additions & 1 deletion

@@ -5,7 +5,8 @@
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import CommonMetadataBuilder
+from vllm.attention.backends.utils import (CommonAttentionState,
+                                           CommonMetadataBuilder)
 from vllm.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn, get_head_sliding_step)
 from vllm.attention.ops.paged_attn import PagedAttention
@@ -98,6 +99,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
         return BlocksparseFlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

vllm/attention/backends/flash_attn.py

Lines changed: 6 additions & 1 deletion

@@ -9,7 +9,8 @@
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
-from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
+                                           compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -142,6 +143,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

vllm/attention/backends/flashinfer.py

Lines changed: 164 additions & 1 deletion

@@ -1,22 +1,27 @@
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
 
 try:
     from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
     from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
 
     import vllm.attention.backends.flash_attn  # noqa
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 except ImportError:
     BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
     BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 
 import torch
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionState, AttentionType)
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
@@ -46,6 +51,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
         return FlashInferMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["FlashInferState"]:
+        return FlashInferState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -75,6 +84,160 @@ def get_supported_head_sizes() -> List[int]:
         return [64, 128, 256]
 
 
+class FlashInferState(AttentionState):
+
+    def __init__(self, runner):
+        self.runner = runner
+        self._is_graph_capturing = False
+        self._workspace_buffer = None
+        self._decode_wrapper = None
+        self._prefill_wrapper = None
+
+    def _get_workspace_buffer(self):
+        if self._workspace_buffer is None:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self.runner.device)
+        return self._workspace_buffer
+
+    def _get_prefill_wrapper(self):
+        if self._prefill_wrapper is None:
+            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(), "NHD")
+        return self._prefill_wrapper
+
+    def _get_decode_wrapper(self):
+        if self._decode_wrapper is None:
+            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+                self.runner.parallel_config))
+            num_kv_heads = self.runner.model_config.get_num_kv_heads(
+                self.runner.parallel_config)
+            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(),
+                "NHD",
+                use_tensor_cores=use_tensor_cores)
+        return self._decode_wrapper
+
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        self._is_graph_capturing = True
+        self._graph_decode_wrapper = None
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
+        self._graph_indices_buffer = torch.empty(
+            max_batch_size * self.runner.cache_config.num_gpu_blocks,
+            dtype=torch.int32,
+            device=self.runner.device)
+        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
+                                                dtype=torch.int32,
+                                                device=self.runner.device)
+        self._graph_last_page_len_buffer = torch.empty(
+            max_batch_size, dtype=torch.int32, device=self.runner.device)
+        yield
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+        del self._graph_decode_workspace_buffer
+        del self._graph_indices_buffer
+        del self._graph_indptr_buffer
+        del self._graph_last_page_len_buffer
+        del self._graph_decode_wrapper
+
+    def graph_clone(self, batch_size: int):
+        assert self._is_graph_capturing
+        state = self.__class__(self.runner)
+        state._workspace_buffer = self._graph_decode_workspace_buffer
+        state._decode_wrapper = self._graph_decode_wrapper
+        state._prefill_wrapper = self._get_prefill_wrapper()
+        return state
+
+    def graph_capture_get_metadata_for_batch(self, batch_size: int):
+        assert self._is_graph_capturing
+        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
+        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
+
+        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+            self.runner.parallel_config))
+        num_kv_heads = self.runner.model_config.get_num_kv_heads(
+            self.runner.parallel_config)
+        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+        self._graph_decode_wrapper = \
+            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
+                self._graph_decode_workspace_buffer, _indptr_buffer,
+                self._graph_indices_buffer, _last_page_len_buffer, "NHD",
+                use_tensor_cores)
+        kv_cache_dtype = get_kv_cache_torch_dtype(
+            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+
+        paged_kv_indptr_tensor_host = torch.arange(0,
+                                                   batch_size + 1,
+                                                   dtype=torch.int32)
+        paged_kv_indices_tensor_host = torch.arange(0,
+                                                    batch_size,
+                                                    dtype=torch.int32)
+        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
+                                                        self.runner.block_size,
+                                                        dtype=torch.int32)
+        query_start_loc_host = torch.arange(0,
+                                            batch_size + 1,
+                                            dtype=torch.int32)
+
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            max_prefill_seq_len=0,
+            block_tables=self._graph_block_tables,
+            paged_kv_indptr=paged_kv_indptr_tensor_host,
+            paged_kv_indices=paged_kv_indices_tensor_host,
+            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=self.runner.model_config.get_head_size(),
+            page_size=self.runner.block_size,
+            seq_start_loc=None,
+            query_start_loc=query_start_loc_host,
+            device=self.runner.device,
+            data_type=kv_cache_dtype,
+            use_cuda_graph=True,
+            decode_wrapper=self._graph_decode_wrapper,
+            prefill_wrapper=None)
+        attn_metadata.begin_forward()
+        return attn_metadata
+
+    def get_graph_input_buffers(self, attn_metadata):
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+        }
+
+    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
+        return
+
+    def begin_forward(self, model_input):
+        assert not self._is_graph_capturing
+        state = self
+        if model_input.attn_metadata.use_cuda_graph:
+            batch_size = model_input.input_tokens.shape[0]
+            state = (self.runner.graph_runners[model_input.virtual_engine]
+                     [batch_size].attn_state)
+        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
+        )
+        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
+        model_input.attn_metadata.begin_forward()
+
+
 @dataclass
 class FlashInferMetadata(AttentionMetadata):
     # Maximum sequence length among prefill batch. 0 if there are decoding
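
Note: the methods above only make sense in the order the model runner calls them. The runner-side wiring is in files not shown in this excerpt, so the sketch below is an approximation of the expected call sequence; runner, capture_graph, graph_runner, and execute_model are assumed names, not code from this diff.

# All of `runner`, `runner.capture_graph`, `graph_runner`, and
# `runner.execute_model` are assumed stand-ins for runner-side code that is
# not part of this excerpt; only the AttentionState calls come from the diff.

def capture_decode_graphs(runner, max_batch_size: int, batch_sizes):
    attn_state = runner.attn_backend.get_state_cls()(runner)
    with attn_state.graph_capture(max_batch_size):
        for batch_size in batch_sizes:
            # Backend-specific dummy metadata for a decode-only batch.
            attn_metadata = attn_state.graph_capture_get_metadata_for_batch(
                batch_size)
            graph_runner = runner.capture_graph(batch_size, attn_metadata)
            # Each captured graph keeps its own clone of the state plus the
            # input buffers it will overwrite on replay.
            graph_runner.attn_state = attn_state.graph_clone(batch_size)
            graph_runner.input_buffers = attn_state.get_graph_input_buffers(
                attn_metadata)
    return attn_state


def execute_step(runner, attn_state, model_input):
    # Per-step hook: FlashInferState uses it to attach the prefill/decode
    # wrappers and call begin_forward() on the metadata before the model runs.
    attn_state.begin_forward(model_input)
    return runner.execute_model(model_input)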

vllm/attention/backends/ipex_attn.py

Lines changed: 5 additions & 0 deletions

@@ -8,6 +8,7 @@
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
 
@@ -28,6 +29,10 @@ def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
     def get_metadata_cls() -> Type["IpexAttnMetadata"]:
         return IpexAttnMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
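
Note: the simpler backends (blocksparse, flash-attn, ipex, and the test mock above) all add the same three-line hook returning the shared CommonAttentionState. Below is a hedged sketch of how a runner would turn that hook into an instance; the weakref.proxy detail is an assumption about runner-side code not included in this excerpt.

import weakref


def init_attn_state(runner):
    # Hedged sketch: the runner resolves the backend's state class and keeps
    # one instance for its own lifetime. The weakref.proxy is an assumed
    # detail to avoid a runner <-> state reference cycle, not code from
    # this diff.
    state_cls = runner.attn_backend.get_state_cls()
    return state_cls(weakref.proxy(runner))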
