From ffa48c9146fda1e8810d1cfa159e1d70aadae6c6 Mon Sep 17 00:00:00 2001
From: Mor Zusman
Date: Wed, 11 Dec 2024 04:53:37 +0200
Subject: [PATCH] [Model] PP support for Mamba-like models (#10992)

Signed-off-by: mzusman
---
 docs/source/models/supported_models.rst     |  6 +-
 tests/distributed/test_pipeline_parallel.py |  6 +-
 vllm/config.py                              | 58 +++++++++----
 vllm/model_executor/models/interfaces.py    | 37 ++++++++
 vllm/model_executor/models/jamba.py         | 93 ++++++++++++++-------
 vllm/model_executor/models/mamba.py         | 68 ++++++++++-----
 vllm/model_executor/models/registry.py      | 11 ++-
 vllm/utils.py                               |  5 ++
 vllm/v1/worker/gpu_model_runner.py          |  8 +-
 vllm/v1/worker/gpu_worker.py                |  6 +-
 vllm/worker/cache_engine.py                 | 12 +--
 11 files changed, 229 insertions(+), 81 deletions(-)

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 4e5b10967e3bb..6540e023c1ab0 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -128,7 +128,7 @@ Text Generation
     - FalconMamba
     - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
     - ✅︎
-    -
+    - ✅︎
   * - :code:`GemmaForCausalLM`
     - Gemma
     - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
@@ -193,7 +193,7 @@ Text Generation
     - Jamba
     - :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
     - ✅︎
-    -
+    - ✅︎
   * - :code:`LlamaForCausalLM`
     - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
     - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
@@ -203,7 +203,7 @@ Text Generation
     - Mamba
     - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
     -
-    -
+    - ✅︎
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
     - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
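The core idea of this patch is per-stage layer accounting: each pipeline rank owns a contiguous slice of the model's layers, and for hybrid models the KV-cache and mamba-cache sizes must be derived from the block types inside that slice only. The standalone sketch below (plain Python, not vLLM code; this `get_pp_indices` is a simplified stand-in for `vllm.distributed.utils.get_pp_indices`, which also honors custom partitions) illustrates the counting logic that `get_num_layers_by_block_type` implements in `vllm/config.py` further down.

    # Simplified stand-in for vllm.distributed.utils.get_pp_indices: an even
    # split of num_layers over pp_size stages, remainder going to the last.
    def get_pp_indices(num_layers: int, pp_rank: int, pp_size: int):
        chunk = num_layers // pp_size
        start = pp_rank * chunk
        end = num_layers if pp_rank == pp_size - 1 else start + chunk
        return start, end

    # Count the layers of one block type inside the slice owned by this rank,
    # mirroring what ModelConfig.get_num_layers_by_block_type does for hybrid
    # models via hf_config.layers_block_type.
    def count_layers_by_type(layers_block_type, block_type, pp_rank, pp_size):
        start, end = get_pp_indices(len(layers_block_type), pp_rank, pp_size)
        return sum(t == block_type for t in layers_block_type[start:end])

    # Toy 8-layer Jamba-like stack split over two pipeline stages.
    layers = ["mamba", "mamba", "mamba", "attention"] * 2
    assert count_layers_by_type(layers, "attention", pp_rank=0, pp_size=2) == 1
    assert count_layers_by_type(layers, "mamba", pp_rank=1, pp_size=2) == 3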
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index b818ca921fcb0..85d408efafe96 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -156,13 +156,13 @@ def iter_params(self, model_name: str):
     # "internlm/internlm-chat-7b": PPTestSettings.fast(),
     "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
     "inceptionai/jais-13b-chat": PPTestSettings.fast(),
-    # TODO: Implement PP
-    # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
+    "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
     "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
     "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
     "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
     # Uses Llama
     # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
+    "state-spaces/mamba-130m-hf": PPTestSettings.fast(),
     "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
     "mosaicml/mpt-7b": PPTestSettings.fast(),
     "nvidia/Minitron-8B-Base": PPTestSettings.fast(),
@@ -234,6 +234,8 @@ def iter_params(self, model_name: str):
     "OpenGVLab/InternVL2-1B",
     "microsoft/Phi-3-vision-128k-instruct",
     "fixie-ai/ultravox-v0_3",
+    # [LANGUAGE GENERATION - HYBRID ARCH]
+    "ai21labs/Jamba-tiny-dev",
 ]
 
diff --git a/vllm/config.py b/vllm/config.py
index c66ddbb47f22e..2a9f0ebae997d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -27,8 +27,8 @@
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
-from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        print_warning_once, random_uuid,
+from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
+                        get_cpu_memory, print_warning_once, random_uuid,
                         resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
@@ -284,6 +284,7 @@ def __init__(
 
         self._verify_tokenizer_mode()
         self.is_attention_free = self._init_attention_free()
+        self.is_hybrid = self._init_is_hybrid()
         self.has_inner_state = self._init_has_inner_state()
 
         if current_platform.is_neuron():
@@ -340,6 +341,10 @@ def _init_attention_free(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)
 
+    def _init_is_hybrid(self) -> bool:
+        architectures = getattr(self.hf_config, "architectures", [])
+        return ModelRegistry.is_hybrid_model(architectures)
+
     def _init_has_inner_state(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.model_has_inner_state(architectures)
@@ -669,26 +674,51 @@ def get_num_attention_heads(self,
         num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
         return num_heads // parallel_config.tensor_parallel_size
 
-    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+    def get_layers_start_end_indices(
+            self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
         from vllm.distributed.utils import get_pp_indices
         total_num_hidden_layers = getattr(self.hf_text_config,
                                           "num_hidden_layers", 0)
         pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
         pp_size = parallel_config.pipeline_parallel_size
         start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
-        return end - start
-
-    def get_num_attention_layers(self,
-                                 parallel_config: "ParallelConfig") -> int:
-        if self.is_attention_free:
-            return 0
+        return start, end
 
-        num_layers = self.get_num_layers(parallel_config)
+    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        start, end = self.get_layers_start_end_indices(parallel_config)
+        return end - start
 
-        # Transformers supports layers_block_type @property
-        layers = getattr(self.hf_config, "layers_block_type",
-                         ["attention"] * num_layers)
-        return len([t for t in layers if t == "attention"])
+    def get_num_layers_by_block_type(
+        self,
+        parallel_config: "ParallelConfig",
+        block_type: LayerBlockType = LayerBlockType.attention,
+    ) -> int:
+        # This function relies on 'layers_block_type' in hf_config;
+        # models without this attribute need the workarounds below.
+        attn_block_type = block_type == LayerBlockType.attention
+        is_transformer = not self.is_hybrid and not self.is_attention_free
+        start, end = self.get_layers_start_end_indices(parallel_config)
+
+        if is_transformer:
+            # Handle the basic case first
+            return end - start if attn_block_type else 0
+        elif self.is_attention_free:
+            # Attention-free model.
+            # Note that this code assumes there is only one
+            # type of attention-free block.
+            return 0 if attn_block_type else end - start
+        else:
+            # Hybrid model
+            layers_block_type_value = getattr(self.hf_config,
+                                              "layers_block_type", None)
+            if layers_block_type_value is None:
+                raise ValueError("The model is a hybrid without a "
+                                 "layers_block_type in the hf_config; "
+                                 "cannot determine the number of "
+                                 f"{block_type.value} layers")
+
+            return sum(t == block_type.value
+                       for t in layers_block_type_value[start:end])
 
     def get_multimodal_config(self) -> "MultiModalConfig":
         """
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index c3979eab905db..70b78fe64f2d8 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -363,6 +363,43 @@ def is_attention_free(
     return isinstance(model, IsAttentionFree)
 
 
+@runtime_checkable
+class IsHybrid(Protocol):
+    """The interface required for all models like Jamba that have both
+    attention and mamba blocks. It also indicates that the model's
+    hf_config has 'layers_block_type'."""
+
+    is_hybrid: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model has both mamba and attention blocks;
+    it also indicates that the model's hf_config has 'layers_block_type'.
+    """
+
+
+@runtime_checkable
+class _IsHybridType(Protocol):
+    is_hybrid: ClassVar[Literal[True]]
+
+
+@overload
+def is_hybrid(model: object) -> TypeIs[IsHybrid]:
+    ...
+
+
+@overload
+def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
+    ...
+
+
+def is_hybrid(
+    model: Union[Type[object], object]
+) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
+    if isinstance(model, type):
+        return isinstance(model, _IsHybridType)
+
+    return isinstance(model, IsHybrid)
+
+
 @runtime_checkable
 class SupportsCrossEncoding(Protocol):
     """The interface required for all models that support cross encoding."""
 
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 5d5e8ae1ee532..6bb4c13ab35df 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -9,6 +9,7 @@
 from vllm.attention.layer import Attention
 from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -25,9 +26,12 @@
                                                MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType
 
-from .interfaces import HasInnerState, SupportsLoRA
-from .utils import maybe_prefix
+from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -281,16 +285,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
         )
 
-        decoder_layers = []
-        for i in range(config.num_hidden_layers):
-            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
-            decoder_layers.append(
-                layer_class(config,
-                            layer_idx=i,
-                            cache_config=cache_config,
-                            quant_config=quant_config,
-                            prefix=f"{prefix}.layers.{i}"))
-        self.layers = nn.ModuleList(decoder_layers)
+        def get_layer(prefix: str):
+            layer_idx = int(prefix.rsplit(".", 1)[1])
+            layer_class = ALL_DECODER_LAYER_TYPES[
+                config.layers_block_type[layer_idx]]
+            return layer_class(
+                config,
+                layer_idx,
+                cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
         self.final_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
 
@@ -304,26 +316,34 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.get_input_embeddings(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+
+        kv_cache_index = 0
+        mamba_cache_index = 0
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             kv_cache = None
             layer_mamba_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
-                kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
-                                     self.config.attn_layer_period]
+                kv_cache = kv_caches[kv_cache_index]
+                kv_cache_index += 1
             if isinstance(layer, JambaMambaDecoderLayer):
-                current_state_layer = i - (1 +
-                                           (i - self.config.attn_layer_offset)
-                                           // self.config.attn_layer_period)
+                current_state_layer = mamba_cache_index
                 layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
                     current_state_layer)
+                mamba_cache_index += 1
 
             hidden_states, residual = layer(
                 positions=positions,
@@ -332,11 +352,17 @@ def forward(
                 attn_metadata=attn_metadata,
                 residual=residual,
                 mamba_cache_params=layer_mamba_cache_params)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
 
 
-class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
+class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
+                       IsHybrid):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -368,6 +394,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.config = config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
         self.scheduler_config = scheduler_config
         self.model = JambaModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
@@ -390,6 +418,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 config.vocab_size)
         self.sampler = get_sampler()
 
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
@@ -406,10 +437,8 @@ def forward(self,
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
 
-            layers_type = self.config.layers_block_type
-            num_mamba_layers = sum(
-                [layer_type == "mamba" for layer_type in layers_type])
-
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
                 self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                 *self._get_mamba_cache_shape())
@@ -423,7 +452,7 @@ def forward(self,
                                               state_indices_tensor)
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, mamba_cache_params,
-                                   inputs_embeds)
+                                   intermediate_tensors, inputs_embeds)
         return hidden_states
 
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@@ -504,8 +533,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
+
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -520,6 +553,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                     if weight_name not in name:
                         continue
 
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
                     weight_loader = param.weight_loader
@@ -533,6 +568,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 8bdcd2c5aad1f..1f5cd02711899 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -8,6 +8,7 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
@@ -18,13 +19,16 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (HasInnerState,
-                                                   IsAttentionFree)
+                                                   IsAttentionFree, SupportsPP)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
+from vllm.utils import LayerBlockType
 
-from .utils import maybe_prefix
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -95,15 +99,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             org_num_embeddings=config.vocab_size,
         )
 
-        decoder_layers = []
-        for i in range(config.num_hidden_layers):
-            decoder_layers.append(
-                MambaDecoderLayer(config,
-                                  cache_config=cache_config,
-                                  quant_config=quant_config))
-        self.layers = nn.ModuleList(decoder_layers)
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: MambaDecoderLayer(
+                config, cache_config=cache_config, quant_config=quant_config),
+            prefix=f"{prefix}.layers")
+
         self.norm_f = RMSNorm(config.hidden_size,
                               eps=config.layer_norm_epsilon)
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embeddings(input_ids)
@@ -114,29 +120,40 @@ def forward(
         positions: torch.Tensor,
         attn_metadata: AttentionMetadata,
         mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
         else:
-            hidden_states = self.get_input_embeddings(input_ids)
-        residual = None
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
 
-        for i in range(len(self.layers)):
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                mamba_cache_params=mamba_cache_params.at_layer_idx(i))
+                mamba_cache_params=mamba_cache_params.at_layer_idx(
+                    i - self.start_layer))
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm_f(hidden_states, residual)
         return hidden_states
 
 
-class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
+class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
@@ -148,7 +165,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.config = config
+        self.vllm_config = vllm_config
         self.scheduler_config = scheduler_config
+        self.model_config = vllm_config.model_config
         self.backbone = MambaModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "backbone"))
         self.unpadded_vocab_size = config.vocab_size
@@ -174,6 +193,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 config.vocab_size)
         self.sampler = get_sampler()
 
+        self.make_empty_intermediate_tensors = (
+            self.backbone.make_empty_intermediate_tensors)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.backbone.get_input_embeddings(input_ids)
 
@@ -189,9 +211,12 @@ def forward(self,
             max_batch_size = (VllmConfig.get_graph_batch_size(
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
+
+            num_mamba_layers = self.model_config.get_num_layers_by_block_type(
+                self.vllm_config.parallel_config, LayerBlockType.mamba)
             self.mamba_cache = MambaCacheManager(
-                self.lm_head.weight.dtype, self.config.num_hidden_layers,
-                max_batch_size, *self._get_mamba_cache_shape())
+                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
+                *self._get_mamba_cache_shape())
 
         (
             mamba_cache_tensors,
@@ -204,7 +229,8 @@ def forward(self,
                                               state_indices_tensor)
 
         hidden_states = self.backbone(input_ids, positions, attn_metadata,
-                                      mamba_cache_params, inputs_embeds)
+                                      mamba_cache_params, intermediate_tensors,
+                                      inputs_embeds)
 
         return hidden_states
 
@@ -252,6 +278,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index e69596aa915b5..4beea4641f5ab 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -21,7 +21,7 @@
 from vllm.platforms import current_platform
 
 from .adapters import as_embedding_model
-from .interfaces import (has_inner_state, is_attention_free,
+from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
                          supports_cross_encoding, supports_multimodal,
                          supports_pp)
 from .interfaces_base import is_pooling_model, is_text_generation_model
@@ -218,6 +218,7 @@ class _ModelInfo:
     supports_pp: bool
     has_inner_state: bool
     is_attention_free: bool
+    is_hybrid: bool
 
     @staticmethod
     def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
@@ -239,6 +240,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
             supports_pp=supports_pp(model),
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
+            is_hybrid=is_hybrid(model),
         )
 
 
@@ -484,6 +486,13 @@ def is_attention_free_model(
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.is_attention_free
 
+    def is_hybrid_model(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return model_cls.is_hybrid
+
 
 ModelRegistry = _ModelRegistry({
     model_arch: _LazyRegisteredModel(
diff --git a/vllm/utils.py b/vllm/utils.py
index 7cdb2cb320b05..1882264c19775 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -170,6 +170,11 @@ class Device(enum.Enum):
     CPU = enum.auto()
 
 
+class LayerBlockType(enum.Enum):
+    attention = "attention"
+    mamba = "mamba"
+
+
 class Counter:
 
     def __init__(self, start: int = 0) -> None:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a3335fa838352..8d9976ded7c5e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -15,8 +15,8 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingType
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
-                        is_pin_memory_available)
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
+                        LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
 from vllm.v1.outputs import ModelRunnerOutput
@@ -68,8 +68,8 @@ def __init__(
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
 
         # Model-related.
-        self.num_attn_layers = model_config.get_num_attention_layers(
-            parallel_config)
+        self.num_attn_layers = model_config.get_num_layers_by_block_type(
+            parallel_config, LayerBlockType.attention)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index d32848c3775ae..49e415ab72e0b 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -14,7 +14,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size
 from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -260,8 +260,8 @@ def _get_cache_block_size(
 ) -> int:
     head_size = model_config.get_head_size()
     num_heads = model_config.get_num_kv_heads(parallel_config)
-    num_attention_layers = model_config.get_num_attention_layers(
-        parallel_config)
+    num_attention_layers = model_config.get_num_layers_by_block_type(
+        parallel_config, LayerBlockType.attention)
 
     key_cache_block = cache_config.block_size * num_heads * head_size
     value_cache_block = key_cache_block
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index ac3270d1c9909..7ccd4571b19df 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -6,8 +6,8 @@
 from vllm.attention import get_attn_backend
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
-                        is_pin_memory_available)
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType,
+                        get_dtype_size, is_pin_memory_available)
 
 logger = init_logger(__name__)
 
@@ -34,8 +34,8 @@ def __init__(
         self.head_size = model_config.get_head_size()
 
         # Models like Jamba, have mixed typed layers, E.g Mamba
-        self.num_attention_layers = model_config.get_num_attention_layers(
-            parallel_config)
+        self.num_attention_layers = model_config.get_num_layers_by_block_type(
+            parallel_config, LayerBlockType.attention)
         self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.block_size = cache_config.block_size
 
@@ -105,8 +105,8 @@ def get_cache_block_size(
     ) -> int:
         head_size = model_config.get_head_size()
         num_heads = model_config.get_num_kv_heads(parallel_config)
-        num_attention_layers = model_config.get_num_attention_layers(
-            parallel_config)
+        num_attention_layers = model_config.get_num_layers_by_block_type(
+            parallel_config, LayerBlockType.attention)
 
         key_cache_block = cache_config.block_size * num_heads * head_size
         value_cache_block = key_cache_block
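With the patch applied, Mamba and Jamba checkpoints should accept a pipeline-parallel degree like any other PP-enabled model. A minimal usage sketch, assuming vLLM is built from this commit and two GPUs are available (the model choice and sampling values are illustrative, not taken from the patch):

    from vllm import LLM, SamplingParams

    # mamba-130m is attention-free, so each pipeline stage holds a slice of
    # mamba layers and no KV cache; the mamba state cache is sized per stage
    # via get_num_layers_by_block_type(..., LayerBlockType.mamba).
    llm = LLM(model="state-spaces/mamba-130m-hf", pipeline_parallel_size=2)

    out = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
    print(out[0].outputs[0].text)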