"""Inference-only MiniCPM model compatible with HuggingFace weights.""" | ||
|
||
import math | ||
from typing import Any, Dict, Iterable, Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from vllm.config import CacheConfig | ||
from vllm.distributed import get_tensor_model_parallel_world_size | ||
|
||
from vllm.model_executor.layers.activation import SiluAndMul | ||
|
||
from vllm.model_executor.layers.layernorm import RMSNorm | ||
from vllm.model_executor.layers.linear import ( | ||
MergedColumnParallelLinear, | ||
QKVParallelLinear, | ||
RowParallelLinear, | ||
) | ||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig | ||
from vllm.model_executor.layers.rotary_embedding import get_rope | ||
from vllm.model_executor.layers.vocab_parallel_embedding import ( | ||
ParallelLMHead, | ||
VocabParallelEmbedding, | ||
) | ||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||
|
||
from sglang.srt.layers.logits_processor import LogitsProcessor | ||
from sglang.srt.layers.radix_attention import RadixAttention | ||
from sglang.srt.managers.controller.model_runner import InputMetadata | ||
|
||
|
||
class MiniCPMMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
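        # gate_up stacks the gate and up projections along the output dim;
        # SiluAndMul splits it in half and computes silu(gate) * up.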
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MiniCPMAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
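        # E.g. 8 KV heads on tp_size=4 gives 2 KV heads per rank, while
        # 2 KV heads on tp_size=4 gives 1 per rank (each replicated on 2 ranks).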
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # Recompute the cos/sin cache in fp32 instead of the model dtype
        # (e.g. bf16); forward() casts q/k to fp32 to match before applying RoPE.
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
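        # Apply RoPE in fp32 (matching the fp32 cos/sin cache set in __init__),
        # then cast back to the original activation dtype.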
        orig_dtype = q.dtype
        q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        q, k = q.to(orig_dtype), k.to(orig_dtype)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MiniCPMDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
        )
        self.mlp = MiniCPMMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata,
        )
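        # MiniCPM scales each residual branch by scale_depth / sqrt(num_hidden_layers);
        # the same factor is applied to the MLP branch below.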
        hidden_states = residual + hidden_states * (
            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
        )

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * (
            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
        )

        return hidden_states, None


class MiniCPMModel(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList(
            [
                MiniCPMDecoderLayer(config, i, quant_config=quant_config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
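        # MiniCPM multiplies token embeddings by scale_emb; externally supplied
        # embeddings are pre-scaled by the caller (see MiniCPMForCausalLM.forward).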
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb
        else:
            hidden_states = input_embeds
        residual = None

        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                input_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MiniCPMForCausalLM(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPMModel(config, quant_config=quant_config)
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )

        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if input_embeds is not None:
            input_embeds = input_embeds * self.config.scale_emb
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
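        # MiniCPM divides the final hidden states by
        # hidden_size / dim_model_base before computing logits.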
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head_weight = self.model.embed_tokens.weight
        else:
            lm_head_weight = self.lm_head.weight
        return self.logits_processor(
            input_ids, hidden_states, lm_head_weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
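        # HF checkpoints store q/k/v and gate/up as separate tensors; they are
        # loaded shard-by-shard into the fused qkv_proj / gate_up_proj weights.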
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            (
                "ws" if weight_name in ["w1", "w3"] else "w2s",
                f"experts.{expert_id}.{weight_name}.weight",
                expert_id,
            )
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
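        # For num_experts=2 this expands to entries such as
        # ("ws", "experts.0.w1.weight", 0), ("w2s", "experts.0.w2.weight", 0),
        # ("ws", "experts.0.w3.weight", 0), ..., ("w2s", "experts.1.w2.weight", 1).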
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param, loaded_weight, weight_name, expert_id=expert_id
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)


EntryClass = MiniCPMForCausalLM
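# sglang's model loader looks for a module-level EntryClass to select the
# implementation for checkpoints with a matching architecture.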