Commit 58b3d90

optimize vram for gguf and add momentum (#1031)
1 parent 81caded commit 58b3d90

7 files changed: +156 −34 lines changed

README.md

Lines changed: 4 additions & 4 deletions

@@ -30,7 +30,7 @@ See our [paper](https://arxiv.org/pdf/2309.05516) for more details. For usage in


## 🆕 What's New
-[2025/11] AutoRound now offers preliminary support for an **enhanced GGUF quantization algorithm** via `--enable_alg_ext`. For detailed accuracy benchmarks, please refer to the accompanying [documentation](./docs/gguf_alg_ext_acc.md).
+[2025/11] AutoRound now offers preliminary support for an enhanced GGUF quantization algorithm via `--enable_alg_ext`. For detailed accuracy benchmarks, please refer to the [documentation](./docs/gguf_alg_ext_acc.md).

[2025/10] AutoRound has been integrated into **SGLang**. You can now run models in the AutoRound format directly using SGLang v0.5.4 or later.

@@ -46,8 +46,7 @@ refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and
for some accuracy results.

[2025/07] AutoRound now offers experimental support for the **GGUF** format, and recommends using optimized RTN mode (--iters 0) for
-all bits other than 3 bits. **A more advanced algorithm** tailored for specific configurations may be available in
-v0.8.1.
+all bits other than 3 bits.

[2025/05] AutoRound has been integrated into **Transformers** and **vLLM**.

@@ -192,7 +191,7 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
- **`layer_config` (dict)**: Configuration for weight quantization (default is `None`), mainly for mixed schemes.

##### Algorithm Settings
--  **`enable_alg_ext` (bool)**: Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
+- **`enable_alg_ext` (bool)**: [Experimental Feature] Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
- **`disable_opt_rtn` (bool)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `False` (improved RTN enabled).

##### Tuning Process Parameters
@@ -208,6 +207,7 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
##### Device/Speed Configuration
- **`enable_torch_compile` (bool)**: If no exception is raised, typically we recommend setting it to True for faster quantization with lower resource usage.
- **`low_gpu_mem_usage` (bool)**: Whether to offload intermediate features to CPU at the cost of ~20% more tuning time (default is `False`).
+- **`low_cpu_mem_usage` (bool)**: [Experimental Feature] Whether to enable immediate saving to reduce RAM usage (default is `False`).
- **`device_map` (str|dict|int)**: The device to be used for tuning, e.g., `"auto"`, `"cpu"`, `"cuda"`, `"0,1,2"` (default is `"0"`). When using `"auto"`, it will try to use all available GPUs.

</details>

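Note on the README options above: they are keyword arguments of the AutoRound compressor shown in the hunk context (`ar.quantize_and_save(...)`). A minimal sketch of how they might be combined — the placeholder model id, the constructor call shape, and the default quantization scheme are assumptions, not part of this commit:

```python
from auto_round import AutoRound

# Hedged sketch; "facebook/opt-125m" is only a placeholder model id and the
# default quantization scheme is assumed.
ar = AutoRound(
    "facebook/opt-125m",
    enable_alg_ext=True,        # [Experimental] algorithm variants for specific schemes
    low_gpu_mem_usage=True,     # offload intermediate features to CPU (~20% slower tuning)
    enable_torch_compile=True,  # usually faster tuning when no exception is raised
    momentum=0.9,               # new in this commit, forwarded to the block optimizer
    device_map="auto",
)
ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
```
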
auto_round/__main__.py

Lines changed: 7 additions & 0 deletions

@@ -172,6 +172,12 @@ def __init__(self, *args, **kwargs):
            type=float,
            help="Learning rate specifically for min-max tuning. " "If None, uses the same value as --lr. ",
        )
+        tuning.add_argument(
+            "--momentum",
+            default=0,
+            type=float,
+            help="Momentum factor for the optimizer. Default is 0 (no momentum).",
+        )
        tuning.add_argument(
            "--gradient_accumulate_steps",
            default=1,
@@ -591,6 +597,7 @@ def tune(args):
        extra_config=extra_config,
        layer_config=layer_config,
        model_dtype=args.model_dtype,
+        momentum=args.momentum,
    )

    model_name = args.model.rstrip("/")

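The new `--momentum` flag is forwarded as `momentum=args.momentum` into the compressor and, as the base.py hunk below shows, ends up as the `momentum` argument of the block-tuning optimizer when that optimizer is SGD-like (Adam-based variants skip it). A minimal, self-contained illustration of the effect with plain PyTorch (not AutoRound code):

```python
import torch

# momentum=0 reproduces the previous behavior; momentum>0 adds a velocity buffer
# to each parameter update, exactly as torch.optim.SGD defines it.
w = torch.nn.Parameter(torch.randn(4, 4))
opt = torch.optim.SGD([w], lr=5e-3, weight_decay=0, momentum=0.9)

loss = (w ** 2).sum()
loss.backward()
opt.step()  # this update uses the accumulated velocity because momentum > 0
```
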
auto_round/compressors/base.py

Lines changed: 22 additions & 8 deletions

@@ -193,7 +193,7 @@ def __init__(
                super_group_size, super_bits, scale_dtype ("fp16" etc.),
                nblocks, to_quant_block_names,
                enable_norm_bias_tuning, enable_quanted_input,
-                disable_deterministic_algorithms, mllm, static_kv_dtype,enable_deterministic_algorithms
+                disable_deterministic_algorithms, mllm, static_kv_dtype,enable_deterministic_algorithms,momentum
        Raises:
            ValueError: If invalid device is provided or tokenizer is missing for non-str model with iters > 0.
            RuntimeError: If model parameters are on meta device.
@@ -234,6 +234,7 @@ def __init__(
        enable_quanted_input: bool = kwargs.pop("enable_quanted_input", True)
        disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True)
        enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
+        self.momentum = kwargs.pop("momentum", 0.0)
        static_kv_dtype = kwargs.pop("static_kv_dtype", None)
        model_dtype = kwargs.pop("model_dtype", None)
        device = kwargs.pop("device", None)
@@ -1567,11 +1568,12 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
        # It is best to modify the model structure in the quantize function and check the format,
        # because it may cause the gguf format to not be exported normally.
        self.model = _handle_moe_model(self.model, formats=formats)
-        # Assign temporary names after replacing modules
-        for n, m in self.model.named_modules():  # TODO check if could removed
+
+        # Temporary names must be assigned after handle_moe_model;
+        # placing them earlier would cause them to be removed when the module is replaced.
+        for n, m in self.model.named_modules():
            m.tmp_name = n

-        # TODO check scale_dtype
        if not self.is_auto_scheme:
            enable_gguf_official_mixed = True
        else:
@@ -2661,12 +2663,24 @@ def _quantize_block(

        lr = torch.tensor(self.lr)
        minmax_lr = torch.tensor(self.minmax_lr)
+        is_adam = "adam" in self.__class__.__name__.lower()
+
+        extra_kwargs = {} if is_adam else {"momentum": self.momentum}
+
        if self.enable_minmax_tuning:
-            optimizer = self.optimizer(
-                [{"params": round_params}, {"params": minmax_params, "lr": minmax_lr}], lr=lr, weight_decay=0
-            )
+            params = [
+                {"params": round_params},
+                {"params": minmax_params, "lr": minmax_lr},
+            ]
        else:
-            optimizer = self.optimizer(round_params, lr=lr, weight_decay=0)
+            params = round_params
+
+        optimizer = self.optimizer(
+            params,
+            lr=lr,
+            weight_decay=0,
+            **extra_kwargs,
+        )

        if len(round_params) + len(minmax_params) <= 0:
            dump_info = (

auto_round/data_type/gguf.py

Lines changed: 120 additions & 13 deletions

@@ -21,6 +21,7 @@
from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
from auto_round.logger import logger
from auto_round.utils import get_reciprocal
+from auto_round.utils.device import clear_memory


@register_dtype("int_sym_dq")
@@ -320,7 +321,7 @@ def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tens


@torch.no_grad()
-def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
+def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None, split_num=1):
    super_bits = 4 if bits == 2 else 6
    super_group_size = 16 if bits == 2 else 8
    group_size = 16 if bits == 2 else 32
@@ -348,6 +349,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri
            nstep=params["nstep"],
            use_mad=params["use_mad"],
            weights=quant_weights,
+            split_num=split_num,
        )
    scale = scale.to(scale_dtype)
    scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
@@ -428,16 +430,8 @@ def quant_tensor_gguf_asym_dq(
    Args:
        tensor (torch.Tensor): Input tensor to quantize.
        bits (int): Number of bits for quantization.
-        group_size (int): Group size for per-group quantization.
        v (float): Perturbation added before rounding.
-        min_scale (float): Minimum allowed scale value.
-        max_scale (float): Maximum allowed scale value.
        scale_dtype (torch.dtype): Data type for quantized scale.
-        tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
-        tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
-        q_scale_thresh (float): Threshold to clamp the quantized scale.
-        super_group_size (int): Number of groups to bundle for secondary quantization.
-        super_bits (int): Number of bits used in secondary quantization.
        imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.

    Returns:
@@ -446,10 +440,19 @@
    orig_dtype = tensor.dtype
    maxq = 2**bits - 1
    group_size = 16 if bits == 2 else 32
+    split_num = 1
+    for dim in tensor.shape:
+        if dim > 100_000:
+            split_num = 16
+            break
+
    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+
    tensor = tensor.to(torch.float32)
    if scale is None:
-        scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(tensor, bits, scale_dtype, imatrix)
+        scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(
+            tensor, bits, scale_dtype, imatrix, split_num=split_num
+        )

    inverse_scale = get_reciprocal(scale)
    int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq)
@@ -458,7 +461,7 @@
    return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}


-def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None):
+def iterative_wls_quant_search_non_chunk(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None):
    """Adapted from Llamacpp. Performs iterative weighted least squares quantization search.

    Args:
@@ -526,6 +529,112 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
    return scale.to(torch.float32), -rmin.to(torch.float32)


+# TODO consolidate iterative_wls_quant_search_chunk and non-chunk
+def iterative_wls_quant_search_chunk(
+    data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=8
+):
+    dtype = torch.float32
+    data = data.to(dtype)
+    maxq = 2**bits - 1
+    minq = 0
+    weights = 1.0 if weights is None else weights.to(dtype)
+
+    results_scale = []
+    results_rmin = []
+    chunk_size = (data.shape[0] + split_num - 1) // split_num
+    for start in range(0, data.shape[0], chunk_size):
+        end = min(start + chunk_size, data.shape[0])
+        chunk = data[start:end]
+        chunk_weights = weights if isinstance(weights, float) else weights[start:end]
+
+        rmin = torch.min(chunk, dim=1, keepdim=True)[0]
+        rmax = torch.max(chunk, dim=1, keepdim=True)[0]
+        sum_w = torch.sum(chunk_weights, dim=1, keepdim=True)
+        sum_x = torch.sum(chunk_weights * chunk, dim=1, keepdim=True)
+        scale = (rmax - rmin) / (maxq - minq)
+        iscale = get_reciprocal(scale)
+        quant_data = torch.clamp(torch.round(iscale * (chunk - rmin)), minq, maxq)
+        diff = scale * quant_data + rmin - chunk
+        best_mad = torch.sum(
+            (chunk_weights * torch.abs(diff)) if use_mad else chunk_weights * torch.pow(diff, 2), dim=1, keepdim=True
+        )
+
+        for is_ in range(nstep):
+            factor = rrmin + rdelta * is_ + maxq - minq
+            scale_new = (rmax - rmin) / factor
+            iscale_new = get_reciprocal(scale_new)
+            quant_data_new = torch.clamp(torch.round(iscale_new * (chunk - rmin)), minq, maxq)
+            mul_weights_quant_data = chunk_weights * quant_data_new
+            sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True)
+            sum_l2 = torch.sum(mul_weights_quant_data * quant_data_new, dim=-1, keepdim=True)
+            sum_xl = torch.sum(mul_weights_quant_data * chunk, dim=-1, keepdim=True)
+            D = sum_w * sum_l2 - torch.pow(sum_l, 2)
+            this_scale = (sum_w * sum_xl - sum_x * sum_l) / D
+            this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D
+            this_min[this_min > 0] = 0
+            this_scale[this_min > 0] = (sum_xl / sum_l2)[this_min > 0]
+            reverse_this_scale = get_reciprocal(this_scale)
+            quant_data = torch.clamp(torch.round(reverse_this_scale * (chunk - this_min)), minq, maxq)
+            diff = this_scale * quant_data + this_min - chunk
+            mad = torch.sum(
+                (chunk_weights * torch.abs(diff)) if use_mad else chunk_weights * torch.pow(diff, 2),
+                dim=-1,
+                keepdim=True,
+            )
+            idx_to_replace = torch.where((mad < best_mad) & (D > 0))[0]
+            best_mad[idx_to_replace] = mad[idx_to_replace]
+            scale[idx_to_replace] = this_scale[idx_to_replace]
+            rmin[idx_to_replace] = this_min[idx_to_replace]
+        results_scale.append(scale.to(torch.float32))
+        results_rmin.append(-rmin.to(torch.float32))
+        if split_num > 1:
+            clear_memory(device_list=[data.device])
+
+    return torch.cat(results_scale, dim=0), torch.cat(results_rmin, dim=0)
+
+
+def iterative_wls_quant_search(
+    data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=1
+):
+    """Adapted from Llamacpp. Performs iterative weighted least squares quantization search.
+
+    Args:
+        data (torch.Tensor): Input tensor to quantize.
+        bits (int): Number of quantization bits.
+        rrmin (float): Initial range scaling factor.
+        rdelta (float): Step size for range scaling.
+        nstep (int): Number of search steps.
+        use_mad (bool): Whether to use mean absolute deviation instead of squared error.
+        weights (torch.Tensor): Weight matrix for each element.
+
+    Returns:
+        Tuple: (Optimal scale tensor, optimal minimum value tensor)
+    """
+
+    # TODO this one should change to try catch later
+    if split_num > 1:
+        return iterative_wls_quant_search_chunk(
+            data=data,
+            bits=bits,
+            rrmin=rrmin,
+            rdelta=rdelta,
+            nstep=nstep,
+            use_mad=use_mad,
+            weights=weights,
+            split_num=split_num,
+        )
+    else:
+        return iterative_wls_quant_search_non_chunk(
+            data=data,
+            bits=bits,
+            rrmin=rrmin,
+            rdelta=rdelta,
+            nstep=nstep,
+            use_mad=use_mad,
+            weights=weights,
+        )
+
+
@torch.no_grad()
def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
    from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K
@@ -550,7 +659,6 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
    return scale


-#
@register_dtype("rtn_int_sym_dq")
def quant_tensor_gguf_sym_dq(
    tensor,
@@ -566,7 +674,6 @@ def quant_tensor_gguf_sym_dq(
    Args:
        tensor: Tensor containing the tensor to be quantized
        bits: Number of bits for quantization (e.g., 2, 3, 4, 8)
-        group_size: Number of elements to share scale for quantization
        v: Rounding value perturbation
        min_scale: Minimum scale coefficient for tensor
        max_scale: Maximum scale coefficient for tensor

auto_round/data_type/int.py

Lines changed: 0 additions & 1 deletion

@@ -71,7 +71,6 @@ def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
        imatrix = 1.0
    else:
        imatrix = imatrix.reshape(1, -1)
-
        imatrix = reshape_pad_tensor_by_group_size(imatrix, group_size, val=1e-5)[0].view(1, -1)
        imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
        imatrix = imatrix.reshape(tensor.shape)

auto_round/export/export_to_awq/utils.py

Lines changed: 0 additions & 7 deletions

@@ -316,10 +316,3 @@ def extra_repr(self) -> str:
            self.w_bit,
            self.group_size,
        )
-
-
-def clear_memory(weight=None):
-    if weight is not None:
-        del weight
-    gc.collect()
-    torch.cuda.empty_cache()

auto_round/export/export_to_gguf/packing.py

Lines changed: 3 additions & 1 deletion

@@ -16,7 +16,7 @@
import torch

from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K
-from auto_round.utils import get_reciprocal
+from auto_round.utils import clear_memory, get_reciprocal

GGML_QUANT_TYPE = {}

@@ -66,6 +66,8 @@ def ggml_quant(
    wmin = wmin.to(device) if wmin is not None else wmin
    d_scale = d_scale.to(device) if d_scale is not None else d_scale
    d_wmin = d_wmin.to(device) if d_wmin is not None else d_wmin
+    imatrix = imatrix.to(device) if imatrix is not None else imatrix
+    clear_memory()
    new_data = quant_func(
        blocks, scale, zp=zp, wmin=wmin, d_scale=d_scale, d_wmin=d_wmin, imatrix=imatrix, original=original
    )
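A short note on the VRAM optimization in `quant_tensor_gguf_asym_dq` / `iterative_wls_quant_search_chunk`: when any tensor dimension exceeds 100,000, the scale/min search runs over `split_num=16` row slices instead of the whole tensor, and `clear_memory()` is called between slices, so only one slice's temporaries are resident at a time. A toy version of the same pattern (this is an illustration, not the library function — it only computes per-row min/max rather than the full weighted least-squares search):

```python
import torch

def rowwise_minmax_chunked(data: torch.Tensor, split_num: int = 16):
    # Process rows in split_num slices to bound peak memory, mirroring the
    # chunking loop in iterative_wls_quant_search_chunk.
    chunk_size = (data.shape[0] + split_num - 1) // split_num
    mins, maxs = [], []
    for start in range(0, data.shape[0], chunk_size):
        chunk = data[start:start + chunk_size]
        mins.append(chunk.min(dim=1, keepdim=True)[0])
        maxs.append(chunk.max(dim=1, keepdim=True)[0])
        # the real implementation calls clear_memory() here when split_num > 1
    return torch.cat(mins, dim=0), torch.cat(maxs, dim=0)

# Usage on a tensor shaped like a reshaped, group-padded weight matrix:
rmin, rmax = rowwise_minmax_chunked(torch.randn(1024, 32), split_num=4)
```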