From 7c7fab301281842d05729b4273501e94955421d6 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Fri, 31 Oct 2025 17:34:21 -0400
Subject: [PATCH 1/5] Pass kernel block sizes to metadata builders

Signed-off-by: Thomas Parnell
---
 vllm/v1/attention/backends/flash_attn.py | 10 ++++--
 vllm/v1/attention/backends/flashinfer.py |  8 ++---
 vllm/v1/attention/backends/utils.py      |  2 ++
 vllm/v1/worker/gpu_model_runner.py       | 30 ++++++++++++----
 vllm/v1/worker/utils.py                  | 44 ++++++++++++++----------
 5 files changed, 62 insertions(+), 32 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 1eac94940e78..833060740955 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -62,7 +62,11 @@ def get_supported_head_sizes(cls) -> list[int]:
 
     @staticmethod
     def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
+        # NOTE(tdoublep): while in principle, FA supports
+        # MultipleOf(16), these are the block sizes that do not
+        # suffer from the NaN propagation problem described here:
+        # https://github.com/Dao-AILab/flash-attention/issues/1974
+        return [16, 32, 64]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
@@ -198,6 +202,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         if get_flash_attn_version() == 3
         else AttentionCGSupport.UNIFORM_BATCH
     )
+    requires_kernel_block_size = True
 
     def __init__(
         self,
@@ -205,6 +210,7 @@ def __init__(
         layer_names: list[str],
         vllm_config: VllmConfig,
         device: torch.device,
+        kernel_block_size: int,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.model_config = vllm_config.model_config
@@ -218,7 +224,7 @@ def __init__(
         self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
         self.kv_cache_dtype = kv_cache_spec.dtype
         self.headdim = self.model_config.get_head_size()
-        self.block_size = kv_cache_spec.block_size
+        self.block_size = kernel_block_size
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = get_flash_attn_version() == 3
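Note: the sketch below is not part of the patch. It illustrates how a supported-size list that mixes explicit sizes with a MultipleOf entry can be resolved against a KV-cache block size; the MultipleOf class and pick_kernel_block_size helper here are simplified stand-ins, not vLLM's actual implementations. Listing explicit sizes, as the hunk above does, lets a backend exclude multiples of 16 that would otherwise be legal.

```python
# Illustrative stand-ins only; not vLLM's MultipleOf or selection logic.
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    """Any block size that is a multiple of `base` is supported."""
    base: int


def pick_kernel_block_size(cache_block_size: int,
                           supported: list) -> int:
    """Pick the largest supported kernel block size that evenly divides
    the KV-cache block size."""
    candidates = []
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if cache_block_size % entry.base == 0:
                candidates.append(cache_block_size)
        elif cache_block_size % entry == 0:
            candidates.append(entry)
    if not candidates:
        raise ValueError(f"no supported kernel block size divides {cache_block_size}")
    return max(candidates)


assert pick_kernel_block_size(256, [16, 32, 64]) == 64       # explicit list caps at 64
assert pick_kernel_block_size(256, [MultipleOf(16)]) == 256  # MultipleOf would allow 256
```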
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index e71d4ca4629d..c42fb61824e2 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -277,6 +277,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     )
 
     reorder_batch_threshold: int = 1
+    requires_kernel_block_size = True
 
     def __init__(
         self,
@@ -284,6 +285,7 @@ def __init__(
         layer_names: list[str],
         vllm_config: VllmConfig,
         device: torch.device,
+        kernel_block_size: int,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.cache_config = vllm_config.cache_config
@@ -302,9 +304,7 @@ def __init__(
         self.disable_split_kv = False
         self.compilation_config = vllm_config.compilation_config
 
-        max_num_pages_per_req = cdiv(
-            self.model_config.max_model_len, self.kv_cache_spec.block_size
-        )
+        max_num_pages_per_req = cdiv(self.model_config.max_model_len, kernel_block_size)
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
         speculative_config = vllm_config.speculative_config
@@ -333,7 +333,7 @@ def __init__(
         self.num_kv_heads = self.kv_cache_spec.num_kv_heads
         self.head_dim = self.kv_cache_spec.head_size
         FlashInferBackend.validate_head_size(self.head_dim)
-        self.page_size = self.kv_cache_spec.block_size
+        self.page_size = kernel_block_size
 
         self.cache_dtype = self.cache_config.cache_dtype
         if self.cache_dtype.startswith("fp8"):
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 07d62e9849e0..e806ff0a3cd1 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -249,6 +249,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
     reorder_batch_threshold: int | None = None
+    # Does the builder need the kernel block size defined
+    requires_kernel_block_size = False
 
     @abstractmethod
     def __init__(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ba852bb89f33..9c822aa17059 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3982,18 +3982,12 @@ def create_attn_groups(
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
             for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
-                attn_group = AttentionGroup.create_with_metadata_builders(
+                attn_group = AttentionGroup(
                     attn_backend,
                     layer_names,
                     kv_cache_spec,
-                    self.vllm_config,
-                    self.device,
                     kv_cache_group_id,
-                    num_metadata_builders=1
-                    if not self.parallel_config.enable_dbo
-                    else 2,
                 )
-                attn_groups.append(attn_group)
             return attn_groups
@@ -4010,7 +4004,25 @@ def create_attn_groups(
         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))
 
+    def initialize_metadata_builders(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
+        """
+        Create the metadata builders for all KV cache groups and attn groups.
+        """
+        for kv_cache_group_id, _ in enumerate(kv_cache_config.kv_cache_groups):
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_group.create_metadata_builders(
+                    self.vllm_config,
+                    self.device,
+                    kernel_block_sizes[kv_cache_group_id],
+                    num_metadata_builders=1
+                    if not self.parallel_config.enable_dbo
+                    else 2,
+                )
+
         # Calculate reorder batch threshold (if needed)
+        # Note (tdoublep): do this *after* constructing builders,
+        # because some of them change the threshold at init time.
         self.calculate_reorder_batch_threshold()
 
     def _check_and_update_cudagraph_mode(
@@ -4576,6 +4588,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
         # tokens each.
         kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
+
+        # create metadata builders
+        self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
+
         # Reinitialize need to after initialize_attn_backend
         self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
         kv_caches = self.initialize_kv_cache_tensors(
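Note: a toy sketch of the block-splitting arithmetic mentioned in the comment above (a 256-token cache block treated as four 64-token kernel blocks). split_block_ids is a hypothetical helper for illustration, not the model runner's code.

```python
def split_block_ids(cache_block_id: int, cache_block_size: int,
                    kernel_block_size: int) -> list[int]:
    """Map one cache-block id to the ids of the kernel-sized blocks it spans,
    assuming kernel blocks are numbered consecutively within each cache block."""
    assert cache_block_size % kernel_block_size == 0
    factor = cache_block_size // kernel_block_size
    return [cache_block_id * factor + i for i in range(factor)]


# One 256-token cache block becomes four consecutive 64-token kernel blocks.
assert split_block_ids(cache_block_id=5, cache_block_size=256,
                       kernel_block_size=64) == [20, 21, 22, 23]
```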
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 396adbcfb289..e29c3499b5b6 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import torch
@@ -134,31 +134,37 @@ def reset_cache(self) -> None:
 
 @dataclass
 class AttentionGroup:
     backend: type[AttentionBackend]
-    # When ubatching is enabled we will have a metadata builder for each ubatch
-    # so that if they use internal persistant buffers for cudagraphs, and they
-    # won't have to worry about conflicting with the other ubatches.
-    metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
     kv_cache_spec: KVCacheSpec
     kv_cache_group_id: int
+    # When ubatching is enabled we will have a metadata builder for each ubatch
+    # so that if they use internal persistant buffers for cudagraphs, and they
+    # won't have to worry about conflicting with the other ubatches.
+    metadata_builders: list[AttentionMetadataBuilder] = field(
+        default_factory=lambda: []
+    )
 
-    @staticmethod
-    def create_with_metadata_builders(
-        backend: type[AttentionBackend],
-        layer_names: list[str],
-        kv_cache_spec: KVCacheSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        kv_cache_group_id: int,
+    def create_metadata_builders(
+        self,
+        vllm_config,
+        device,
+        kernel_block_size: int,
         num_metadata_builders: int = 1,
-    ) -> "AttentionGroup":
-        metadata_builders = [
-            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
+    ):
+        builder_cls = self.backend.get_builder_cls()
+        builder_extra_args = {}
+        if builder_cls.requires_kernel_block_size:
+            builder_extra_args["kernel_block_size"] = kernel_block_size
+        self.metadata_builders = [
+            builder_cls(
+                self.kv_cache_spec,
+                self.layer_names,
+                vllm_config,
+                device,
+                **builder_extra_args,
+            )
             for _ in range(num_metadata_builders)
         ]
-        return AttentionGroup(
-            backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
-        )
 
     def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
         assert len(self.metadata_builders) > ubatch_id
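Note: a minimal sketch of the opt-in keyword pattern introduced by patch 1, using simplified stand-in classes rather than vLLM's builders: only a builder that sets requires_kernel_block_size receives the extra kernel_block_size argument. Patch 3 below replaces this mechanism with a copied KVCacheSpec.

```python
# Simplified stand-in classes; not the vLLM implementation.
class BuilderWithoutKernelBlockSize:
    requires_kernel_block_size = False

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        self.block_size = kv_cache_spec["block_size"]


class BuilderWithKernelBlockSize:
    requires_kernel_block_size = True

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device,
                 kernel_block_size):
        self.block_size = kernel_block_size


def make_builders(builder_cls, kv_cache_spec, layer_names, vllm_config, device,
                  kernel_block_size, num_metadata_builders=1):
    # Forward kernel_block_size only to builders that declare they need it.
    extra_args = {}
    if builder_cls.requires_kernel_block_size:
        extra_args["kernel_block_size"] = kernel_block_size
    return [
        builder_cls(kv_cache_spec, layer_names, vllm_config, device, **extra_args)
        for _ in range(num_metadata_builders)
    ]


spec = {"block_size": 256}
(a,) = make_builders(BuilderWithoutKernelBlockSize, spec, [], None, None, 64)
b1, b2 = make_builders(BuilderWithKernelBlockSize, spec, [], None, None, 64,
                       num_metadata_builders=2)
assert (a.block_size, b1.block_size, b2.block_size) == (256, 64, 64)
```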
From 3d0061d150f8f9234905d0de3e5d8f8fa0663724 Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Fri, 31 Oct 2025 17:35:37 -0400
Subject: [PATCH 2/5] Reduce diff

Signed-off-by: Thomas Parnell
---
 vllm/v1/worker/gpu_model_runner.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9c822aa17059..02877eefa085 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3988,6 +3988,7 @@ def create_attn_groups(
                     kv_cache_spec,
                     kv_cache_group_id,
                 )
+                attn_groups.append(attn_group)
             return attn_groups

From 61813aa9f96e7b91b9d8c62a1988bdb642e62bbc Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Sat, 1 Nov 2025 10:11:34 -0400
Subject: [PATCH 3/5] Replace block size in KVCacheSpec

Signed-off-by: Thomas Parnell
---
 vllm/v1/attention/backends/flash_attn.py |  4 +---
 vllm/v1/attention/backends/flashinfer.py |  8 ++++----
 vllm/v1/attention/backends/utils.py      |  2 --
 vllm/v1/kv_cache_interface.py            |  8 +++++++-
 vllm/v1/worker/utils.py                  | 12 +++++-------
 5 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 833060740955..07f9ef173b4e 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -202,7 +202,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         if get_flash_attn_version() == 3
         else AttentionCGSupport.UNIFORM_BATCH
     )
-    requires_kernel_block_size = True
 
     def __init__(
         self,
@@ -210,7 +209,6 @@ def __init__(
         layer_names: list[str],
         vllm_config: VllmConfig,
         device: torch.device,
-        kernel_block_size: int,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.model_config = vllm_config.model_config
@@ -224,7 +222,7 @@ def __init__(
         self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
         self.kv_cache_dtype = kv_cache_spec.dtype
         self.headdim = self.model_config.get_head_size()
-        self.block_size = kernel_block_size
+        self.block_size = kv_cache_spec.block_size
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = get_flash_attn_version() == 3
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index c42fb61824e2..e71d4ca4629d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -277,7 +277,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     )
 
     reorder_batch_threshold: int = 1
-    requires_kernel_block_size = True
 
     def __init__(
         self,
@@ -285,7 +284,6 @@ def __init__(
         layer_names: list[str],
         vllm_config: VllmConfig,
         device: torch.device,
-        kernel_block_size: int,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.cache_config = vllm_config.cache_config
@@ -304,7 +302,9 @@ def __init__(
         self.disable_split_kv = False
         self.compilation_config = vllm_config.compilation_config
 
-        max_num_pages_per_req = cdiv(self.model_config.max_model_len, kernel_block_size)
+        max_num_pages_per_req = cdiv(
+            self.model_config.max_model_len, self.kv_cache_spec.block_size
+        )
         max_num_reqs = vllm_config.scheduler_config.max_num_seqs
         max_num_pages = max_num_reqs * max_num_pages_per_req
         speculative_config = vllm_config.speculative_config
@@ -333,7 +333,7 @@ def __init__(
         self.num_kv_heads = self.kv_cache_spec.num_kv_heads
         self.head_dim = self.kv_cache_spec.head_size
         FlashInferBackend.validate_head_size(self.head_dim)
-        self.page_size = kernel_block_size
+        self.page_size = self.kv_cache_spec.block_size
 
         self.cache_dtype = self.cache_config.cache_dtype
         if self.cache_dtype.startswith("fp8"):
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index e806ff0a3cd1..07d62e9849e0 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -249,8 +249,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
     # If not, set this to None. Otherwise set it to the query
     # length that will be pulled into the front of the batch.
     reorder_batch_threshold: int | None = None
-    # Does the builder need the kernel block size defined
-    requires_kernel_block_size = False
 
     @abstractmethod
     def __init__(
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 0f564fdb3b08..7f33eb7e699c 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
 from math import prod
 
 import torch
@@ -44,6 +44,12 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         """
         raise NotImplementedError
 
+    def copy_with_new_block_size(self, block_size: int) -> Self:
+        """
+        Create a new KVCacheSpec from self but replacing the block size.
+        """
+        return replace(self, block_size=block_size)
+
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """
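Note: copy_with_new_block_size builds on dataclasses.replace, which returns a new instance with one field overridden. The sketch below uses a simplified ToySpec stand-in, not the real KVCacheSpec. Because builders already read block_size from the spec they are given, handing them such a copy removes the need for a separate kernel_block_size argument.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToySpec:
    """Simplified stand-in for KVCacheSpec."""
    block_size: int
    num_kv_heads: int
    head_size: int

    def copy_with_new_block_size(self, block_size: int) -> "ToySpec":
        # replace() returns a new instance; every other field is copied as-is.
        return replace(self, block_size=block_size)


spec = ToySpec(block_size=256, num_kv_heads=8, head_size=128)
kernel_spec = spec.copy_with_new_block_size(64)
assert kernel_spec == ToySpec(block_size=64, num_kv_heads=8, head_size=128)
assert spec.block_size == 256  # the original spec is left untouched
```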
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index e29c3499b5b6..5d71ce19abe5 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -151,17 +151,15 @@ def create_metadata_builders(
         kernel_block_size: int,
         num_metadata_builders: int = 1,
     ):
-        builder_cls = self.backend.get_builder_cls()
-        builder_extra_args = {}
-        if builder_cls.requires_kernel_block_size:
-            builder_extra_args["kernel_block_size"] = kernel_block_size
+        kv_cache_spec_kernel = self.kv_cache_spec.copy_with_new_block_size(
+            kernel_block_size
+        )
         self.metadata_builders = [
-            builder_cls(
-                self.kv_cache_spec,
+            self.backend.get_builder_cls()(
+                kv_cache_spec_kernel,
                 self.layer_names,
                 vllm_config,
                 device,
-                **builder_extra_args,
             )
             for _ in range(num_metadata_builders)
         ]

From 083afeb09027e674afc0b5a1add4d43591d28e6f Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Sun, 2 Nov 2025 01:54:57 -0500
Subject: [PATCH 4/5] Fix codex suggestion

Signed-off-by: Thomas Parnell
---
 vllm/v1/worker/gpu_model_runner.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 02877eefa085..206449d8cc12 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4011,7 +4011,10 @@ def initialize_metadata_builders(
         """
         Create the metadata builders for all KV cache groups and attn groups.
         """
-        for kv_cache_group_id, _ in enumerate(kv_cache_config.kv_cache_groups):
+        for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
+            if kv_cache_group_id == len(kernel_block_sizes):
+                # There may be a last group for layers without kv cache.
+                continue
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 attn_group.create_metadata_builders(
                     self.vllm_config,

From 2e5efe5fc9bda27e7c72651f59345100144eec3f Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Sun, 2 Nov 2025 02:41:17 -0500
Subject: [PATCH 5/5] Fix issue for encoders

Signed-off-by: Thomas Parnell
---
 vllm/v1/worker/gpu_model_runner.py |  7 +++----
 vllm/v1/worker/utils.py            | 10 ++++++----
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 206449d8cc12..f6c447937fc0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4012,14 +4012,13 @@ def initialize_metadata_builders(
         Create the metadata builders for all KV cache groups and attn groups.
         """
         for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
-            if kv_cache_group_id == len(kernel_block_sizes):
-                # There may be a last group for layers without kv cache.
-                continue
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 attn_group.create_metadata_builders(
                     self.vllm_config,
                     self.device,
-                    kernel_block_sizes[kv_cache_group_id],
+                    kernel_block_sizes[kv_cache_group_id]
+                    if kv_cache_group_id < len(kernel_block_sizes)
+                    else None,
                     num_metadata_builders=1
                     if not self.parallel_config.enable_dbo
                     else 2,
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 5d71ce19abe5..0ca7e81a5c7b 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -148,15 +148,17 @@ def create_metadata_builders(
         self,
         vllm_config,
         device,
-        kernel_block_size: int,
+        kernel_block_size: int | None,
         num_metadata_builders: int = 1,
     ):
-        kv_cache_spec_kernel = self.kv_cache_spec.copy_with_new_block_size(
-            kernel_block_size
+        kv_cache_spec_builder = (
+            self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
+            if kernel_block_size is not None
+            else self.kv_cache_spec
         )
         self.metadata_builders = [
             self.backend.get_builder_cls()(
-                kv_cache_spec_kernel,
+                kv_cache_spec_builder,
                 self.layer_names,
                 vllm_config,
                 device,
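Note: taken together, patches 4 and 5 handle KV cache groups that have no kernel block size, such as a trailing group for layers without KV cache (the case the patch 5 subject attributes to encoders). The helpers below are hypothetical and only summarize that behavior; they are not the runner's actual code.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToySpec:
    """Simplified stand-in for KVCacheSpec."""
    block_size: int

    def copy_with_new_block_size(self, block_size: int) -> "ToySpec":
        return replace(self, block_size=block_size)


def kernel_block_size_for_group(kernel_block_sizes: list, group_id: int):
    # Groups beyond the end of kernel_block_sizes have no kernel block size.
    return kernel_block_sizes[group_id] if group_id < len(kernel_block_sizes) else None


def spec_for_builder(spec: ToySpec, kernel_block_size) -> ToySpec:
    # None means: build the metadata builder with the spec's own block size.
    return (
        spec.copy_with_new_block_size(kernel_block_size)
        if kernel_block_size is not None
        else spec
    )


assert spec_for_builder(ToySpec(256), kernel_block_size_for_group([64], 0)) == ToySpec(64)
assert spec_for_builder(ToySpec(256), kernel_block_size_for_group([64], 1)) == ToySpec(256)
```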