Commit 81caded

update gguf alg ext (#1026)
1 parent 6d3d87d commit 81caded

File tree

7 files changed (+66, -68 lines)

README.md

Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,8 @@ See our [paper](https://arxiv.org/pdf/2309.05516) for more details. For usage in
 
 
 ## 🆕 What's New
+[2025/11] AutoRound now offers preliminary support for an **enhanced GGUF quantization algorithm** via `--enable_alg_ext`. For detailed accuracy benchmarks, please refer to the accompanying [documentation](./docs/gguf_alg_ext_acc.md).
+
 [2025/10] AutoRound has been integrated into **SGLang**. You can now run models in the AutoRound format directly using the latest SGLang later than v0.5.4.
 
 [2025/10] We enhanced the RTN mode (--iters 0) to significantly reduce quantization cost compared to the default tuning mode. Check out [this doc](./docs/opt_rtn.md) for some accuracy results. If you don’t have sufficient resources, you can use this mode for 4-bit quantization.
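As a usage illustration for the new README entry, here is a minimal sketch of opting in from the Python API. The `enable_alg_ext=True` keyword is an assumption mirroring the `--enable_alg_ext` CLI flag, and the model name is a placeholder; only `ar.quantize()` is confirmed by this commit's test changes.

```python
# Hedged sketch, not a verified API reference: the enable_alg_ext keyword and
# the checkpoint name below are assumptions inferred from this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "Qwen/Qwen3-8B"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# enable_alg_ext is assumed to mirror the --enable_alg_ext flag from the README diff
ar = AutoRound(model, tokenizer, enable_alg_ext=True)
ar.quantize()  # same entry point exercised by test_alg_ext in this commit
```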

auto_round/alg_ext.abi3.so

-293 KB
Binary file not shown.

auto_round/compressors/base.py

Lines changed: 42 additions & 59 deletions

@@ -381,6 +381,16 @@ def __init__(
 
         self.attention_mask = []
 
+        self.wrapper_block = wrapper_block
+        if self.enable_alg_ext:
+            try:
+                logger.warning_once("using algorithm extension for quantization.")
+                from auto_round.alg_ext import wrapper_autoround
+
+                wrapper_autoround(self)
+            except (ImportError, ModuleNotFoundError):
+                logger.error("algorithm extension import error, fallback to default mode")
+
     def _gen_auto_scheme(
         self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
     ) -> dict[str, dict]:

@@ -2516,6 +2526,32 @@ def quantize_block(
         input_ids, input_others = normalize_input(inputs)
         return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload)
 
+    def _get_loss(
+        self,
+        output_q: torch.Tensor,
+        current_output: torch.Tensor,
+        indices: torch.Tensor,
+        mse_loss: Callable,
+        device: Union[str, torch.device] = "cpu",
+    ):
+        if self.attention_mask:
+            tmp_attention_mask = [self.attention_mask[i] for i in indices]
+            tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+            tmp_attention_mask.unsqueeze_(-1)
+        else:
+            tmp_attention_mask = 1.0
+        if self.amp:
+            with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                loss = mse_loss(  # pylint: disable=not-callable
+                    output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                )
+        else:
+            loss = mse_loss(  # pylint: disable=not-callable
+                output_q.to(torch.float32) * tmp_attention_mask,
+                current_output.to(torch.float32) * tmp_attention_mask,
+            )
+        return loss
+
     def _quantize_block(
         self,
         block: torch.nn.Module,

@@ -2600,7 +2636,7 @@ def _quantize_block(
             clear_memory(device_list=self.device_list)
             input_ids = q_input
 
-        quantized_layer_names, unquantized_layer_names = wrapper_block(
+        quantized_layer_names, unquantized_layer_names = self.wrapper_block(
             block,
             self.enable_minmax_tuning,
             self.enable_norm_bias_tuning,

@@ -2675,6 +2711,9 @@ def _quantize_block(
         num_elm = self._get_current_num_elm(input_ids, whole_indices)
 
         for i in range(self.iters):
+            if self.enable_alg_ext and self.data_type.endswith("dq"):
+                for n, m in block.named_modules():
+                    m.cur_iter = i
             total_loss = 0
             if self.sampler == "rand":
                 whole_indices = torch.randperm(nsamples)[:global_batch_size]

@@ -2688,25 +2727,7 @@ def _quantize_block(
 
                 output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
 
-                if self.attention_mask:
-                    tmp_attention_mask = [self.attention_mask[i] for i in indices]
-                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(loss_device)
-                    tmp_attention_mask.unsqueeze_(-1)
-                    num_elm = torch.sum(tmp_attention_mask).item()
-                    if num_elm == 0:
-                        num_elm = 1
-                else:
-                    tmp_attention_mask = 1.0
-                if self.amp:
-                    with autocast(device_type=str(loss_device).split(":")[0], dtype=self.amp_dtype):
-                        loss = mse_loss(  # pylint: disable=not-callable
-                            output_q * tmp_attention_mask, current_output * tmp_attention_mask
-                        )
-                else:
-                    loss = mse_loss(  # pylint: disable=not-callable
-                        output_q.to(torch.float32) * tmp_attention_mask,
-                        current_output.to(torch.float32) * tmp_attention_mask,
-                    )
+                loss = self._get_loss(output_q, current_output, indices, mse_loss, device)
 
                 total_loss += loss.item() / num_elm
 

@@ -2836,44 +2857,6 @@ def _quantize_blocks(
                for i in range(len(input_others[key])):
                    to_dtype(input_others[key][i], tmp_dtype)
 
-        if (
-            self.sym
-            and self.enable_alg_ext
-            and self.super_group_size is None
-            and (
-                (self.data_type.startswith("int") and self.act_bits >= 8)
-                or self.data_type.startswith("mx")
-                or self.data_type.startswith("nv")
-            )
-        ):
-            try:
-                from auto_round.alg_ext import quantize_block_ext
-
-                BaseCompressor.quantize_block_ext = quantize_block_ext
-                quantize_block = self.quantize_block_ext  # must use self.quantize_block_ext
-                if self.bits > 2 and (not self.data_type.startswith("mx") or not self.data_type.startswith("nv")):
-                    logger.warning(
-                        "algorithm extension has only undergone limited validation on "
-                        "INT2,mxfp4 and nvfp4; use with caution."
-                    )
-                else:
-                    logger.info("using algorithm extension for quantization.")
-            except (ImportError, ModuleNotFoundError):
-                logger.error("algorithm extension import error, fallback to default mode")
-                quantize_block = self._quantize_block
-        elif self.enable_alg_ext and self.data_type.endswith("dq"):
-            try:
-                from auto_round.alg_ext import dq_quantize_block_ext
-
-                BaseCompressor.dq_quantize_block_ext = dq_quantize_block_ext
-                quantize_block = self.dq_quantize_block_ext
-                logger.info("using algorithm extension for quantization.")
-            except (ImportError, ModuleNotFoundError):
-                logger.error("algorithm extension import error, fallback to default mode")
-                quantize_block = self._quantize_block
-        else:
-            quantize_block = self._quantize_block
-
         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
 

@@ -2891,7 +2874,7 @@ def _quantize_blocks(
             m = WrapperMultiblock(modules)
 
             m.config = model.config if hasattr(model, "config") else None
-            q_input, input_ids = quantize_block(
+            q_input, input_ids = self._quantize_block(
                 m,
                 input_ids,
                 input_others,

auto_round/data_type/int.py

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
     else:
         imatrix = imatrix.reshape(1, -1)
 
+    imatrix = reshape_pad_tensor_by_group_size(imatrix, group_size, val=1e-5)[0].view(1, -1)
     imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
     imatrix = imatrix.reshape(tensor.shape)
 

auto_round/data_type/utils.py

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@
 from auto_round.utils import logger
 
 
-def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int):
+def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int, val: float = 0.0):
     """Reshapes and pads the tensor to ensure that it can be quantized in groups of `group_size`.
 
     This function adjusts the

@@ -55,7 +55,7 @@ def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int):
         return data, orig_shape, pad_len
     else:
         pad_len = (data.shape[1] + group_size - 1) // group_size * group_size - data.shape[1]
-        data_new = torch.nn.functional.pad(data, (0, pad_len))
+        data_new = torch.nn.functional.pad(data, (0, pad_len), value=val)
         data_new = data_new.reshape(-1, group_size)
         return data_new, orig_shape, pad_len
 
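Taken together, the two data_type changes let the importance matrix be padded to a whole number of groups with a small nonzero value (1e-5) instead of zeros. A quick standalone illustration of the new padding behavior; the helper below is a simplified stand-in for `reshape_pad_tensor_by_group_size`, not the library function itself.

```python
# Simplified stand-in to show the effect of the new `val` argument; the real
# reshape_pad_tensor_by_group_size also returns orig_shape and pad_len.
import torch
import torch.nn.functional as F


def pad_to_groups(data: torch.Tensor, group_size: int, val: float = 0.0) -> torch.Tensor:
    pad_len = (data.shape[1] + group_size - 1) // group_size * group_size - data.shape[1]
    return F.pad(data, (0, pad_len), value=val).reshape(-1, group_size)


imatrix = torch.ones(1, 10)                 # 10 importance weights, group_size 4 -> pad 2
print(pad_to_groups(imatrix, 4))            # old behavior: padded entries are 0.0
print(pad_to_groups(imatrix, 4, val=1e-5))  # new behavior: padded entries are 1e-5
```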

docs/gguf_alg_ext_acc.md

Lines changed: 18 additions & 6 deletions

@@ -8,9 +8,21 @@ to stabilize accuracy during evaluation. All other settings follow the default c
 |method|scheme|Llama-3.1-8B|Qwen2.5-7B-Instruct|Qwen3-8b|Qwen3-30B-A3B-Instruct-2507|
 |:-----|:-----|:-----------|:------------------|:-------|:--------------------------|
 |**BF16** | - |0.6295(100%)|0.6571(100%) |0.6322(100%)|0.6746(100%) |
-| **original** | q2_k_s | 0.5535(87.92%)| 0.6266(95.35%)|0.5901(93.35%)|0.6386(94.66%)|
-| **enable_alg_ext** |q2_k_s|0.5740(91.18%)|0.6349(96.62%)|0.5962(94.31%)|0.6460(95.77%)|
-| **original** | q3_k_s | 0.6040(95.95%)|0.6382(97.12%)|0.6128(96.94%)|0.6598(97.82%)|
-| **enable_alg_ext** |q3_k_s|0.6081(96.59%)|0.6503(98.97%)|0.6252(98.89%)|0.6622(98.17%)|
-| **original** | q4_k_s | 0.6228(98.94%)|0.6560(99.83%)|0.6303(99.70%)|0.6762(100.24%)|
-| **enable_alg_ext** |q4_k_s|0.6239(99.11%)|0.6605(100.51%)|0.6320(99.98%)|0.6777(100.46%)|
+| **Optimized RTN** | q2_k_s | 0.5535(87.92%)| 0.6266(95.35%)|0.5901(93.35%)|0.6386(94.66%)|
+| **AutoRound+alg_ext** |q2_k_s|0.5740(91.18%)|0.6349(96.62%)|0.5962(94.31%)|0.6460(95.77%)|
+| **Optimized RTN** | q3_k_s | 0.6040(95.95%)|0.6382(97.12%)|0.6128(96.94%)|0.6598(97.82%)|
+| **AutoRound+alg_ext** |q3_k_s|0.6081(96.59%)|0.6503(98.97%)|0.6252(98.89%)|0.6622(98.17%)|
+| **Optimized RTN** | q3_k_m |0.6083(96.63%) |0.6418(97.68%)|0.6194(97.97%)||
+| **AutoRound+alg_ext** |q3_k_m|0.6127(97.33%)|0.6533(99.42%)|0.6197(98.02%)||
+| **Optimized RTN** | q4_k_s | 0.6228(98.94%)|0.6560(99.83%)|0.6303(99.70%)|0.6762(100.24%)|
+| **AutoRound+alg_ext** |q4_k_s|0.6239(99.11%)|0.6605(100.51%)|0.6320(99.98%)|0.6777(100.46%)|
+| **Optimized RTN** | q4_k_m |0.6252(99.32%) |0.6558(99.80%)|0.6296(99.59%)||
+| **AutoRound+alg_ext** |q4_k_m|0.6257(99.40%)|0.6575(100.06%)|0.6340(100.29%)||
+
+**Time cost**
+|model |Optimized RTN |AutoRound+alg_ext|
+|:--------------------------|:-------------|:----------------|
+|Llama-3.1-8B |1m25s |29m43s |
+|Qwen2.5-7B-Instruct |1m20s |35m35s |
+|Qwen3-8b |1m29s |47m58s |
+|Qwen3-30B-A3B-Instruct-2507|25m12s |12h47m39s |
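The percentages in parentheses read as each score relative to the BF16 row in the same column; that reading is an assumption based on the 100% BF16 baseline, and a quick check against the Llama-3.1-8B q2_k_s entries is consistent with it.

```python
# Assumption: percent = quantized score / BF16 score in the same column.
bf16 = 0.6295            # Llama-3.1-8B, BF16 row
alg_ext_q2_k_s = 0.5740  # Llama-3.1-8B, AutoRound+alg_ext, q2_k_s
print(f"{alg_ext_q2_k_s / bf16:.2%}")  # 91.18%, matching the table
```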

test/test_cpu/test_autoround.py

Lines changed: 1 addition & 1 deletion

@@ -699,7 +699,7 @@ def test_alg_ext(self):
         ar.quantize()
 
     def test_alg_ext_import(self):
-        from auto_round.alg_ext import dq_quantize_block_ext, quantize_block_ext
+        from auto_round.alg_ext import wrapper_autoround
 
     def test_invalid_layer_config(self):
         with self.assertRaises(ValueError):
