
Commit 9344b72

Commit message: format

1 parent 9584570 commit 9344b72

File tree

6 files changed: +29 -14 lines changed

    float8_experimental/float8_dynamic_linear.py
    float8_experimental/float8_linear.py
    float8_experimental/float8_linear_utils.py
    float8_experimental/float8_python_api.py
    float8_experimental/float8_tensor.py
    test/test_base.py

float8_experimental/float8_dynamic_linear.py
Lines changed: 2 additions & 4 deletions

@@ -22,7 +22,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale
 from torch._prims_common import suggest_memory_format


@@ -106,9 +106,7 @@ def cast_to_float8_e4m3fn(
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
     scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
-    return Float8Tensor.to_float8(
-        inpt_tensor, scale, e4m3_dtype, mm_config=mm_config
-    )
+    return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


 def cast_to_float8_e5m2_bw(
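
For readers unfamiliar with the dynamic-scaling path touched above: cast_to_float8_e4m3fn computes a per-tensor scale from the tensor's amax and then produces a Float8Tensor in the e4m3 dtype. The standalone sketch below illustrates that amax -> scale -> saturated-cast pattern in plain PyTorch; the helper names and the clamping details are illustrative assumptions, not the library's implementation, and float8 dtypes require PyTorch 2.1 or newer.

    # Minimal sketch of per-tensor dynamic float8 scaling, assuming the
    # amax -> scale -> saturated cast pattern used by cast_to_float8_e4m3fn.
    # Helper names are hypothetical; requires PyTorch >= 2.1 for float8 dtypes.
    import torch

    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max  # ~448.0

    def tensor_to_scale_sketch(t: torch.Tensor) -> torch.Tensor:
        # Map the tensor's absolute maximum onto the fp8 representable range.
        amax = t.abs().max().to(torch.float64)
        return (E4M3_MAX_POS / torch.clamp(amax, min=1e-12)).to(torch.float32)

    def cast_to_fp8_sketch(t: torch.Tensor) -> torch.Tensor:
        scale = tensor_to_scale_sketch(t)
        # Saturate to the fp8 max instead of overflowing, then cast.
        scaled = torch.clamp(t * scale, min=-E4M3_MAX_POS, max=E4M3_MAX_POS)
        return scaled.to(torch.float8_e4m3fn)

    x = torch.randn(16, 32)
    x_fp8 = cast_to_fp8_sketch(x)                                 # fp8 payload
    x_roundtrip = x_fp8.to(torch.float32) / tensor_to_scale_sketch(x)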

float8_experimental/float8_linear.py
Lines changed: 6 additions & 1 deletion

@@ -21,7 +21,12 @@
     to_fp8_no_autograd,
 )

-from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import (
+    amax_history_to_scale,
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_amax,
+)


 def _maybe_initialize_amaxes_scales_for_float8_cast(
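
The imports reflowed above (amax_history_to_scale, tensor_to_amax) belong to the delayed-scaling path of Float8Linear: amax values observed over recent iterations are kept in a history buffer, which is reduced to a single amax and converted into a scale. The sketch below shows the general idea; the max-over-history reduction and the function name are assumptions for illustration, not the library's exact implementation.

    # Hedged sketch of amax-history-based (delayed) scaling. The "max over
    # the history window" reduction is assumed for illustration.
    import torch

    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max  # ~448.0

    def amax_history_to_scale_sketch(amax_history: torch.Tensor) -> torch.Tensor:
        # Reduce the rolling window of past amax observations to one value,
        # then derive the scale that maps it onto the fp8 max.
        amax = torch.clamp(amax_history.max(), min=1e-12)
        return E4M3_MAX_POS / amax

    history = torch.tensor([0.7, 1.3, 0.9, 2.1])   # amax values from recent steps
    scale = amax_history_to_scale_sketch(history)  # ~448 / 2.1 ~= 213.3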

float8_experimental/float8_linear_utils.py
Lines changed: 5 additions & 1 deletion

@@ -14,7 +14,11 @@
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear

-from float8_experimental.float8_utils import amax_history_to_scale_stack, e4m3_dtype, e5m2_dtype
+from float8_experimental.float8_utils import (
+    amax_history_to_scale_stack,
+    e4m3_dtype,
+    e5m2_dtype,
+)
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

 log = logging.getLogger(__name__)

float8_experimental/float8_python_api.py
Lines changed: 0 additions & 1 deletion

@@ -9,7 +9,6 @@
 to simplify the product code.
 """

-
 from typing import Optional, Tuple

 import float8_experimental.float8_aten_api  # noqa

float8_experimental/float8_tensor.py
Lines changed: 5 additions & 1 deletion

@@ -9,7 +9,11 @@
 import torch

 import torch.distributed._functional_collectives as funcol
-from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated, e4m3_dtype
+from float8_experimental.float8_utils import (
+    e4m3_dtype,
+    tensor_to_amax,
+    to_fp8_saturated,
+)
 from torch.distributed._tensor import DTensor

 aten = torch.ops.aten

test/test_base.py
Lines changed: 11 additions & 6 deletions

@@ -31,11 +31,11 @@
 from float8_experimental.float8_utils import (
     amax_to_scale,
     compute_error,
+    e4m3_dtype,
+    e5m2_dtype,
     fp8_tensor_statistics,
     FP8_TYPES,
     tensor_to_scale,
-    e4m3_dtype,
-    e5m2_dtype,
 )

 random.seed(0)
@@ -397,10 +397,15 @@ def test_merge_configs(self):


 class TestNumerics:
-    @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn,
-                                              torch.float8_e5m2,
-                                              torch.float8_e4m3fnuz,
-                                              torch.float8_e5m2fnuz])
+    @pytest.mark.parametrize(
+        "float8_dtype",
+        [
+            torch.float8_e4m3fn,
+            torch.float8_e5m2,
+            torch.float8_e4m3fnuz,
+            torch.float8_e5m2fnuz,
+        ],
+    )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_small_amax_float16(self, float8_dtype):
         # If we calculate scale naively with FP8_MAX_POS / amax,
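
The test being reformatted here, test_small_amax_float16, guards against a numerics pitfall hinted at in its comment: computing the scale naively as FP8_MAX_POS / amax overflows float16 (max finite value 65504) when amax is very small. The snippet below reproduces that failure mode and one possible clamping fix; the clamping strategy shown is an illustrative assumption, not necessarily the fix used in float8_utils.

    # Sketch of the small-amax problem referenced in the test comment above,
    # assuming scales end up stored in float16. Requires PyTorch >= 2.1.
    import torch

    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448.0
    FLOAT16_MAX = torch.finfo(torch.float16).max         # 65504.0

    amax = torch.tensor(1e-3, dtype=torch.float64)

    # Naive scale overflows float16 once FP8_E4M3_MAX / amax exceeds 65504.
    naive_scale = (FP8_E4M3_MAX / amax).to(torch.float16)         # -> inf

    # One possible fix (illustrative): clamp amax so the scale stays finite.
    clamped_amax = torch.clamp(amax, min=FP8_E4M3_MAX / FLOAT16_MAX)
    safe_scale = (FP8_E4M3_MAX / clamped_amax).to(torch.float16)  # -> 65504.0, finite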
