diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index bbe8e73c..2f65abec 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -49,9 +49,9 @@ def to_fp8_no_autograd(
 
         if x.dim() in {3, 4}:
             prev_x_shape = x.shape
-            x = x.view(-1, x.size(-1))
+            x = x.reshape(-1, x.size(-1))
             bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
-            bits_fp8 = bits_fp8.view(prev_x_shape)
+            bits_fp8 = bits_fp8.reshape(prev_x_shape)
         else:
             bits_fp8 = saturated_cast(x, x_scale, float8_dtype)
     else:
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index 438d1b4c..21574fba 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -70,14 +70,10 @@ def amax_history_to_scale_stack(
 
 @torch.no_grad()
 def tensor_to_amax(x, distributed_reduction=False):
-    if False and float8_experimental.config.use_fused_cast and x.is_cuda:
+    if float8_experimental.config.use_fused_cast and x.is_cuda:
         from float8_experimental.fused_kernels.fused_casting_kernels import abs_max
 
         amax = abs_max(x)
-        diff = abs_max(x) - x.abs().max().to(torch.float32)
-        assert (
-            diff.item() == 0
-        ), f"Expected {amax} to be equal to {x.abs().max().to(torch.float32)} but got {diff}"
     else:
         amax = x.abs().max().to(torch.float32)
 
diff --git a/float8_experimental/fused_kernels/fused_casting_kernels.py b/float8_experimental/fused_kernels/fused_casting_kernels.py
index 83eb646d..06bb41d8 100644
--- a/float8_experimental/fused_kernels/fused_casting_kernels.py
+++ b/float8_experimental/fused_kernels/fused_casting_kernels.py
@@ -61,13 +61,21 @@ def abs_max_kernel(
 
 
 def abs_max(x: torch.Tensor) -> torch.Tensor:
-    "Calculates the global max of the absolute values of a tensor"
+    """Calculates the global max of the absolute values of a tensor.
+
+    This kernel launches a grid of 512 threads; each thread calculates the
+    maximum of x.numel() // 512 elements. The results are then reduced to a
+    single value in a follow-up kernel.
+
+    Args:
+        x: Input tensor to calculate the abs_max for
+    """
     x = x.contiguous()
     if x.numel() % 512 == 0:
         output = torch.full(
             (512, 1), -float("inf"), device=x.device, dtype=torch.float32
         )
-        grid = lambda meta: (meta["X_BLOCK_SIZE"],)
+        grid = lambda meta: (512,)
         X_BLOCK_SIZE = 1
         R_BLOCK_SIZE = 1024
         r_numel = x.numel() // 512
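
For reference, a plain-PyTorch sketch of the two-stage reduction that the new abs_max docstring describes. This is not the Triton kernel itself: only the 512-way split, the divisibility restriction, and the float32 accumulation are taken from the diff above; the function name abs_max_reference and the use of reshape/amax are illustrative stand-ins for the fused kernels.

import torch

def abs_max_reference(x: torch.Tensor) -> torch.Tensor:
    # Same restriction as the fused path in abs_max.
    assert x.numel() % 512 == 0
    # Stage 1: split the flattened tensor into 512 chunks and take one
    # partial abs-max per chunk (mirrors the 512-program grid).
    partial = x.contiguous().reshape(512, -1).abs().amax(dim=1).to(torch.float32)
    # Stage 2: reduce the 512 partial results to a single scalar,
    # standing in for the follow-up reduction kernel.
    return partial.amax()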