Skip to content

Commit

Permalink
[Model] PP support for Mamba-like models (#10992)
Browse files Browse the repository at this point in the history
Signed-off-by: mzusman <[email protected]>
  • Loading branch information
mzusman authored Dec 11, 2024
1 parent d5c5154 commit ffa48c9
Show file tree
Hide file tree
Showing 11 changed files with 229 additions and 81 deletions.
6 changes: 3 additions & 3 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Text Generation
- FalconMamba
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
- ✅︎
-
- ✅︎
* - :code:`GemmaForCausalLM`
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
Expand Down Expand Up @@ -193,7 +193,7 @@ Text Generation
- Jamba
- :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎
-
- ✅︎
* - :code:`LlamaForCausalLM`
- Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
- :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
Expand All @@ -203,7 +203,7 @@ Text Generation
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
-
-
- ✅︎
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
Expand Down
6 changes: 4 additions & 2 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ def iter_params(self, model_name: str):
# "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
# TODO: Implement PP
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
# Uses Llama
# "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
"state-spaces/mamba-130m-hf": PPTestSettings.fast(),
"mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
"mosaicml/mpt-7b": PPTestSettings.fast(),
"nvidia/Minitron-8B-Base": PPTestSettings.fast(),
Expand Down Expand Up @@ -234,6 +234,8 @@ def iter_params(self, model_name: str):
"OpenGVLab/InternVL2-1B",
"microsoft/Phi-3-vision-128k-instruct",
"fixie-ai/ultravox-v0_3",
# [LANGUAGE GENERATION - HYBRID ARCH]
"ai21labs/Jamba-tiny-dev",
]


Expand Down
58 changes: 44 additions & 14 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
print_warning_once, random_uuid,
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)

if TYPE_CHECKING:
Expand Down Expand Up @@ -284,6 +284,7 @@ def __init__(
self._verify_tokenizer_mode()

self.is_attention_free = self._init_attention_free()
self.is_hybrid = self._init_is_hybrid()
self.has_inner_state = self._init_has_inner_state()

if current_platform.is_neuron():
Expand Down Expand Up @@ -340,6 +341,10 @@ def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures)

def _init_is_hybrid(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_hybrid_model(architectures)

def _init_has_inner_state(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.model_has_inner_state(architectures)
Expand Down Expand Up @@ -669,26 +674,51 @@ def get_num_attention_heads(self,
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return end - start

def get_num_attention_layers(self,
parallel_config: "ParallelConfig") -> int:
if self.is_attention_free:
return 0
return start, end

num_layers = self.get_num_layers(parallel_config)
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
start, end = self.get_layers_start_end_indices(parallel_config)
return end - start

# Transformers supports layers_block_type @property
layers = getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
return len([t for t in layers if t == "attention"])
def get_num_layers_by_block_type(
self,
parallel_config: "ParallelConfig",
block_type: LayerBlockType = LayerBlockType.attention,
) -> int:
# This function relies on 'layers_block_type' in hf_config,
# for w/o this attribute, we will need to have workarounds like so
attn_block_type = block_type == LayerBlockType.attention
is_transformer = not self.is_hybrid and not self.is_attention_free
start, end = self.get_layers_start_end_indices(parallel_config)

if is_transformer:
# Handle the basic case first
return end - start if attn_block_type else 0
elif self.is_attention_free:
# Attention free
# Note that this code assumes there
# is only one type of attention-free block type.
return 0 if attn_block_type else end - start
else:
# Hybrid model
layers_block_type_value = getattr(self.hf_config,
"layers_block_type", None)
if layers_block_type_value is None:
raise ValueError("The model is an hybrid without a"
"layers_block_type in the hf_config,"
"cannot determine the num of "
f"{block_type.value} layers")

return sum(t == block_type.value
for t in layers_block_type_value[start:end])

def get_multimodal_config(self) -> "MultiModalConfig":
"""
Expand Down
37 changes: 37 additions & 0 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,43 @@ def is_attention_free(
return isinstance(model, IsAttentionFree)


@runtime_checkable
class IsHybrid(Protocol):
"""The interface required for all models like Jamba that have both
attention and mamba blocks, indicates that
hf_config has 'layers_block_type'"""

is_hybrid: ClassVar[Literal[True]] = True
"""
A flag that indicates this model has both mamba and attention blocks
, also indicates that the model's hf_config has
'layers_block_type' """


@runtime_checkable
class _IsHybridType(Protocol):
is_hybrid: ClassVar[Literal[True]]


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
...


@overload
def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
...


def is_hybrid(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
if isinstance(model, type):
return isinstance(model, _IsHybridType)

return isinstance(model, IsHybrid)


@runtime_checkable
class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding."""
Expand Down
93 changes: 65 additions & 28 deletions vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.attention.layer import Attention
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Expand All @@ -25,9 +26,12 @@
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType

from .interfaces import HasInnerState, SupportsLoRA
from .utils import maybe_prefix
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

KVCache = Tuple[torch.Tensor, torch.Tensor]

Expand Down Expand Up @@ -281,16 +285,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
org_num_embeddings=config.vocab_size,
)

decoder_layers = []
for i in range(config.num_hidden_layers):
layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
decoder_layers.append(
layer_class(config,
layer_idx=i,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{i}"))
self.layers = nn.ModuleList(decoder_layers)
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[
config.layers_block_type[layer_idx]]
return layer_class(
config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
)

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

Expand All @@ -304,26 +316,34 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

kv_cache_index = 0
mamba_cache_index = 0
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
kv_cache = None
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
self.config.attn_layer_period]
kv_cache = kv_caches[kv_cache_index]
kv_cache_index += 1
if isinstance(layer, JambaMambaDecoderLayer):
current_state_layer = i - (1 +
(i - self.config.attn_layer_offset)
// self.config.attn_layer_period)
current_state_layer = mamba_cache_index
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
mamba_cache_index += 1

hidden_states, residual = layer(
positions=positions,
Expand All @@ -332,11 +352,17 @@ def forward(
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states


class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
Expand Down Expand Up @@ -368,6 +394,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

super().__init__()
self.config = config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config
self.model = JambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
Expand All @@ -390,6 +418,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.vocab_size)
self.sampler = get_sampler()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

Expand All @@ -406,10 +437,8 @@ def forward(self,
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)

layers_type = self.config.layers_block_type
num_mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
Expand All @@ -423,7 +452,7 @@ def forward(self,
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params,
inputs_embeds)
intermediate_tensors, inputs_embeds)
return hidden_states

def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
Expand Down Expand Up @@ -504,8 +533,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.

if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
Expand All @@ -520,6 +553,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
if weight_name not in name:
continue

if is_pp_missing_parameter(name, self):
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
Expand All @@ -533,6 +568,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
Expand Down
Loading

0 comments on commit ffa48c9

Please sign in to comment.