This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit: fix kernel args
drisspg committed Feb 28, 2024
1 parent 3932fab commit 4afcd0b
Showing 3 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions float8_experimental/float8_tensor.py
@@ -49,9 +49,9 @@ def to_fp8_no_autograd(

         if x.dim() in {3, 4}:
             prev_x_shape = x.shape
-            x = x.view(-1, x.size(-1))
+            x = x.reshape(-1, x.size(-1))
             bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
-            bits_fp8 = bits_fp8.view(prev_x_shape)
+            bits_fp8 = bits_fp8.reshape(prev_x_shape)
         else:
             bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
     else:
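
The view-to-reshape change matters for non-contiguous 3D/4D inputs: view requires compatible strides and raises, while reshape falls back to a copy. A minimal sketch (not part of the commit) illustrating the difference:

import torch

# Non-contiguous 3D activation, e.g. produced by a transpose upstream.
x = torch.randn(4, 8, 16).transpose(1, 2)

try:
    x.view(-1, x.size(-1))        # view needs compatible strides -> RuntimeError here
except RuntimeError as err:
    print("view failed:", err)

flat = x.reshape(-1, x.size(-1))  # reshape copies when a view is not possible
print(flat.shape)                 # torch.Size([64, 8])
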
6 changes: 1 addition & 5 deletions float8_experimental/float8_utils.py
@@ -70,14 +70,10 @@ def amax_history_to_scale_stack(

 @torch.no_grad()
 def tensor_to_amax(x, distributed_reduction=False):
-    if False and float8_experimental.config.use_fused_cast and x.is_cuda:
+    if float8_experimental.config.use_fused_cast and x.is_cuda:
         from float8_experimental.fused_kernels.fused_casting_kernels import abs_max

         amax = abs_max(x)
-        diff = abs_max(x) - x.abs().max().to(torch.float32)
-        assert (
-            diff.item() == 0
-        ), f"Expected {amax} to be equal to {x.abs().max().to(torch.float32)} but got {diff}"
     else:
         amax = x.abs().max().to(torch.float32)

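This hunk re-enables the fused abs_max path (the `if False and ...` guard is dropped) and removes the per-call assert that compared the Triton result against the eager reduction. A hedged sketch of how that comparison could live in a standalone check instead; `reference_amax` is a hypothetical helper based on the `else` branch above, and the CUDA usage is commented out because it needs a GPU:

import torch

def reference_amax(x: torch.Tensor) -> torch.Tensor:
    # Eager fallback used when the fused kernel is disabled or x is not on CUDA.
    return x.abs().max().to(torch.float32)

# Example check against the fused kernel (requires CUDA):
# from float8_experimental.fused_kernels.fused_casting_kernels import abs_max
# x = torch.randn(1024, 1024, device="cuda")
# torch.testing.assert_close(abs_max(x), reference_amax(x))
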
12 changes: 10 additions & 2 deletions float8_experimental/fused_kernels/fused_casting_kernels.py
@@ -61,13 +61,21 @@ def abs_max_kernel(


 def abs_max(x: torch.Tensor) -> torch.Tensor:
-    "Calculates the global max of the absolute values of a tensor"
+    """Calculates the global max of the absolute values of a tensor
+    This kernel launches a grid of 512 threads, each thread calculates the
+    maximum of x.numel // 512 elements. The results are then reduced to a single
+    value in a follow up kernel.
+    Args:
+        x: Input tensor to calculate the abs_max for
+    """
     x = x.contiguous()
     if x.numel() % 512 == 0:
         output = torch.full(
             (512, 1), -float("inf"), device=x.device, dtype=torch.float32
         )
-        grid = lambda meta: (meta["X_BLOCK_SIZE"],)
+        grid = lambda meta: (512,)
         X_BLOCK_SIZE = 1
         R_BLOCK_SIZE = 1024
         r_numel = x.numel() // 512
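
The new docstring describes a two-stage reduction: 512 programs each reduce numel // 512 elements to a partial max, and a follow-up kernel collapses the 512 partials into one value. A plain-PyTorch sketch of those semantics (not the Triton kernel itself; `abs_max_two_stage` is a hypothetical name, and it assumes numel is divisible by 512 as the fast path requires):

import torch

def abs_max_two_stage(x: torch.Tensor, num_programs: int = 512) -> torch.Tensor:
    assert x.numel() % num_programs == 0, "fast path assumes numel divisible by 512"
    # Stage 1: each "program" reduces a chunk of numel // 512 elements to one partial max.
    partials = x.reshape(num_programs, -1).abs().amax(dim=1)
    # Stage 2: the follow-up reduction collapses the 512 partial maxima to a scalar.
    return partials.amax().to(torch.float32)

x = torch.randn(512 * 4)
print(torch.equal(abs_max_two_stage(x), x.abs().max()))  # True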

