
Commit

performance boooooooooooost
drisspg committed Jan 19, 2024
1 parent 48f21f6 commit 20da1c0
Showing 1 changed file with 6 additions and 4 deletions.
float8_experimental/float8_ops.py
@@ -223,17 +223,19 @@ def backward(ctx, go_fp8: torch.Tensor):
         x_fp8, w_fp8 = ctx.saved_tensors
 
         # calculate dL/dX
-        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8.size(-1))
+        go_fp8_reshaped = go_fp8.view(-1, go_fp8.size(-1))
         w_fp8_t_c_t = w_fp8.t().contiguous().t()
         dL_dX = float8_mm_helper(go_fp8_reshaped, w_fp8_t_c_t)
-        dL_dX = dL_dX.reshape(*go_fp8.shape[:-1], dL_dX.size(-1))
+        dL_dX = dL_dX.view(*go_fp8.shape[:-1], dL_dX.size(-1))
 
         # calculate dL/dW
-        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8.size(-1)).t().contiguous()
+        x_fp8_reshaped_t_c = x_fp8.view(-1, x_fp8.size(-1)).t().contiguous()
         go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
 
         dL_dW = float8_mm_helper(x_fp8_reshaped_t_c, go_fp8_reshaped_t_c_t)
-        dL_dW = dL_dW.t()
+        # The contiguous call is not needed for correctness, but allows for a faster backward
+        # pass in conjunction with compile for both single-gpu and fsdp.
+        dL_dW = dL_dW.t().contiguous()
 
         empty_grads = None, None, None, None, None, None, None, None, None
         return dL_dX, dL_dW, *empty_grads
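
Why the reshape -> view swap helps: torch.Tensor.reshape may silently allocate
and copy when the source layout does not permit a view, while torch.Tensor.view
is guaranteed to be a zero-copy metadata change and raises a RuntimeError when
that is impossible. Using view on the backward's hot path rules out hidden
copies, presumably because these tensors are known to be contiguous here. A
minimal standalone sketch of the difference (illustrative only, not code from
this repository):

import torch

x = torch.randn(4, 8, 16)

# On a contiguous tensor, view is a pure metadata change: same storage.
flat = x.view(-1, x.size(-1))
assert flat.data_ptr() == x.data_ptr()

# On a non-contiguous tensor, reshape silently falls back to a copy...
xt = x.transpose(0, 1)
copied = xt.reshape(-1, xt.size(-1))
assert copied.data_ptr() != xt.data_ptr()

# ...while view refuses and raises instead of copying behind your back.
try:
    xt.view(-1, xt.size(-1))
except RuntimeError as err:
    print("view raised:", err)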

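The surrounding t().contiguous().t() calls are worth a note as well. That idiom
keeps a tensor's logical shape but re-materializes its storage column-major, a
layout that fp8 GEMM backends (for example cuBLASLt, reached via
torch._scaled_mm) commonly require for the second operand; that reading is an
inference from the code, not something the commit states. A small sketch of the
layout effect, with a made-up weight shape standing in for w_fp8:

import torch

w = torch.randn(256, 512)               # hypothetical stand-in for w_fp8
w_t_c_t = w.t().contiguous().t()

print(w.shape, w.stride())              # torch.Size([256, 512]) (512, 1): row-major
print(w_t_c_t.shape, w_t_c_t.stride())  # torch.Size([256, 512]) (1, 256): column-major

By the same token, the new dL_dW.t().contiguous() returns a plainly row-major
gradient; as the commit's own comment says, that call is not needed for
correctness but makes the compiled backward pass faster on both single-GPU and
FSDP runs.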