
Commit

Update on "[2/x] clean up casting functions: delayed scaling"
Summary:

Removes delayed scaling from `float8_tensor.py`. After this PR, the
invariant is that everything in `float8_tensor.py` requires the scale to
be calculated elsewhere. This moves the codebase toward a separation of
concerns: calculating the scale (via various scaling strategies) is
decoupled from creating an instance of `Float8Tensor`.

Note that stateful delayed scaling is the reason this separation is needed; a minimal sketch of the resulting flow is shown below.
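
As an illustration of the new invariant, the following minimal sketch calculates a scale outside of `float8_tensor.py` and only then performs the float8 conversion. This is a hedged sketch, not the library's exact API: `amax_to_scale` is a hypothetical helper, and the final cast is written as a plain elementwise cast rather than the real `Float8Tensor` constructor (only `tensor_to_amax` and `to_fp8_no_autograd` appear in the diff below).

```python
# Minimal sketch of the post-PR invariant, assuming PyTorch >= 2.1 for the
# float8 dtypes. amax_to_scale is a hypothetical helper for illustration.
import torch

def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Map the observed max magnitude onto the dtype's largest representable
    # value; clamp to avoid dividing by zero on an all-zero tensor.
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)

x = torch.randn(16, 16)

# Step 1: calculate the scale outside of float8_tensor.py. Dynamic scaling is
# shown here; delayed scaling would instead read amax from a stateful buffer
# updated across iterations, which is exactly the state this PR keeps out of
# the tensor code.
amax = x.abs().max().float()
scale = amax_to_scale(amax, torch.float8_e4m3fn)

# Step 2: hand the precomputed scale to the conversion. The cast below is
# illustrative; in the codebase the cast goes through
# to_fp8_no_autograd(tensor, scale, float8_dtype, ...).
x_fp8 = (x * scale).to(torch.float8_e4m3fn)
```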

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 26, 2024
1 parent 08f4052 commit 24f09e4
Showing 1 changed file with 0 additions and 4 deletions.
float8_experimental/float8_tensor.py

```diff
@@ -207,7 +207,6 @@ def forward(
         tensor: torch.Tensor,
         scale: torch.Tensor,
         float8_dtype=e4m3_dtype,
-        # amax_buffer: Optional[torch.Tensor] = None,
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
@@ -216,11 +215,8 @@ def forward(
             tensor: the tensor to convert
             scale: the scale to use to convert the tensor
             float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
-            amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
             emulate: whether to emulate the matmuls in fp32
         """
-        # if amax_buffer is not None:
-        #     amax_buffer.fill_(tensor_to_amax(tensor))
 
         return to_fp8_no_autograd(
             tensor,
```
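
For reference, after the deletions above the cast entry point reduces to the signature below, reconstructed from the diff's context lines; the `ctx` parameter is an assumption based on this being the `forward` of an autograd function.

```python
def forward(
    ctx,  # assumed: standard first argument of torch.autograd.Function.forward
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype=e4m3_dtype,
    linear_mm_config: Optional[LinearMMConfig] = None,
    gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
    ...
```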
