Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[2/x] clean up casting functions: delayed scaling #346

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions benchmarks/bench_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,13 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
A,
scale_a,
fp8_dtype,
None, # amax_buffer
a_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
B,
scale_b,
fp8_dtype,
None, # amax_buffer
b_config,
GemmInputRole.WEIGHT,
)
Expand Down
3 changes: 1 addition & 2 deletions float8_experimental/float8_scaling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def cast_to_float8_e4m3_dynamic(
inpt_tensor,
scale,
e4m3_dtype,
None, # amax_buffer
linear_mm_config,
gemm_input_role,
)
Expand All @@ -59,11 +58,11 @@ def cast_to_float8_delayed(
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
amax_buffer.fill_(tensor_to_amax(tensor))
return ToFloat8ConstrFunc.apply(
tensor,
scale,
float8_dtype,
amax_buffer,
linear_mm_config,
gemm_input_role,
)
Expand Down
4 changes: 0 additions & 4 deletions float8_experimental/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ def forward(
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype=e4m3_dtype,
amax_buffer: Optional[torch.Tensor] = None,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
Expand All @@ -216,11 +215,8 @@ def forward(
tensor: the tensor to convert
scale: the scale to use to convert the tensor
float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
amax_buffer: an Optional buffer buffer to store the amax value in prior to conversion
emulate: whether to emulate the matmuls in fp32
"""
if amax_buffer is not None:
amax_buffer.fill_(tensor_to_amax(tensor))

return to_fp8_no_autograd(
tensor,
Expand Down
13 changes: 5 additions & 8 deletions float8_experimental/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
import torch
import torch.nn as nn
import torch.utils._pytree as pytree
from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic
from float8_experimental.float8_scaling_utils import (
cast_to_float8_delayed,
cast_to_float8_e4m3_dynamic,
)

from float8_experimental.float8_tensor import (
Float8Tensor,
Expand Down Expand Up @@ -168,7 +171,6 @@ def fsdp_pre_all_gather(self, mesh):
self._tensor,
self._precomputed_scale,
torch.float8_e4m3fn,
None, # amax_buffer
self._linear_mm_config,
GemmInputRole.WEIGHT,
)
Expand Down Expand Up @@ -352,12 +354,7 @@ def fsdp_pre_all_gather(self, mesh):
)
self.is_amax_initialized = True

# this will:
# 1. cast the tensor to float8 using `_scale_buffer`
# 2. populate `_amax_buffer` inplace
# TODO(future PR): clean up all the casting functions and clearly
# separate dynamic vs delayed, tech debt has accumulated
float8_tensor = ToFloat8ConstrFunc.apply(
float8_tensor = cast_to_float8_delayed(
self._tensor,
self._scale_buffer,
e4m3_dtype,
Expand Down
2 changes: 0 additions & 2 deletions float8_experimental/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
self.weight,
scale,
dtype,
None, # amax_buffer
self.linear_mm_config,
GemmInputRole.WEIGHT,
)
Expand Down Expand Up @@ -205,7 +204,6 @@ def cast_to_float8_e4m3_inference(
inpt_tensor,
scale,
e4m3_dtype,
None, # amax_buffer
linear_mm_config,
GemmInputRole.INPUT,
)
Expand Down
10 changes: 2 additions & 8 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,15 +451,13 @@ def test_different_configs_error(self):
x_fp32,
x_scale,
fp8_dtype,
None, # amax_buffer
linear_config_a,
GemmInputRole.INPUT,
)
b = ToFloat8ConstrFunc.apply(
x_fp32,
x_scale,
fp8_dtype,
None, # amax_buffer
linear_config_b,
GemmInputRole.WEIGHT,
)
Expand Down Expand Up @@ -489,10 +487,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
b_scale = tensor_to_scale(b, input_dtype).float()

a_fp8 = ToFloat8ConstrFunc.apply(
a, a_scale, input_dtype, None, None, GemmInputRole.INPUT
a, a_scale, input_dtype, None, GemmInputRole.INPUT
)
b_fp8 = ToFloat8ConstrFunc.apply(
b, b_scale, input_dtype, None, None, GemmInputRole.WEIGHT
b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
)

with pytest.raises(
Expand All @@ -512,15 +510,13 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
a,
a_scale,
input_dtype,
None, # amax_buffer
pad_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b,
b_scale,
input_dtype,
None, # amax_buffer
pad_config,
GemmInputRole.WEIGHT,
)
Expand All @@ -537,15 +533,13 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
a,
a_scale,
input_dtype,
None, # amax_buffer
emulated_config,
GemmInputRole.INPUT,
)
b_fp8 = ToFloat8ConstrFunc.apply(
b,
b_scale,
input_dtype,
None, # amax_buffer
emulated_config,
GemmInputRole.WEIGHT,
)
Expand Down
5 changes: 3 additions & 2 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
get_float8_layers,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import LinearMMConfig, ToFloat8ConstrFunc
from float8_experimental.float8_scaling_utils import cast_to_float8_delayed
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

from torch._dynamo.test_case import TestCase as DynamoTestCase
Expand Down Expand Up @@ -178,7 +179,7 @@ def __init__(self, graph_break: bool):
self.graph_break = graph_break

def forward(self, x):
x_fp8 = ToFloat8ConstrFunc.apply(
x_fp8 = cast_to_float8_delayed(
x,
self.fp8_scale_x,
e4m3_dtype,
Expand Down
6 changes: 2 additions & 4 deletions test/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ def test_scaled_mm(mesh: DeviceMesh, size=16):
y_scale = tensor_to_scale(y_fp32, fp8_dtype).float()

x_fp8 = ToFloat8ConstrFunc.apply(
x_fp32, x_scale, fp8_dtype, None, None, GemmInputRole.INPUT
x_fp32, x_scale, fp8_dtype, None, GemmInputRole.INPUT
)
y_fp8 = ToFloat8ConstrFunc.apply(
y_fp32, y_scale, fp8_dtype, None, None, GemmInputRole.WEIGHT
y_fp32, y_scale, fp8_dtype, None, GemmInputRole.WEIGHT
)

dist_x_fp8 = DTensor.from_local(x_fp8, mesh, [lhs_placement], run_check=False)
Expand Down Expand Up @@ -169,15 +169,13 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
dist_x_scale,
fp8_dtype,
None,
None,
GemmInputRole.INPUT,
)
dist_weight_fp8 = ToFloat8ConstrFunc.apply(
dist_wight_fp32,
dist_weight_scale,
fp8_dtype,
None,
None,
GemmInputRole.WEIGHT,
)

Expand Down
Loading