Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modelopt/torch/export/layer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,7 @@ def module_match_name_list(module, name_list):
"Qwen3MoeSparseMoeBlock",
"Qwen3NextSparseMoeBlock",
"Qwen3_5MoeSparseMoeBlock",
"Qwen3VLMoeTextSparseMoeBlock",
"DeepseekMoE",
],
):
Expand Down
72 changes: 47 additions & 25 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,9 +687,27 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
return self.w2_linear[expert_idx](x1)


class _Qwen3VLMoeExpertModule(nn.Module):
"""Container for a single Qwen3VL MoE expert's linear layers.

Produces the naming pattern: experts.{id}.gate_proj.weight
(consistent with standard Qwen3 MoE per-expert module structure).
"""

def __init__(self, hidden_size: int, expert_dim: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, expert_dim, bias=False)
self.up_proj = nn.Linear(hidden_size, expert_dim, bias=False)
self.down_proj = nn.Linear(expert_dim, hidden_size, bias=False)


class _QuantQwen3VLMoeTextExperts(QuantModule):
def _setup(self):
"""Modify the Qwen3VLMoeTextExperts by using nn.Linear layers."""
"""Modify the Qwen3VLMoeTextExperts by using per-expert nn.Module containers.

This produces the naming pattern: experts.{id}.gate_proj.weight
(consistent with standard Qwen3 MoE per-expert module structure).
"""
from accelerate import init_empty_weights

dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device
Expand All @@ -709,35 +727,37 @@ def _copy_weight(module, weight):
raise AttributeError("Could not find intermediate dimension size in model")

with init_empty_weights():
gate_proj = nn.ModuleList(
[
nn.Linear(self.hidden_size, expert_dim, bias=False)
for _ in range(self.num_experts)
]
)
up_proj = nn.ModuleList(
[
nn.Linear(self.hidden_size, expert_dim, bias=False)
for _ in range(self.num_experts)
]
)
down_proj = nn.ModuleList(
expert_modules = nn.ModuleList(
[
nn.Linear(expert_dim, self.hidden_size, bias=False)
_Qwen3VLMoeExpertModule(self.hidden_size, expert_dim)
for _ in range(self.num_experts)
]
)

for idx in range(self.num_experts):
_copy_weight(gate_proj[idx], self.gate_up_proj[idx, :, :expert_dim].T)
_copy_weight(up_proj[idx], self.gate_up_proj[idx, :, expert_dim:].T)
_copy_weight(down_proj[idx], self.down_proj[idx, :].T)
_copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :, :expert_dim].T)
_copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, :, expert_dim:].T)
_copy_weight(expert_modules[idx].down_proj, self.down_proj[idx, :].T)

delattr(self, "gate_up_proj")
delattr(self, "down_proj")
self.gate_proj = gate_proj
self.up_proj = up_proj
self.down_proj = down_proj
# Register expert modules directly as numbered children
# so the naming pattern is: experts.{id}.gate_proj.weight (no extra nesting)
for idx in range(self.num_experts):
self.add_module(str(idx), expert_modules[idx])

def __len__(self):
    """Return the number of experts, so ``len(module)`` works like standard MoE expert lists."""
    expert_count = self.num_experts
    return expert_count

def __iter__(self):
    """Return an iterator over the expert submodules, in index order."""
    # Experts are registered as numbered children ("0", "1", ...), so look
    # each one up by its stringified index.
    return (getattr(self, str(i)) for i in range(self.num_experts))

def __getitem__(self, idx):
    """Return the expert submodule at position ``idx``.

    ``idx`` may be any int-convertible value (e.g. a 0-d tensor index),
    since experts are stored as children named by their decimal index.
    """
    child_name = str(int(idx))
    return getattr(self, child_name)

def forward(
self,
Expand All @@ -753,13 +773,15 @@ def forward(
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx[0]])
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate = self.gate_proj[expert_idx](current_state)
up = self.up_proj[expert_idx](current_state)
expert = self[expert_idx]
gate = expert.gate_proj(current_state)
up = expert.up_proj(current_state)
gated_output = up * self.act_fn(gate)
out = self.down_proj[expert_idx](gated_output)
out = expert.down_proj(gated_output)
weighted_output = out * routing_weights[token_idx, expert_idx, None]
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
next_states = next_states.view(batch_size, -1, self.hidden_size)
Expand Down
Loading