2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -122,7 +122,7 @@ transforms:
backend: trtllm
fuse_moe:
stage: post_load_fusion
enabled: true
enabled: false
backend: trtllm
fuse_fp8_moe:
stage: post_load_fusion
128 changes: 91 additions & 37 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
@@ -1,29 +1,49 @@
from typing import Callable, List
from typing import Callable, List, Optional

import torch
import torch.nn.functional as F

from tensorrt_llm._torch.utils import ActivationType


def _resolve_torch_fn(act_fn: ActivationType) -> Callable[[torch.Tensor], torch.Tensor]:
def _resolve_act_fn(act_fn: ActivationType, is_gated: bool = False):
"""
Returns an elementwise activation callable matching the given activation function.
Supported: ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2
Returns an activation callable matching the given activation function.

Args:
act_fn: Activation type (Silu, Swiglu, Relu2, SwigluBias)
is_gated: If True, returns (gate, up) -> output for gated MLP.
If False, returns (x) -> output for simple MLP.

Supported activations:
- Silu/Swiglu: silu(x) or silu(gate) * up
- Relu2: relu(x)^2 or relu(gate)^2
- SwigluBias: (up + 1) * (gate * sigmoid(1.702 * gate)) with clamping [GPT-OSS, gated only]
"""
assert act_fn in [ActivationType.Silu, ActivationType.Swiglu, ActivationType.Relu2], (
f"Unsupported activation '{ActivationType(act_fn).name}'. Use 'silu', 'swiglu' or 'relu2'."
)
torch_fn = None
if act_fn == ActivationType.Silu or act_fn == ActivationType.Swiglu:
torch_fn = F.silu
return (lambda gate, up: F.silu(gate) * up) if is_gated else F.silu

elif act_fn == ActivationType.Relu2:
return (
(lambda gate, up: torch.square(F.relu(gate)))
if is_gated
else (lambda x: torch.square(F.relu(x)))
)

elif act_fn == ActivationType.SwigluBias:
if not is_gated:
return F.silu # Fallback for ungated case
# GPT-OSS style with fixed parameters
alpha, limit = 1.702, 7.0

def relu2(x: torch.Tensor) -> torch.Tensor:
return torch.square(F.relu(x))
def swiglu_bias(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
gate = gate.clamp(max=limit)
up = up.clamp(min=-limit, max=limit)
return (up + 1.0) * (gate * torch.sigmoid(alpha * gate))

torch_fn = relu2
return torch_fn
return swiglu_bias

raise ValueError(f"Unsupported activation: {act_fn}")


def _template_moe(
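For reference, a minimal sketch (editorial, not part of the diff) of how the resolved callables behave. It assumes _resolve_act_fn is importable from this module (a private helper, imported here only for illustration) and exercises the gated Silu and SwigluBias paths defined above.

import torch
import torch.nn.functional as F

from tensorrt_llm._torch.utils import ActivationType
# Private helper; imported only to illustrate the callables returned above.
from tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.torch_moe import _resolve_act_fn

gate, up = torch.randn(4, 8), torch.randn(4, 8)

# Gated Silu: silu(gate) * up
silu_gated = _resolve_act_fn(ActivationType.Silu, is_gated=True)
assert torch.allclose(silu_gated(gate, up), F.silu(gate) * up)

# SwigluBias (GPT-OSS): clamp, then (up + 1) * (gate * sigmoid(1.702 * gate))
swiglu_bias = _resolve_act_fn(ActivationType.SwigluBias, is_gated=True)
g = gate.clamp(max=7.0)
u = up.clamp(min=-7.0, max=7.0)
assert torch.allclose(swiglu_bias(gate, up), (u + 1.0) * (g * torch.sigmoid(1.702 * g)))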
@@ -97,52 +117,83 @@ def torch_moe(
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
apply_routing_on_input: bool = False,
w1_bias_stacked: Optional[torch.Tensor] = None,
w2_bias_stacked: Optional[torch.Tensor] = None,
w3_bias_stacked: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Unified Mixture-of-Experts (MoE) operator that uses a Mixtral-style dispatch
(token routing + index_add_ accumulation) and a selectable per-expert MLP.

MLP Styles:
- is_gated_mlp=True (default, Mixtral/LLaMA style):
output = down_proj(act_fn(gate_proj(x), up_proj(x)))
= W2 @ act_fn(W1 @ x, W3 @ x)

- is_gated_mlp=False (simple 2-layer MLP):
output = down_proj(act_fn(up_proj(x)))
= W2 @ act_fn(W1 @ x)
Note: w3_weight is ignored in this mode.

Parameters:
x (torch.Tensor): Input tensor of shape (B, H) or (B, S, H), where B is the batch size,
S is the sequence length, and H is the hidden size.
selected_experts (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the indices
of the selected experts for each token. Only experts within range [0,num_experts) is processed
of the selected experts for each token. Only experts within the range [0, num_experts) are processed.
routing_weights (torch.Tensor): A tensor of shape (B, TOP_K) or (B*S, TOP_K) containing the normalized
routing weights for the selected experts.
w1_weight: List of per-expert weight tensors of up projection.
w2_weight: List of per-expert weight tensors of down projection.
w3_weight: List of per-expert weight tensors of gate projection.
is_gated_mlp: If True, use a gated MLP. If False, use a simple MLP.
w1_weight: List of per-expert weight tensors.
- If is_gated_mlp=True: gate projection (I, H), i.e., gate_proj.weight
- If is_gated_mlp=False: up projection (I, H), i.e., up_proj.weight
w2_weight: List of per-expert weight tensors of down projection (H, I), i.e., down_proj.weight
w3_weight: List of per-expert weight tensors of up projection (I, H), i.e., up_proj.weight
Only used when is_gated_mlp=True. Can be empty list [] when is_gated_mlp=False.
is_gated_mlp: If True, use gated MLP: act_fn(gate, up). If False, use simple MLP: act_fn(x).
act_fn: Activation function applied inside the expert MLP.
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
apply_routing_on_input: If True, multiply routing weights with INPUT before MLP
This means: silu(input * routing_weight)
If False, multiply routing weights with OUTPUT after MLP
This means: silu(input) * routing_weight
Supported: ActivationType.Silu (default), ActivationType.Swiglu, ActivationType.Relu2,
ActivationType.SwigluBias (GPT-OSS style with alpha=1.702, limit=7.0).
apply_routing_on_input: If True, multiply routing weights with INPUT before MLP.
If False, multiply routing weights with OUTPUT after MLP.
w1_bias_stacked: Optional stacked bias tensor (E, I).
- If is_gated_mlp=True: bias for gate projection
- If is_gated_mlp=False: bias for up projection
w2_bias_stacked: Optional stacked bias tensor for down projection (E, H).
w3_bias_stacked: Optional stacked bias tensor for up projection (E, I).
Only used when is_gated_mlp=True.

Returns:
torch.Tensor: Output tensor with the same shape as the input x.
"""
torch_act_fn = _resolve_torch_fn(act_fn)
act_fn_resolved = _resolve_act_fn(act_fn, is_gated=is_gated_mlp)

mlps = []
if is_gated_mlp:
# Standard per-expert list format with gated MLP

def make_mlp(i: int):
W1 = w1_weight[i] # (I, H)
W2 = w2_weight[i] # (H, I)
W3 = w3_weight[i] # (I, H)
W1 = w1_weight[i] # (I, H) - gate
W2 = w2_weight[i] # (H, I) - down
W3 = w3_weight[i] # (I, H) - up
b1 = w1_bias_stacked[i] if w1_bias_stacked is not None else None
b2 = w2_bias_stacked[i] if w2_bias_stacked is not None else None
b3 = w3_bias_stacked[i] if w3_bias_stacked is not None else None
return lambda inp: F.linear(
torch_act_fn(F.linear(inp.to(W1.dtype), W1)) * F.linear(inp.to(W3.dtype), W3), W2
act_fn_resolved(
F.linear(inp.to(W1.dtype), W1, b1), F.linear(inp.to(W3.dtype), W3, b3)
),
W2,
b2,
)

mlps = [make_mlp(i) for i in range(len(w1_weight))]

else:
# Standard per-expert list format with simple MLP

def make_mlp(i: int):
W_up = w1_weight[i] # (I, H)
W_down = w2_weight[i] # (H, I)
return lambda inp: F.linear(torch_act_fn(F.linear(inp, W_up)), W_down)
b_up = w1_bias_stacked[i] if w1_bias_stacked is not None else None
b_down = w2_bias_stacked[i] if w2_bias_stacked is not None else None
return lambda inp: F.linear(act_fn_resolved(F.linear(inp, W_up, b_up)), W_down, b_down)

mlps = [make_mlp(i) for i in range(len(w1_weight))]

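A hedged usage sketch of the op described by the docstring above; shapes follow the Parameters section, and it assumes the auto_deploy custom ops have been registered (e.g. by importing the custom_ops package) so that torch.ops.auto_deploy.torch_moe resolves.

import torch
from tensorrt_llm._torch.utils import ActivationType

E, H, I, B, TOP_K = 4, 16, 32, 8, 2  # experts, hidden, intermediate, tokens, top-k
x = torch.randn(B, H)

# Top-k routing: indices and normalized weights per token.
probs = torch.softmax(torch.randn(B, E), dim=-1)
routing_weights, selected_experts = torch.topk(probs, TOP_K)
routing_weights = routing_weights / routing_weights.sum(-1, keepdim=True)

# Per-expert weight lists, laid out for F.linear as documented above.
w1 = [torch.randn(I, H) for _ in range(E)]  # gate_proj
w2 = [torch.randn(H, I) for _ in range(E)]  # down_proj
w3 = [torch.randn(I, H) for _ in range(E)]  # up_proj

out = torch.ops.auto_deploy.torch_moe(
    x,
    selected_experts,
    routing_weights,
    w1_weight=w1,
    w2_weight=w2,
    w3_weight=w3,
    is_gated_mlp=True,
    act_fn=int(ActivationType.Silu),
)
assert out.shape == x.shape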
Expand All @@ -160,6 +211,9 @@ def torch_moe_fake(
is_gated_mlp: bool = True,
act_fn: int = int(ActivationType.Silu),
apply_routing_on_input: bool = False,
w1_bias_stacked: Optional[torch.Tensor] = None,
w2_bias_stacked: Optional[torch.Tensor] = None,
w3_bias_stacked: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(x)

@@ -279,7 +333,7 @@ def torch_quant_fp8_moe(
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
"""

torch_act_fn = _resolve_torch_fn(act_fn)
act_fn_resolved = _resolve_act_fn(act_fn, is_gated=is_gated_mlp)

if is_gated_mlp:

@@ -299,7 +353,7 @@ def mlp(inp):
input_scale=w3_input_scale[i],
weight_scale=w3_weight_scale[i],
)
prod = torch_act_fn(gate_out) * up_out
prod = act_fn_resolved(gate_out, up_out)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
prod,
w2_weight[i],
@@ -324,7 +378,7 @@ def mlp(inp):
weight_scale=w1_weight_scale[i],
)
return torch.ops.auto_deploy.torch_quant_fp8_linear(
torch_act_fn(up_out),
act_fn_resolved(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
@@ -414,7 +468,7 @@ def torch_quant_nvfp4_moe(
Supported: ActivationType.Silu (default), ActivationType.Relu2 (ReLU then square).
"""

torch_act_fn = _resolve_torch_fn(act_fn)
act_fn_resolved = _resolve_act_fn(act_fn, is_gated=is_gated_mlp)

if is_gated_mlp:

@@ -438,7 +492,7 @@ def mlp(inp):
weight_scale=w3_weight_scale[i],
alpha=w3_alpha[i],
)
prod = torch_act_fn(gate_out) * up_out
prod = act_fn_resolved(gate_out, up_out)
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
prod,
w2_weight[i],
@@ -467,7 +521,7 @@ def mlp(inp):
alpha=w1_alpha[i],
)
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
torch_act_fn(up_out),
act_fn_resolved(up_out),
w2_weight[i],
bias=None,
input_scale=w2_input_scale[i],
109 changes: 109 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/models/patches/gptoss.py
@@ -0,0 +1,109 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A patch for GPT-OSS MoE to make it compatible with torch.export.

GPT-OSS uses a dense MoE pattern where all experts are computed for all tokens,
with soft routing weights. This patch replaces the BMM-based forward with torch_moe
using SwigluBias activation.
"""

import torch

from tensorrt_llm._torch.utils import ActivationType

from ...export.interface import BaseExportPatch, ExportPatchRegistry


def _forward_gptoss_mlp(self, hidden_states: torch.Tensor):
"""GPT-OSS MoE forward rewritten for torch.export compatibility.

Uses torch_moe with SwigluBias activation to match the original GPT-OSS
dense MoE computation: (up + 1) * (gate * sigmoid(alpha * gate)) with clamping.
"""
batch_size, seq_len, hidden_size = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_size)
num_tokens = hidden_states_flat.shape[0]

router_scores, _ = self.router(hidden_states)

experts = self.experts
num_experts = experts.num_experts

selected_experts = (
torch.arange(num_experts, device=hidden_states.device).unsqueeze(0).expand(num_tokens, -1)
)

# gate_up_proj: [E, H, 2*I] with interleaved gate/up
# Split into gate [E, H, I] and up [E, H, I], then transpose to [I, H] for F.linear
gate_proj = experts.gate_up_proj[..., ::2] # [E, H, I] - gate
up_proj = experts.gate_up_proj[..., 1::2] # [E, H, I] - up

# Create per-expert weight lists with shape [I, H] for F.linear
w1_weight = [gate_proj[i].T for i in range(num_experts)] # gate: [I, H]
w3_weight = [up_proj[i].T for i in range(num_experts)] # up: [I, H]
# down_proj: [E, I, H] -> transpose to [H, I] for F.linear
w2_weight = [experts.down_proj[i].T for i in range(num_experts)] # down: [H, I]

# Biases: gate_up_proj_bias [E, 2*I] -> split into gate [E, I] and up [E, I]
w1_bias_stacked = experts.gate_up_proj_bias[..., ::2] # [E, I] - gate bias
w3_bias_stacked = experts.gate_up_proj_bias[..., 1::2] # [E, I] - up bias
w2_bias_stacked = experts.down_proj_bias # [E, H] - down bias

final_hidden_states = torch.ops.auto_deploy.torch_moe(
hidden_states_flat,
selected_experts,
router_scores, # [num_tokens, num_experts] - soft routing weights
w1_weight=w1_weight,
w2_weight=w2_weight,
w3_weight=w3_weight,
is_gated_mlp=True,
act_fn=int(ActivationType.SwigluBias),
w1_bias_stacked=w1_bias_stacked,
w2_bias_stacked=w2_bias_stacked,
w3_bias_stacked=w3_bias_stacked,
)

final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_size)
return final_hidden_states, router_scores


@ExportPatchRegistry.register("hf_gptoss_moe")
class GptOssMoePatch(BaseExportPatch):
"""Patch for GPT-OSS MoE to make it compatible with torch.export.

GPT-OSS uses a dense MoE pattern with:
- Soft routing over ALL experts (not top-k sparse)
- SwigluBias activation: (up + 1) * (gate * sigmoid(alpha * gate)) with clamping
- Biases on all projections

The original BMM-based forward is replaced with torch_moe custom op
which handles the computation in an export-compatible way.
"""

def _apply_patch(self):
"""Apply the GPT-OSS MoE patch."""
try:
from transformers.models.gpt_oss import modeling_gpt_oss

self.modeling_module = modeling_gpt_oss
self.original_values["GptOssMLP.forward"] = modeling_gpt_oss.GptOssMLP.forward
modeling_gpt_oss.GptOssMLP.forward = _forward_gptoss_mlp
except (ImportError, AttributeError):
pass

def _revert_patch(self):
"""Revert the GPT-OSS MoE patch."""
if hasattr(self, "modeling_module") and "GptOssMLP.forward" in self.original_values:
self.modeling_module.GptOssMLP.forward = self.original_values["GptOssMLP.forward"]
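To make the weight layout above concrete, a small editorial sketch of the interleaved gate/up split and the dense routing the patch relies on (shapes only; E, H, I, and num_tokens are illustrative values, not taken from the model).

import torch

E, H, I, num_tokens = 2, 4, 3, 5

# gate_up_proj stores gate and up interleaved along the last dimension: [E, H, 2*I].
gate_up_proj = torch.randn(E, H, 2 * I)
gate = gate_up_proj[..., ::2]   # even columns -> gate, [E, H, I]
up = gate_up_proj[..., 1::2]    # odd columns  -> up,   [E, H, I]
assert gate.shape == up.shape == (E, H, I)

# torch_moe expects weights shaped for F.linear, hence the per-expert transpose.
w1_weight = [gate[i].T for i in range(E)]   # [I, H]
w3_weight = [up[i].T for i in range(E)]     # [I, H]

# Dense (soft) routing: every token "selects" all E experts and the router's
# soft scores are used directly as routing weights.
selected_experts = torch.arange(E).unsqueeze(0).expand(num_tokens, -1)  # [num_tokens, E]
assert selected_experts.shape == (num_tokens, E)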