@@ -128,71 +128,6 @@ def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
     return False


-# TODO: rename to hp_tensor_and_scale_to_float8_tensor
-def to_fp8_no_autograd(
-    x: torch.Tensor,
-    x_scale: torch.Tensor,
-    float8_dtype: torch.dtype,
-    linear_mm_config: Optional[LinearMMConfig],
-    gemm_input_role: Optional[GemmInputRole],
-) -> "Float8Tensor":
-    """Convert a tensor to float8 without autograd.
-    This is used in multiple places in the codebase to convert a tensor to float8.
-
-    This function will apply the scaling and then convert to a Float8Tensor.
-
-    Note:
-        We will call this function with a DTensor subclass. Ideally this would be an aten op
-        that DTensor could overload to ensure proper semantics. There are some technical issues
-        with that composing with FakeTensor, so we special-case it here.
-
-        DTensor invariant: DTensor must always be the outermost tensor subclass.
-
-    Args:
-        x: the tensor to convert
-        x_scale: the scale to use to convert the tensor
-        float8_dtype: the float8 dtype to use
-        linear_mm_config: Defines the configuration for the scaled_mm of
-            the 3 fwd/bwd gemms of linear
-        gemm_input_role: Defines the role of this tensor (x, w or dL_dY) in
-            the 3 fwd/bwd gemms of linear
-    """
-    x_scaled = x * x_scale
-    bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)
-
-    if isinstance(bits_fp8, DTensor):
-        assert isinstance(
-            x_scale, DTensor
-        ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
-        bits_mesh = bits_fp8.device_mesh
-        bits_placements = bits_fp8.placements
-        local_bits = bits_fp8.to_local()
-        local_scale = x_scale.to_local()
-        inner_float8_tensor = Float8Tensor(
-            local_bits,
-            local_scale,
-            x.dtype,
-            linear_mm_config=linear_mm_config,
-            gemm_input_role=gemm_input_role,
-        )
-        return DTensor.from_local(
-            inner_float8_tensor,
-            bits_mesh,
-            bits_placements,
-            run_check=False,
-            shape=bits_fp8.size(),
-            stride=bits_fp8.stride(),
-        )
-
-    return Float8Tensor(
-        bits_fp8,
-        x_scale,
-        x.dtype,
-        linear_mm_config=linear_mm_config,
-        gemm_input_role=gemm_input_role,
-    )
-
-
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -210,18 +145,56 @@ def forward(
         linear_mm_config: Optional[LinearMMConfig] = None,
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
-        """Autograd-enabled wrapper around to_fp8_no_autograd that will also populate the amax buffer.
-        Args
+        """
+        This function will apply the scaling and then convert to a Float8Tensor.
+
+        Note:
+            We will call this function with a DTensor subclass. Ideally this would be an aten op
+            that DTensor could overload to ensure proper semantics. There are some technical issues
+            with that composing with FakeTensor, so we special-case it here.
+
+            DTensor invariant: DTensor must always be the outermost tensor subclass.
+
+        Args:
             tensor: the tensor to convert
             scale: the scale to use to convert the tensor
-            float8_dtype: the float8 dtype, either torch.float8_e4m3fn or torch.float8_e5m2
-            emulate: whether to emulate the matmuls in fp32
+            float8_dtype: the float8 dtype to use
+            linear_mm_config: Defines the configuration for the scaled_mm of
+                the 3 fwd/bwd gemms of linear
+            gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+                the 3 fwd/bwd gemms of linear
         """
+        tensor_scaled = tensor * scale
+        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+
+        if isinstance(bits_fp8, DTensor):
+            assert isinstance(
+                scale, DTensor
+            ), "Expected Float8 scale to be a DTensor if bits_fp8 is a DTensor"
+            bits_mesh = bits_fp8.device_mesh
+            bits_placements = bits_fp8.placements
+            local_bits = bits_fp8.to_local()
+            local_scale = scale.to_local()
+            inner_float8_tensor = Float8Tensor(
+                local_bits,
+                local_scale,
+                tensor.dtype,
+                linear_mm_config=linear_mm_config,
+                gemm_input_role=gemm_input_role,
+            )
+            return DTensor.from_local(
+                inner_float8_tensor,
+                bits_mesh,
+                bits_placements,
+                run_check=False,
+                shape=bits_fp8.size(),
+                stride=bits_fp8.stride(),
+            )

-        return to_fp8_no_autograd(
-            tensor,
+        return Float8Tensor(
+            bits_fp8,
             scale,
-            float8_dtype,
+            tensor.dtype,
             linear_mm_config=linear_mm_config,
             gemm_input_role=gemm_input_role,
         )
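With the helper gone, callers construct scaled float8 tensors through the autograd wrapper directly. A usage sketch under stated assumptions: the amax-based scale below is illustrative (the codebase computes scales through its own helpers), and it assumes `ToFloat8ConstrFunc`, `GemmInputRole`, and a float8-capable PyTorch are in scope:

```python
import torch

# Illustrative per-tensor scale: map the largest |value| of x onto the
# target dtype's maximum representable value (448 for float8_e4m3fn).
x = torch.randn(16, 32)
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()

# All five forward() arguments passed positionally through .apply().
x_fp8 = ToFloat8ConstrFunc.apply(
    x,
    scale,
    torch.float8_e4m3fn,
    None,                 # linear_mm_config: none needed for a bare cast
    GemmInputRole.INPUT,  # this tensor acts as the input of linear's gemms
)
```

The forward in the hunk above computes `tensor * scale` and saturates the result; the matching backward is outside this hunk, but an autograd `Function` wrapper like this one conventionally passes the incoming gradient through unchanged.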