This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[5/x] clean up casting: cast_to_float8_e4m3_dynamic -> cast_to_float8_dynamic #349

Closed
wants to merge 1 commit into from
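In short: the dynamic-cast helper is renamed and the hard-coded e4m3 target dtype becomes an explicit parameter, so the same helper can produce any float8 dtype. Below is a rough sketch of the signature change, reconstructed from the diffs that follow; it is not copied verbatim from the repo, and the import location of LinearMMConfig is an assumption (that class is not touched by this PR).

import torch

from float8_experimental.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,  # assumed to live here; not part of this diff
)


# old helper: the target dtype (e4m3) is baked into the name
def cast_to_float8_e4m3_dynamic(
    inpt_tensor: torch.Tensor,
    linear_mm_config: LinearMMConfig,
    reduce_amax: bool = False,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:
    ...


# new helper: the caller passes the target float8 dtype explicitly
def cast_to_float8_dynamic(
    inpt_tensor: torch.Tensor,
    float8_dtype: torch.dtype,
    linear_mm_config: LinearMMConfig,
    reduce_amax: bool = False,
    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:
    ...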
7 changes: 4 additions & 3 deletions float8_experimental/float8_linear.py
@@ -19,7 +19,7 @@
 from float8_experimental.float8_scaling_utils import (
     _maybe_initialize_amaxes_scales_for_float8_cast,
     cast_to_float8_delayed,
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
     NoopFwToFloat8E5M2BwDelayed,
     NoopFwToFloat8E5M2BwDynamic,
 )
@@ -270,7 +270,7 @@ def cast_input_to_float8(
             )
         else:
             assert self.scaling_type_input is ScalingType.DYNAMIC
-            input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
+            input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config)
         return input_fp8
 
     def cast_weight_to_float8(
@@ -305,8 +305,9 @@ def cast_weight_to_float8(
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
-                weight_fp8 = cast_to_float8_e4m3_dynamic(
+                weight_fp8 = cast_to_float8_dynamic(
                     self.weight,
+                    e4m3_dtype,
                     self.linear_mm_config,
                     gemm_input_role=GemmInputRole.WEIGHT,
                 )
8 changes: 4 additions & 4 deletions float8_experimental/float8_scaling_utils.py
@@ -30,25 +30,25 @@
 )
 
 
-def cast_to_float8_e4m3_dynamic(
+def cast_to_float8_dynamic(
     inpt_tensor: torch.Tensor,
+    float8_dtype: torch.dtype,
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
+    scale = tensor_to_scale(inpt_tensor, float8_dtype, reduce_amax)
     return hp_tensor_and_scale_to_float8(
         inpt_tensor,
         scale,
-        e4m3_dtype,
+        float8_dtype,
         linear_mm_config,
         gemm_input_role,
     )
 
 
-# TODO(future PR): align name with cast_to_float8_e4m3_dynamic
 def cast_to_float8_delayed(
     tensor: torch.Tensor,
     scale: torch.Tensor,
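For illustration, here is a minimal usage sketch of the renamed helper. This is not code from the PR: it assumes LinearMMConfig is importable from float8_experimental.float8_tensor and can be constructed with its defaults, and that float8_experimental.float8_utils exports e5m2_dtype alongside e4m3_dtype; the tensor shapes and values are arbitrary.

import torch

from float8_experimental.float8_scaling_utils import cast_to_float8_dynamic
from float8_experimental.float8_tensor import GemmInputRole, LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype

x = torch.randn(16, 32, dtype=torch.bfloat16)
cfg = LinearMMConfig()  # assumed default-constructible

# before this PR: x_fp8 = cast_to_float8_e4m3_dynamic(x, cfg)
# after this PR, the target float8 dtype is an explicit argument:
x_fp8 = cast_to_float8_dynamic(
    x,
    e4m3_dtype,  # typically torch.float8_e4m3fn
    cfg,
    gemm_input_role=GemmInputRole.INPUT,
)

# the same helper can now target other float8 dtypes, e.g. e5m2
g_fp8 = cast_to_float8_dynamic(torch.randn_like(x), e5m2_dtype, cfg)

The call sites updated in float8_linear.py, float8_tensor_parallel.py, and fsdp_utils.py below follow the same pattern, passing e4m3_dtype explicitly.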
12 changes: 8 additions & 4 deletions float8_experimental/float8_tensor_parallel.py
@@ -2,10 +2,11 @@
 import torch.nn as nn
 from float8_experimental.config import ScalingType
 from float8_experimental.float8_scaling_utils import (
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
     NoopFwToFloat8E5M2BwDynamic,
 )
 from float8_experimental.float8_tensor import GemmInputRole
+from float8_experimental.float8_utils import e4m3_dtype
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
@@ -45,8 +46,9 @@ def _prepare_input_fn(
                 input_tensor, device_mesh, input_layouts, run_check=False
             )
 
-        input_tensor = cast_to_float8_e4m3_dynamic(
+        input_tensor = cast_to_float8_dynamic(
             input_tensor,
+            e4m3_dtype,
             mod.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
@@ -98,8 +100,9 @@ def _prepare_input_fn(
                 input_tensor, device_mesh, input_layouts, run_check=False
             )
 
-        input_tensor = cast_to_float8_e4m3_dynamic(
+        input_tensor = cast_to_float8_dynamic(
             input_tensor,
+            e4m3_dtype,
             mod.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
@@ -196,8 +199,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
                 input, mesh, (input_layout,), run_check=False
             )
 
-        dt_inp = cast_to_float8_e4m3_dynamic(
+        dt_inp = cast_to_float8_dynamic(
             dt_inp,
+            e4m3_dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
5 changes: 3 additions & 2 deletions float8_experimental/fsdp_utils.py
@@ -12,7 +12,7 @@
 import torch.utils._pytree as pytree
 from float8_experimental.float8_scaling_utils import (
     cast_to_float8_delayed,
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
 )
 
 from float8_experimental.float8_tensor import (
@@ -175,8 +175,9 @@ def fsdp_pre_all_gather(self, mesh):
                 GemmInputRole.WEIGHT,
             )
         else:
-            float8_tensor = cast_to_float8_e4m3_dynamic(
+            float8_tensor = cast_to_float8_dynamic(
                 self._tensor,
+                e4m3_dtype,
                 self._linear_mm_config,
                 reduce_amax=True,
                 gemm_input_role=GemmInputRole.WEIGHT,
1 change: 0 additions & 1 deletion test/test_base.py
@@ -24,7 +24,6 @@
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,