This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

my abs_max is busted
drisspg committed Feb 28, 2024
1 parent f430039 commit 3932fab
Showing 3 changed files with 35 additions and 20 deletions.
float8_experimental/float8_tensor.py (9 changes: 7 additions & 2 deletions)
@@ -47,8 +47,13 @@ def to_fp8_no_autograd(
     ):
         from driss_torch import saturated_cast

-        bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
-
+        if x.dim() in {3, 4}:
+            prev_x_shape = x.shape
+            x = x.view(-1, x.size(-1))
+            bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
+            bits_fp8 = bits_fp8.view(prev_x_shape)
+        else:
+            bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
     else:
         x_scaled = x * x_scale
         bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
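Note: the change above suggests saturated_cast only handles 2-D inputs, so 3-D and 4-D tensors are flattened to 2-D for the cast and reshaped back afterwards. A minimal sketch of that round-trip, using a clamp-based cast as a hypothetical stand-in for the external driss_torch.saturated_cast kernel:

import torch

def saturated_cast_2d(x: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical stand-in for driss_torch.saturated_cast: scale, then
    # clamp to the target dtype's finite range before converting.
    finfo = torch.finfo(dtype)
    return (x * scale).clamp(finfo.min, finfo.max).to(dtype)

def cast_any_rank(x: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Mirror of the reshape round-trip in the diff above.
    if x.dim() in {3, 4}:
        prev_x_shape = x.shape
        out = saturated_cast_2d(x.view(-1, x.size(-1)), scale, dtype)
        return out.view(prev_x_shape)
    return saturated_cast_2d(x, scale, dtype)

y = cast_any_rank(torch.randn(2, 3, 8), torch.tensor(1.0), torch.float8_e4m3fn)
assert y.shape == (2, 3, 8)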
float8_experimental/float8_utils.py (9 changes: 7 additions & 2 deletions)
@@ -70,11 +70,16 @@ def amax_history_to_scale_stack(
 @torch.no_grad()
 def tensor_to_amax(x, distributed_reduction=False):
-    if float8_experimental.config.use_fused_cast and x.is_cuda:
+    if False and float8_experimental.config.use_fused_cast and x.is_cuda:
         from float8_experimental.fused_kernels.fused_casting_kernels import abs_max

         amax = abs_max(x)
+        diff = abs_max(x) - x.abs().max().to(torch.float32)
+        assert (
+            diff.item() == 0
+        ), f"Expected {amax} to be equal to {x.abs().max().to(torch.float32)} but got {diff}"
     else:
-        amax = x.abs().max()
+        amax = x.abs().max().to(torch.float32)

     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
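Note: the two changes above are a classic debugging pattern. The "if False and ..." gate takes the fused path out of service without deleting it, and the added assert cross-checks the fused abs_max against the eager reference it is supposed to match. A minimal sketch of that cross-check, with hypothetical names:

import torch

def reference_amax(x: torch.Tensor) -> torch.Tensor:
    # Eager ground truth: global max of absolute values, in float32.
    return x.abs().max().to(torch.float32)

def check_fused_amax(fused_fn, x: torch.Tensor) -> None:
    # Compare a candidate fused kernel against the eager reference and
    # fail loudly on any mismatch, mirroring the assert in the diff above.
    expected = reference_amax(x)
    actual = fused_fn(x)
    diff = actual - expected
    assert diff.item() == 0, f"Expected {expected} but got {actual} (diff: {diff})"

# The reference trivially passes its own check.
check_fused_amax(reference_amax, torch.randn(512, 64))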
float8_experimental/fused_kernels/fused_casting_kernels.py (37 changes: 21 additions & 16 deletions)
@@ -48,7 +48,7 @@ def abs_max_kernel(
     r_mask = r_index < r_numel
     values = tl.load(
         x_ptr + (r_index + (r_numel * x_index)),
-        x_mask & r_mask,
+        r_mask,
         eviction_policy="evict_last",
         other=0.0,
     ).to(tl.float32)
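Note: the one-line change above relaxes the load mask from x_mask & r_mask to r_mask alone. A load mask that is stricter than it should be replaces valid lanes with the fill value (other=0.0), which silently corrupts a max-of-abs reduction; that is a plausible reading of why this abs_max was "busted". A toy masked-load emulation in plain PyTorch (names hypothetical):

import torch

def masked_load(buf: torch.Tensor, idx: torch.Tensor, mask: torch.Tensor, other: float = 0.0) -> torch.Tensor:
    # Emulates tl.load(ptr + idx, mask, other=...): masked-off lanes read `other`.
    out = torch.full(idx.shape, other, dtype=buf.dtype)
    out[mask] = buf[idx[mask]]
    return out

buf = torch.tensor([5.0, 2.0, -9.0, 7.0])
idx = torch.arange(4)
r_mask = idx < 4                    # correct bounds check: all lanes valid
too_strict = r_mask & (idx < 2)     # an extra predicate, like the stray x_mask
print(masked_load(buf, idx, r_mask).abs().max())      # tensor(9.) -- correct
print(masked_load(buf, idx, too_strict).abs().max())  # tensor(5.) -- valid lanes dropped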
@@ -62,21 +62,26 @@
 def abs_max(x: torch.Tensor) -> torch.Tensor:
     "Calculates the global max of the absolute values of a tensor"
-    output = torch.empty((512, 1), device=x.device, dtype=torch.float32)
-    n_elements = x.numel()
-    grid = lambda meta: (meta["X_BLOCK_SIZE"],)
-    X_BLOCK_SIZE = 1
-    R_BLOCK_SIZE = 1024
-    r_numel = n_elements // 512
-    abs_max_kernel[grid](
-        x,
-        output,
-        x_numel=512,
-        r_numel=r_numel,
-        X_BLOCK_SIZE=X_BLOCK_SIZE,
-        R_BLOCK_SIZE=R_BLOCK_SIZE,
-    )
-    return output.max()
+    x = x.contiguous()
+    if x.numel() % 512 == 0:
+        output = torch.full(
+            (512, 1), -float("inf"), device=x.device, dtype=torch.float32
+        )
+        grid = lambda meta: (meta["X_BLOCK_SIZE"],)
+        X_BLOCK_SIZE = 1
+        R_BLOCK_SIZE = 1024
+        r_numel = x.numel() // 512
+        abs_max_kernel[grid](
+            x,
+            output,
+            x_numel=512,
+            r_numel=r_numel,
+            X_BLOCK_SIZE=X_BLOCK_SIZE,
+            R_BLOCK_SIZE=R_BLOCK_SIZE,
+        )
+        return output.max()
+    else:
+        return x.abs().max().to(torch.float32)


 @triton.jit
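Note: the restructured launcher above is a two-stage reduction. The flattened tensor is split into 512 contiguous chunks, each Triton program reduces its chunk to one partial max in output (now seeded with -inf via torch.full rather than left uninitialized by torch.empty, whose garbage contents could exceed the true max), and output.max() finishes the job; sizes not divisible by 512 fall back to the eager path. A pure-PyTorch sketch of the same two-stage shape, no Triton required:

import torch

def abs_max_two_stage(x: torch.Tensor, num_chunks: int = 512) -> torch.Tensor:
    # Stage 1: each "program" reduces one contiguous chunk to a partial max.
    # Stage 2: reduce the partials to the global answer.
    x = x.contiguous()
    if x.numel() % num_chunks != 0:
        # Same fallback as the diff above: eager reduction in float32.
        return x.abs().max().to(torch.float32)
    partials = x.view(num_chunks, -1).abs().amax(dim=1).to(torch.float32)
    return partials.max()

x = torch.randn(4, 1024)
assert torch.equal(abs_max_two_stage(x), x.abs().max().to(torch.float32))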

