This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[6/x] clean up casting: rename delayed and dynamic casting functions
Summary:

Renames the delayed and dynamic casting functions to
`hp_tensor_to_float8_delayed` and `hp_tensor_to_float8_dynamic`
to clarify what they do and how they differ from
`hp_tensor_and_scale_to_float8`.
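
As a rough sketch of the intended call shapes after this rename (not part of this commit; the input tensor, scale, and amax buffer values below are made up for illustration, and it assumes `LinearMMConfig()` can be constructed with its defaults):

```
import torch

from float8_experimental.float8_scaling_utils import (
    hp_tensor_to_float8_delayed,
    hp_tensor_to_float8_dynamic,
)
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

# hypothetical high precision ("hp") input tensor
x_hp = torch.randn(16, 32)
linear_mm_config = LinearMMConfig()

# dynamic scaling: the scale is computed from the tensor itself
x_fp8 = hp_tensor_to_float8_dynamic(x_hp, e4m3_dtype, linear_mm_config)

# delayed scaling: the caller supplies a precomputed scale, and the current
# max(abs(x_hp)) is recorded into amax_buffer in place for future scale updates
scale = torch.tensor(1.0)        # hypothetical precomputed scale
amax_buffer = torch.tensor(0.0)  # filled in place with max(abs(x_hp))
x_fp8_delayed = hp_tensor_to_float8_delayed(
    x_hp, scale, e4m3_dtype, amax_buffer, linear_mm_config
)

# both paths bottom out in hp_tensor_and_scale_to_float8(hp_tensor, scale, ...)
```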

Test Plan:

```
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
vkuzo committed Jul 26, 2024
1 parent b1dfe2b commit 3b786be
Showing 5 changed files with 60 additions and 29 deletions.
14 changes: 8 additions & 6 deletions float8_experimental/float8_linear.py
@@ -18,8 +18,8 @@

from float8_experimental.float8_scaling_utils import (
_maybe_initialize_amaxes_scales_for_float8_cast,
- cast_to_float8_delayed,
- cast_to_float8_dynamic,
+ hp_tensor_to_float8_delayed,
+ hp_tensor_to_float8_dynamic,
NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
)
@@ -260,7 +260,7 @@ def cast_input_to_float8(
is_amax_initialized,
reduce_amax=True,
)
- input_fp8 = cast_to_float8_delayed(
+ input_fp8 = hp_tensor_to_float8_delayed(
input,
self.fp8_scale_input,
e4m3_dtype,
@@ -270,7 +270,9 @@
)
else:
assert self.scaling_type_input is ScalingType.DYNAMIC
- input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config)
+ input_fp8 = hp_tensor_to_float8_dynamic(
+ input, e4m3_dtype, self.linear_mm_config
+ )
return input_fp8

def cast_weight_to_float8(
@@ -292,7 +294,7 @@
reduce_amax=False,
)

- weight_fp8 = cast_to_float8_delayed(
+ weight_fp8 = hp_tensor_to_float8_delayed(
weight,
self.fp8_scale_weight,
e4m3_dtype,
@@ -305,7 +307,7 @@
if isinstance(self.weight, Float8Tensor): # cast by FSDP
weight_fp8 = self.weight
else:
- weight_fp8 = cast_to_float8_dynamic(
+ weight_fp8 = hp_tensor_to_float8_dynamic(
self.weight,
e4m3_dtype,
self.linear_mm_config,
55 changes: 42 additions & 13 deletions float8_experimental/float8_scaling_utils.py
@@ -30,37 +30,66 @@
)


- def cast_to_float8_dynamic(
- inpt_tensor: torch.Tensor,
+ def hp_tensor_to_float8_dynamic(
+ hp_tensor: torch.Tensor,
float8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
reduce_amax: bool = False,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
) -> Float8Tensor:
- if tensor_already_casted_to_fp8(inpt_tensor):
- return inpt_tensor
- scale = tensor_to_scale(inpt_tensor, float8_dtype, reduce_amax)
+ """
+ Given a high precision tensor `hp_tensor`,
+ scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result.
+ Args:
+ hp_tensor: the tensor to convert
+ float8_dtype: the float8 dtype to use
+ linear_mm_config: Defines the configuration for the scaled_mm for
+ the 3 fwd/bwd gemms of linear
+ reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
+ gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+ the 3 fwd/bwd gemms of linear
+ """
+ if tensor_already_casted_to_fp8(hp_tensor):
+ return hp_tensor
+ scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
return hp_tensor_and_scale_to_float8(
- inpt_tensor,
+ hp_tensor,
scale,
float8_dtype,
linear_mm_config,
gemm_input_role,
)


- def cast_to_float8_delayed(
- tensor: torch.Tensor,
- scale: torch.Tensor,
+ def hp_tensor_to_float8_delayed(
+ hp_tensor: torch.Tensor,
+ s: torch.Tensor,
float8_dtype: torch.dtype,
amax_buffer: torch.Tensor,
linear_mm_config: Optional[LinearMMConfig] = None,
gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
- ):
- amax_buffer.fill_(tensor_to_amax(tensor))
+ ) -> Float8Tensor:
+ """
+ Given a high precision tensor `hp_tensor` and relevant metadata, scales it using
+ delayed scaling and returns a `Float8Tensor` of the result. Specifically:
+ 1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace
+ 2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor
+ Args:
+ hp_tensor: the tensor to convert
+ s: the scale to use to convert the tensor
+ float8_dtype: the float8 dtype to use
+ amax_buffer: the buffer to modify inplace with max(abs(hp_tensor))
+ linear_mm_config: Defines the configuration for the scaled_mm for
+ the 3 fwd/bwd gemms of linear
+ gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+ the 3 fwd/bwd gemms of linear
+ """
+ amax_buffer.fill_(tensor_to_amax(hp_tensor))
return hp_tensor_and_scale_to_float8(
- tensor,
- scale,
+ hp_tensor,
+ s,
float8_dtype,
linear_mm_config,
gemm_input_role,
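
For context alongside the new docstrings, a tiny numeric sketch of the scaling math the two helpers implement (not part of this commit; the E4M3 max value of 448.0, the clamping, and the concrete precomputed scale are assumptions for illustration):

```
import torch

E4M3_MAX = 448.0  # max representable value of torch.float8_e4m3fn

x = torch.randn(4, 4)

# dynamic scaling: derive the scale from the current tensor
amax = x.abs().max()
scale = E4M3_MAX / torch.clamp(amax, min=1e-12)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)

# delayed scaling: reuse a scale computed from previously observed amax values,
# and only record the current amax for a future scale update
amax_buffer = torch.tensor(0.0)
amax_buffer.fill_(x.abs().max())           # step 1 in the docstring above
precomputed_scale = torch.tensor(256.0)    # hypothetical, derived from amax history
x_fp8_delayed = (x * precomputed_scale).to(torch.float8_e4m3fn)  # step 2
```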
8 changes: 4 additions & 4 deletions float8_experimental/float8_tensor_parallel.py
@@ -2,7 +2,7 @@
import torch.nn as nn
from float8_experimental.config import ScalingType
from float8_experimental.float8_scaling_utils import (
- cast_to_float8_dynamic,
+ hp_tensor_to_float8_dynamic,
NoopFwToFloat8E5M2BwDynamic,
)
from float8_experimental.float8_tensor import GemmInputRole
@@ -46,7 +46,7 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

- input_tensor = cast_to_float8_dynamic(
+ input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
e4m3_dtype,
mod.linear_mm_config,
@@ -100,7 +100,7 @@ def _prepare_input_fn(
input_tensor, device_mesh, input_layouts, run_check=False
)

- input_tensor = cast_to_float8_dynamic(
+ input_tensor = hp_tensor_to_float8_dynamic(
input_tensor,
e4m3_dtype,
mod.linear_mm_config,
@@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
input, mesh, (input_layout,), run_check=False
)

- dt_inp = cast_to_float8_dynamic(
+ dt_inp = hp_tensor_to_float8_dynamic(
dt_inp,
e4m3_dtype,
self.linear_mm_config,
8 changes: 4 additions & 4 deletions float8_experimental/fsdp_utils.py
@@ -11,8 +11,8 @@
import torch.nn as nn
import torch.utils._pytree as pytree
from float8_experimental.float8_scaling_utils import (
- cast_to_float8_delayed,
- cast_to_float8_dynamic,
+ hp_tensor_to_float8_delayed,
+ hp_tensor_to_float8_dynamic,
)

from float8_experimental.float8_tensor import (
@@ -175,7 +175,7 @@ def fsdp_pre_all_gather(self, mesh):
GemmInputRole.WEIGHT,
)
else:
- float8_tensor = cast_to_float8_dynamic(
+ float8_tensor = hp_tensor_to_float8_dynamic(
self._tensor,
e4m3_dtype,
self._linear_mm_config,
@@ -355,7 +355,7 @@ def fsdp_pre_all_gather(self, mesh):
)
self.is_amax_initialized = True

- float8_tensor = cast_to_float8_delayed(
+ float8_tensor = hp_tensor_to_float8_delayed(
self._tensor,
self._scale_buffer,
e4m3_dtype,
4 changes: 2 additions & 2 deletions test/test_compile.py
@@ -20,7 +20,7 @@
get_float8_layers,
sync_float8_amax_and_scale_history,
)
- from float8_experimental.float8_scaling_utils import cast_to_float8_delayed
+ from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_delayed
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

@@ -179,7 +179,7 @@ def __init__(self, graph_break: bool):
self.graph_break = graph_break

def forward(self, x):
- x_fp8 = cast_to_float8_delayed(
+ x_fp8 = hp_tensor_to_float8_delayed(
x,
self.fp8_scale_x,
e4m3_dtype,
