diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index 1d7aaa49..8093c271 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -223,17 +223,19 @@ def backward(ctx, go_fp8: torch.Tensor):
         x_fp8, w_fp8 = ctx.saved_tensors

         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
-        dL_dX = dL_dX.reshape(*go_fp8.shape[:-1], dL_dX.size(-1))
+        dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))

         # calculate dL/dW
-        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8.size(-1)).t().contiguous()
+        x_fp8_reshaped_t_c = x_fp8.view(-1, x_fp8.size(-1)).t().contiguous()
         go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
         dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t)
-        dL_dW = dL_dW.t()
+        # The contiguous call is not needed for correctness, but allows for a faster backward
+        # pass in conjunction with compile for both single-gpu and fsdp.
+        dL_dW = dL_dW.t().contiguous()
         empty_grads = None, None, None, None, None, None, None, None, None
         return dL_dX, dL_dW, *empty_grads
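
Note on the added .contiguous(): dL_dW.t() alone returns a non-contiguous view (the transpose only swaps strides), which downstream consumers may need to copy or handle with a slower kernel. The snippet below is a minimal, standalone sketch of that layout effect using plain PyTorch ops; the tensor names and shapes are illustrative only and not taken from float8_experimental.

    import torch

    # Stand-in for dL_dW as produced by the fp8 matmul: a (K, N) row-major result.
    dL_dW_raw = torch.randn(1024, 4096)

    # .t() only swaps strides, so the transposed view is not contiguous.
    dL_dW_view = dL_dW_raw.t()
    print(dL_dW_view.is_contiguous(), dL_dW_view.stride())  # False (1, 4096)

    # .t().contiguous() materializes a row-major (N, K) tensor, the layout the
    # diff's comment says is faster in the backward pass with compile and fsdp.
    dL_dW = dL_dW_raw.t().contiguous()
    print(dL_dW.is_contiguous(), dL_dW.stride())  # True (1024, 1)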