 
 import torch
 
-from float8_experimental.config import Float8LinearConfig
+from float8_experimental.config import Float8LinearConfig, TensorScalingType
 
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
@@ -148,18 +148,6 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
 
 
-class TensorScalingType(enum.Enum):
-    DELAYED = "delayed"
-    DYNAMIC = "dynamic"
-
-    def short_str(self):
-        if self is TensorScalingType.DELAYED:
-            return "del"
-        else:
-            assert self is TensorScalingType.DYNAMIC
-            return "dyn"
-
-
 class Float8Linear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
@@ -170,9 +158,6 @@ def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
         * `delayed_scaling_recipe`: configuration for delayed scaling
-        * `scaling_type_input`: delayed vs dynamic scaling for `input`
-        * `scaling_type_weight`: delayed vs dynamic scaling for `weight`
-        * `scaling_type_grad_output`: delayed vs dynamic scaling for `grad_output`
         * `config`: Float8LinearConfig
         """
 
@@ -182,20 +167,13 @@ def __init__(self, *args, **kwargs):
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         emulate = kwargs.pop("emulate", False)
-        scaling_type_input = kwargs.pop("scaling_type_input", TensorScalingType.DYNAMIC)
-        scaling_type_weight = kwargs.pop(
-            "scaling_type_weight", TensorScalingType.DYNAMIC
-        )
-        scaling_type_grad_output = kwargs.pop(
-            "scaling_type_grad_output", TensorScalingType.DYNAMIC
-        )
         config = kwargs.pop("config")
         super().__init__(*args, **kwargs)
 
         # Defines the scaling behavior of input, weight, grad_output
-        self.scaling_type_input = scaling_type_input
-        self.scaling_type_weight = scaling_type_weight
-        self.scaling_type_grad_output = scaling_type_grad_output
+        self.scaling_type_input = config.cast_config_input.scaling_type
+        self.scaling_type_weight = config.cast_config_weight.scaling_type
+        self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
 
         # Convenience flag to skip code related to delayed scaling
         self.has_any_delayed_scaling = (
             self.scaling_type_input is TensorScalingType.DELAYED
@@ -457,9 +435,6 @@ def from_float(
         cls,
         mod,
         emulate: bool = False,
-        scaling_type_input=TensorScalingType.DYNAMIC,
-        scaling_type_weight=TensorScalingType.DYNAMIC,
-        scaling_type_grad_output=TensorScalingType.DYNAMIC,
         config: Optional[Float8LinearConfig] = None,
     ):
         """
@@ -477,9 +452,6 @@ def from_float(
             mod.in_features,
             mod.out_features,
             bias=False,
-            scaling_type_input=scaling_type_input,
-            scaling_type_weight=scaling_type_weight,
-            scaling_type_grad_output=scaling_type_grad_output,
             emulate=emulate,
             config=config,
         )
@@ -495,15 +467,17 @@ def from_float(
         # 2. buffers need to be already created for the delayed scaling version
         # of the weight wrapper to be initialized
         if config.enable_fsdp_fp8_all_gather:
-            if scaling_type_weight is TensorScalingType.DYNAMIC:
+            if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
                         new_mod.linear_mm_config,
                     )
                 )
             else:
-                assert scaling_type_weight is TensorScalingType.DELAYED
+                assert (
+                    config.cast_config_weight.scaling_type is TensorScalingType.DELAYED
+                )
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDelayedFloat8CastTensor(
                         new_mod.weight,
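For context, a minimal usage sketch of the config-driven API this diff moves to: scaling types are no longer passed as `scaling_type_*` kwargs, and `from_float` reads them from `config.cast_config_*`. This assumes the per-tensor cast config class is named `CastConfig` and that `Float8Linear` lives in `float8_experimental.float8_linear`; neither appears in this diff, so the exact names and constructor arguments may differ.

import torch

# Assumed import paths; CastConfig is not shown in this diff and may be named differently.
from float8_experimental.config import CastConfig, Float8LinearConfig, TensorScalingType
from float8_experimental.float8_linear import Float8Linear

# Select delayed scaling for input/weight and dynamic scaling for grad_output
# through the config object instead of the removed scaling_type_* kwargs.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=TensorScalingType.DELAYED),
    cast_config_weight=CastConfig(scaling_type=TensorScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=TensorScalingType.DYNAMIC),
)

# from_float now takes only (mod, emulate, config); scaling behavior comes from config.
fp8_linear = Float8Linear.from_float(torch.nn.Linear(64, 32), config=config)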