Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,17 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
description="If True, the amax will be synced across the distributed processes.",
)

sync_expert_weight_amax: bool = ModeloptField(
default=False,
title="Sync weight quantizer amax across MoE experts",
description=(
"If True, the weight quantizer amax values are synchronized (max) across "
"local experts in SequentialMLP layers during calibration. This matches "
"TEGroupedMLP behavior where all experts share a single weight quantizer. "
"Only affects MoE models with SequentialMLP experts."
),
)


class MseCalibConfig(QuantizeAlgorithmConfig):
"""Configuration for per-tensor MSE calibration.
Expand Down
6 changes: 4 additions & 2 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def max_calibrate(
model: nn.Module,
forward_loop: ForwardLoop | None = None,
distributed_sync=True,
sync_expert_weight_amax=False,
):
"""Calibrate the model using max.

Expand All @@ -117,6 +118,7 @@ def max_calibrate(
forward_loop: A callable which takes the model as argument and
forwards calibration data through the model.
distributed_sync: Whether to sync input_quantizer amax across distributed processes.
sync_expert_weight_amax: Whether to sync weight quantizer amax across MoE experts.

See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
details on the remaining arguments.
Expand All @@ -128,10 +130,10 @@ def max_calibrate(
forward_loop(model)
finish_stats_collection(model)

# Sync input_quantizer amax across local experts within each rank (for SequentialMLP)
# Sync quantizer amax across local experts within each rank (for SequentialMLP)
for name, module in model.named_modules():
if hasattr(module, "layer_sync_moe_local_experts_amax"):
module.layer_sync_moe_local_experts_amax()
module.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_expert_weight_amax)

if not distributed_sync:
return
Expand Down
8 changes: 6 additions & 2 deletions modelopt/torch/quantization/plugins/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,19 +553,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
self._count_expert_tokens = False
return output

def layer_sync_moe_local_experts_amax(self):
def layer_sync_moe_local_experts_amax(self, sync_weight_amax=False):
"""Sync input_quantizer amax across experts so all share the same amax per quantizer.

Skipped when _moe_calib_experts_ratio is set, as each expert is calibrated independently.
Also skipped when experts is a fused module (e.g. Llama4TextExperts) with shared quantizers.

Args:
sync_weight_amax: If True, also sync weight quantizer amax across experts.

"""
if self._moe_calib_experts_ratio is not None:
return
try:
iter(self.experts)
except TypeError:
return
sync_moe_expert_amax(self.experts)
sync_moe_expert_amax(self.experts, sync_weight_amax=sync_weight_amax)


class _QuantLlama4TextExperts(QuantModule):
Expand Down
21 changes: 13 additions & 8 deletions modelopt/torch/quantization/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,23 +575,28 @@ def _setup(self):
expert.linear_fc1.parallel_state = self.parallel_state
expert.linear_fc2.parallel_state = self.parallel_state

def layer_sync_moe_local_experts_amax(self):
"""Sync input quantizer amax across local experts in a SequentialMLP.
def layer_sync_moe_local_experts_amax(self, sync_weight_amax=False):
"""Sync quantizer amax across local experts in a SequentialMLP.

Ensures all experts have the same input quantizer amax. This function operates
on a single rank and does not require distributed sync.
Always syncs input quantizer amax across experts. Optionally syncs weight
quantizer amax as well, which matches TEGroupedMLP behavior where all
experts are fused into a single GEMM with one quantizer per linear layer.

Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
This function should be called before the distributed sync to ensure the amax values
are synchronized across the layer first.
Args:
sync_weight_amax: If True, also sync weight quantizer amax across experts.

This function operates on a single rank and does not require distributed sync.
Distributed amax sync across EP and ETP (for RowParallel) happens in
model_calib.max_calibrate(). This function should be called before the
distributed sync to ensure the amax values are synchronized across the layer first.

Note:
Because there are logic which calls collective communication based on whether amax is not None,
we need to guarantee that all experts must have amax. Otherwise, there will be deadlock
when synchronizing over EP since some ranks may have amax None and not calling the collective
communication.
"""
sync_moe_expert_amax(self.local_experts)
sync_moe_expert_amax(self.local_experts, sync_weight_amax=sync_weight_amax)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Override the default to enable singleton_local_shards.
Expand Down
17 changes: 9 additions & 8 deletions modelopt/torch/quantization/utils/core_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,12 +516,15 @@ def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict):
module.load_state_dict(quantizer_state_dict[key])


def sync_moe_expert_amax(experts):
"""Sync input_quantizer amax across MoE experts and fix missing weight amax.
def sync_moe_expert_amax(experts, sync_weight_amax=False):
"""Sync quantizer amax across MoE experts and fix missing weight amax.

1. Takes the element-wise max of each ``input_quantizer`` amax across all experts
and writes it back, so every expert shares the same input amax.
2. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert
2. If ``sync_weight_amax`` is True, also syncs ``weight_quantizer`` amax across
experts (max across experts). This matches TEGroupedMLP behavior where all
experts share a single weight quantizer.
3. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert
received no tokens during calibration), runs a weight-only ``max_calibrate``
to populate the missing amax.
"""
Expand All @@ -530,11 +533,9 @@ def sync_moe_expert_amax(experts):
amax_dict: dict[str, torch.Tensor] = {}
for expert in experts:
for name, module in expert.named_modules():
if (
isinstance(module, TensorQuantizer)
and module.amax is not None
and "input_quantizer" in name
):
if not isinstance(module, TensorQuantizer) or module.amax is None:
continue
if "input_quantizer" in name or (sync_weight_amax and "weight_quantizer" in name):
stored_amax = amax_dict.get(name)
amax_tensor = module.amax.detach().clone()
amax_dict[name] = (
Expand Down
81 changes: 44 additions & 37 deletions tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def _gpt_model_provider(
meta_device=False,
ep_size=1,
etp_size=None,
use_te=False,
transformer_impl="local",
# Hybrid mamba MOE parameters
is_hybrid=False,
Expand Down Expand Up @@ -291,7 +290,6 @@ def _gpt_model_provider(
use_cpu_initialization=meta_device,
num_moe_experts=num_moe_experts,
moe_grouped_gemm=moe_grouped_gemm,
use_te=use_te,
)

if not meta_device:
Expand All @@ -315,7 +313,6 @@ def _test_sharded_state_dict(
etp_size = model_config.get("etp_size", None)
num_moe_experts = model_config.get("num_moe_experts", None)
moe_grouped_gemm = model_config.get("moe_grouped_gemm", False)
use_te = model_config.get("use_te", False)
transformer_impl = model_config.get("transformer_impl", "local")
# Hybrid mamba MOE parameters
is_hybrid = model_config.get("is_hybrid", False)
Expand All @@ -334,7 +331,6 @@ def _test_sharded_state_dict(
vocab_size=256,
num_moe_experts=num_moe_experts,
moe_grouped_gemm=moe_grouped_gemm,
use_te=use_te,
ep_size=ep_size,
etp_size=etp_size,
transformer_impl=transformer_impl,
Expand All @@ -347,7 +343,6 @@ def _test_sharded_state_dict(
vocab_size=256,
num_moe_experts=num_moe_experts,
moe_grouped_gemm=moe_grouped_gemm,
use_te=use_te,
meta_device=meta_device,
ep_size=ep_size,
etp_size=etp_size,
Expand Down Expand Up @@ -445,8 +440,6 @@ def test_homogeneous_sharded_state_dict(
)

model_config = {"transformer_impl": transformer_impl}
if transformer_impl == "modelopt":
model_config["use_te"] = True
dist_workers.run(
partial(
_test_sharded_state_dict,
Expand Down Expand Up @@ -581,25 +574,27 @@ def test_fp8_real_quantize(dist_workers):
dist_workers.run(_test_fp8_real_quantize_helper)


@pytest.mark.skip(reason="TODO: etp requires sequence parallelism now in Megatron due to a bug;")
# TODO: etp requires sequence parallelism now in Megatron due to a bug
@pytest.mark.parametrize(
"config",
[mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG],
)
@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
def test_moe_sharded_state_dict(dist_workers, need_4_gpus, tmp_path, config, moe_grouped_gemm):
if moe_grouped_gemm:
pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
if config == mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG and moe_grouped_gemm:
pytest.skip("TEGroupedMLP only supports per-tensor quantization.")

# TODO: Add support for compress=True for TEGroupedMLP
moe_config = {
"tp_size": 2,
"ep_size": 2,
"etp_size": 2,
"etp_size": 1,
"num_moe_experts": 4,
"moe_grouped_gemm": moe_grouped_gemm,
"use_te": moe_grouped_gemm,
"transformer_impl": "modelopt",
"transformer_impl": "transformer_engine" if moe_grouped_gemm else "modelopt",
}
if not moe_grouped_gemm:
moe_config["tp_size"] = 1 # TODO: TP+EP is not supported by QuantSequentialMLP
dist_workers.run(
partial(
_test_sharded_state_dict,
Expand All @@ -614,41 +609,38 @@ def test_moe_sharded_state_dict(dist_workers, need_4_gpus, tmp_path, config, moe
)


def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, rank, size):
def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, quant_cfg, rank, size):
"""Test that TEGrouped and sequential MoE models produce similar amax values."""
initialize_for_megatron(
tensor_model_parallel_size=tp_size,
expert_model_parallel_size=ep_size,
expert_tensor_parallel_size=etp_size,
seed=SEED,
)

# Create TEGrouped MoE model
te_grouped_moe_model = _gpt_model_provider(
tp_size=tp_size,
ep_size=ep_size,
etp_size=etp_size,
hidden_size=32,
moe_grouped_gemm=True,
use_te=True,
transformer_impl="transformer_engine",
num_moe_experts=4,
)

# Create forward function with cached inputs
forward = get_forward(te_grouped_moe_model)
forward = get_forward(te_grouped_moe_model, batch_size=8)

num_te_grouped_mlp = sum(
isinstance(module, TEGroupedMLP) for module in te_grouped_moe_model.modules()
)
assert num_te_grouped_mlp == 4, (
f"TEGrupedMoEModel has {num_te_grouped_mlp} TEGroupedMLP modules, it should have 4"
f"TEGroupedMoEModel has {num_te_grouped_mlp} TEGroupedMLP modules, it should have 4"
)

# Create sequential MoE model
sequential_moe_model = _gpt_model_provider(
tp_size=tp_size,
ep_size=ep_size,
etp_size=etp_size,
hidden_size=32,
moe_grouped_gemm=False,
num_moe_experts=4,
Expand All @@ -669,28 +661,37 @@ def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, r
assert torch.allclose(te_grouped_moe_output, sequential_moe_output, atol=1e-6, rtol=1e-6)

# Quantize grouped model
mtq.quantize(te_grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward)
mtq.quantize(te_grouped_moe_model, quant_cfg, forward)

# Quantize non-grouped model
mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward)
# Quantize non-grouped model with synced weight amax to match TEGroupedMLP behavior
seq_quant_cfg = copy.deepcopy(quant_cfg)
seq_quant_cfg["algorithm"] = {"method": "max", "sync_expert_weight_amax": True}
mtq.quantize(sequential_moe_model, seq_quant_cfg, forward)

# Compare model outputs after quantization
te_grouped_moe_quant_output = forward(te_grouped_moe_model)
sequential_moe_quant_output = forward(sequential_moe_model)

assert torch.allclose(
te_grouped_moe_quant_output, sequential_moe_quant_output, atol=1e-6, rtol=1e-6
)


def test_te_grouped_vs_sequential_quantize(dist_workers_size_4):
# TODO SequentialMLP local spec doesn't support EP and TP simultaneously yet
@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG])
def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg):
"""Test that TEGrouped and sequential MoE models produce similar quantized models."""
pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
dist_workers_size_4.run(partial(_test_te_grouped_vs_sequential_quantize_helper, 1, 2, 2))
dist_workers_size_4.run(
partial(_test_te_grouped_vs_sequential_quantize_helper, 1, 2, quant_cfg)
)


@pytest.mark.parametrize("ep_size", [1, 2])
@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, moe_grouped_gemm):
@pytest.mark.parametrize("sync_weight_amax", [True, False])
def test_layer_sync_moe_local_experts_amax(
dist_workers, ep_size, moe_grouped_gemm, sync_weight_amax
):
"""Test expert model parallel synchronization."""
if torch.cuda.device_count() < ep_size:
pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")
Expand All @@ -700,11 +701,14 @@ def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, moe_grouped_ge
_test_layer_sync_moe_local_experts_amax,
ep_size,
moe_grouped_gemm,
sync_weight_amax,
),
)


def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size):
def _test_layer_sync_moe_local_experts_amax(
ep_size, moe_grouped_gemm, sync_weight_amax, rank, size
):
initialize_for_megatron(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
Expand All @@ -718,21 +722,21 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz
etp_size=1,
hidden_size=256,
moe_grouped_gemm=moe_grouped_gemm,
use_te=moe_grouped_gemm,
num_moe_experts=8,
transformer_impl="modelopt",
)
# Make weight initialization different across experts, otherwise experts will have similar amax values
for layer in model.decoder.layers:
for i, expert in enumerate(layer.mlp.experts.local_experts):
expert.linear_fc1.weight.data.fill_(0.1 + i * 0.05)
expert.linear_fc2.weight.data.fill_(0.2 + i * 0.05)
if not moe_grouped_gemm:
# Make weight initialization different across experts, otherwise experts will have similar amax values
for layer in model.decoder.layers:
for i, expert in enumerate(layer.mlp.experts.local_experts):
expert.linear_fc1.weight.data.fill_(0.1 + i * 0.05)
expert.linear_fc2.weight.data.fill_(0.2 + i * 0.05)

quant_cfg = mtq.FP8_DEFAULT_CFG
model = mtq.quantize(model, quant_cfg, get_forward(model))

for layer in model.decoder.layers:
layer.mlp.experts.layer_sync_moe_local_experts_amax()
layer.mlp.experts.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_weight_amax)

for layer in model.decoder.layers:
# Check input quantizer amax is synced across local experts
Expand All @@ -750,18 +754,22 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz
else:
assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)

# Check weight quantizer amax is different across local experts
# Check weight quantizer amax
fc1_amax = None
fc2_amax = None
for expert in layer.mlp.experts.local_experts:
assert expert.linear_fc1.weight_quantizer.amax is not None
assert expert.linear_fc2.weight_quantizer.amax is not None
if fc1_amax is None:
fc1_amax = expert.linear_fc1.weight_quantizer.amax
elif sync_weight_amax:
assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
else:
assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
if fc2_amax is None:
fc2_amax = expert.linear_fc2.weight_quantizer.amax
elif sync_weight_amax:
assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
else:
assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)

Expand All @@ -785,7 +793,6 @@ def _test_expert_model_parallel_amax_sync(
etp_size=etp_size,
hidden_size=256,
moe_grouped_gemm=moe_grouped_gemm,
use_te=moe_grouped_gemm,
num_moe_experts=8,
transformer_impl="modelopt",
)
Expand Down
Loading