diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index cf2336bf4a..2f5df5399a 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1060,6 +1060,17 @@ class MaxCalibConfig(QuantizeAlgorithmConfig): description="If True, the amax will be synced across the distributed processes.", ) + sync_expert_weight_amax: bool = ModeloptField( + default=False, + title="Sync weight quantizer amax across MoE experts", + description=( + "If True, the weight quantizer amax values are synchronized (max) across " + "local experts in SequentialMLP layers during calibration. This matches " + "TEGroupedMLP behavior where all experts share a single weight quantizer. " + "Only affects MoE models with SequentialMLP experts." + ), + ) + class MseCalibConfig(QuantizeAlgorithmConfig): """Configuration for per-tensor MSE calibration. diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 89097fd32c..82be0e3450 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -109,6 +109,7 @@ def max_calibrate( model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True, + sync_expert_weight_amax=False, ): """Calibrate the model using max. @@ -117,6 +118,7 @@ def max_calibrate( forward_loop: A callable which takes the model as argument and forwards calibration data through the model. distributed_sync: Whether to sync input_quantizer amax across distributed processes. + sync_expert_weight_amax: Whether to sync weight quantizer amax across MoE experts. See :class:`MaxCalibConfig ` for details on the remaining arguments. @@ -128,10 +130,10 @@ def max_calibrate( forward_loop(model) finish_stats_collection(model) - # Sync input_quantizer amax across local experts within each rank (for SequentialMLP) + # Sync quantizer amax across local experts within each rank (for SequentialMLP) for name, module in model.named_modules(): if hasattr(module, "layer_sync_moe_local_experts_amax"): - module.layer_sync_moe_local_experts_amax() + module.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_expert_weight_amax) if not distributed_sync: return diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 0d02716a6e..812550e4f4 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -553,11 +553,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self._count_expert_tokens = False return output - def layer_sync_moe_local_experts_amax(self): + def layer_sync_moe_local_experts_amax(self, sync_weight_amax=False): """Sync input_quantizer amax across experts so all share the same amax per quantizer. Skipped when _moe_calib_experts_ratio is set, as each expert is calibrated independently. Also skipped when experts is a fused module (e.g. Llama4TextExperts) with shared quantizers. + + Args: + sync_weight_amax: If True, also sync weight quantizer amax across experts. + """ if self._moe_calib_experts_ratio is not None: return @@ -565,7 +569,7 @@ def layer_sync_moe_local_experts_amax(self): iter(self.experts) except TypeError: return - sync_moe_expert_amax(self.experts) + sync_moe_expert_amax(self.experts, sync_weight_amax=sync_weight_amax) class _QuantLlama4TextExperts(QuantModule): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index f08dc8275f..0b50fd937a 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -575,15 +575,20 @@ def _setup(self): expert.linear_fc1.parallel_state = self.parallel_state expert.linear_fc2.parallel_state = self.parallel_state - def layer_sync_moe_local_experts_amax(self): - """Sync input quantizer amax across local experts in a SequentialMLP. + def layer_sync_moe_local_experts_amax(self, sync_weight_amax=False): + """Sync quantizer amax across local experts in a SequentialMLP. - Ensures all experts have the same input quantizer amax. This function operates - on a single rank and does not require distributed sync. + Always syncs input quantizer amax across experts. Optionally syncs weight + quantizer amax as well, which matches TEGroupedMLP behavior where all + experts are fused into a single GEMM with one quantizer per linear layer. - Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate(). - This function should be called before the distributed sync to ensure the amax values - are synchronized across the layer first. + Args: + sync_weight_amax: If True, also sync weight quantizer amax across experts. + + This function operates on a single rank and does not require distributed sync. + Distributed amax sync across EP and ETP (for RowParallel) happens in + model_calib.max_calibrate(). This function should be called before the + distributed sync to ensure the amax values are synchronized across the layer first. Note: Because there are logic which calls collective communication based on whether amax is not None, @@ -591,7 +596,7 @@ def layer_sync_moe_local_experts_amax(self): when synchronizing over EP since some ranks may have amax None and not calling the collective communication. """ - sync_moe_expert_amax(self.local_experts) + sync_moe_expert_amax(self.local_experts, sync_weight_amax=sync_weight_amax) def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Override the default to enable singleton_local_shards. diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index 4340b8dc1f..0fd6220185 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -516,12 +516,15 @@ def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict): module.load_state_dict(quantizer_state_dict[key]) -def sync_moe_expert_amax(experts): - """Sync input_quantizer amax across MoE experts and fix missing weight amax. +def sync_moe_expert_amax(experts, sync_weight_amax=False): + """Sync quantizer amax across MoE experts and fix missing weight amax. 1. Takes the element-wise max of each ``input_quantizer`` amax across all experts and writes it back, so every expert shares the same input amax. - 2. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert + 2. If ``sync_weight_amax`` is True, also syncs ``weight_quantizer`` amax across + experts (max across experts). This matches TEGroupedMLP behavior where all + experts share a single weight quantizer. + 3. For any ``weight_quantizer`` that is enabled but has ``amax is None`` (expert received no tokens during calibration), runs a weight-only ``max_calibrate`` to populate the missing amax. """ @@ -530,11 +533,9 @@ def sync_moe_expert_amax(experts): amax_dict: dict[str, torch.Tensor] = {} for expert in experts: for name, module in expert.named_modules(): - if ( - isinstance(module, TensorQuantizer) - and module.amax is not None - and "input_quantizer" in name - ): + if not isinstance(module, TensorQuantizer) or module.amax is None: + continue + if "input_quantizer" in name or (sync_weight_amax and "weight_quantizer" in name): stored_amax = amax_dict.get(name) amax_tensor = module.amax.detach().clone() amax_dict[name] = ( diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index d8ba6fbed7..aaaa3bcae6 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -245,7 +245,6 @@ def _gpt_model_provider( meta_device=False, ep_size=1, etp_size=None, - use_te=False, transformer_impl="local", # Hybrid mamba MOE parameters is_hybrid=False, @@ -291,7 +290,6 @@ def _gpt_model_provider( use_cpu_initialization=meta_device, num_moe_experts=num_moe_experts, moe_grouped_gemm=moe_grouped_gemm, - use_te=use_te, ) if not meta_device: @@ -315,7 +313,6 @@ def _test_sharded_state_dict( etp_size = model_config.get("etp_size", None) num_moe_experts = model_config.get("num_moe_experts", None) moe_grouped_gemm = model_config.get("moe_grouped_gemm", False) - use_te = model_config.get("use_te", False) transformer_impl = model_config.get("transformer_impl", "local") # Hybrid mamba MOE parameters is_hybrid = model_config.get("is_hybrid", False) @@ -334,7 +331,6 @@ def _test_sharded_state_dict( vocab_size=256, num_moe_experts=num_moe_experts, moe_grouped_gemm=moe_grouped_gemm, - use_te=use_te, ep_size=ep_size, etp_size=etp_size, transformer_impl=transformer_impl, @@ -347,7 +343,6 @@ def _test_sharded_state_dict( vocab_size=256, num_moe_experts=num_moe_experts, moe_grouped_gemm=moe_grouped_gemm, - use_te=use_te, meta_device=meta_device, ep_size=ep_size, etp_size=etp_size, @@ -445,8 +440,6 @@ def test_homogeneous_sharded_state_dict( ) model_config = {"transformer_impl": transformer_impl} - if transformer_impl == "modelopt": - model_config["use_te"] = True dist_workers.run( partial( _test_sharded_state_dict, @@ -581,25 +574,27 @@ def test_fp8_real_quantize(dist_workers): dist_workers.run(_test_fp8_real_quantize_helper) -@pytest.mark.skip(reason="TODO: etp requires sequence parallelism now in Megatron due to a bug;") +# TODO: etp requires sequence parallelism now in Megatron due to a bug @pytest.mark.parametrize( "config", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG], ) @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) def test_moe_sharded_state_dict(dist_workers, need_4_gpus, tmp_path, config, moe_grouped_gemm): - if moe_grouped_gemm: - pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently") + if config == mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG and moe_grouped_gemm: + pytest.skip("TEGroupedMLP only supports per-tensor quantization.") + # TODO: Add support for compress=True for TEGroupedMLP moe_config = { "tp_size": 2, "ep_size": 2, - "etp_size": 2, + "etp_size": 1, "num_moe_experts": 4, "moe_grouped_gemm": moe_grouped_gemm, - "use_te": moe_grouped_gemm, - "transformer_impl": "modelopt", + "transformer_impl": "transformer_engine" if moe_grouped_gemm else "modelopt", } + if not moe_grouped_gemm: + moe_config["tp_size"] = 1 # TODO: TP+EP is not supported by QuantSequentialMLP dist_workers.run( partial( _test_sharded_state_dict, @@ -614,12 +609,11 @@ def test_moe_sharded_state_dict(dist_workers, need_4_gpus, tmp_path, config, moe ) -def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, rank, size): +def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, quant_cfg, rank, size): """Test that TEGrouped and sequential MoE models produce similar amax values.""" initialize_for_megatron( tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size, - expert_tensor_parallel_size=etp_size, seed=SEED, ) @@ -627,28 +621,26 @@ def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, r te_grouped_moe_model = _gpt_model_provider( tp_size=tp_size, ep_size=ep_size, - etp_size=etp_size, hidden_size=32, moe_grouped_gemm=True, - use_te=True, + transformer_impl="transformer_engine", num_moe_experts=4, ) # Create forward function with cached inputs - forward = get_forward(te_grouped_moe_model) + forward = get_forward(te_grouped_moe_model, batch_size=8) num_te_grouped_mlp = sum( isinstance(module, TEGroupedMLP) for module in te_grouped_moe_model.modules() ) assert num_te_grouped_mlp == 4, ( - f"TEGrupedMoEModel has {num_te_grouped_mlp} TEGroupedMLP modules, it should have 4" + f"TEGroupedMoEModel has {num_te_grouped_mlp} TEGroupedMLP modules, it should have 4" ) # Create sequential MoE model sequential_moe_model = _gpt_model_provider( tp_size=tp_size, ep_size=ep_size, - etp_size=etp_size, hidden_size=32, moe_grouped_gemm=False, num_moe_experts=4, @@ -669,28 +661,34 @@ def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, r assert torch.allclose(te_grouped_moe_output, sequential_moe_output, atol=1e-6, rtol=1e-6) # Quantize grouped model - mtq.quantize(te_grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward) + mtq.quantize(te_grouped_moe_model, quant_cfg, forward) - # Quantize non-grouped model - mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward) + # Quantize non-grouped model with synced weight amax to match TEGroupedMLP behavior + seq_quant_cfg = copy.deepcopy(quant_cfg) + seq_quant_cfg["algorithm"] = {"method": "max", "sync_expert_weight_amax": True} + mtq.quantize(sequential_moe_model, seq_quant_cfg, forward) # Compare model outputs after quantization te_grouped_moe_quant_output = forward(te_grouped_moe_model) sequential_moe_quant_output = forward(sequential_moe_model) + assert torch.allclose( te_grouped_moe_quant_output, sequential_moe_quant_output, atol=1e-6, rtol=1e-6 ) -def test_te_grouped_vs_sequential_quantize(dist_workers_size_4): +# TODO SequentialMLP local spec doesn't support EP and TP simultaneously yet +@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]) +def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg): """Test that TEGrouped and sequential MoE models produce similar quantized models.""" - pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently") - dist_workers_size_4.run(partial(_test_te_grouped_vs_sequential_quantize_helper, 1, 2, 2)) + dist_workers_size_4.run( + partial(_test_te_grouped_vs_sequential_quantize_helper, 1, 2, quant_cfg) + ) @pytest.mark.parametrize("ep_size", [1, 2]) -@pytest.mark.parametrize("moe_grouped_gemm", [True, False]) -def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, moe_grouped_gemm): +@pytest.mark.parametrize("sync_weight_amax", [True, False]) +def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax): """Test expert model parallel synchronization.""" if torch.cuda.device_count() < ep_size: pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test") @@ -699,12 +697,12 @@ def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, moe_grouped_ge partial( _test_layer_sync_moe_local_experts_amax, ep_size, - moe_grouped_gemm, + sync_weight_amax, ), ) -def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, size): +def _test_layer_sync_moe_local_experts_amax(ep_size, sync_weight_amax, rank, size): initialize_for_megatron( tensor_model_parallel_size=1, pipeline_model_parallel_size=1, @@ -717,8 +715,6 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz ep_size=ep_size, etp_size=1, hidden_size=256, - moe_grouped_gemm=moe_grouped_gemm, - use_te=moe_grouped_gemm, num_moe_experts=8, transformer_impl="modelopt", ) @@ -732,7 +728,7 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz model = mtq.quantize(model, quant_cfg, get_forward(model)) for layer in model.decoder.layers: - layer.mlp.experts.layer_sync_moe_local_experts_amax() + layer.mlp.experts.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_weight_amax) for layer in model.decoder.layers: # Check input quantizer amax is synced across local experts @@ -750,7 +746,7 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz else: assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax) - # Check weight quantizer amax is different across local experts + # Check weight quantizer amax fc1_amax = None fc2_amax = None for expert in layer.mlp.experts.local_experts: @@ -758,10 +754,14 @@ def _test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, rank, siz assert expert.linear_fc2.weight_quantizer.amax is not None if fc1_amax is None: fc1_amax = expert.linear_fc1.weight_quantizer.amax + elif sync_weight_amax: + assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax) else: assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax) if fc2_amax is None: fc2_amax = expert.linear_fc2.weight_quantizer.amax + elif sync_weight_amax: + assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax) else: assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax) @@ -785,7 +785,6 @@ def _test_expert_model_parallel_amax_sync( etp_size=etp_size, hidden_size=256, moe_grouped_gemm=moe_grouped_gemm, - use_te=moe_grouped_gemm, num_moe_experts=8, transformer_impl="modelopt", )