diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 1eac94940e78..07f9ef173b4e 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -62,7 +62,11 @@ def get_supported_head_sizes(cls) -> list[int]:
 
     @staticmethod
     def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
+        # NOTE(tdoublep): while in principle, FA supports
+        # MultipleOf(16), these are the block sizes that do not
+        # suffer from the NaN propagation problem described here:
+        # https://github.com/Dao-AILab/flash-attention/issues/1974
+        return [16, 32, 64]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py
index 0f564fdb3b08..7f33eb7e699c 100644
--- a/vllm/v1/kv_cache_interface.py
+++ b/vllm/v1/kv_cache_interface.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
 from math import prod
 
 import torch
@@ -44,6 +44,12 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         """
         raise NotImplementedError
 
+    def copy_with_new_block_size(self, block_size: int) -> Self:
+        """
+        Create a new KVCacheSpec from self but replacing the block size.
+        """
+        return replace(self, block_size=block_size)
+
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ba852bb89f33..f6c447937fc0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3982,16 +3982,11 @@ def create_attn_groups(
        ) -> list[AttentionGroup]:
            attn_groups: list[AttentionGroup] = []
            for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
-                attn_group = AttentionGroup.create_with_metadata_builders(
+                attn_group = AttentionGroup(
                    attn_backend,
                    layer_names,
                    kv_cache_spec,
-                    self.vllm_config,
-                    self.device,
                    kv_cache_group_id,
-                    num_metadata_builders=1
-                    if not self.parallel_config.enable_dbo
-                    else 2,
                )
                attn_groups.append(attn_group)
 
@@ -4010,7 +4005,27 @@ def create_attn_groups(
        for i, attn_backend_map in enumerate(attention_backend_maps):
            self.attn_groups.append(create_attn_groups(attn_backend_map, i))
 
+    def initialize_metadata_builders(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
+        """
+        Create the metadata builders for all KV cache groups and attn groups.
+        """
+        for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_group.create_metadata_builders(
+                    self.vllm_config,
+                    self.device,
+                    kernel_block_sizes[kv_cache_group_id]
+                    if kv_cache_group_id < len(kernel_block_sizes)
+                    else None,
+                    num_metadata_builders=1
+                    if not self.parallel_config.enable_dbo
+                    else 2,
+                )
        # Calculate reorder batch threshold (if needed)
+        # Note (tdoublep): do this *after* constructing builders,
+        # because some of them change the threshold at init time.
        self.calculate_reorder_batch_threshold()
 
    def _check_and_update_cudagraph_mode(
@@ -4576,6 +4591,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
        # tokens each.
        kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
+
+        # create metadata builders
+        self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
+
        # Reinitialize need to after initialize_attn_backend
        self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
        kv_caches = self.initialize_kv_cache_tensors(
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 396adbcfb289..0ca7e81a5c7b 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING
 
 import torch
@@ -134,31 +134,37 @@ def reset_cache(self) -> None:
 @dataclass
 class AttentionGroup:
    backend: type[AttentionBackend]
-    # When ubatching is enabled we will have a metadata builder for each ubatch
-    # so that if they use internal persistant buffers for cudagraphs, and they
-    # won't have to worry about conflicting with the other ubatches.
-    metadata_builders: list[AttentionMetadataBuilder]
    layer_names: list[str]
    kv_cache_spec: KVCacheSpec
    kv_cache_group_id: int
+    # When ubatching is enabled we will have a metadata builder for each ubatch
+    # so that if they use internal persistant buffers for cudagraphs, and they
+    # won't have to worry about conflicting with the other ubatches.
+    metadata_builders: list[AttentionMetadataBuilder] = field(
+        default_factory=lambda: []
+    )
 
-    @staticmethod
-    def create_with_metadata_builders(
-        backend: type[AttentionBackend],
-        layer_names: list[str],
-        kv_cache_spec: KVCacheSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        kv_cache_group_id: int,
+    def create_metadata_builders(
+        self,
+        vllm_config,
+        device,
+        kernel_block_size: int | None,
        num_metadata_builders: int = 1,
-    ) -> "AttentionGroup":
-        metadata_builders = [
-            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
+    ):
+        kv_cache_spec_builder = (
+            self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
+            if kernel_block_size is not None
+            else self.kv_cache_spec
+        )
+        self.metadata_builders = [
+            self.backend.get_builder_cls()(
+                kv_cache_spec_builder,
+                self.layer_names,
+                vllm_config,
+                device,
+            )
            for _ in range(num_metadata_builders)
        ]
-        return AttentionGroup(
-            backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
-        )
 
    def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
        assert len(self.metadata_builders) > ubatch_id
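For context on the new `KVCacheSpec.copy_with_new_block_size` helper: it relies on `dataclasses.replace`, which constructs a new dataclass instance copying every field except those explicitly overridden, so the metadata builder can be handed a spec carrying the kernel block size while the group keeps the original spec. Below is a minimal standalone sketch of that pattern; `ToyKVCacheSpec` is a hypothetical stand-in, not the real vLLM class.

```python
# Minimal sketch of the dataclasses.replace pattern behind
# KVCacheSpec.copy_with_new_block_size. ToyKVCacheSpec is a
# hypothetical stand-in, not the actual vLLM spec class.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyKVCacheSpec:
    block_size: int
    num_kv_heads: int
    head_size: int

    def copy_with_new_block_size(self, block_size: int) -> "ToyKVCacheSpec":
        # replace() returns a new instance with every field copied
        # except the ones overridden here.
        return replace(self, block_size=block_size)


if __name__ == "__main__":
    spec = ToyKVCacheSpec(block_size=256, num_kv_heads=8, head_size=128)
    kernel_spec = spec.copy_with_new_block_size(64)
    assert kernel_spec.block_size == 64
    assert kernel_spec.head_size == spec.head_size
    print(spec, kernel_spec)
```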