
Commit 122b307

Fix torchao_convert, remove StretchedAffineQuantizedTensor (#3015)
1 parent c801f10 · commit 122b307

5 files changed: +152 −162 lines changed

test/prototype/test_parq.py

Lines changed: 80 additions & 21 deletions
@@ -54,9 +54,16 @@ def split_param_groups(model) -> tuple[list, list, list]:
     params_quant, params_embed, params_no_quant = [], [], []

     def get_param_groups(model):
+        seen_data_ptrs = set()  # avoid duplicates in case of tied weights
         for module in model.children():
             is_linear = _is_linear(module)
             for n, p in module.named_parameters():
+                if n == "weight":
+                    data_ptr = p.data_ptr()
+                    if data_ptr in seen_data_ptrs:
+                        continue
+                    seen_data_ptrs.add(data_ptr)
+
                 if is_linear and n == "weight":
                     params_quant.append(p)
                 elif isinstance(module, nn.Embedding) and n == "weight":
@@ -152,7 +159,12 @@ def compare_parq_convert(
 def check_torchao_tensor_subclass(
     test_case: common_utils.TestCase, model: nn.Module, weight_only: bool = False
 ):
-    for module in model.modules():
+    for name, module in model.named_modules():
+        if not hasattr(module, "weight") or f"{name}.weight" in getattr(
+            model, "_tied_weights_keys", []
+        ):
+            continue
+
         if not weight_only and _is_linear(module):
             test_case.assertTrue(isinstance(module.weight, IntxUnpackedToInt8Tensor))
             test_case.assertTrue(
@@ -163,34 +175,58 @@ def check_torchao_tensor_subclass(
             test_case.assertTrue(module.weight.activation_quantization is None)


+def apply_activation_quantization(
+    model: nn.Module, optimizer: torch.optim.Optimizer, model_dtype: torch.dtype
+):
+    # apply torchao quantized activations on top
+    activation_config = IntxFakeQuantizeConfig(
+        torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
+    )
+    qat_config = QATConfig(activation_config=activation_config, step="prepare")
+    for filter_fn in optimizer.get_filter_fns(model):
+        try:
+            quantize_(model, qat_config, filter_fn=filter_fn)
+        except ValueError as e:
+            if str(e) == "Activation fake quantization is not supported for embedding":
+                pass
+
+
 class M(nn.Module):
-    def __init__(self, m=256, n=128, k=16, bias=False, embedding=True):
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
         super().__init__()
-        self.embedding = nn.Embedding(10, m) if embedding else nn.Identity()
+        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
         self.linear1 = nn.Linear(m, n, bias=bias)
         self.linear2 = nn.Linear(n, k, bias=bias)
         self.relu = nn.ReLU()
         self.sigmoid = nn.Sigmoid()

+        if embedding and tied_weights:
+            assert self.embedding.weight.shape == self.linear2.weight.shape
+            self.linear2.weight = self.embedding.weight
+            self._tied_weights_keys.append("linear2.weight")
+
     def reset_parameters(self):
         for module in (self.linear1, self.linear2):
             nn.init.xavier_uniform_(module.weight)
             if module.bias is not None:
                 nn.init.zeros_(module.bias)

     def example_inputs(self, device=None):
-        return (
-            torch.randint(1, 10, (1, self.linear1.in_features), device=device)
-            if isinstance(self.embedding, nn.Embedding)
-            else torch.randn(1, self.linear1.in_features, device=device)
-        )
+        if isinstance(self.embedding, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embedding.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs

     def forward(self, x):
         x = self.embedding(x)
-        x = self.linear1(x)
-        x = self.relu(x)
-        x = self.linear2(x)
-        x = self.sigmoid(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
         return x

@@ -297,7 +333,7 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
             ProxHardQuant(),
             quant_per_channel=True,
         )
-        compare_parq_convert(model, m_ref, optimizer)
+        compare_parq_convert(model, m_ref, optimizer, weight_only=True)

     @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
     @common_utils.parametrize("b", [2, 3, 4, 8])
@@ -399,6 +435,30 @@ def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         compare_parq_convert(model, m_ref, optimizer, weight_only=True)
         check_torchao_tensor_subclass(self, model, weight_only=True)

+    @common_utils.parametrize("b", [2, 3])
+    @common_utils.parametrize(
+        "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
+    )
+    def test_intx_weight_only_tied_embed_linear(
+        self, b: int = 2, model_dtype: torch.dtype = torch.float32
+    ):
+        model = M(m=256, n=256, tied_weights=True).to(_DEVICE)
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+        base_optimizer = torch.optim.SGD(build_param_groups(model, b))
+        optimizer = QuantOptimizer(
+            base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
+        )
+        optimizer.zero_grad()
+        optimizer.step()
+
+        apply_activation_quantization(model, optimizer, model_dtype)
+        optimizer.torchao_convert(model)
+        check_torchao_tensor_subclass(self, model)
+        self.assertTrue(
+            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+        )
+

 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
@@ -435,16 +495,12 @@ def test_int8_dynamic_activation_intx_e2e(
         optimizer = QuantOptimizer(
             base_optimizer, quantizer, ProxHardQuant(), quant_per_channel=True
         )
+
         optimizer.zero_grad()
         optimizer.step()

-        # apply torchao quantized activations on top
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False, scale_precision=model_dtype
-        )
-        qat_config = QATConfig(activation_config=activation_config, step="prepare")
-        for filter_fn in optimizer.get_filter_fns(model):
-            quantize_(model, qat_config, filter_fn=filter_fn)
+        apply_activation_quantization(model, optimizer, model_dtype)
+
         out = model(x)
         torch.testing.assert_close(out, ref_out, atol=0, rtol=0)

@@ -462,7 +518,10 @@ def test_int8_dynamic_activation_intx_e2e(
         check_torchao_tensor_subclass(self, model)

         if attach_hf_config:
-            reg_param_names = {n for n, m in model.named_modules() if _is_linear(m)}
+            reg_param_names = {
+                n for n, m in model.named_modules() if isinstance(m, nn.Embedding)
+            }
+            reg_param_names.add("_default")
             module_fqn_to_config = (
                 model.config.quantization_config.quant_type.module_fqn_to_config
             )
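
For readers unfamiliar with the weight-tying pattern the new tests exercise, the following standalone sketch shows why parameter collection must deduplicate by data_ptr() when an embedding and a linear layer share storage. The TiedToy module and unique_weights helper below are illustrative only, not the test file's actual M class or helpers.

# Standalone sketch of embedding/linear weight tying (illustrative, not part
# of the commit).
import torch.nn as nn


class TiedToy(nn.Module):
    def __init__(self, vocab: int = 16, dim: int = 32):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab, bias=False)
        # Tie the output projection to the embedding table: both attributes
        # now reference the same Parameter (same storage, same data_ptr()).
        self.head.weight = self.embedding.weight


def unique_weights(model: nn.Module) -> list[nn.Parameter]:
    """Collect each weight once, even when modules share storage."""
    seen, params = set(), []
    for module in model.modules():
        weight = getattr(module, "weight", None)
        if isinstance(weight, nn.Parameter) and weight.data_ptr() not in seen:
            seen.add(weight.data_ptr())
            params.append(weight)
    return params


model = TiedToy()
assert model.head.weight is model.embedding.weight
assert len(unique_weights(model)) == 1  # the tied weight is counted only once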

torchao/core/config.py

Lines changed: 1 addition & 0 deletions
@@ -196,6 +196,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
     "torchao.prototype.parq",
     "torchao.dtypes",
     "torchao.prototype.awq",
+    "torchao.prototype.parq.quant",
     "torchao.quantization.quantize_.common",
     "torchao.quantization.quantize_.workflows",
 }
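
As a rough illustration of why the allowlist entry above is needed: config serialization typically refuses classes whose defining module it does not recognize, so classes defined under torchao.prototype.parq.quant would fail an exact-match check without this addition. The sketch below is a simplified stand-in, not torchao's actual implementation, and the names are hypothetical.

# Minimal sketch of an exact-match module allowlist check (illustrative only).
ALLOWED_MODULES = {
    "torchao.prototype.parq",
    "torchao.prototype.parq.quant",  # newly allowlisted by this commit
    "torchao.dtypes",
}


def assert_serializable(config: object) -> None:
    module = type(config).__module__
    if module not in ALLOWED_MODULES:
        raise ValueError(f"Cannot serialize config from unrecognized module {module!r}")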

torchao/prototype/parq/optim/quantopt.py

Lines changed: 27 additions & 9 deletions
@@ -14,6 +14,7 @@
 from torch.optim import Optimizer

 from torchao.quantization import quantize_
+from torchao.quantization.quant_api import _is_linear

 from ..quant import Quantizer, UnifTorchaoQuantizer
 from ..quant.config_torchao import (
@@ -158,24 +159,30 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
         self.restore_latent_params()

         # TODO(lvj): find more robust way to identify embedding layers
-        embed_data_ptrs = {
-            module.weight.data_ptr()
-            for module in model.modules()
-            if isinstance(module, nn.Embedding)
-        }
+        embed_data_ptrs = set()
+        linear_data_ptrs = set()
+        for module in model.modules():
+            if isinstance(module, nn.Embedding):
+                embed_data_ptrs.add(module.weight.data_ptr())
+            elif _is_linear(module) and module.weight.data_ptr() not in embed_data_ptrs:
+                linear_data_ptrs.add(module.weight.data_ptr())

         filter_fns = []
         configs = []
         attach_hf_config = _is_hf_model(model)
-        for group, filter_fn in zip(
-            self.regularized_param_groups(), self.get_filter_fns(model)
+        all_linear_layers_idx = -1
+        for i, (group, filter_fn) in enumerate(
+            zip(self.regularized_param_groups(), self.get_filter_fns(model))
         ):
             filter_fns.append(filter_fn)
             quantizer = group.get("quantizer", self.quantizer)
             if not isinstance(quantizer, UnifTorchaoQuantizer) or not group["params"]:
                 configs.append(None)
                 continue

+            if set((p.data_ptr() for p in group["params"])) == linear_data_ptrs:
+                all_linear_layers_idx = i
+
             device = group["params"][0].device
             any_embed = any(p.data_ptr() in embed_data_ptrs for p in group["params"])
             config = _get_config_from_quantizer(
@@ -187,10 +194,21 @@ def torchao_convert(self, model: nn.Module, weight_only: bool = False) -> None:
             )
             configs.append(config)

+        filter_fns_orig = filter_fns[:]
+        configs_orig = configs[:]
+
+        # If one group has all the linear layers, then set its config as default
+        if all_linear_layers_idx > -1:
+            module_to_config = {"_default": configs[all_linear_layers_idx]}
+            del filter_fns[all_linear_layers_idx]
+            del configs[all_linear_layers_idx]
+        else:
+            module_to_config = None
+
         if attach_hf_config:
-            _attach_hf_quantization_config(model, filter_fns, configs)
+            _attach_hf_quantization_config(model, filter_fns, configs, module_to_config)

-        for config, filter_fn in zip(configs, filter_fns):
+        for config, filter_fn in zip(configs_orig, filter_fns_orig):
             quantize_(model, config, filter_fn=filter_fn)

     @torch._disable_dynamo
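
To make the new control flow above easier to follow, here is a simplified standalone sketch of the "_default" selection that torchao_convert now performs: the group whose parameters exactly cover the model's linear weights is promoted to the "_default" entry for the HF quantization config, while quantize_ still runs with the original per-group configs. pick_default_config is a hypothetical helper, not part of torchao.

# Simplified sketch of the "_default" selection logic (illustrative only).
from typing import Any, Optional


def pick_default_config(
    param_groups: list[dict[str, Any]],
    configs: list[Optional[Any]],
    linear_data_ptrs: set[int],
) -> Optional[dict[str, Any]]:
    """Return {"_default": config} when one group's params are exactly the
    model's linear weights; otherwise return None."""
    for group, config in zip(param_groups, configs):
        if config is None:
            continue
        if {p.data_ptr() for p in group["params"]} == linear_data_ptrs:
            return {"_default": config}
    return None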
