diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d4cf249fe..e36a9d5e5 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -184,6 +184,20 @@ def sync_quantizer_amax_across_tp( if hasattr(module, "sync_moe_local_experts_amax"): module.sync_moe_local_experts_amax() + # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache) + # So we should sync amax across DP and TP for these quantizers + for name, module in model.named_modules(): + if not (hasattr(module, "k_bmm_quantizer") and hasattr(module, "parallel_state")): + continue + for quantizer in [module.k_bmm_quantizer, module.v_bmm_quantizer]: + if isinstance(quantizer, TensorQuantizer) and quantizer.amax is not None: + quantizer.sync_amax_across_distributed_group( + [ + module.parallel_state.data_parallel_group, + module.parallel_state.tensor_parallel_group, + ] + ) + @torch.no_grad() def mse_calibrate( diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index de7407f70..e7f2f2e49 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -18,7 +18,7 @@ import contextlib import math import warnings -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any, Protocol import torch @@ -1155,11 +1155,23 @@ def set_from_modelopt_state(self, modelopt_state, properties_only: bool = False) modelopt_state.get("_pytorch_state_metadata", {}) ) - def sync_amax_across_distributed_group(self, parallel_group: DistributedProcessGroup): - """Synchronize the amax across all ranks in the given group.""" - if parallel_group.is_initialized() and getattr(self, "_amax", None) is not None: + def sync_amax_across_distributed_group( + self, + parallel_group: DistributedProcessGroup | Iterable[DistributedProcessGroup], + ): + """Synchronize the amax across all ranks in the given group(s).""" + if getattr(self, "_amax", None) is None: + return + + for pg in ( + [parallel_group] + if isinstance(parallel_group, DistributedProcessGroup) + else list(parallel_group) + ): + if not pg.is_initialized(): + continue try: - dist.all_reduce(self._amax, op=dist.ReduceOp.MAX, group=parallel_group.group) + dist.all_reduce(self._amax, op=dist.ReduceOp.MAX, group=pg.group) except RuntimeError as e: # This error happens if the distributed backend is using GPU and # the tensor is not on GPU (or vice versa). 
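A minimal sketch of the multi-group synchronization pattern introduced above, assuming `torch.distributed` is already initialized and `amax` is a scalar tensor on the device the backend expects. The `groups` argument stands in for the unwrapped `pg.group` handles of the `DistributedProcessGroup` wrappers, so the wrapper's `is_initialized()` check is represented here by a plain `None` test; names such as `sync_amax` are illustrative, not part of the library API.

```python
from collections.abc import Iterable

import torch
import torch.distributed as dist


def sync_amax(amax: torch.Tensor, groups) -> torch.Tensor:
    """Reduce a per-tensor amax in-place with MAX over one or many process groups.

    Mirrors the generalized ``sync_amax_across_distributed_group``: after the
    all-reduce, data-parallel and tensor-parallel ranks hold identical amax values.
    """
    # Accept either a single group or an iterable of groups, as the patched method does.
    group_list = list(groups) if isinstance(groups, Iterable) else [groups]
    for group in group_list:
        if group is None:  # stands in for an uninitialized / disabled parallel group
            continue
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax
```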
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 95e8651aa..ab806bc00 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -16,6 +16,7 @@ """Support quantization for megatron linear layers.""" import logging +import types import warnings from typing import Any @@ -28,6 +29,7 @@ from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule +from megatron.core.transformer.attention import Attention from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -38,7 +40,6 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..model_calib import max_calibrate from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper @@ -98,7 +99,9 @@ def quant_module_get_extra_state(self) -> dict: """ extra_state = {} - is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False + is_enabled = any( + isinstance(child, TensorQuantizer) and child.is_enabled for child in self.children() + ) if not is_enabled: return extra_state @@ -201,6 +204,19 @@ def quant_module_set_extra_state(self, state: Any): self.allow_post_restore = False +def _create_incompatible_method(method_name: str): + """Create a method that raises an error for incompatible flash decode methods.""" + + def _incompatible_method(self, *args, **kwargs): + raise NotImplementedError( + f"{method_name} is not compatible with ModelOpt KV cache quantization. " + f"KV cache quantization requires core_attention to be called. " + f"Please raise an issue at https://github.com/NVIDIA/Model-Optimizer if you need this feature." + ) + + return _incompatible_method + + def megatron_replace_quant_module_hook(model: torch.nn.Module): """Configure Megatron-Core model quantization support. @@ -211,8 +227,32 @@ def megatron_replace_quant_module_hook(model: torch.nn.Module): 1. We change TransformerConfig to enable heterogenous distributed checkpointing. 2. We enable all sub- QuantModule to store quantizer_state as extra_state by typing-matching the QuantModuleRegistry. + 3. For Attention modules, we configure them to use core_attention path for KV cache quantization. """ + def _configure_attention_for_kv_cache_quant(module: Attention): + """Configure Attention module for KV cache quantization compatibility.""" + # Disable flash_decode if enabled (it bypasses core_attention) + if getattr(module.config, "flash_decode", False): + warnings.warn( + "flash_decode=True is incompatible with ModelOpt KV cache quantization. " + "Setting flash_decode=False. Flash decode bypasses core_attention during decode phase." 
+ ) + module.config.flash_decode = False + + # Set dtype and device for core_attention (needed for modelopt_post_restore) + assert hasattr(module, "core_attention"), "Attention module must have core_attention" + param = next(iter(module.parameters()), None) + if param is not None: + module.core_attention.dtype = param.dtype + module.core_attention.device = param.device + + # Patch flash_decode and flash_decode_and_prefill to raise errors + module.flash_decode = types.MethodType(_create_incompatible_method("flash_decode"), module) + module.flash_decode_and_prefill = types.MethodType( + _create_incompatible_method("flash_decode_and_prefill"), module + ) + def _register_extra_state_callbacks(model: torch.nn.Module): for name, module in model.named_modules(): if type(module) in QuantModuleRegistry: @@ -223,6 +263,10 @@ def _register_extra_state_callbacks(model: torch.nn.Module): quant_module_set_extra_state, ) + # Configure Attention modules for KV cache quantization + if isinstance(module, Attention): + _configure_attention_for_kv_cache_quant(module) + for name, module in model.named_modules(): if isinstance(module, MegatronModule): if "vision_model" not in name: @@ -632,152 +676,37 @@ def _setup(self): self.k_bmm_quantizer = TensorQuantizer() self.v_bmm_quantizer = TensorQuantizer() - def _calibrate_quantizers(self): - """Calibrate quantizers with minimal dummy tensors.""" - # Get device and dtype from the parent module's parameters - param = next(iter(self.parameters()), None) - device = param.device if param is not None else torch.device("cuda") - dtype = param.dtype if param is not None else torch.float16 - - # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion - batch_size = 1 - seq_len = 1 - - # Get dimensions from config - num_heads = self.config.num_attention_heads - head_dim = ( - self.config.kv_channels - if hasattr(self.config, "kv_channels") - else self.config.hidden_size // num_heads + # Set parallel_state for distributed sync of BMM quantizers + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + data_parallel_group = get_data_parallel_group() + self.parallel_state = ParallelState( + data_parallel_group, + mcore_parallel.get_tensor_model_parallel_group(), ) - # Determine tensor format (default to sbhd if not specified) - apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False) - qkv_format = "bshd" if apply_rope_fusion else "sbhd" - - if qkv_format == "sbhd": - dummy_tensor = torch.randn( - seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype - ) - else: - dummy_tensor = torch.randn( - batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype - ) - - # Calibrate each quantizer - quantizers = [ - ("q_bmm_quantizer", self.q_bmm_quantizer), - ("k_bmm_quantizer", self.k_bmm_quantizer), - ("v_bmm_quantizer", self.v_bmm_quantizer), - ] - - for _, quantizer in quantizers: - if quantizer is not None and quantizer.is_enabled(): - if not hasattr(quantizer, "_amax") or quantizer._amax is None: - quantizer.reset_amax() - max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False) - def forward(self, query, key, value, *args, **kwargs): - """Apply post-RoPE quantization to KV cache. - - TEDotProductAttention receives Q, K, V after RoPE is applied, - so we quantize them directly for KV cache quantization. 
- """ + """Apply post-RoPE quantization to KV cache.""" # Quantize Q, K, V query = self.q_bmm_quantizer(query) key = self.k_bmm_quantizer(key) value = self.v_bmm_quantizer(value) - return super().forward(query, key, value, *args, **kwargs) - def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): - """Create a sharded state dictionary for distributed checkpointing.""" - sharded_state_dict = {} - - # First add non-quantizer parameters - for k, v in self.state_dict(prefix="", keep_vars=True).items(): - if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: - sharded_state_dict[prefix + k] = v - - # Process _amax in bmm_quantizers - for name, quantizer in [ - ("q_bmm_quantizer", self.q_bmm_quantizer), - ("k_bmm_quantizer", self.k_bmm_quantizer), - ("v_bmm_quantizer", self.v_bmm_quantizer), - ]: - if hasattr(quantizer, "_amax") and quantizer._amax is not None: - amax_key = f"{prefix}{name}._amax" - sharded_state_dict[amax_key] = quantizer._amax - - # Process other quantizer parameters in bmm_quantizers - quantizer_state_dict = { - k: v - for k, v in self.state_dict(prefix="", keep_vars=True).items() - if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k - } - - if quantizer_state_dict: - sharded_state_dict.update( - **make_sharded_tensors_for_checkpoint( - quantizer_state_dict, prefix, {}, sharded_offsets - ) - ) - - return sharded_state_dict - - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - """Handle loading state dict for quantizers.""" - for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]: - full_prefix = f"{prefix}{quantizer_name}." - amax_key = f"{prefix}{quantizer_name}._amax" - - # If amax is in state_dict, rename it to the format expected by TensorQuantizer - if amax_key in state_dict: - expected_amax_key = f"{full_prefix}_amax" - state_dict[expected_amax_key] = state_dict.pop(amax_key) - - # Handle other quantizer states - for k in list(state_dict.keys()): - if "_quantizer" in k and "_amax" not in k: - name = k.split(prefix)[-1] if prefix else k - if name in self.state_dict(): - state_dict[k] = state_dict[k].view_as(self.state_dict()[name]) - - super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) - def modelopt_post_restore(self, name=""): """Restore quantizer states after model loading.""" - super().modelopt_post_restore(name) - - def _check_unsupported_states(quantizer): - """Check for unsupported quantizer states and warn if found.""" - if not hasattr(quantizer, "state_dict"): - return - - for k in quantizer.state_dict(): - if k not in ["_amax", "_pre_quant_scale"]: - warnings.warn( - f"Restore of {k} for {name} is not supported. The restore of this layer might be " - f"incorrect. Please implement a custom restore for {k}." 
- ) - - calibration_needed = False - - for quantizer_name, quantizer in [ - ("q_bmm_quantizer", self.q_bmm_quantizer), - ("k_bmm_quantizer", self.k_bmm_quantizer), - ("v_bmm_quantizer", self.v_bmm_quantizer), - ]: - if not hasattr(self, quantizer_name) or not quantizer.is_enabled(): - continue - - _check_unsupported_states(quantizer) - - if not hasattr(quantizer, "_amax") or quantizer._amax is None: - calibration_needed = True - - if calibration_needed: - self._calibrate_quantizers() + for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]: + # TODO: Add support for non-scalar states such as + # Affine KVCache bias vector which is per head per channel + if not all(v.numel() == 1 for v in tq.state_dict().values()): + raise NotImplementedError( + "Only scalar states are supported for KV Cache/BMM Quantizers" + ) + # dtype and device should have been set in `megatron_replace_quant_module_hook` + # via `_configure_attention_for_kv_cache_quant` + assert hasattr(self, "device") and hasattr(self, "dtype") + self.to(device=self.device, dtype=self.dtype) @QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"}) diff --git a/tests/_test_utils/torch/megatron/models.py b/tests/_test_utils/torch/megatron/models.py index ed69fd496..9509358b4 100644 --- a/tests/_test_utils/torch/megatron/models.py +++ b/tests/_test_utils/torch/megatron/models.py @@ -311,6 +311,7 @@ def get_mcore_mamba_hybrid_model( max_sequence_length: int = 4, vocab_size: int = 64, bf16: bool = True, + sequence_parallel: bool = False, # Mamba-specific parameters mamba_state_dim: int = 32, mamba_head_dim: int = 16, @@ -337,7 +338,7 @@ def get_mcore_mamba_hybrid_model( config = TransformerConfig( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, - sequence_parallel=False, + sequence_parallel=sequence_parallel, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, diff --git a/tests/_test_utils/torch/megatron/utils.py b/tests/_test_utils/torch/megatron/utils.py index 5ca0cf14c..bb91f83cd 100644 --- a/tests/_test_utils/torch/megatron/utils.py +++ b/tests/_test_utils/torch/megatron/utils.py @@ -214,6 +214,16 @@ def convert_maybe_fp8(v): f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}" ) + # Test backward pass on model_test + model_test.train() + loss = forward_fn(model_test).sum() + loss.backward() + + # Assert that trainable parameters have gradients computed + for name, param in model_test.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"Parameter {name} has no gradient computed" + def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model): """Copy weights from TEGrouped MoE model to sequential MoE model.""" diff --git a/tests/_test_utils/torch/quantization/quantize_common.py b/tests/_test_utils/torch/quantization/quantize_common.py index 8647aaa00..f62d2d991 100644 --- a/tests/_test_utils/torch/quantization/quantize_common.py +++ b/tests/_test_utils/torch/quantization/quantize_common.py @@ -127,6 +127,23 @@ def _distributed_attr_check(quantizer, attr: str, op=dist.ReduceOp.MAX, groups=[ assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) +def verify_kv_cache_amax_sync(model, group=None): + kv_quantizers_found = False + for name, module in model.named_modules(): + if hasattr(module, "k_bmm_quantizer") and hasattr(module, 
"v_bmm_quantizer"): + kv_quantizers_found = True + + for quantizer in [module.k_bmm_quantizer, module.v_bmm_quantizer]: + if quantizer.amax is not None: + quantizer_amax = quantizer.amax.clone() + dist.all_reduce(quantizer_amax, op=dist.ReduceOp.MAX, group=group) + assert torch.allclose(quantizer_amax, quantizer.amax), ( + f"KV cache quantizer amax not synced across distributed group for {name}" + ) + + return kv_quantizers_found + + original_awq_lite = model_calib_module.awq_lite diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 5b2a8cc0a..d02b02c18 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -20,7 +20,12 @@ import torch from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch.distributed.utils import spawn_multiprocess_job -from _test_utils.torch.megatron.models import MegatronModel, get_mcore_gpt_model +from _test_utils.torch.megatron.models import ( + MambaModel, + MegatronModel, + get_mcore_gpt_model, + get_mcore_mamba_hybrid_model, +) from _test_utils.torch.megatron.utils import ( compare_amax_sync_across_expert_parallel, copy_weights_from_grouped_to_non_grouped, @@ -34,6 +39,7 @@ from _test_utils.torch.quantization.quantize_common import ( auto_quantize_helper, data_tensor_context_parallel_test_helper, + verify_kv_cache_amax_sync, ) skip_if_no_megatron() @@ -52,7 +58,6 @@ import modelopt.torch.quantization as mtq from modelopt.torch.quantization.nn import QuantModuleRegistry from modelopt.torch.quantization.plugins.megatron import _QuantTEMCoreRowParallelLinear -from modelopt.torch.utils.plugins import megatron_prefill try: from megatron.core.extensions.transformer_engine import TERowParallelLinear @@ -81,6 +86,30 @@ def get_batch(model, batch_size=2): return input_ids, labels, position_ids, attention_mask, loss_mask +def get_forward(model, batch_size=2): + """Return a forward function with cached batch inputs.""" + input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model, batch_size) + + def forward(model): + # MambaModel doesn't accept loss_mask argument + if isinstance(model, MambaModel): + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + ) + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + return forward + + def test_convert_megatron_parallel_linear(distributed_setup_size_1): initialize_for_megatron(seed=SEED) set_seed(SEED) @@ -267,12 +296,37 @@ def _gpt_model_provider( etp_size=None, use_te=False, transformer_impl="local", + # Hybrid mamba MOE parameters + is_hybrid=False, + hybrid_override_pattern=None, + mamba_head_dim=16, ): - """Build the model.""" + from contextlib import nullcontext + + device_ctx = torch.device("meta") if meta_device else nullcontext() - if meta_device: - with torch.device("meta"): - gpt_model = get_mcore_gpt_model( + with device_ctx: + if is_hybrid: + # Derive num_layers from pattern length, default to 4 + num_layers = len(hybrid_override_pattern) if hybrid_override_pattern else 4 + model = get_mcore_mamba_hybrid_model( + tensor_model_parallel_size=tp_size, + num_layers=num_layers, + hidden_size=hidden_size, + vocab_size=vocab_size, + num_attention_heads=8, + ffn_hidden_size=None, + hybrid_override_pattern=hybrid_override_pattern, + mamba_head_dim=mamba_head_dim, + 
mamba_num_groups=tp_size, # Must be divisible by tp_size + num_moe_experts=num_moe_experts, + sequence_parallel=True, # Required for MoE + TP + # EP/ETP passed via config_kwargs + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + ) + else: + model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size, expert_tensor_parallel_size=etp_size, @@ -288,27 +342,14 @@ def _gpt_model_provider( moe_grouped_gemm=moe_grouped_gemm, use_te=use_te, ) - else: - gpt_model = get_mcore_gpt_model( - tensor_model_parallel_size=tp_size, - expert_model_parallel_size=ep_size, - expert_tensor_parallel_size=etp_size, - num_layers=4, - ffn_hidden_size=None, - num_attention_heads=8, - activation_func="squared_relu", - transformer_impl=transformer_impl, - hidden_size=hidden_size, - vocab_size=vocab_size, - num_moe_experts=num_moe_experts, - moe_grouped_gemm=moe_grouped_gemm, - use_te=use_te, - ).cuda() - return gpt_model.eval() + + if not meta_device: + model = model.cuda() + return model.eval() def _test_sharded_state_dict( - tmp_path, config, hidden_size, modelopt_version, compress, meta_device, moe_config, rank, size + tmp_path, config, hidden_size, modelopt_version, compress, meta_device, model_config, rank, size ): # Must disable output_layer quantization since output_layer amax cannot be restore via # sharded_state_dict. All output_layer quantizers state are removed. @@ -318,13 +359,16 @@ def _test_sharded_state_dict( mto.conversion.__version__ = modelopt_version mtq.plugins.megatron.__version__ = modelopt_version - tp_size = moe_config.get("tp_size", size) - ep_size = moe_config.get("ep_size", 1) - etp_size = moe_config.get("etp_size", None) - num_moe_experts = moe_config.get("num_moe_experts", None) - moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False) - use_te = moe_config.get("use_te", False) - transformer_impl = moe_config.get("transformer_impl", "local") + tp_size = model_config.get("tp_size", size) + ep_size = model_config.get("ep_size", 1) + etp_size = model_config.get("etp_size", None) + num_moe_experts = model_config.get("num_moe_experts", None) + moe_grouped_gemm = model_config.get("moe_grouped_gemm", False) + use_te = model_config.get("use_te", False) + transformer_impl = model_config.get("transformer_impl", "local") + # Hybrid mamba MOE parameters + is_hybrid = model_config.get("is_hybrid", False) + hybrid_override_pattern = model_config.get("hybrid_override_pattern", None) initialize_for_megatron( tensor_model_parallel_size=tp_size, @@ -343,6 +387,8 @@ def _test_sharded_state_dict( ep_size=ep_size, etp_size=etp_size, transformer_impl=transformer_impl, + is_hybrid=is_hybrid, + hybrid_override_pattern=hybrid_override_pattern, ) model_test = _gpt_model_provider( tp_size, @@ -355,16 +401,12 @@ def _test_sharded_state_dict( ep_size=ep_size, etp_size=etp_size, transformer_impl=transformer_impl, + is_hybrid=is_hybrid, + hybrid_override_pattern=hybrid_override_pattern, ) - prompt_tokens = torch.randint( - 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) - ).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) - - model_ref = mtq.quantize(model_ref, config, forward_fn) + forward = get_forward(model_ref) + model_ref = mtq.quantize(model_ref, config, forward) if compress: mtq.compress(model_ref) @@ -376,7 +418,7 @@ def forward_fn(model): tmp_path, model_ref, model_test, - forward_fn, + forward, meta_device=meta_device, version=modelopt_version, ) @@ -413,6 +455,14 @@ def forward_fn(model): } ) 
+# Combined NVFP4 GEMM + KV cache quantization config +NVFP4_GEMM_KV_CFG = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG) +NVFP4_GEMM_KV_CFG["quant_cfg"].update(mtq.NVFP4_KV_CFG["quant_cfg"]) + +# Combined FP8 GEMM + KV cache quantization config +FP8_GEMM_KV_CFG = copy.deepcopy(mtq.FP8_DEFAULT_CFG) +FP8_GEMM_KV_CFG["quant_cfg"].update(mtq.FP8_KV_CFG["quant_cfg"]) + @pytest.mark.parametrize( "config", @@ -424,22 +474,75 @@ def forward_fn(model): mtq.W4A8_AWQ_BETA_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG, - # Note: KV cache configs (FP8_KV_CFG, NVFP4_KV_CFG) are tested separately in test_kv_cache_quant - # They require TEDotProductAttention which needs transformer_impl="modelopt", not "local" + mtq.FP8_KV_CFG, + mtq.NVFP4_KV_CFG, ], ) @pytest.mark.parametrize("compress", [False, True]) @pytest.mark.parametrize("meta_device", [False, True]) -def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device): +@pytest.mark.parametrize("transformer_impl", ["local", "modelopt"]) +def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device, transformer_impl): if compress and config is mtq.W4A8_AWQ_BETA_CFG: pytest.skip("W4A8_AWQ_BETA_CFG is not supported for compress") + if config in (mtq.FP8_KV_CFG, mtq.NVFP4_KV_CFG): + if transformer_impl != "modelopt" or compress or meta_device: + pytest.skip( + "KV cache configs require transformer_impl='modelopt' and no compress/meta_device" + ) + size = torch.cuda.device_count() + model_config = {"transformer_impl": transformer_impl} + if transformer_impl == "modelopt": + model_config["use_te"] = True spawn_multiprocess_job( size=size, job=partial( - _test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device, {} + _test_sharded_state_dict, + tmp_path, + config, + 256, + None, + compress, + meta_device, + model_config, + ), + backend="nccl", + ) + + +@pytest.mark.parametrize( + "config", + [ + NVFP4_GEMM_KV_CFG, + FP8_GEMM_KV_CFG, + ], +) +def test_homogeneous_sharded_state_dict_hybrid(tmp_path, config): + """Test sharded state dict for hybrid Mamba MOE models.""" + if torch.cuda.device_count() < 4: + pytest.skip("Hybrid MOE test requires at least 4 GPUs") + + model_config = { + "is_hybrid": True, + "hybrid_override_pattern": "MEM*E", # 5 layers: Mamba → MoE → Mamba → Attention → MoE + "num_moe_experts": 8, + "tp_size": 2, + "ep_size": 2, + "etp_size": 2, + } + spawn_multiprocess_job( + size=4, + job=partial( + _test_sharded_state_dict, + tmp_path, + config, + 256, + None, + False, # compress + False, # meta_device + model_config, ), backend="nccl", ) @@ -534,16 +637,13 @@ def _test_fp8_real_quantize_helper(rank, size): config = mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size) - prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) - - forward_fn(model) + forward = get_forward(model) + forward(model) # real quant the model cur_mem = get_model_size(model) - real_quant_model = mtq.quantize(model, config, forward_fn) + real_quant_model = mtq.quantize(model, config, forward) mtq.compress(real_quant_model) real_quant_mem = get_model_size(real_quant_model) @@ -551,7 +651,7 @@ def forward_fn(model): assert real_quant_mem < (cur_mem / 2) * 1.1, "Memory after real quantization is not reduced." 
# check forward works after real quantization - forward_fn(real_quant_model) + forward(real_quant_model) assert real_quant_mem < cur_mem @@ -606,12 +706,6 @@ def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, r seed=SEED, ) - # Create input - prompt_tokens = torch.randint(0, 64, (2, 16)).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) - # Create TEGrouped MoE model te_grouped_moe_model = _gpt_model_provider( tp_size=tp_size, @@ -622,6 +716,10 @@ def forward_fn(model): use_te=True, num_moe_experts=4, ) + + # Create forward function with cached inputs + forward = get_forward(te_grouped_moe_model) + num_te_grouped_mlp = sum( isinstance(module, TEGroupedMLP) for module in te_grouped_moe_model.modules() ) @@ -649,19 +747,19 @@ def forward_fn(model): copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model) # Compare model outputs before quantization - te_grouped_moe_output = forward_fn(te_grouped_moe_model) - sequential_moe_output = forward_fn(sequential_moe_model) + te_grouped_moe_output = forward(te_grouped_moe_model) + sequential_moe_output = forward(sequential_moe_model) assert torch.allclose(te_grouped_moe_output, sequential_moe_output, atol=1e-6, rtol=1e-6) # Quantize grouped model - mtq.quantize(te_grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) + mtq.quantize(te_grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward) # Quantize non-grouped model - mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) + mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward) # Compare model outputs after quantization - te_grouped_moe_quant_output = forward_fn(te_grouped_moe_model) - sequential_moe_quant_output = forward_fn(sequential_moe_model) + te_grouped_moe_quant_output = forward(te_grouped_moe_model) + sequential_moe_quant_output = forward(sequential_moe_model) assert torch.allclose( te_grouped_moe_quant_output, sequential_moe_quant_output, atol=1e-6, rtol=1e-6 ) @@ -716,18 +814,15 @@ def _test_expert_model_parallel_amax_sync( param.data.fill_(const_val) weight_idx += 1 - prompt_tokens = (torch.ones((2, model.max_sequence_length)) * 0.05 + rank * 0.5).cuda().long() - # force all expert routing for module in model.modules(): if isinstance(module, TopKRouter): module.topk = module.num_experts - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) + forward = get_forward(model) # quantize the model - model = mtq.quantize(model, config, forward_fn) + model = mtq.quantize(model, config, forward) # Check initial sync status initial_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert initial_sync, ( @@ -735,7 +830,7 @@ def forward_fn(model): ) # Test if the amax values are inconsistent when distributed sync is disabled - mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=False) + mtq.model_calib.max_calibrate(model, forward, distributed_sync=False) inconsistent_amax, _, _ = compare_amax_sync_across_expert_parallel( model, compare_across_experts=False ) @@ -745,7 +840,7 @@ def forward_fn(model): "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled" ) # calibrate the model with distributed sync and test synchronization - mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=True) + mtq.model_calib.max_calibrate(model, forward, distributed_sync=True) for module in model.modules(): if hasattr(module, "sync_moe_local_experts_amax"): module.sync_moe_local_experts_amax() @@ -798,14 
+893,11 @@ def _test_kv_cache_quant_helper(config, rank, size): transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec ).cuda() - # Create dummy input for calibration - prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) + # Create forward function with cached inputs + forward = get_forward(model) # Test KV cache quantization with the given config - quantized_model = mtq.quantize(model, config, forward_fn) + quantized_model = mtq.quantize(model, config, forward) # Find TEDotProductAttention modules and verify they have KV cache quantizers te_attention_found = False @@ -823,104 +915,10 @@ def forward_fn(model): assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model" # Quick smoke test that forward still works - output = forward_fn(quantized_model) + output = forward(quantized_model) assert output is not None, "Forward pass failed" -def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): - """Helper for testing KV cache quantization with sharded state dict save/load.""" - # Disable output_layer quantization (same as other sharded state dict tests) - config["quant_cfg"]["*output_layer*"] = {"enable": False} - - initialize_for_megatron( - tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED - ) - - # Create GPT models with TEDotProductAttention (transformer_impl="modelopt") - model_ref = get_mcore_gpt_model( - tensor_model_parallel_size=size, - num_layers=2, # At least 2 layers to test multiple attention modules - hidden_size=64, - num_attention_heads=4, - vocab_size=64, - transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention - ).cuda() - - model_test = get_mcore_gpt_model( - tensor_model_parallel_size=size, - num_layers=2, - hidden_size=64, - num_attention_heads=4, - vocab_size=64, - transformer_impl="modelopt", - ).cuda() - - prompt_tokens = torch.randint( - 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) - ).cuda() - - def forward_fn(model): - return megatron_prefill(model, prompt_tokens) - - # Quantize the reference model - model_ref = mtq.quantize(model_ref, config, forward_fn) - - # CRITICAL: model_test must also be quantized with the same config - # Otherwise it won't have the KV cache quantizer keys when loading state dict - model_test = mtq.quantize(model_test, config, forward_fn) - - # Verify KV cache quantizers were created - kv_quantizers_found = False - for name, module in model_ref.named_modules(): - if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): - kv_quantizers_found = True - assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" - assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" - - assert kv_quantizers_found, "No KV cache quantizers found in quantized model" - - # Test sharded state dict save/load - sharded_state_dict_test_helper( - tmp_path, - model_ref, - model_test, - forward_fn, - meta_device=False, - version=None, - ) - - # Verify KV cache quantizers are restored correctly in model_test - for (name_ref, module_ref), (name_test, module_test) in zip( - model_ref.named_modules(), model_test.named_modules() - ): - if hasattr(module_ref, "k_bmm_quantizer"): - assert hasattr(module_test, "k_bmm_quantizer"), ( - f"K quantizer missing after restore in {name_test}" - ) - assert hasattr(module_test, "v_bmm_quantizer"), ( - f"V quantizer missing after 
restore in {name_test}" - ) - - # Check that quantizer states match - if hasattr(module_ref.k_bmm_quantizer, "_amax"): - assert hasattr(module_test.k_bmm_quantizer, "_amax"), ( - f"K quantizer _amax missing in {name_test}" - ) - if module_ref.k_bmm_quantizer._amax is not None: - assert torch.allclose( - module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax - ), f"K quantizer _amax mismatch in {name_test}" - - if hasattr(module_ref.v_bmm_quantizer, "_amax"): - assert hasattr(module_test.v_bmm_quantizer, "_amax"), ( - f"V quantizer _amax missing in {name_test}" - ) - if module_ref.v_bmm_quantizer._amax is not None: - assert torch.allclose( - module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax - ), f"V quantizer _amax mismatch in {name_test}" - - @pytest.mark.parametrize( "config", [ @@ -940,24 +938,42 @@ def test_kv_cache_quant(config): spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl") -@pytest.mark.parametrize( - "config", - [ - mtq.FP8_KV_CFG, - mtq.NVFP4_KV_CFG, - ], -) -def test_kv_cache_sharded_state_dict(tmp_path, config): - """Test KV cache quantization with sharded state dict save/load. +def _test_kv_cache_amax_sync_helper(config, rank, size, tensor_model_parallel_size=1): + """Helper function for testing KV cache quantizer amax sync across distributed world.""" + # Use rank in seed to produce different amax values across ranks + seed = SEED + rank + initialize_for_megatron( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=1, + seed=seed, + ) - This test verifies the complete workflow of saving and loading KV cache quantized - models with distributed checkpointing, ensuring quantizer states are properly - preserved across the save/load cycle. - """ - size = min(2, torch.cuda.device_count()) # Use 2 GPUs if available, else 1 + model = get_mcore_gpt_model( + tensor_model_parallel_size=tensor_model_parallel_size, + num_layers=1, + hidden_size=64, + num_attention_heads=4, + vocab_size=32, + transformer_impl="modelopt", + ).cuda() + + forward = get_forward(model) + + # Quantize with KV cache config + quantized_model = mtq.quantize(model, config, forward) + + # Verify KV cache quantizer amax is synced across the whole world + kv_quantizers_found = verify_kv_cache_amax_sync(quantized_model) + assert kv_quantizers_found, "No KV cache quantizers found in model" + + +def test_kv_cache_amax_sync(need_2_gpus): + """Test KV cache quantizer amax is synced across the distributed world.""" spawn_multiprocess_job( - size=size, - job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config), + size=2, + job=partial( + _test_kv_cache_amax_sync_helper, NVFP4_GEMM_KV_CFG, tensor_model_parallel_size=2 + ), backend="nccl", ) @@ -968,16 +984,7 @@ def test_convert_mcore_te_gpt_model(distributed_setup_size_1): initialize_for_megatron(tensor_model_parallel_size=1, seed=SEED) model = get_mcore_gpt_model(tensor_model_parallel_size=1, transformer_impl="transformer_engine") - input_ids, labels, position_ids, attention_mask, loss_mask = get_batch(model) - - def forward(model): - return model.forward( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - labels=labels, - loss_mask=loss_mask, - ) + forward = get_forward(model) for name, param in model.named_parameters(): param.requires_grad = True
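For reference, a hedged usage sketch of the combined GEMM + KV-cache flow these tests exercise. The config composition copies how `NVFP4_GEMM_KV_CFG` is built in the test file above; the model and calibration loop are placeholders (any Megatron-Core model built with `transformer_impl="modelopt"` so attention goes through `TEDotProductAttention`).

```python
import copy

import modelopt.torch.quantization as mtq

# Compose a GEMM config with a KV-cache config, as NVFP4_GEMM_KV_CFG is built above.
cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
cfg["quant_cfg"].update(mtq.NVFP4_KV_CFG["quant_cfg"])

# Placeholders: `model` is a Megatron-Core model using TEDotProductAttention and
# `forward_loop` runs calibration batches through it (see get_forward in the tests).
# model = mtq.quantize(model, cfg, forward_loop)
#
# After calibration, each attention module's k_bmm_quantizer / v_bmm_quantizer amax
# is max-reduced across the data- and tensor-parallel groups by the model_calib.py
# change at the top of this diff.
```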