@@ -1,22 +1,27 @@
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
 
 try:
     from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
     from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
 
     import vllm.attention.backends.flash_attn  # noqa
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 except ImportError:
     BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
     BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 
 import torch
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionState, AttentionType)
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
@@ -46,6 +51,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
     def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
         return FlashInferMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["FlashInferState"]:
+        return FlashInferState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -75,6 +84,160 @@ def get_supported_head_sizes() -> List[int]:
         return [64, 128, 256]
 
 
+class FlashInferState(AttentionState):
+
+    def __init__(self, runner):
+        self.runner = runner
+        self._is_graph_capturing = False
+        self._workspace_buffer = None
+        self._decode_wrapper = None
+        self._prefill_wrapper = None
+
+    def _get_workspace_buffer(self):
+        if self._workspace_buffer is None:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self.runner.device)
+        return self._workspace_buffer
+
+    def _get_prefill_wrapper(self):
+        if self._prefill_wrapper is None:
+            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(), "NHD")
+        return self._prefill_wrapper
+
+    def _get_decode_wrapper(self):
+        if self._decode_wrapper is None:
+            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+                self.runner.parallel_config))
+            num_kv_heads = self.runner.model_config.get_num_kv_heads(
+                self.runner.parallel_config)
+            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(),
+                "NHD",
+                use_tensor_cores=use_tensor_cores)
+        return self._decode_wrapper
+
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        self._is_graph_capturing = True
+        self._graph_decode_wrapper = None
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
+        self._graph_indices_buffer = torch.empty(
+            max_batch_size * self.runner.cache_config.num_gpu_blocks,
+            dtype=torch.int32,
+            device=self.runner.device)
+        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
+                                                dtype=torch.int32,
+                                                device=self.runner.device)
+        self._graph_last_page_len_buffer = torch.empty(
+            max_batch_size, dtype=torch.int32, device=self.runner.device)
+        yield
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+        del self._graph_decode_workspace_buffer
+        del self._graph_indices_buffer
+        del self._graph_indptr_buffer
+        del self._graph_last_page_len_buffer
+        del self._graph_decode_wrapper
+
+    def graph_clone(self, batch_size: int):
+        assert self._is_graph_capturing
+        state = self.__class__(self.runner)
+        state._workspace_buffer = self._graph_decode_workspace_buffer
+        state._decode_wrapper = self._graph_decode_wrapper
+        state._prefill_wrapper = self._get_prefill_wrapper()
+        return state
+
+    def graph_capture_get_metadata_for_batch(self, batch_size: int):
+        assert self._is_graph_capturing
+        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
+        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
+
+        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+            self.runner.parallel_config))
+        num_kv_heads = self.runner.model_config.get_num_kv_heads(
+            self.runner.parallel_config)
+        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+        self._graph_decode_wrapper = \
+            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
+            self._graph_decode_workspace_buffer, _indptr_buffer,
+            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
+            use_tensor_cores)
+        kv_cache_dtype = get_kv_cache_torch_dtype(
+            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+
+        paged_kv_indptr_tensor_host = torch.arange(0,
+                                                   batch_size + 1,
+                                                   dtype=torch.int32)
+        paged_kv_indices_tensor_host = torch.arange(0,
+                                                    batch_size,
+                                                    dtype=torch.int32)
+        paged_kv_last_page_len_tensor_host = torch.full(
+            (batch_size, ), self.runner.block_size, dtype=torch.int32)
+        query_start_loc_host = torch.arange(0,
+                                            batch_size + 1,
+                                            dtype=torch.int32)
+
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            max_prefill_seq_len=0,
+            block_tables=self._graph_block_tables,
+            paged_kv_indptr=paged_kv_indptr_tensor_host,
+            paged_kv_indices=paged_kv_indices_tensor_host,
+            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=self.runner.model_config.get_head_size(),
+            page_size=self.runner.block_size,
+            seq_start_loc=None,
+            query_start_loc=query_start_loc_host,
+            device=self.runner.device,
+            data_type=kv_cache_dtype,
+            use_cuda_graph=True,
+            decode_wrapper=self._graph_decode_wrapper,
+            prefill_wrapper=None)
+        attn_metadata.begin_forward()
+        return attn_metadata
+
+    def get_graph_input_buffers(self, attn_metadata):
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+        }
+
+    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
+        return
+
+    def begin_forward(self, model_input):
+        assert not self._is_graph_capturing
+        state = self
+        if model_input.attn_metadata.use_cuda_graph:
+            batch_size = model_input.input_tokens.shape[0]
+            state = (self.runner.graph_runners[model_input.virtual_engine]
+                     [batch_size].attn_state)
+        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
+        )
+        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
+        model_input.attn_metadata.begin_forward()
+
+
 @dataclass
 class FlashInferMetadata(AttentionMetadata):
     # Maximum sequence length among prefill batch. 0 if there are decoding
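In short, `FlashInferState` lazily allocates the shared FlashInfer workspace buffer and the prefill/decode wrappers; `graph_capture` sets up the persistent buffers used while CUDA graphs are captured, `graph_capture_get_metadata_for_batch` builds decode-only placeholder metadata for a given batch size, `graph_clone` shares the capture-time buffers with each per-graph state, and `begin_forward` attaches the wrappers to the metadata before a non-captured run. The tensor-core heuristic appears twice in the diff; a minimal self-contained sketch of it, with an illustrative helper name that is not part of this change:

def should_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
    # The decode wrappers are asked to use tensor cores whenever the
    # grouped-query ratio (query heads per KV head) is at least 4.
    return num_qo_heads // num_kv_heads >= 4

# Example: 32 query heads sharing 8 KV heads (ratio 4) enables tensor cores.
assert should_use_tensor_cores(32, 8)
assert not should_use_tensor_cores(32, 32)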