This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

cast to fp32 in amax
drisspg committed Mar 5, 2024
1 parent 1dd4573 commit 110ec4b
Showing 1 changed file with 10 additions and 9 deletions.
float8_experimental/float8_utils.py: 19 changes (10 additions & 9 deletions)
@@ -23,8 +23,10 @@
 
 
 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
-    scale = torch.empty_like(amax, dtype=torch.float32)
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    assert amax.dtype == torch.float32, "amax must be a float32 tensor"
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else: # e5m2
@@ -35,16 +37,15 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
         res = torch.clamp(res, max=FP16_MAX_POS)
-    scale.copy_(res)
-    return scale
+    return res
 
 
 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: str,
 ):
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
@@ -87,7 +88,7 @@ def tensor_to_amax(x, distributed_reduction=False):
 
 
 @torch.no_grad()
-def tensor_to_scale(x, float8_dtype):
+def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype):
     amax = tensor_to_amax(x)
     if float8_experimental.config.use_fused_cast and x.is_cuda:
         from float8_experimental.fused_kernels.fused_casting_kernels import (
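As a quick illustration of the contract this commit introduces: amax_to_scale now requires its amax argument to already be a float32 tensor and returns the computed scale directly, instead of copying it into a preallocated buffer. The sketch below is hypothetical usage, not code from this repository; the input tensor, its shape, and the explicit .to(torch.float32) cast are assumptions for illustration.

import torch

from float8_experimental.float8_utils import amax_to_scale

# Hypothetical caller: compute the absolute max in float32 first, since the
# new assert in amax_to_scale rejects non-float32 amax tensors.
x = torch.randn(1024, 1024, dtype=torch.bfloat16)
amax = torch.max(torch.abs(x)).to(torch.float32)

scale = amax_to_scale(amax, torch.float8_e4m3fn, x.dtype)
# The returned scale inherits amax's float32 dtype.
assert scale.dtype == torch.float32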
