diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 339e9d0bb9..7907c79bd6 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -763,7 +763,8 @@ def run_search(self): def _get_auto_quantize_score(grad_output, output_diff): - return ((grad_output.float() ** 2) * (output_diff.float() ** 2)).sum() + x = grad_output.float() * output_diff.float() + return x.to(torch.float64).square().sum() def _add_auto_quantize_score(grad_output, output_diff, score_tensor):