diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 0d7876149..825d02f47 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -931,8 +931,11 @@ def forward(self, inputs):
         if self._if_quant:
             # Check if the input tensor is contiguous
            # Non-contiguous tensors will generate incorrect FP4 quantization results
+            # NOTE: The previous in-place `inputs.data` assignment caused illegal memory access
+            # in distributed training; the tensor appears to be corrupted upstream, before
+            # reaching the quantizer. TODO: investigate tensor corruption in the attention path.
            if hasattr(inputs, "is_contiguous") and not inputs.is_contiguous():
-                inputs.data = inputs.data.contiguous()
+                inputs = inputs.contiguous()
             if self.fake_quant:
                 outputs = self._fake_quantize(inputs)
             elif not self._dequantize:
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index a33f715cf..a5b99f28f 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -85,7 +85,7 @@ def real_quant_module_get_extra_state(self) -> dict:
 def quant_module_get_extra_state(self) -> dict:
     """Populating the extra_state when state_dict() is called.
 
-    quantizer_state, real_quantizer_state, and q_tensor_state are usually stored
+    quantizer_state, real_quantizer_state, and q_tensor_state used to be stored (before 0.29)
     in the modelopt_state metadata where the keys are the full
     module name. The issue is that NeMo-MCore model's full module name
     can change if pipeline-parallelism (PP) and expert-parallelism (EP)
@@ -94,8 +94,9 @@ def quant_module_get_extra_state(self) -> dict:
     which avoids the need to store the full module name.
""" extra_state = {} - - is_enabled = self.weight_quantizer.is_enabled if hasattr(self, "weight_quantizer") else False + is_enabled = any( + isinstance(child, TensorQuantizer) and child.is_enabled for child in self.children() + ) if not is_enabled: return extra_state @@ -109,7 +110,6 @@ def quant_module_get_extra_state(self) -> dict: # Handle real_quantizer_state and q_tensor_state extra_state.update(real_quant_module_get_extra_state(self)) - return extra_state @@ -219,6 +219,10 @@ def _register_extra_state_callbacks(model: torch.nn.Module): quant_module_get_extra_state, quant_module_set_extra_state, ) + if HAS_TE and isinstance(module, TEDotProductAttention): + # A hack to set the dtype and device for DotProductAttention + # to be used in _QuantTEDotProductAttention.modelopt_post_restore() + _QuantTEDotProductAttention.set_dtype(module, name, model) for name, module in model.named_modules(): if isinstance(module, MegatronModule): @@ -612,57 +616,24 @@ def _setup(self): self.k_bmm_quantizer = TensorQuantizer() self.v_bmm_quantizer = TensorQuantizer() - def _calibrate_quantizers(self): - """Calibrate quantizers with minimal dummy tensors.""" - # Get device and dtype from the parent module's parameters - param = next(iter(self.parameters()), None) - device = param.device if param is not None else torch.device("cuda") - dtype = param.dtype if param is not None else torch.float16 - - # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion - batch_size = 1 - seq_len = 1 - - # Get dimensions from config - num_heads = self.config.num_attention_heads - head_dim = ( - self.config.kv_channels - if hasattr(self.config, "kv_channels") - else self.config.hidden_size // num_heads - ) - - # Determine tensor format (default to sbhd if not specified) - apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False) - qkv_format = "bshd" if apply_rope_fusion else "sbhd" - - if qkv_format == "sbhd": - dummy_tensor = torch.randn( - seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype - ) - else: - dummy_tensor = torch.randn( - batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype - ) - - # Calibrate each quantizer - quantizers = [ - ("q_bmm_quantizer", self.q_bmm_quantizer), - ("k_bmm_quantizer", self.k_bmm_quantizer), - ("v_bmm_quantizer", self.v_bmm_quantizer), - ] - - for _, quantizer in quantizers: - if quantizer is not None and quantizer.is_enabled(): - if not hasattr(quantizer, "_amax") or quantizer._amax is None: - quantizer.reset_amax() - max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False) - def forward(self, query, key, value, *args, **kwargs): """Apply post-RoPE quantization to KV cache. TEDotProductAttention receives Q, K, V after RoPE is applied, so we quantize them directly for KV cache quantization. 
""" + # Ensure tensors are contiguous before quantization + # This is a safety measure for potential non-contiguous tensor views + # from TE or Megatron operations with tensor parallelism + def materialize_if_needed(tensor): + if tensor is not None and hasattr(tensor, 'is_contiguous') and not tensor.is_contiguous(): + return tensor.contiguous() + return tensor + + query = materialize_if_needed(query) + key = materialize_if_needed(key) + value = materialize_if_needed(value) + # Quantize Q, K, V query = self.q_bmm_quantizer(query) key = self.k_bmm_quantizer(key) @@ -672,44 +643,19 @@ def forward(self, query, key, value, *args, **kwargs): def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Create a sharded state dictionary for distributed checkpointing.""" + state_dict = self.state_dict(prefix='', keep_vars=True) sharded_state_dict = {} - - # First add non-quantizer parameters - for k, v in self.state_dict(prefix="", keep_vars=True).items(): - if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: - sharded_state_dict[prefix + k] = v - - # Process _amax in bmm_quantizers - for name, quantizer in [ - ("q_bmm_quantizer", self.q_bmm_quantizer), - ("k_bmm_quantizer", self.k_bmm_quantizer), - ("v_bmm_quantizer", self.v_bmm_quantizer), - ]: - if hasattr(quantizer, "_amax") and quantizer._amax is not None: - amax_key = f"{prefix}{name}._amax" - sharded_state_dict[amax_key] = quantizer._amax - - # Process other quantizer parameters in bmm_quantizers - quantizer_state_dict = { - k: v - for k, v in self.state_dict(prefix="", keep_vars=True).items() - if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k - } - - if quantizer_state_dict: - sharded_state_dict.update( - **make_sharded_tensors_for_checkpoint( - quantizer_state_dict, prefix, {}, sharded_offsets - ) - ) - + tmp = make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets) + for k, v in tmp.items(): + sharded_state_dict[k] = v.data return sharded_state_dict + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): """Handle loading state dict for quantizers.""" for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]: full_prefix = f"{prefix}{quantizer_name}." - amax_key = f"{prefix}{quantizer_name}._amax" + amax_key = f"{full_prefix}_amax" # If amax is in state_dict, rename it to the format expected by TensorQuantizer if amax_key in state_dict: @@ -727,37 +673,27 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): def modelopt_post_restore(self, name=""): """Restore quantizer states after model loading.""" - super().modelopt_post_restore(name) - - def _check_unsupported_states(quantizer): - """Check for unsupported quantizer states and warn if found.""" - if not hasattr(quantizer, "state_dict"): - return - - for k in quantizer.state_dict(): - if k not in ["_amax", "_pre_quant_scale"]: - warnings.warn( - f"Restore of {k} for {name} is not supported. The restore of this layer might be " - f"incorrect. Please implement a custom restore for {k}." 
-                    )
-
-        calibration_needed = False
-
-        for quantizer_name, quantizer in [
-            ("q_bmm_quantizer", self.q_bmm_quantizer),
-            ("k_bmm_quantizer", self.k_bmm_quantizer),
-            ("v_bmm_quantizer", self.v_bmm_quantizer),
-        ]:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
-                continue
+        for tq in [self.q_bmm_quantizer, self.k_bmm_quantizer, self.v_bmm_quantizer]:
+            # TODO: Add support for non-scalar states, such as an affine KV-cache bias
+            # vector, which is per-head and per-channel.
+            assert all(v.numel() == 1 for v in tq.state_dict().values()), (
+                "Only scalar states are supported for KV cache/BMM quantizers"
+            )
+        # Should have been set in the `megatron_replace_quant_module_hook`
+        assert hasattr(self, "device") and hasattr(self, "dtype")
+        self.to(device=self.device, dtype=self.dtype)
 
-            _check_unsupported_states(quantizer)
+    @staticmethod
+    def set_dtype(module: "TEDotProductAttention", name, model: torch.nn.Module):
+        """Set the dtype and device for the module from a parameter in the model.
 
-            if not hasattr(quantizer, "_amax") or quantizer._amax is None:
-                calibration_needed = True
+        DotProductAttention does not have any parameters, so let's get a parameter
+        from the parent module instead.
+        """
+        parent = model.get_submodule(name.rsplit(".", 1)[0]) if "." in name else model
+        param = next(iter(parent.parameters()))
+        module.dtype = param.dtype
+        module.device = param.device
 
-        if calibration_needed:
-            self._calibrate_quantizers()
 
 
 @QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py
index 5f69e3999..70a705465 100644
--- a/modelopt/torch/quantization/tensor_quant.py
+++ b/modelopt/torch/quantization/tensor_quant.py
@@ -43,14 +43,29 @@
 def _fp8_eager(x, amax=None):
+    """Eager-mode implementation of FP8 E4M3 fake quantization.
+
+    Args:
+        x: Input tensor.
+        amax: Absolute max value used for scaling. If None, only the dtype conversion is performed.
+
+    Returns:
+        Fake-quantized tensor in the original dtype.
+    """
     dtype = x.dtype
+
     if amax is not None:
         scale = 448.0 / (amax.to(torch.float32))
         scale_inv = 1 / scale
         x = x.to(torch.float32) * scale
+        # Clamp to the FP8 E4M3 range to prevent NaN/Inf during conversion
+        x = torch.clamp(x, min=-448.0, max=448.0)
+
     x = x.to(torch.float8_e4m3fn)
+
     if amax is not None:
         x = x.to(torch.float32) * scale_inv
+
     return x.to(dtype)
 
 
@@ -76,7 +91,11 @@ def scaled_e4m3_impl(
         return fp8_eager(inputs, amax)
 
     cuda_ext_fp8 = get_cuda_ext_fp8(raise_if_failed=False)
-    if cuda_ext_fp8 is None:
+    # NOTE: CUDA extension disabled due to a bug with GQA/MQA (singleton KV head dimension)
+    # and tensor parallelism. The fake_e4m3fy() kernel produces corrupted output for
+    # tensors with shape [seq_len, 1, head_dim] when TP > 1.
+    # Using the eager fallback until the kernel is fixed.
+    if cuda_ext_fp8 is None:
         return fp8_eager(inputs, amax)
 
     with torch.cuda.device(
diff --git a/tests/_test_utils/torch/megatron/utils.py b/tests/_test_utils/torch/megatron/utils.py
index 695189f6c..9c2a34562 100644
--- a/tests/_test_utils/torch/megatron/utils.py
+++ b/tests/_test_utils/torch/megatron/utils.py
@@ -204,6 +204,7 @@ def convert_maybe_fp8(v):
             f"{k} v:{v}, s[k]: {state_dict_test[k]}"
         )
 
+    model_test.train()
     logits_test = forward_fn(model_test)
 
     logits_diff = (logits_test - logits_ref) / logits_ref
@@ -211,6 +212,15 @@ def convert_maybe_fp8(v):
         f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}"
     )
 
+    # Test backward pass on model_test
+    loss = logits_test.sum()
+    loss.backward()
+
+    # Assert that trainable parameters have gradients computed
+    for name, param in model_test.named_parameters():
+        if param.requires_grad:
+            assert param.grad is not None, f"Parameter {name} has no gradient computed"
+
 
 def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model):
     """Copy weights from TEGrouped MoE model to sequential MoE model."""
diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py
index 2993749b1..01a5a994d 100644
--- a/tests/gpu/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -836,11 +836,7 @@ def forward_fn(model):
 
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)
-
-    # CRITICAL: model_test must also be quantized with the same config
-    # Otherwise it won't have the KV cache quantizer keys when loading state dict
-    model_test = mtq.quantize(model_test, config, forward_fn)
-
+
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
@@ -851,6 +847,10 @@ def forward_fn(model):
 
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
 
+    # CRITICAL: model_test must also be quantized with the same config
+    # Otherwise it won't have the KV cache quantizer keys when loading state dict
+    # model_test = mtq.quantize(model_test, config, forward_fn)
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
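
For reference, a minimal standalone sketch (not part of the patch) of the eager FP8 E4M3 fake-quantization round trip that _fp8_eager above implements: scale by 448/amax, clamp to the finite E4M3 range, cast to float8 and back. It assumes PyTorch >= 2.1 for torch.float8_e4m3fn; the helper name fake_quant_e4m3 is illustrative only.

# Standalone sketch of eager FP8 E4M3 fake quantization (illustrative, not library code).
from typing import Optional

import torch


def fake_quant_e4m3(x: torch.Tensor, amax: Optional[torch.Tensor] = None) -> torch.Tensor:
    dtype = x.dtype
    if amax is not None:
        # 448 is the largest finite value representable in float8_e4m3fn, so scaling by
        # 448/amax maps |x| <= amax onto the full E4M3 range.
        scale = 448.0 / amax.to(torch.float32)
        x = x.to(torch.float32) * scale
        # The "fn" variant of E4M3 has no inf; clamp so out-of-range values don't become NaN.
        x = torch.clamp(x, min=-448.0, max=448.0)
    x = x.to(torch.float8_e4m3fn).to(torch.float32)  # quantize, then widen back
    if amax is not None:
        x = x / scale  # undo the scaling (dequantize)
    return x.to(dtype)


if __name__ == "__main__":
    t = torch.randn(4, 8)
    out = fake_quant_e4m3(t, amax=t.abs().amax())
    print((t - out).abs().max())  # small quantization error, bounded by the E4M3 step size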