14
14
15
15
import torch
16
16
17
- from float8_experimental .config import Float8LinearConfig , TensorScalingType
17
+ from float8_experimental .config import Float8LinearConfig , ScalingType
18
18
19
19
from float8_experimental .float8_dynamic_utils import (
20
20
cast_to_float8_e4m3_dynamic ,
@@ -159,9 +159,9 @@ def __init__(self, *args, **kwargs):
159
159
self .scaling_type_grad_output = config .cast_config_grad_output .scaling_type
160
160
# Convenience flag to skip code related to delayed scaling
161
161
self .has_any_delayed_scaling = (
162
- self .scaling_type_input is TensorScalingType .DELAYED
163
- or self .scaling_type_weight is TensorScalingType .DELAYED
164
- or self .scaling_type_grad_output is TensorScalingType .DELAYED
162
+ self .scaling_type_input is ScalingType .DELAYED
163
+ or self .scaling_type_weight is ScalingType .DELAYED
164
+ or self .scaling_type_grad_output is ScalingType .DELAYED
165
165
)
166
166
167
167
self .config = config
@@ -284,7 +284,7 @@ def cast_input_to_float8(
284
284
autocast_dtype = torch .get_autocast_gpu_dtype ()
285
285
input = input .to (autocast_dtype )
286
286
287
- if self .scaling_type_input is TensorScalingType .DELAYED :
287
+ if self .scaling_type_input is ScalingType .DELAYED :
288
288
scale_fn_name = self .config .delayed_scaling_config .scale_fn_name
289
289
_maybe_initialize_amaxes_scales_for_float8_cast (
290
290
input ,
@@ -305,14 +305,14 @@ def cast_input_to_float8(
305
305
gemm_input_role = GemmInputRole .INPUT ,
306
306
)
307
307
else :
308
- assert self .scaling_type_input is TensorScalingType .DYNAMIC
308
+ assert self .scaling_type_input is ScalingType .DYNAMIC
309
309
input_fp8 = cast_to_float8_e4m3_dynamic (input , self .linear_mm_config )
310
310
return input_fp8
311
311
312
312
def cast_weight_to_float8 (
313
313
self , weight : torch .Tensor , is_amax_initialized : bool
314
314
) -> torch .Tensor :
315
- if self .scaling_type_weight is TensorScalingType .DELAYED :
315
+ if self .scaling_type_weight is ScalingType .DELAYED :
316
316
if isinstance (self .weight , Float8Tensor ): # cast by FSDP
317
317
weight_fp8 = self .weight
318
318
else :
@@ -337,7 +337,7 @@ def cast_weight_to_float8(
337
337
gemm_input_role = GemmInputRole .WEIGHT ,
338
338
)
339
339
else :
340
- assert self .scaling_type_weight is TensorScalingType .DYNAMIC
340
+ assert self .scaling_type_weight is ScalingType .DYNAMIC
341
341
if isinstance (self .weight , Float8Tensor ): # cast by FSDP
342
342
weight_fp8 = self .weight
343
343
else :
@@ -349,7 +349,7 @@ def cast_weight_to_float8(
349
349
return weight_fp8
350
350
351
351
def cast_output_to_float8_in_bw (self , output : torch .Tensor ) -> torch .Tensor :
352
- if self .scaling_type_grad_output is TensorScalingType .DELAYED :
352
+ if self .scaling_type_grad_output is ScalingType .DELAYED :
353
353
scale_fn_name = self .config .delayed_scaling_config .scale_fn_name
354
354
output = NoopFwToFloat8E5M2Bw .apply (
355
355
output ,
@@ -361,7 +361,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
361
361
self .linear_mm_config ,
362
362
)
363
363
else :
364
- assert self .scaling_type_grad_output is TensorScalingType .DYNAMIC
364
+ assert self .scaling_type_grad_output is ScalingType .DYNAMIC
365
365
output = cast_to_float8_e5m2_dynamic_bw (output , self .linear_mm_config )
366
366
return output
367
367
@@ -448,17 +448,15 @@ def from_float(
448
448
# 2. buffers need to be already created for the delayed scaling version
449
449
# of the weight wrapper to be initialized
450
450
if config .enable_fsdp_float8_all_gather :
451
- if config .cast_config_weight .scaling_type is TensorScalingType .DYNAMIC :
451
+ if config .cast_config_weight .scaling_type is ScalingType .DYNAMIC :
452
452
new_mod .weight = torch .nn .Parameter (
453
453
WeightWithDynamicFloat8CastTensor (
454
454
new_mod .weight ,
455
455
new_mod .linear_mm_config ,
456
456
)
457
457
)
458
458
else :
459
- assert (
460
- config .cast_config_weight .scaling_type is TensorScalingType .DELAYED
461
- )
459
+ assert config .cast_config_weight .scaling_type is ScalingType .DELAYED
462
460
new_mod .weight = torch .nn .Parameter (
463
461
WeightWithDelayedFloat8CastTensor (
464
462
new_mod .weight ,
0 commit comments