diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 021779f037..be62e08782 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -198,12 +198,6 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
-    if isinstance(model, Float8Linear):
-        with torch.device("meta"):
-            new_module = nn.Linear(model.in_features, model.out_features)
-            new_module.weight = model.weight
-            new_module.bias = model.bias
-        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
@@ -244,12 +238,6 @@ def _replace_with_custom_fn_if_matches_filter_with_name(
     Returns:
         None
     """
-    if isinstance(model, Float8Linear):
-        with torch.device("meta"):
-            new_module = nn.Linear(model.in_features, model.out_features)
-            new_module.weight = model.weight
-            new_module.bias = model.bias
-        model = new_module
     if filter_fn(model, cur_fqn[:-1]):
         if device is not None:
             model.to(device=device)  # move to device before quantization
@@ -1673,6 +1661,10 @@ def _float8_weight_only_transform(
         "applying int8 weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
+
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     new_weight = _float8_weight_only_quant_tensor(module.weight, config)
 
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
@@ -1882,6 +1874,9 @@ def _float8_dynamic_activation_float8_weight_transform(
         "applying float8 dynamic activation quant requires module to have weight attribute"
         + f"but {module} does not have one"
     )
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor(
         module.weight, config
     )
@@ -1917,6 +1912,9 @@ def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
 ):
     assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0"
 
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     weight = module.weight
     weight_dtype = config.weight_dtype
     activation_dtype = config.activation_dtype
@@ -1981,6 +1979,9 @@ def _float8_static_activation_float8_weight_transform(
         "Float8 static activation quantization is only supported on CUDA 8.9 and above"
     )
 
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     scale = config.scale
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -2340,6 +2341,9 @@ def _fpx_weight_only_transform(
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
+    if isinstance(module, Float8Linear):
+        module = _unwrap_float8_linear(module)
+
     from torchao.dtypes import to_affine_quantized_fpx
     from torchao.dtypes.floatx import FloatxTensorCoreLayout
 
@@ -2398,6 +2402,21 @@ def _module_fqn_to_config_handler(
     return module
 
 
+def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear:
+    """
+    Unwrap a torchao Float8Linear by returning a nn.Linear with the same weights and bias.
+
+    Torchao inference quantization techniques are generally only applicable to nn.Linear
+    layers, so this helper is useful for unwrapping models trained with torchao float8 training,
+    which replaces nn.Linear layers with Float8Linear layers.
+    """
+    with torch.device("meta"):
+        new_module = nn.Linear(module.in_features, module.out_features)
+        new_module.weight = module.weight
+        new_module.bias = module.bias
+    return new_module
+
+
 torch.serialization.add_safe_globals(
     [
         _int8_asymm_per_token_quant,