@@ -54,9 +54,16 @@ def split_param_groups(model) -> tuple[list, list, list]:
     params_quant, params_embed, params_no_quant = [], [], []

     def get_param_groups(model):
+        seen_data_ptrs = set()  # avoid duplicates in case of tied weights
         for module in model.children():
             is_linear = _is_linear(module)
             for n, p in module.named_parameters():
+                if n == "weight":
+                    data_ptr = p.data_ptr()
+                    if data_ptr in seen_data_ptrs:
+                        continue
+                    seen_data_ptrs.add(data_ptr)
+
                 if is_linear and n == "weight":
                     params_quant.append(p)
                 elif isinstance(module, nn.Embedding) and n == "weight":
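The deduplication above works because tied parameters are literally the same nn.Parameter, so every module that shares them reports the same storage address. A minimal standalone sketch of that invariant (not part of the patch; the shapes are illustrative):

import torch.nn as nn

# Tie an embedding table to a linear head: both modules now wrap one Parameter.
embedding = nn.Embedding(16, 256)
linear = nn.Linear(256, 16, bias=False)
linear.weight = embedding.weight

assert embedding.weight.data_ptr() == linear.weight.data_ptr()

# Deduplicating on data_ptr() therefore keeps a single copy in the param groups.
seen, params = set(), []
for p in (embedding.weight, linear.weight):
    if p.data_ptr() in seen:
        continue
    seen.add(p.data_ptr())
    params.append(p)
assert len(params) == 1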
@@ -152,7 +159,12 @@ def compare_parq_convert(
 def check_torchao_tensor_subclass(
     test_case: common_utils.TestCase, model: nn.Module, weight_only: bool = False
 ):
-    for module in model.modules():
+    for name, module in model.named_modules():
+        if not hasattr(module, "weight") or f"{name}.weight" in getattr(
+            model, "_tied_weights_keys", []
+        ):
+            continue
+
         if not weight_only and _is_linear(module):
             test_case.assertTrue(isinstance(module.weight, IntxUnpackedToInt8Tensor))
             test_case.assertTrue(
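The skip condition follows the Hugging Face convention of listing tied copies in _tied_weights_keys, so only the module that owns the shared weight gets checked. A small standalone sketch under that assumption (TiedModel and its submodule names are made up for illustration):

import torch.nn as nn

class TiedModel(nn.Module):
    _tied_weights_keys = ["decoder.weight"]

    def __init__(self):
        super().__init__()
        self.encoder = nn.Embedding(16, 32)
        self.decoder = nn.Linear(32, 16, bias=False)
        self.decoder.weight = self.encoder.weight  # tied copy

model = TiedModel()
for name, module in model.named_modules():
    # Skip modules without a weight and modules whose weight is a tied alias.
    if not hasattr(module, "weight") or f"{name}.weight" in getattr(
        model, "_tied_weights_keys", []
    ):
        continue
    print(name)  # prints only "encoder"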
@@ -163,34 +175,58 @@ def check_torchao_tensor_subclass(
             test_case.assertTrue(module.weight.activation_quantization is None)


+def apply_activation_quantization(
+    model: nn.Module, optimizer: torch.optim.Optimizer, model_dtype: torch.dtype
+):
+    # apply torchao quantized activations on top
+    activation_config = IntxFakeQuantizeConfig(
+        torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
+    )
+    qat_config = QATConfig(activation_config=activation_config, step="prepare")
+    for filter_fn in optimizer.get_filter_fns(model):
+        try:
+            quantize_(model, qat_config, filter_fn=filter_fn)
+        except ValueError as e:
+            if str(e) == "Activation fake quantization is not supported for embedding":
+                pass
+
+
 class M(nn.Module):
-    def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
         super().__init__()
-        self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
+        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
         self.linear1 = nn.Linear(m, n, bias=bias)
         self.linear2 = nn.Linear(n, k, bias=bias)
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()

+        if embedding and tied_weights:
+            assert self.embedding.weight.shape == self.linear2.weight.shape
+            self.linear2.weight = self.embedding.weight
+            self._tied_weights_keys.append("linear2.weight")
+
     def reset_parameters(self):
         for module in (self.linear1, self.linear2):
             nn.init.xavier_uniform_(module.weight)
             if module.bias is not None:
                 nn.init.zeros_(module.bias)

     def example_inputs(self, device=None):
-        return (
-            torch.randint(1, 10, (1, self.linear1.in_features), device=device)
-            if isinstance(self.embedding, nn.Embedding)
-            else torch.randn(1, self.linear1.in_features, device=device)
-        )
+        if isinstance(self.embedding, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embedding.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs

     def forward(self, x):
         x = self.embedding(x)
-        x = self.linear1(x)
-        x = self.relu(x)
-        x = self.linear2(x)
-        x = self.sigmoid(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
         return x

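For the tied configuration to be valid, M must be built with m == n so that the embedding table (k, m) and linear2.weight (k, n) share a shape, which is why the new tied-weights test below constructs M(m=256, n=256). A standalone sketch of the same pattern outside the test harness, using k = 16 and m = n = 256:

import torch
import torch.nn as nn

k, m, n = 16, 256, 256
embedding = nn.Embedding(k, m)
linear1 = nn.Linear(m, n, bias=False)
linear2 = nn.Linear(n, k, bias=False)
linear2.weight = embedding.weight  # tie the output projection to the embedding table

x = torch.randint(1, k, (1, m))  # token ids, as in M.example_inputs
out = torch.sigmoid(linear2(torch.relu(linear1(embedding(x)))))
assert out.shape == (1, m, k)
assert embedding.weight.data_ptr() == linear2.weight.data_ptr()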
@@ -297,7 +333,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        compare_parq_convert(model, m_ref, optimizer)
+        compare_parq_convert(model, m_ref, optimizer, weight_only=True)

     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     @common_utils.parametrize("b", [2, 3, 4, 8])
@@ -399,6 +435,30 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
         check_torchao_tensor_subclass(self, model, weight_only=True)

+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize(
+        "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
+    )
+    def test_intx_weight_only_tied_embed_linear(
+        self, b: int = 2, model_dtype: torch.dtype = torch.float32
+    ):
+        model = M(m=256, n=256, tied_weights=True).to(_DEVICE)
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+        base_optimizer = torch.optim.SGD(build_param_groups(model, b))
+        optimizer = QuantOptimizer(
+            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+        )
+        optimizer.zero_grad()
+        optimizer.step()
+
+        apply_activation_quantization(model, optimizer, model_dtype)
+        optimizer.torchao_convert(model)
+        check_torchao_tensor_subclass(self, model)
+        self.assertTrue(
+            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+        )
+

 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
@@ -435,16 +495,12 @@ def test_int8_dynamic_activation_intx_e2e(
         optimizer = QuantOptimizer(
             base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
         )
+
         optimizer.zero_grad()
         optimizer.step()

-        # apply torchao quantized activations on top
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
-        )
-        qat_config = QATConfig(activation_config=activation_config, step="prepare")
-        for filter_fn in optimizer.get_filter_fns(model):
-            quantize_(model, qat_config, filter_fn=filter_fn)
+        apply_activation_quantization(model, optimizer, model_dtype)
+

         out = model(x)
         torch.testing.assert_close(out, ref_out, atol=0, rtol=0)
@@ -462,7 +518,10 @@ def test_int8_dynamic_activation_intx_e2e(
         check_torchao_tensor_subclass(self, model)

         if attach_hf_config:
-            reg_param_names = {n for n, m in model.named_modules() if _is_linear(m)}
+            reg_param_names = {
+                n for n, m in model.named_modules() if isinstance(m, nn.Embedding)
+            }
+            reg_param_names.add("_default")
             module_fqn_to_config = (
                 model.config.quantization_config.quant_type.module_fqn_to_config
             )