[Model] Support Multi-GPU for Qwen-MoE model (#2573)
This PR introduces multi-GPU (tensor parallel) support for the Qwen-MoE model.
Validated on 2x RTX 4090.
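
As a back-of-the-envelope illustration of what tensor parallelism changes here, the sketch below works out the per-GPU sizes implied by the `// config.tensor_parallel_shards` divisions in the model code. The dimension values are hypothetical placeholders, not read from the actual Qwen-MoE config.

    # Sketch only: hypothetical dimensions, not the real Qwen-MoE config values.
    tensor_parallel_shards = 2           # e.g. two RTX 4090s

    num_attention_heads = 16             # placeholder values
    num_key_value_heads = 16
    head_dim = 128
    moe_intermediate_size = 1408

    # Per-GPU sizes after sharding, mirroring the `// config.tensor_parallel_shards`
    # divisions in qwen2_moe_model.py.
    heads_per_shard = num_attention_heads // tensor_parallel_shards        # 8
    kv_heads_per_shard = num_key_value_heads // tensor_parallel_shards     # 8
    moe_inter_per_shard = moe_intermediate_size // tensor_parallel_shards  # 704

    # Per-shard segment sizes of the fused QKV projection, analogous to the
    # segs=[q, k, v] argument of the ShardSingleDim strategy in the diff below.
    q = heads_per_shard * head_dim       # 1024
    k = kv_heads_per_shard * head_dim    # 1024
    v = kv_heads_per_shard * head_dim    # 1024
    print(q, k, v, moe_inter_per_shard)
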
MasterJH5574 committed Jun 13, 2024
1 parent 07c92b0 commit 94a0295
Showing 1 changed file with 37 additions and 8 deletions.
python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py (37 additions, 8 deletions)
@@ -14,11 +14,10 @@
 from mlc_llm.nn import PagedKVCache, RopeMode
 from mlc_llm.nn.expert import MixtralExperts
 from mlc_llm.support import logging
+from mlc_llm.support import tensor_parallel as tp
 
 logger = logging.getLogger(__name__)
 
-# TODO(mlc-team): Support Tensor Parallel.
-
 
 @dataclasses.dataclass
 class Qwen2MoeConfig(QWen2Config):  # pylint: disable=too-many-instance-attributes
@@ -68,10 +67,7 @@ def __init__(self, config: Qwen2MoeConfig):
         )
         self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards
         self.norm_topk_prob = config.norm_topk_prob
-        self.share_expert_intermediate_size = (
-            config.shared_expert_intermediate_size // config.tensor_parallel_shards
-        )
-        self.shared_expert = Qwen2MoeMLP(config, self.share_expert_intermediate_size)
+        self.shared_expert = Qwen2MoeMLP(config, config.shared_expert_intermediate_size)
         self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)
 
         self.gate = nn.Linear(
@@ -154,7 +150,42 @@ def __init__(self, config: Qwen2MoeConfig):
         self.post_attention_layernorm = nn.RMSNorm(
             config.hidden_size, -1, config.rms_norm_eps, bias=False
         )
 
+        def _set_tp():
+            def _set(layer, hint):
+                layer.attrs["shard_strategy"] = hint
+
+            hd = config.head_dim
+            q = self.self_attn.num_attention_heads * hd
+            k = self.self_attn.num_key_value_heads * hd
+            v = self.self_attn.num_key_value_heads * hd
+            si = self.mlp.shared_expert.intermediate_size
+            mi = self.mlp.moe_intermediate_size
+            _set(
+                self.self_attn.c_attn.weight,
+                tp.ShardSingleDim("_shard_qkv_weight", dim=0, segs=[q, k, v]),
+            )
+            _set(
+                self.self_attn.c_attn.bias,
+                tp.ShardSingleDim("_shard_qkv_bias", dim=0, segs=[q, k, v]),
+            )
+            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim("_shard_o", dim=1))
+            _set(
+                self.mlp.shared_expert.gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_shared_mlp_up", segs=[si, si], dim=0),
+            )
+            _set(
+                self.mlp.shared_expert.down_proj.weight,
+                tp.ShardSingleDim("_shard_shared_mlp_down", dim=1),
+            )
+            _set(
+                self.mlp.moe_gate_up_proj.weight,
+                tp.ShardSingleDim("_shard_moe_mlp_up", segs=[mi, mi], dim=1),
+            )
+            _set(self.mlp.moe_down_proj.weight, tp.ShardSingleDim("_shard_moe_mlp_down", dim=2))
+
+        self.tensor_parallel_shards = config.tensor_parallel_shards
+        _set_tp()
 
     def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
         out = self.input_layernorm(hidden_states)
@@ -202,8 +233,6 @@ def __init__(self, config: Qwen2MoeConfig):
         self.vocab_size = config.vocab_size
         self.tensor_parallel_shards = config.tensor_parallel_shards
         self.head_dim = config.head_dim
-        if self.tensor_parallel_shards != 1:
-            raise ValueError("Currently only support tensor_parallel_shards=1.")
 
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
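
For readers skimming the diff: conceptually, the shard strategies attached above split each parameter along a single dimension, and segmented parameters such as the fused QKV projection are split segment by segment so every GPU keeps its own contiguous Q, K and V slice. The NumPy sketch below illustrates that idea only; it is not MLC's actual ShardSingleDim implementation, and the toy sizes are made up.

    import numpy as np

    def shard_single_dim(weight, num_shards, dim=0, segs=None):
        # Conceptual stand-in for a single-dimension shard strategy: split `weight`
        # into `num_shards` slices along `dim`. With `segs` (e.g. fused Q/K/V sizes),
        # each segment is split independently and shard i keeps the i-th slice of
        # every segment, concatenated back together.
        if segs is None:
            return np.split(weight, num_shards, axis=dim)
        segments = np.split(weight, np.cumsum(segs)[:-1], axis=dim)
        per_segment = [np.split(seg, num_shards, axis=dim) for seg in segments]
        return [
            np.concatenate([slices[i] for slices in per_segment], axis=dim)
            for i in range(num_shards)
        ]

    # Toy fused-QKV weight with hypothetical sizes, sharded across 2 GPUs.
    q, k, v, hidden = 8, 4, 4, 6
    qkv_weight = np.arange((q + k + v) * hidden, dtype=np.float32).reshape(q + k + v, hidden)
    shards = shard_single_dim(qkv_weight, num_shards=2, dim=0, segs=[q, k, v])
    print([s.shape for s in shards])  # [(8, 6), (8, 6)]: each GPU holds half of Q, K and V
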
