This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 7e0182f

vkuzo authored and facebook-github-bot committed
clean up casting functions: delete to_fp8_no_autograd (#347)
Summary:
Pull Request resolved: #347

`ToFloat8ConstrFunc` was just calling `to_fp8_no_autograd`; unify them to reduce confusion. We can rename the function in a future PR, keeping PRs small for now.

Reviewed By: drisspg

Differential Revision: D60292694

fbshipit-source-id: e46e47368f786b158f5b81a17830b251d8f4b586
1 parent ab2b828 commit 7e0182f
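For context, the call sites in this commit switch from keyword to positional arguments, presumably because `torch.autograd.Function.apply` forwards its arguments positionally to `forward`. A minimal before/after sketch of the pattern, lifted from the float8_scaling_utils.py diff below (not new code in this commit):

```python
# Before: module-level function, keyword arguments
res = to_fp8_no_autograd(
    go,
    fp8_scale_grad_output,
    e5m2_dtype,
    linear_mm_config=ctx.linear_mm_config,
    gemm_input_role=GemmInputRole.GRAD_OUTPUT,
)

# After: autograd.Function entry point, positional arguments
res = ToFloat8ConstrFunc.apply(
    go,
    fp8_scale_grad_output,
    e5m2_dtype,
    ctx.linear_mm_config,
    GemmInputRole.GRAD_OUTPUT,
)
```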

2 files changed: +51, -79 lines


float8_experimental/float8_scaling_utils.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -18,7 +18,6 @@
     LinearMMConfig,
     ScaledMMConfig,
     tensor_already_casted_to_fp8,
-    to_fp8_no_autograd,
     ToFloat8ConstrFunc,
 )
 
@@ -146,12 +145,12 @@ def backward(ctx, go):
 
         fp8_amax_grad_output.fill_(tensor_to_amax(go))
 
-        res = to_fp8_no_autograd(
+        res = ToFloat8ConstrFunc.apply(
             go,
             fp8_scale_grad_output,
             e5m2_dtype,
-            linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -178,11 +177,11 @@ def backward(ctx, gradY):
         if tensor_already_casted_to_fp8(gradY):
             return gradY, None
         gradY_scale = tensor_to_scale(gradY, e5m2_dtype)
-        fp8_tensor = to_fp8_no_autograd(
+        fp8_tensor = ToFloat8ConstrFunc.apply(
             gradY,
             gradY_scale,
             e5m2_dtype,
-            linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            ctx.linear_mm_config,
+            GemmInputRole.GRAD_OUTPUT,
         )
         return fp8_tensor, None
```

float8_experimental/float8_tensor.py

Lines changed: 45 additions & 72 deletions
```diff
@@ -128,71 +128,6 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
     return False
 
 
-# TODO: rename to hp_tensor_and_scale_to_float8_tensor
-def to_fp8_no_autograd(
-    x: torch.Tensor,
-    x_scale: torch.Tensor,
-    float8_dtype: torch.dtype,
-    linear_mm_config: Optional[LinearMMConfig],
-    gemm_input_role: Optional[GemmInputRole],
-) -> "Float8Tensor":
-    """Convert a tensor to float8 without autograd
-    This is used in multiple places in the codebase to convert a tensor to float8
-
-    This function will apply the scaling, and then convert to a Float8Tensor
-
-    Note:
-    We will call this function with a DTensor subclass. Ideally this would be an aten OP
-    that DTensor could overload to ensure proper semantics. There are some techincal issues
-    with that composing with FakeTensor, so we special case here.
-
-    DTensor Invariant: DTensor must always be the outer most tensor subclass
-
-    Args:
-        x: the tensor to convert
-        scale: the scale to use to convert the tensor
-        float8_dtype: the float8 dtype to use
-        linear_mm_config: Defines the configuration for the scaled_mm for
-            the 3 fwd/bwd gemms of linear
-        gemm_input_role: Defines the role of this tensor (x, w or dL_dY) in
-            the 3 fwd/bwd gemms of linear
-    """
-    x_scaled = x * x_scale
-    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
-
-    if isinstance(bits_fp8, DTensor):
-        assert isinstance(
-            x, DTensor
-        ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
-        bits_mesh = bits_fp8.device_mesh
-        bits_placements = bits_fp8.placements
-        local_bits = bits_fp8.to_local()
-        local_scale = x_scale.to_local()
-        inner_float8_tensor = Float8Tensor(
-            local_bits,
-            local_scale,
-            x.dtype,
-            linear_mm_config=linear_mm_config,
-            gemm_input_role=gemm_input_role,
-        )
-        return DTensor.from_local(
-            inner_float8_tensor,
-            bits_mesh,
-            bits_placements,
-            run_check=False,
-            shape=bits_fp8.size(),
-            stride=bits_fp8.stride(),
-        )
-
-    return Float8Tensor(
-        bits_fp8,
-        x_scale,
-        x.dtype,
-        linear_mm_config=linear_mm_config,
-        gemm_input_role=gemm_input_role,
-    )
-
-
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -210,18 +145,56 @@ def forward(
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
-        """Autograd enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
-        Args
+        """
+        This function will apply the scaling, and then convert to a Float8Tensor
+
+        Note:
+        We will call this function with a DTensor subclass. Ideally this would be an aten OP
+        that DTensor could overload to ensure proper semantics. There are some techincal issues
+        with that composing with FakeTensor, so we special case here.
+
+        DTensor Invariant: DTensor must always be the outer most tensor subclass
+
+        Args:
             tensor: the tensor to convert
             scale: the scale to use to convert the tensor
-            float8_dtype: the float8 dtype either, torch.float8_e4m3fn or torch.float8_e5m2fn
-            emulate: whether to emulate the matmuls in fp32
+            float8_dtype: the float8 dtype to use
+            linear_mm_config: Defines the configuration for the scaled_mm for
+                the 3 fwd/bwd gemms of linear
+            gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+                the 3 fwd/bwd gemms of linear
         """
+        tensor_scaled = tensor * scale
+        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+
+        if isinstance(bits_fp8, DTensor):
+            assert isinstance(
+                scale, DTensor
+            ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
+            bits_mesh = bits_fp8.device_mesh
+            bits_placements = bits_fp8.placements
+            local_bits = bits_fp8.to_local()
+            local_scale = scale.to_local()
+            inner_float8_tensor = Float8Tensor(
+                local_bits,
+                local_scale,
+                tensor.dtype,
+                linear_mm_config=linear_mm_config,
+                gemm_input_role=gemm_input_role,
+            )
+            return DTensor.from_local(
+                inner_float8_tensor,
+                bits_mesh,
+                bits_placements,
+                run_check=False,
+                shape=bits_fp8.size(),
+                stride=bits_fp8.stride(),
+            )
 
-        return to_fp8_no_autograd(
-            tensor,
+        return Float8Tensor(
+            bits_fp8,
             scale,
-            float8_dtype,
+            tensor.dtype,
             linear_mm_config=linear_mm_config,
            gemm_input_role=gemm_input_role,
         )
```
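As a quick orientation for readers outside this codebase, the sketch below mirrors the backward-pass call sites above to show how the unified entry point is used. The import paths and the sample bfloat16 input are assumptions for illustration; the commit itself only shows these names used inside float8_experimental.

```python
import torch

# Assumed import locations; only the names themselves appear in this commit.
from float8_experimental.float8_tensor import GemmInputRole, ToFloat8ConstrFunc
from float8_experimental.float8_utils import e5m2_dtype, tensor_to_scale

# A hypothetical high-precision gradient tensor.
grad_output = torch.randn(16, 32, dtype=torch.bfloat16)

# Compute a per-tensor scale for the target float8 dtype, then cast.
# ToFloat8ConstrFunc.apply takes positional arguments, matching the call
# sites in the diff; linear_mm_config is left at its default (None) here.
scale = tensor_to_scale(grad_output, e5m2_dtype)
grad_output_fp8 = ToFloat8ConstrFunc.apply(
    grad_output,
    scale,
    e5m2_dtype,
    None,
    GemmInputRole.GRAD_OUTPUT,
)
```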
