This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[6/x] clean up casting: rename delayed and dynamic casting functions #350

Closed · wants to merge 1 commit
14 changes: 8 additions & 6 deletions float8_experimental/float8_linear.py
@@ -18,8 +18,8 @@

from float8_experimental.float8_scaling_utils import (
_maybe_initialize_amaxes_scales_for_float8_cast,
cast_to_float8_delayed,
cast_to_float8_dynamic,
hp_tensor_to_float8_delayed,
hp_tensor_to_float8_dynamic,
NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
)
@@ -260,7 +260,7 @@ def cast_input_to_float8(
is_amax_initialized,
reduce_amax=True,
)
input_fp8 = cast_to_float8_delayed(
input_fp8 = hp_tensor_to_float8_delayed(
input,
self.fp8_scale_input,
e4m3_dtype,
@@ -270,7 +270,9 @@
)
else:
assert self.scaling_type_input is ScalingType.DYNAMIC
input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config)
input_fp8 = hp_tensor_to_float8_dynamic(
input, e4m3_dtype, self.linear_mm_config
)
return input_fp8

def cast_weight_to_float8(
@@ -292,7 +294,7 @@ def cast_weight_to_float8(
reduce_amax=False,
)

weight_fp8 = cast_to_float8_delayed(
weight_fp8 = hp_tensor_to_float8_delayed(
weight,
self.fp8_scale_weight,
e4m3_dtype,
@@ -305,7 +307,7 @@
if isinstance(self.weight, Float8Tensor): # cast by FSDP
weight_fp8 = self.weight
else:
weight_fp8 = cast_to_float8_dynamic(
weight_fp8 = hp_tensor_to_float8_dynamic(
self.weight,
e4m3_dtype,
self.linear_mm_config,
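
The call sites above change only the helper names; the arguments are passed through unchanged. Below is a minimal sketch of the updated dynamic path; the input tensor and the default-constructed LinearMMConfig are stand-ins for the module's real state (illustrative assumptions, not code from this PR):

import torch

from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

# Stand-ins for the real Float8Linear input and config (illustrative only).
x = torch.randn(16, 32, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()

# Previously: cast_to_float8_dynamic(x, e4m3_dtype, linear_mm_config)
x_fp8 = hp_tensor_to_float8_dynamic(x, e4m3_dtype, linear_mm_config)
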
55 changes: 42 additions & 13 deletions float8_experimental/float8_scaling_utils.py
@@ -30,37 +30,66 @@
)


def cast_to_float8_dynamic(
inpt_tensor: torch.Tensor,
def hp_tensor_to_float8_dynamic(
hp_tensor: torch.Tensor,
float8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
return inpt_tensor
scale = tensor_to_scale(inpt_tensor, float8_dtype, reduce_amax)
"""
Given a high precision tensor `hp_tensor`,
scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result.

Args:
hp_tensor: the tensor to convert
float8_dtype: the float8 dtype to use
linear_mm_config: Defines the configuration for the scaled_mm for
the 3 fwd/bwd gemms of linear
reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
the 3 fwd/bwd gemms of linear
"""
if tensor_already_casted_to_fp8(hp_tensor):
return hp_tensor
scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
return hp_tensor_and_scale_to_float8(
inpt_tensor,
hp_tensor,
scale,
float8_dtype,
linear_mm_config,
gemm_input_role,
)


def cast_to_float8_delayed(
tensor: torch.Tensor,
scale: torch.Tensor,
def hp_tensor_to_float8_delayed(
hp_tensor: torch.Tensor,
s: torch.Tensor,
float8_dtype: torch.dtype,
amax_buffer: torch.Tensor,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
):
amax_buffer.fill_(tensor_to_amax(tensor))
) -> Float8Tensor:
"""
Given a high precision tensor `hp_tensor` and relevant metadata, scales it using
delayed scaling and returns a `Float8Tensor` of the result. Specifically:
1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace
2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor

Args:
hp_tensor: the tensor to convert
s: the scale to use to convert the tensor
float8_dtype: the float8 dtype to use
amax_buffer: the buffer to modify inplace with max(abs(hp_tensor))
linear_mm_config: Defines the configuration for the scaled_mm for
the 3 fwd/bwd gemms of linear
gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
the 3 fwd/bwd gemms of linear
"""
amax_buffer.fill_(tensor_to_amax(hp_tensor))
return hp_tensor_and_scale_to_float8(
tensor,
scale,
hp_tensor,
s,
float8_dtype,
linear_mm_config,
gemm_input_role,
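
To make the new docstrings concrete, here is a usage sketch of the delayed-scaling helper. The scale and amax buffers are toy single-element tensors (in Float8Linear they are persistent module buffers), and constructing LinearMMConfig() with its defaults is an assumption made for illustration:

import torch

from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_delayed
from float8_experimental.float8_tensor import GemmInputRole, LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

hp_tensor = torch.randn(16, 32, dtype=torch.bfloat16)
scale = torch.tensor(1.0)        # precomputed scale, the `s` argument above
amax_buffer = torch.tensor(0.0)  # filled in place with max(abs(hp_tensor))

w_fp8 = hp_tensor_to_float8_delayed(
    hp_tensor,
    scale,
    e4m3_dtype,
    amax_buffer,
    linear_mm_config=LinearMMConfig(),
    gemm_input_role=GemmInputRole.WEIGHT,
)

# After the call, amax_buffer holds max(abs(hp_tensor)); with delayed scaling
# the scale itself is only updated later, e.g. by sync_float8_amax_and_scale_history.
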
8 changes: 4 additions & 4 deletions float8_experimental/float8_tensor_parallel.py
@@ -2,7 +2,7 @@
import torch.nn as nn
from float8_experimental.config import ScalingType
from float8_experimental.float8_scaling_utils import (
cast_to_float8_dynamic,
hp_tensor_to_float8_dynamic,
NoopFwToFloat8E5M2BwDynamic,
)
from float8_experimental.float8_tensor import GemmInputRole
@@ -46,7 +46,7 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = cast_to_float8_dynamic(
input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
e4m3_dtype,
mod.linear_mm_config,
@@ -100,7 +100,7 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

input_tensor = cast_to_float8_dynamic(
input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
e4m3_dtype,
mod.linear_mm_config,
@@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
input, mesh, (input_layout,), run_check=False
)

dt_inp = cast_to_float8_dynamic(
dt_inp = hp_tensor_to_float8_dynamic(
dt_inp,
e4m3_dtype,
self.linear_mm_config,
8 changes: 4 additions & 4 deletions float8_experimental/fsdp_utils.py
@@ -11,8 +11,8 @@
import torch.nn as nn
import torch.utils._pytree as pytree
from float8_experimental.float8_scaling_utils import (
cast_to_float8_delayed,
cast_to_float8_dynamic,
hp_tensor_to_float8_delayed,
hp_tensor_to_float8_dynamic,
)

from float8_experimental.float8_tensor import (
@@ -175,7 +175,7 @@ def fsdp_pre_all_gather(self, mesh):
GemmInputRole.WEIGHT,
)
else:
float8_tensor = cast_to_float8_dynamic(
float8_tensor = hp_tensor_to_float8_dynamic(
self._tensor,
e4m3_dtype,
self._linear_mm_config,
@@ -355,7 +355,7 @@ def fsdp_pre_all_gather(self, mesh):
)
self.is_amax_initialized = True

float8_tensor = cast_to_float8_delayed(
float8_tensor = hp_tensor_to_float8_delayed(
self._tensor,
self._scale_buffer,
e4m3_dtype,
4 changes: 2 additions & 2 deletions test/test_compile.py
@@ -20,7 +20,7 @@
get_float8_layers,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_scaling_utils import cast_to_float8_delayed
from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_delayed
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

@@ -179,7 +179,7 @@ def __init__(self, graph_break: bool):
self.graph_break = graph_break

def forward(self, x):
x_fp8 = cast_to_float8_delayed(
x_fp8 = hp_tensor_to_float8_delayed(
x,
self.fp8_scale_x,
e4m3_dtype,