
Commit 18961c5

[Hybrid] Pass kernel block size to builders (#27753)
Signed-off-by: Thomas Parnell <[email protected]>
1 parent 470ad11 commit 18961c5

File tree: 4 files changed, +62 -27 lines changed


vllm/v1/attention/backends/flash_attn.py

Lines changed: 5 additions & 1 deletion
@@ -62,7 +62,11 @@ def get_supported_head_sizes(cls) -> list[int]:
 
     @staticmethod
     def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
+        # NOTE(tdoublep): while in principle, FA supports
+        # MultipleOf(16), these are the block sizes that do not
+        # suffer from the NaN propagation problem described here:
+        # https://github.com/Dao-AILab/flash-attention/issues/1974
+        return [16, 32, 64]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
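
This change narrows the advertised kernel block sizes from the open-ended MultipleOf(16) constraint to an explicit allow-list. The sketch below illustrates how a mixed list of plain ints and MultipleOf constraints can be checked against a requested block size; the helper is_supported_block_size and the simplified MultipleOf dataclass are illustrative stand-ins, not vLLM APIs.

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Simplified stand-in for vLLM's MultipleOf constraint.
    base: int


def is_supported_block_size(
    block_size: int, supported: list[int | MultipleOf]
) -> bool:
    # A block size is acceptable if it matches an explicit entry or
    # satisfies any MultipleOf constraint in the list.
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if block_size % entry.base == 0:
                return True
        elif block_size == entry:
            return True
    return False


# With this commit FlashAttention advertises [16, 32, 64], so 256 is no
# longer accepted even though it is a multiple of 16.
assert is_supported_block_size(64, [16, 32, 64])
assert not is_supported_block_size(256, [16, 32, 64])
assert is_supported_block_size(256, [MultipleOf(16)])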

vllm/v1/kv_cache_interface.py

Lines changed: 7 additions & 1 deletion
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
 from math import prod
 
 import torch
@@ -44,6 +44,12 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         """
         raise NotImplementedError
 
+    def copy_with_new_block_size(self, block_size: int) -> Self:
+        """
+        Create a new KVCacheSpec from self but replacing the block size.
+        """
+        return replace(self, block_size=block_size)
+
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """

vllm/v1/worker/gpu_model_runner.py

Lines changed: 25 additions & 6 deletions
@@ -4039,16 +4039,11 @@ def create_attn_groups(
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
             for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
-                attn_group = AttentionGroup.create_with_metadata_builders(
+                attn_group = AttentionGroup(
                     attn_backend,
                     layer_names,
                     kv_cache_spec,
-                    self.vllm_config,
-                    self.device,
                     kv_cache_group_id,
-                    num_metadata_builders=1
-                    if not self.parallel_config.enable_dbo
-                    else 2,
                 )
 
                 attn_groups.append(attn_group)
@@ -4067,7 +4062,27 @@ def create_attn_groups(
         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))
 
+    def initialize_metadata_builders(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
+        """
+        Create the metadata builders for all KV cache groups and attn groups.
+        """
+        for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_group.create_metadata_builders(
+                    self.vllm_config,
+                    self.device,
+                    kernel_block_sizes[kv_cache_group_id]
+                    if kv_cache_group_id < len(kernel_block_sizes)
+                    else None,
+                    num_metadata_builders=1
+                    if not self.parallel_config.enable_dbo
+                    else 2,
+                )
         # Calculate reorder batch threshold (if needed)
+        # Note (tdoublep): do this *after* constructing builders,
+        # because some of them change the threshold at init time.
         self.calculate_reorder_batch_threshold()
 
     def _check_and_update_cudagraph_mode(
@@ -4633,6 +4648,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
         # tokens each.
         kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
+
+        # create metadata builders
+        self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
+
         # Reinitialize need to after initialize_attn_backend
         self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
         kv_caches = self.initialize_kv_cache_tensors(
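
The new initialize_metadata_builders walks every attention group in each KV cache group and passes the kernel block size prepared for that group, falling back to None when kernel_block_sizes has no entry for it; the number of builders per group depends on whether dual-batch overlap (DBO) is enabled. A hedged sketch of just that selection logic, with stand-in function names:

def pick_kernel_block_size(
    kernel_block_sizes: list[int], kv_cache_group_id: int
) -> int | None:
    # None means "keep the spec's own block size"; used when no kernel
    # block size was prepared for this KV cache group.
    if kv_cache_group_id < len(kernel_block_sizes):
        return kernel_block_sizes[kv_cache_group_id]
    return None


def num_metadata_builders(enable_dbo: bool) -> int:
    # One builder normally; two when dual-batch overlap is enabled, so
    # each ubatch gets its own builder (and its own persistent buffers).
    return 2 if enable_dbo else 1


assert pick_kernel_block_size([64, 32], 1) == 32
assert pick_kernel_block_size([64], 1) is None
assert num_metadata_builders(enable_dbo=True) == 2
assert num_metadata_builders(enable_dbo=False) == 1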

vllm/v1/worker/utils.py

Lines changed: 25 additions & 19 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import torch
@@ -134,31 +134,37 @@ def reset_cache(self) -> None:
 @dataclass
 class AttentionGroup:
     backend: type[AttentionBackend]
-    # When ubatching is enabled we will have a metadata builder for each ubatch
-    # so that if they use internal persistant buffers for cudagraphs, and they
-    # won't have to worry about conflicting with the other ubatches.
-    metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
     kv_cache_spec: KVCacheSpec
     kv_cache_group_id: int
+    # When ubatching is enabled we will have a metadata builder for each ubatch
+    # so that if they use internal persistant buffers for cudagraphs, and they
+    # won't have to worry about conflicting with the other ubatches.
+    metadata_builders: list[AttentionMetadataBuilder] = field(
+        default_factory=lambda: []
+    )
 
-    @staticmethod
-    def create_with_metadata_builders(
-        backend: type[AttentionBackend],
-        layer_names: list[str],
-        kv_cache_spec: KVCacheSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        kv_cache_group_id: int,
+    def create_metadata_builders(
+        self,
+        vllm_config,
+        device,
+        kernel_block_size: int | None,
         num_metadata_builders: int = 1,
-    ) -> "AttentionGroup":
-        metadata_builders = [
-            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
+    ):
+        kv_cache_spec_builder = (
+            self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
+            if kernel_block_size is not None
+            else self.kv_cache_spec
+        )
+        self.metadata_builders = [
+            self.backend.get_builder_cls()(
+                kv_cache_spec_builder,
+                self.layer_names,
+                vllm_config,
+                device,
+            )
             for _ in range(num_metadata_builders)
         ]
-        return AttentionGroup(
-            backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
-        )
 
     def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
         assert len(self.metadata_builders) > ubatch_id
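
After this refactor an AttentionGroup starts with an empty metadata_builders list (hence the default_factory) and the builders are attached later, once the kernel block size is known. A self-contained toy mirroring that two-step pattern; all names below are illustrative, not vLLM types:

from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class ToySpec:
    block_size: int


class ToyBuilder:
    def __init__(self, spec: ToySpec) -> None:
        self.spec = spec


@dataclass
class ToyGroup:
    kv_cache_spec: ToySpec
    # A mutable default needs default_factory, as in the diff above.
    metadata_builders: list[ToyBuilder] = field(default_factory=list)

    def create_metadata_builders(
        self, kernel_block_size: int | None, n: int = 1
    ) -> None:
        # Builders see a copy of the spec with the kernel block size,
        # while the group keeps its original spec.
        spec = (
            replace(self.kv_cache_spec, block_size=kernel_block_size)
            if kernel_block_size is not None
            else self.kv_cache_spec
        )
        self.metadata_builders = [ToyBuilder(spec) for _ in range(n)]


group = ToyGroup(kv_cache_spec=ToySpec(block_size=256))
assert group.metadata_builders == []  # no builders until explicitly created
group.create_metadata_builders(kernel_block_size=64, n=2)
assert all(b.spec.block_size == 64 for b in group.metadata_builders)
assert group.kv_cache_spec.block_size == 256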
