[3/x] clean up casting functions: delete to_fp8_no_autograd
Summary:

`ToFloat8ConstrFunc` was just calling `to_fp8_no_autograd`; unify them to reduce confusion.

We can rename the function in a future PR, keeping PRs small for now.

Test Plan:

```
./test/test_everything.sh
```

ghstack-source-id: f5235acc53b32879dca1e3411acb50b1f32b1fd6
Pull Request resolved: #347
vkuzo committed Jul 26, 2024
1 parent 55dbc70 commit d6c3720
Showing 2 changed files with 51 additions and 79 deletions.
13 changes: 6 additions & 7 deletions float8_experimental/float8_scaling_utils.py
```diff
@@ -18,7 +18,6 @@
     LinearMMConfig,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
-    to_fp8_no_autograd,
     ToFloat8ConstrFunc,
 )
```

```diff
@@ -146,12 +145,12 @@ def backward(ctx, go):
 
         fp8_amax_grad_output.fill_(tensor_to_amax(go))
 
-        res = to_fp8_no_autograd(
+        res = ToFloat8ConstrFunc.apply(
             go,
             fp8_scale_grad_output,
             e5m2_dtype,
-            linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
```
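A side effect of switching these call sites from `to_fp8_no_autograd` to `ToFloat8ConstrFunc.apply` is that the keyword arguments become positional: `torch.autograd.Function.apply` does not accept keyword arguments. A minimal sketch of that constraint, using a toy function rather than anything from this repo:

```python
import torch

class Scale(torch.autograd.Function):
    """Toy autograd function, only to illustrate the .apply() calling convention."""

    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, g):
        # one gradient per forward input; the non-tensor factor gets None
        return g * ctx.factor, None

x = torch.ones(2, requires_grad=True)
y = Scale.apply(x, 2.0)       # positional arguments work
# Scale.apply(x, factor=2.0)  # raises TypeError: kwargs are not supported
```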
```diff
@@ -178,11 +177,11 @@ def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
-        fp8_tensor = to_fp8_no_autograd(
+        fp8_tensor = ToFloat8ConstrFunc.apply(
             gradY,
             gradY_scale,
             e5m2_dtype,
-            linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
         )
         return fp8_tensor, None
```

117 changes: 45 additions & 72 deletions float8_experimental/float8_tensor.py
```diff
@@ -128,71 +128,6 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
     return False
 
 
-# TODO: rename to hp_tensor_and_scale_to_float8_tensor
-def to_fp8_no_autograd(
-    x: torch.Tensor,
-    x_scale: torch.Tensor,
-    float8_dtype: torch.dtype,
-    linear_mm_config: Optional[LinearMMConfig],
-    gemm_input_role: Optional[GemmInputRole],
-) -> "Float8Tensor":
-    """Convert a tensor to float8 without autograd
-    This is used in multiple places in the codebase to convert a tensor to float8
-    This function will apply the scaling, and then convert to a Float8Tensor
-    Note:
-    We will call this function with a DTensor subclass. Ideally this would be an aten OP
-    that DTensor could overload to ensure proper semantics. There are some technical issues
-    with that composing with FakeTensor, so we special case here.
-    DTensor Invariant: DTensor must always be the outermost tensor subclass
-    Args:
-        x: the tensor to convert
-        scale: the scale to use to convert the tensor
-        float8_dtype: the float8 dtype to use
-        linear_mm_config: Defines the configuration for the scaled_mm for
-          the 3 fwd/bwd gemms of linear
-        gemm_input_role: Defines the role of this tensor (x, w or dL_dY) in
-          the 3 fwd/bwd gemms of linear
-    """
-    x_scaled = x * x_scale
-    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
-
-    if isinstance(bits_fp8, DTensor):
-        assert isinstance(
-            x, DTensor
-        ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
-        bits_mesh = bits_fp8.device_mesh
-        bits_placements = bits_fp8.placements
-        local_bits = bits_fp8.to_local()
-        local_scale = x_scale.to_local()
-        inner_float8_tensor = Float8Tensor(
-            local_bits,
-            local_scale,
-            x.dtype,
-            linear_mm_config=linear_mm_config,
-            gemm_input_role=gemm_input_role,
-        )
-        return DTensor.from_local(
-            inner_float8_tensor,
-            bits_mesh,
-            bits_placements,
-            run_check=False,
-            shape=bits_fp8.size(),
-            stride=bits_fp8.stride(),
-        )
-
-    return Float8Tensor(
-        bits_fp8,
-        x_scale,
-        x.dtype,
-        linear_mm_config=linear_mm_config,
-        gemm_input_role=gemm_input_role,
-    )
-
-
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
```
```diff
@@ -210,18 +145,56 @@ def forward(
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
-        """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
-        Args
+        """
+        This function will apply the scaling, and then convert to a Float8Tensor
+        Note:
+        We will call this function with a DTensor subclass. Ideally this would be an aten OP
+        that DTensor could overload to ensure proper semantics. There are some technical issues
+        with that composing with FakeTensor, so we special case here.
+        DTensor Invariant: DTensor must always be the outermost tensor subclass
+        Args:
             tensor: the tensor to convert
             scale: the scale to use to convert the tensor
-            float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
-            emulate: whether to emulate the matmuls in fp32
+            float8_dtype: the float8 dtype to use
+            linear_mm_config: Defines the configuration for the scaled_mm for
+              the 3 fwd/bwd gemms of linear
+            gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+              the 3 fwd/bwd gemms of linear
         """
+        tensor_scaled = tensor * scale
+        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+
+        if isinstance(bits_fp8, DTensor):
+            assert isinstance(
+                scale, DTensor
+            ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
+            bits_mesh = bits_fp8.device_mesh
+            bits_placements = bits_fp8.placements
+            local_bits = bits_fp8.to_local()
+            local_scale = scale.to_local()
+            inner_float8_tensor = Float8Tensor(
+                local_bits,
+                local_scale,
+                tensor.dtype,
+                linear_mm_config=linear_mm_config,
+                gemm_input_role=gemm_input_role,
+            )
+            return DTensor.from_local(
+                inner_float8_tensor,
+                bits_mesh,
+                bits_placements,
+                run_check=False,
+                shape=bits_fp8.size(),
+                stride=bits_fp8.stride(),
+            )
 
-        return to_fp8_no_autograd(
-            tensor,
+        return Float8Tensor(
+            bits_fp8,
             scale,
-            float8_dtype,
+            tensor.dtype,
             linear_mm_config=linear_mm_config,
             gemm_input_role=gemm_input_role,
         )
```
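With the helper gone, `ToFloat8ConstrFunc.apply` is the single entry point for building a `Float8Tensor` from a high-precision tensor and a scale. A hypothetical usage sketch; the import paths for `tensor_to_scale` and `e4m3_dtype` are assumptions (they are not shown in this diff):

```python
import torch

from float8_experimental.float8_tensor import GemmInputRole, ToFloat8ConstrFunc
from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale  # assumed module path

x = torch.randn(16, 32, dtype=torch.bfloat16)
scale = tensor_to_scale(x, e4m3_dtype)  # amax-based scale for the e4m3 dtype
# Arguments are positional because autograd.Function.apply rejects kwargs.
x_fp8 = ToFloat8ConstrFunc.apply(x, scale, e4m3_dtype, None, GemmInputRole.INPUT)
```

Because the cast goes through an autograd `Function`, the resulting tensor stays connected to the autograd graph instead of being detached.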
