From 110ec4b07adf838fd6ca4289349683204631e203 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 5 Mar 2024 11:01:54 -0800
Subject: [PATCH] cast to fp32 in amax

---
 float8_experimental/float8_utils.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index 21574fba..00f0ed03 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -23,8 +23,10 @@
 
 
 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
-    scale = torch.empty_like(amax, dtype=torch.float32)
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    assert amax.dtype == torch.float32, "amax must be a float32 tensor"
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else:  # e5m2
@@ -35,16 +37,15 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
         res = torch.clamp(res, max=FP16_MAX_POS)
-    scale.copy_(res)
-    return scale
+    return res
 
 
 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: str,
 ):
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
@@ -87,7 +88,7 @@ def tensor_to_amax(x, distributed_reduction=False):
 
 
 @torch.no_grad()
-def tensor_to_scale(x, float8_dtype):
+def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype):
     amax = tensor_to_amax(x)
     if float8_experimental.config.use_fused_cast and x.is_cuda:
         from float8_experimental.fused_kernels.fused_casting_kernels import (