
Commit

When running with compile, inductor throws error
drisspg committed Feb 15, 2024
1 parent 956195b commit 97a509e
Showing 1 changed file with 11 additions and 10 deletions.
float8_experimental/float8_linear_utils.py
@@ -239,16 +239,17 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
         all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

-    (
-        reduced_fp8_amax_tensor,
-        reduced_fp8_amax_w_tensor,
-        reduced_fp8_amax_dL_dY_tensor,
-    ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
-
-    for idx, child in enumerate(fp8_layers):
-        child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
-        child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
-        child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
+    # Split the reduced tensor into single-element tensors
+    # [x1, x2, x3, w1, w2, w3, dL_dY1, dL_dY2, dL_dY3] -> [[x1], [x2], [x3], [w1], [w2], [w3], [dL_dY1], [dL_dY2], [dL_dY3]]
+    splits = torch.split(all_reduced_amax_tensor, 1)
+
+    # Then foreach_copy the split tensors back into the original tensors
+    torch._foreach_copy_(
+        fp8_amax_x_tensor_list
+        + fp8_amax_w_tensor_list
+        + fp8_amax_dL_dY_tensor_list,
+        splits,
+    )

     # We create two stacked tensor groups, one for the amax history and one for the current scales
     fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
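For context, here is a minimal, self-contained sketch of the pattern the new code uses. The names amax_x, amax_w, amax_dL_dY, and all_reduced are hypothetical stand-ins for the per-layer buffer lists and the all-reduced tensor in the real module: torch.split(t, 1) yields one single-element view per element, and torch._foreach_copy_ writes them all back into their destination buffers in one fused in-place call, replacing the Python-level loop over fp8_layers.

import torch

# Hypothetical stand-ins for the per-layer amax buffers gathered from the model.
amax_x = [torch.zeros(1) for _ in range(3)]
amax_w = [torch.zeros(1) for _ in range(3)]
amax_dL_dY = [torch.zeros(1) for _ in range(3)]

# Stand-in for the all-reduced amax tensor: one value per buffer,
# concatenated in the same order as the destination lists below.
all_reduced = torch.arange(9, dtype=torch.float32)

# Nine single-element views: (tensor([0.]), tensor([1.]), ..., tensor([8.]))
splits = torch.split(all_reduced, 1)

# One fused in-place copy into every destination buffer at once.
torch._foreach_copy_(amax_x + amax_w + amax_dL_dY, list(splits))

print(amax_x)       # [tensor([0.]), tensor([1.]), tensor([2.])]
print(amax_dL_dY)   # [tensor([6.]), tensor([7.]), tensor([8.])]

Per the commit message, the previous per-module loop raised an error under torch.compile with inductor; the foreach form expresses the same copies as a single op over flat tensor lists.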

