Commit cfaa23f
AWQ Apply Scales Bugfix when smooth layer output length doesn't match balance layer input length (#1451)
### Summary

We are hitting an edge case in AWQ that did not surface with the initial Llama/Qwen testing models. When a smooth layer's number of output features does not match a balance layer's number of input features, the current code errors out with a shape mismatch when updating the smooth layer's weights via `weights.div(scales)`. We hit this in #1440 for Phi-3 models, whose AutoAWQ config maps the fused `qkv_proj` smooth layer to the `o_proj` balance layer (see [here](https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/phi3.py#L51-L57)). AutoAWQ resolves this by scaling only the last rows of the smooth layer so that the shapes line up, as shown [here](https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123). This PR ports that fix and, together with #1440, allows Phi-3 models to be quantized with `AWQModifier`. A minimal repro of the failure mode is sketched below, after this description.

As with v_proj -> o_proj, mappings whose shapes don't line up are excluded from the resolved mappings. This allows [phi-3-mini](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/tree/main?show_file_info=model-00001-of-00002.safetensors) to include the mapping, because `qkv_proj out_features == 3 * o_proj in_features == 9216`, but excludes it for [phi-3-medium](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/tree/main?show_file_info=model-00001-of-00006.safetensors), which has `qkv_proj out_features == 7680` and `o_proj in_features == 5120`. If the mapping is included for phi-3-medium, the model blows up, with wikitext eval perplexities above 2000. This implementation was agreed upon with @anmarques.

PS: I also switched `mul` & `div` to the in-place `mul_` & `div_`, to avoid unnecessary memory allocation.

-------------

### Test Plan

With these changes and with #1440, `examples/awq/llama_example.py` works with `"microsoft/Phi-3-mini-128k-instruct"` and produces similar results whether or not the qkv_proj -> o_proj mapping is included.

Without mapping:

| Tasks    | Version | Filter | n-shot | Metric          |   | Value   |   | Stderr |
|----------|--------:|--------|-------:|-----------------|---|--------:|---|--------|
| wikitext | 2       | none   | 5      | bits_per_byte   | ↓ | 0.6474  | ± | N/A    |
|          |         | none   | 5      | byte_perplexity | ↓ | 1.5664  | ± | N/A    |
|          |         | none   | 5      | word_perplexity | ↓ | 11.0201 | ± | N/A    |

With mapping:

| Tasks    | Version | Filter | n-shot | Metric          |   | Value   |   | Stderr |
|----------|--------:|--------|-------:|-----------------|---|--------:|---|--------|
| wikitext | 2       | none   | 5      | bits_per_byte   | ↓ | 0.6482  | ± | N/A    |
|          |         | none   | 5      | byte_perplexity | ↓ | 1.5672  | ± | N/A    |
|          |         | none   | 5      | word_perplexity | ↓ | 11.0527 | ± | N/A    |

I also confirmed that re-running with `meta-llama/Llama-3.2-3B-Instruct` and `meta-llama/Llama-2-7b-hf` does not deviate in PPL scores from what is currently on `main`.

---------

Signed-off-by: Brian Dellabetta <[email protected]>
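As an aside to the summary above, here is a minimal standalone sketch of the failure mode (illustrative tensor names, not the modifier's actual variables; dimensions borrowed from phi-3-medium, where the fused `qkv_proj` has 7680 output rows but `o_proj` expects 5120 input channels):

```python
import torch

qkv_weight = torch.randn(7680, 5120)  # fused qkv_proj weight (phi-3-medium shapes)
scales = torch.rand(5120) + 0.5       # per-channel scales computed against o_proj

try:
    qkv_weight.div(scales.view(-1, 1))  # the old code path
except RuntimeError as err:
    # broadcasting (7680, 5120) against (5120, 1) fails on the row dimension
    print(f"shape mismatch: {err}")
```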
1 parent cc2b5d9 commit cfaa23f

File tree

1 file changed (+31, −12): src/llmcompressor/modifiers/awq/base.py
```diff
@@ -310,14 +310,24 @@ def _set_resolved_mappings(self, model: Module) -> None:
                 if not balance_layer:
                     continue
 
-                # exclude v_proj/o_proj mappings whose shapes are incompatible
+                # exclude v_proj->o_proj mappings whose shapes are incompatible
                 # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
                 if (
-                    ".v_proj" in layer_name
-                    and ".o_proj" in balance_name
-                    and isinstance(smooth_layer, torch.nn.Linear)
+                    isinstance(smooth_layer, torch.nn.Linear)
                     and isinstance(balance_layer, torch.nn.Linear)
-                    and smooth_layer.weight.shape != balance_layer.weight.shape
+                    and ".o_proj" in balance_name
+                    and (
+                        (
+                            ".v_proj" in layer_name
+                            and smooth_layer.out_features
+                            != balance_layer.in_features
+                        )
+                        or (
+                            ".qkv_proj" in layer_name
+                            and smooth_layer.out_features
+                            != 3 * balance_layer.in_features
+                        )
+                    )
                 ):
                     num_skipped_oproj_mappings += 1
                     continue
```
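To make the new exclusion condition concrete, here is a small sketch (the `is_compatible` helper is hypothetical, not code from this PR) evaluating it against the two Phi-3 variants cited in the summary, with dimensions from the linked safetensors metadata:

```python
import torch

def is_compatible(smooth: torch.nn.Linear, o_proj: torch.nn.Linear, fused_qkv: bool) -> bool:
    # a fused qkv_proj must satisfy out_features == 3 * o_proj.in_features;
    # a plain v_proj must match o_proj.in_features exactly
    factor = 3 if fused_qkv else 1
    return smooth.out_features == factor * o_proj.in_features

# phi-3-mini: qkv_proj out_features 9216 == 3 * 3072 -> mapping kept
assert is_compatible(torch.nn.Linear(3072, 9216), torch.nn.Linear(3072, 3072), fused_qkv=True)

# phi-3-medium: qkv_proj out_features 7680 != 3 * 5120 -> mapping skipped
assert not is_compatible(torch.nn.Linear(5120, 7680), torch.nn.Linear(5120, 5120), fused_qkv=True)
```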
```diff
@@ -466,33 +476,42 @@ def _apply_smoothing(self, model: Module) -> None:
                 inp, w_mean, x_mean, module2inspect, balance_layers, fp16_output
             )
 
-            scales = best_scales
-
             @torch.no_grad()
             def smooth(module):
                 with align_module_device(module):
+                    scales = best_scales.to(module.weight.device)
                     if module in balance_layers:
-                        module.weight.mul_(scales.view(1, -1).to(module.weight.device))
+                        update_offload_parameter(
+                            module,
+                            "weight",
+                            module.weight.mul_(scales.view(1, -1)),
+                        )
                     elif module == smooth_layer:
                         if module.weight.ndim == 1:
                             update_offload_parameter(
                                 module,
                                 "weight",
-                                module.weight.div(scales.to(module.weight.device)),
+                                module.weight.div_(scales),
                             )
                         else:
+                            # NOTE: edge case when smooth layer number of out_features
+                            # is not equal to balance layer number of in_features
+                            # e.g. when fused qkv_proj is used to smooth o_proj
+                            # in this case, default to scaling the last output features
+                            # because the desired smooth layer is v_proj
+                            # https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123
                             update_offload_parameter(
                                 module,
                                 "weight",
-                                module.weight.div(
-                                    scales.view(-1, 1).to(module.weight.device)
+                                module.weight[-scales.size(0) :].div_(
+                                    scales.view(-1, 1)
                                 ),
                             )
                     if hasattr(module, "bias") and module.bias is not None:
                         update_offload_parameter(
                             module,
                             "bias",
-                            module.bias.div(scales.to(module.bias.device)),
+                            module.bias.div_(scales),
                         )
 
             parent = get_fsdp_parent(mapping.smooth_name, model)
```
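And a standalone sketch of the AutoAWQ-style slicing adopted in the second hunk (dimensions assume phi-3-mini's fused layout; `qkv_weight` and `scales` are illustrative stand-ins). The in-place `div_` touches only the trailing `scales.size(0)` rows, i.e. the v_proj slice, and avoids allocating a scaled copy of the weight:

```python
import torch

hidden = 3072  # phi-3-mini hidden size; q, k, and v each contribute `hidden` rows
qkv_weight = torch.randn(3 * hidden, hidden)  # fused [q; k; v] weight
scales = torch.rand(hidden) + 0.5             # one scale per o_proj input channel

before = qkv_weight.clone()
# mirrors module.weight[-scales.size(0):].div_(scales.view(-1, 1))
qkv_weight[-scales.size(0):].div_(scales.view(-1, 1))

# q and k rows are untouched; only the v slice was scaled, in place
assert torch.equal(qkv_weight[: 2 * hidden], before[: 2 * hidden])
assert torch.allclose(qkv_weight[2 * hidden :], before[2 * hidden :] / scales.view(-1, 1))
```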
