
clean up casting: cast_to_float8_e4m3_dynamic -> cast_to_float8_dynamic (#349)

Summary:
Pull Request resolved: #349

Moves the dtype from the function name to an argument, to match the delayed scaling version.
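
For context, a minimal before/after sketch of a call site (names taken from the float8_linear.py change below; the surrounding variables belong to Float8Linear):

from float8_experimental.float8_utils import e4m3_dtype

# before: target dtype baked into the function name
input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)

# after: target dtype passed explicitly, mirroring cast_to_float8_delayed
input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config)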

Reviewed By: drisspg

Differential Revision: D60310239

fbshipit-source-id: d266f8d9a17ed3170176c058e9960541a1d3946b
vkuzo authored and facebook-github-bot committed Jul 26, 2024
1 parent 6ac2f82 commit 4fb2877
Showing 5 changed files with 19 additions and 14 deletions.
7 changes: 4 additions & 3 deletions float8_experimental/float8_linear.py
@@ -19,7 +19,7 @@
 from float8_experimental.float8_scaling_utils import (
     _maybe_initialize_amaxes_scales_for_float8_cast,
     cast_to_float8_delayed,
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
     NoopFwToFloat8E5M2BwDelayed,
     NoopFwToFloat8E5M2BwDynamic,
 )
@@ -270,7 +270,7 @@ def cast_input_to_float8(
             )
         else:
             assert self.scaling_type_input is ScalingType.DYNAMIC
-            input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
+            input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config)
         return input_fp8

     def cast_weight_to_float8(
@@ -305,8 +305,9 @@ def cast_weight_to_float8(
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             weight_fp8 = self.weight
         else:
-            weight_fp8 = cast_to_float8_e4m3_dynamic(
+            weight_fp8 = cast_to_float8_dynamic(
                 self.weight,
+                e4m3_dtype,
                 self.linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
             )
8 changes: 4 additions & 4 deletions float8_experimental/float8_scaling_utils.py
@@ -30,25 +30,25 @@
 )


-def cast_to_float8_e4m3_dynamic(
+def cast_to_float8_dynamic(
     inpt_tensor: torch.Tensor,
+    float8_dtype: torch.dtype,
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
+    scale = tensor_to_scale(inpt_tensor, float8_dtype, reduce_amax)
     return hp_tensor_and_scale_to_float8(
         inpt_tensor,
         scale,
-        e4m3_dtype,
+        float8_dtype,
         linear_mm_config,
         gemm_input_role,
     )


-# TODO(future PR): align name with cast_to_float8_e4m3_dynamic
 def cast_to_float8_delayed(
     tensor: torch.Tensor,
     scale: torch.Tensor,
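
With the dtype now an explicit argument, the same helper can in principle target any float8 format, which the old e4m3-specific name could not express. A hypothetical sketch, not taken from this diff: it assumes e5m2_dtype is exported from float8_experimental.float8_utils alongside e4m3_dtype, that GemmInputRole.GRAD_OUTPUT exists, and that weight, grad_out, and linear_mm_config are already defined (in the repo itself, gradient casts are routed through NoopFwToFloat8E5M2BwDynamic rather than this helper):

from float8_experimental.float8_scaling_utils import cast_to_float8_dynamic
from float8_experimental.float8_tensor import GemmInputRole
from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype  # e5m2_dtype assumed exported here

# weights and inputs are typically cast to e4m3
w_fp8 = cast_to_float8_dynamic(
    weight, e4m3_dtype, linear_mm_config, gemm_input_role=GemmInputRole.WEIGHT
)

# gradients would use e5m2; purely illustrative of the generalized signature
g_fp8 = cast_to_float8_dynamic(
    grad_out, e5m2_dtype, linear_mm_config, gemm_input_role=GemmInputRole.GRAD_OUTPUT
)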
12 changes: 8 additions & 4 deletions float8_experimental/float8_tensor_parallel.py
@@ -2,10 +2,11 @@
 import torch.nn as nn
 from float8_experimental.config import ScalingType
 from float8_experimental.float8_scaling_utils import (
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
     NoopFwToFloat8E5M2BwDynamic,
 )
 from float8_experimental.float8_tensor import GemmInputRole
+from float8_experimental.float8_utils import e4m3_dtype
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
@@ -45,8 +46,9 @@ def _prepare_input_fn(
                 input_tensor, device_mesh, input_layouts, run_check=False
             )

-        input_tensor = cast_to_float8_e4m3_dynamic(
+        input_tensor = cast_to_float8_dynamic(
             input_tensor,
+            e4m3_dtype,
             mod.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
@@ -98,8 +100,9 @@ def _prepare_input_fn(
                 input_tensor, device_mesh, input_layouts, run_check=False
             )

-        input_tensor = cast_to_float8_e4m3_dynamic(
+        input_tensor = cast_to_float8_dynamic(
             input_tensor,
+            e4m3_dtype,
             mod.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
@@ -196,8 +199,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
                 input, mesh, (input_layout,), run_check=False
             )

-        dt_inp = cast_to_float8_e4m3_dynamic(
+        dt_inp = cast_to_float8_dynamic(
             dt_inp,
+            e4m3_dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
5 changes: 3 additions & 2 deletions float8_experimental/fsdp_utils.py
@@ -12,7 +12,7 @@
 import torch.utils._pytree as pytree
 from float8_experimental.float8_scaling_utils import (
     cast_to_float8_delayed,
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
 )

 from float8_experimental.float8_tensor import (
@@ -175,8 +175,9 @@ def fsdp_pre_all_gather(self, mesh):
                 GemmInputRole.WEIGHT,
             )
         else:
-            float8_tensor = cast_to_float8_e4m3_dynamic(
+            float8_tensor = cast_to_float8_dynamic(
                 self._tensor,
+                e4m3_dtype,
                 self._linear_mm_config,
                 reduce_amax=True,
                 gemm_input_role=GemmInputRole.WEIGHT,
1 change: 0 additions & 1 deletion test/test_base.py
@@ -24,7 +24,6 @@
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
