vllm/v1/attention/backends/flash_attn.py: 6 changes (5 additions, 1 deletion)
@@ -62,7 +62,11 @@ def get_supported_head_sizes(cls) -> list[int]:

     @staticmethod
     def get_supported_kernel_block_size() -> list[int | MultipleOf]:
-        return [MultipleOf(16)]
+        # NOTE(tdoublep): while in principle, FA supports
+        # MultipleOf(16), these are the block sizes that do not
+        # suffer from the NaN propagation problem described here:
+        # https://github.com/Dao-AILab/flash-attention/issues/1974
+        return [16, 32, 64]

     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
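For context on the change above: callers of get_supported_kernel_block_size() now have to match against an explicit allow-list of block sizes rather than a divisibility rule. The following is a minimal, self-contained sketch of what that distinction means for a caller; MultipleOf is mimicked by a toy stand-in and pick_kernel_block_size is a hypothetical helper, not vLLM's actual selection logic.

# Toy stand-ins; the real MultipleOf lives in vLLM and the selection policy
# used by the model runner may differ.
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


def pick_kernel_block_size(
    cache_block_size: int, supported: list[int | MultipleOf]
) -> int:
    """Pick the largest supported kernel block size that divides the
    KV-cache manager's block size (hypothetical selection policy)."""
    candidates = []
    for entry in supported:
        if isinstance(entry, MultipleOf):
            # Any multiple of entry.base works; the cache block size itself
            # qualifies whenever it is such a multiple.
            if cache_block_size % entry.base == 0:
                candidates.append(cache_block_size)
        elif cache_block_size % entry == 0:
            candidates.append(entry)
    if not candidates:
        raise ValueError("no compatible kernel block size")
    return max(candidates)


# Old contract: [MultipleOf(16)] accepts a 256-token cache block directly.
assert pick_kernel_block_size(256, [MultipleOf(16)]) == 256
# New contract: only 16/32/64 are allowed, so the runner falls back to 64
# and has to split each 256-token cache block across four kernel blocks.
assert pick_kernel_block_size(256, [16, 32, 64]) == 64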
vllm/v1/kv_cache_interface.py: 8 changes (7 additions, 1 deletion)
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import copy
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, fields, replace
 from math import prod

 import torch
@@ -44,6 +44,12 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         """
         raise NotImplementedError

+    def copy_with_new_block_size(self, block_size: int) -> Self:
+        """
+        Create a new KVCacheSpec from self but replacing the block size.
+        """
+        return replace(self, block_size=block_size)
+
     @classmethod
     def merge(cls, specs: list[Self]) -> Self:
         """
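For context on the new copy_with_new_block_size helper: dataclasses.replace builds a new instance of a (possibly frozen) dataclass, copying every field except those overridden by keyword, which is what lets a spec be re-targeted at a different kernel block size without mutating the original. A minimal runnable sketch, using a hypothetical ToySpec rather than the real KVCacheSpec hierarchy:

# Hypothetical stand-in for a KVCacheSpec-like dataclass; not vLLM code.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToySpec:
    block_size: int
    num_kv_heads: int


spec = ToySpec(block_size=256, num_kv_heads=8)
kernel_spec = replace(spec, block_size=64)  # new object, other fields copied
assert spec.block_size == 256 and kernel_spec.block_size == 64
assert kernel_spec.num_kv_heads == spec.num_kv_heads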
vllm/v1/worker/gpu_model_runner.py: 31 changes (25 additions, 6 deletions)
@@ -3982,16 +3982,11 @@ def create_attn_groups(
         ) -> list[AttentionGroup]:
             attn_groups: list[AttentionGroup] = []
             for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
-                attn_group = AttentionGroup.create_with_metadata_builders(
+                attn_group = AttentionGroup(
                     attn_backend,
                     layer_names,
                     kv_cache_spec,
-                    self.vllm_config,
-                    self.device,
                     kv_cache_group_id,
-                    num_metadata_builders=1
-                    if not self.parallel_config.enable_dbo
-                    else 2,
                 )

                 attn_groups.append(attn_group)
@@ -4010,7 +4005,27 @@
         for i, attn_backend_map in enumerate(attention_backend_maps):
             self.attn_groups.append(create_attn_groups(attn_backend_map, i))

+    def initialize_metadata_builders(
+        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
+    ) -> None:
+        """
+        Create the metadata builders for all KV cache groups and attn groups.
+        """
+        for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_group.create_metadata_builders(
+                    self.vllm_config,
+                    self.device,
+                    kernel_block_sizes[kv_cache_group_id]
+                    if kv_cache_group_id < len(kernel_block_sizes)
+                    else None,
+                    num_metadata_builders=1
+                    if not self.parallel_config.enable_dbo
+                    else 2,
+                )
         # Calculate reorder batch threshold (if needed)
+        # Note (tdoublep): do this *after* constructing builders,
+        # because some of them change the threshold at init time.
         self.calculate_reorder_batch_threshold()

     def _check_and_update_cudagraph_mode(
@@ -4576,6 +4591,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
         # tokens each.
         kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
+
+        # create metadata builders
+        self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
+
         # Reinitialize need to after initialize_attn_backend
         self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
         kv_caches = self.initialize_kv_cache_tensors(
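The comment in the last hunk about splitting a 256-token block into four 64-token kernel blocks is plain index arithmetic: if the KV-cache tensor is laid out so that each manager-level block covers a contiguous run of kernel-level blocks, manager block b maps to kernel blocks b*split through b*split + split - 1. A small sketch with a hypothetical expand_block_table helper; the contiguous-layout assumption is mine, not a statement of the actual vLLM implementation.

def expand_block_table(
    block_ids: list[int], manager_block_size: int, kernel_block_size: int
) -> list[int]:
    """Map manager-level block ids to kernel-level block ids, assuming each
    manager block covers `split` consecutive kernel blocks (layout assumption)."""
    assert manager_block_size % kernel_block_size == 0
    split = manager_block_size // kernel_block_size
    return [bid * split + off for bid in block_ids for off in range(split)]


# A request holding manager blocks [3, 7] with 256-token blocks and a
# 64-token kernel block size addresses kernel blocks 12..15 and 28..31.
assert expand_block_table([3, 7], 256, 64) == [12, 13, 14, 15, 28, 29, 30, 31]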
vllm/v1/worker/utils.py: 44 changes (25 additions, 19 deletions)
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING

 import torch
@@ -134,31 +134,37 @@ def reset_cache(self) -> None:
 @dataclass
 class AttentionGroup:
     backend: type[AttentionBackend]
-    # When ubatching is enabled we will have a metadata builder for each ubatch
-    # so that if they use internal persistant buffers for cudagraphs, and they
-    # won't have to worry about conflicting with the other ubatches.
-    metadata_builders: list[AttentionMetadataBuilder]
     layer_names: list[str]
     kv_cache_spec: KVCacheSpec
     kv_cache_group_id: int
+    # When ubatching is enabled we will have a metadata builder for each ubatch
+    # so that if they use internal persistant buffers for cudagraphs, and they
+    # won't have to worry about conflicting with the other ubatches.
+    metadata_builders: list[AttentionMetadataBuilder] = field(
+        default_factory=lambda: []
+    )

-    @staticmethod
-    def create_with_metadata_builders(
-        backend: type[AttentionBackend],
-        layer_names: list[str],
-        kv_cache_spec: KVCacheSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        kv_cache_group_id: int,
+    def create_metadata_builders(
+        self,
+        vllm_config,
+        device,
+        kernel_block_size: int | None,
         num_metadata_builders: int = 1,
-    ) -> "AttentionGroup":
-        metadata_builders = [
-            backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device)
+    ):
+        kv_cache_spec_builder = (
+            self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
+            if kernel_block_size is not None
+            else self.kv_cache_spec
+        )
+        self.metadata_builders = [
+            self.backend.get_builder_cls()(
+                kv_cache_spec_builder,
+                self.layer_names,
+                vllm_config,
+                device,
+            )
             for _ in range(num_metadata_builders)
         ]
-        return AttentionGroup(
-            backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
-        )

     def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
         assert len(self.metadata_builders) > ubatch_id
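One detail worth spelling out in the AttentionGroup change: metadata_builders now defaults to an empty list via field(default_factory=...), because dataclasses reject bare mutable defaults, and the list is only populated later by create_metadata_builders once the kernel block size is known. A minimal sketch of the same two-phase construction with toy classes; the names are hypothetical, not the real vLLM types.

from dataclasses import dataclass, field


class ToyBuilder:
    """Stand-in for an attention metadata builder."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size


@dataclass
class ToyGroup:
    block_size: int
    # A bare `= []` default would raise "mutable default ... is not allowed",
    # and a single shared list would leak state across instances.
    builders: list[ToyBuilder] = field(default_factory=list)

    def create_builders(self, kernel_block_size: int | None, num: int = 1) -> None:
        # Builders see the kernel block size when one is given, otherwise the
        # group's own block size; one builder per ubatch when num == 2.
        size = kernel_block_size if kernel_block_size is not None else self.block_size
        self.builders = [ToyBuilder(size) for _ in range(num)]


group = ToyGroup(block_size=256)   # phase 1: construct the group
group.create_builders(64, num=2)   # phase 2: attach per-ubatch builders later
assert [b.block_size for b in group.builders] == [64, 64]

This mirrors the runner-side flow above: AttentionGroup instances are constructed in create_attn_groups, and initialize_metadata_builders fills in the builders (two per group when DBO/ubatching is enabled) only after _prepare_kernel_block_sizes has run.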