From 46514b23c4b2f0105ab4cdfdcec6f8f337db2903 Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Sun, 1 Dec 2024 19:31:26 -0800 Subject: [PATCH] Support imatrix-guided quantization for NPU CW (#12468) * init commit * remove print * add interface * fix * fix * fix style --- .../ipex_llm/ggml/model/llama/llama_cpp.py | 29 +++++++++++++++++ .../ipex_llm/transformers/low_bit_linear.py | 13 ++++++-- .../src/ipex_llm/transformers/npu_model.py | 22 +++++++++---- .../transformers/npu_models/convert.py | 32 +++++++++++++++---- python/llm/src/ipex_llm/transformers/utils.py | 29 +++++++++++++---- 5 files changed, 104 insertions(+), 21 deletions(-) diff --git a/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py b/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py index 8308308045a..e98c2622ab2 100644 --- a/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py +++ b/python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py @@ -1018,6 +1018,35 @@ def ggml_quantize_tensor_rtn( _lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t +def ggml_quantize_tensor_rtn_with_weights( + src, # type: ctypes.Array[ctypes.c_float] # type: ignore + dst: ctypes.c_void_p, + scale_ptr, # type: ctypes.Array[ctypes.c_float] # type: ignore + qtype: ctypes.c_int, + n: ctypes.c_size_t, + k: ctypes.c_int, + hist, # type: ctypes.Array[ctypes.c_int64] # type: ignore + scale_search: ctypes.c_bool, + weights, # type: ctypes.Array[ctypes.c_float] # type: ignore +) -> int: + return _lib.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr, qtype, n, k, + hist, scale_search, weights) + + +_lib.ggml_quantize_tensor_rtn_with_weights.argtypes = [ + ctypes.POINTER(ctypes.c_float), + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_float), + ctypes.c_int, + ctypes.c_size_t, + ctypes.c_int, + ctypes.POINTER(ctypes.c_int64), + ctypes.c_bool, + ctypes.POINTER(ctypes.c_float), +] +_lib.ggml_quantize_tensor_rtn_with_weights.restype = ctypes.c_size_t + + def ggml_type_size(qtype: ctypes.c_int) -> int: return _lib.ggml_type_size(qtype) diff --git a/python/llm/src/ipex_llm/transformers/low_bit_linear.py b/python/llm/src/ipex_llm/transformers/low_bit_linear.py index fb59160aa12..82fbdf6f506 100644 --- a/python/llm/src/ipex_llm/transformers/low_bit_linear.py +++ b/python/llm/src/ipex_llm/transformers/low_bit_linear.py @@ -246,8 +246,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int, if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]: if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]: scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float)) - ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, - k, hist, enable_scale_search) + if imatrix is None: + ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n, + k, hist, enable_scale_search) + else: + imatrix = imatrix.data.data_ptr() + imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float)) + ggml.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr, + qtype, n, + k, hist, + enable_scale_search, + imatrix) return dst_tensor, scale.type(torch.float16) else: ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search) diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py index 5cc15129125..2eb54e515b4 100644 --- a/python/llm/src/ipex_llm/transformers/npu_model.py +++ b/python/llm/src/ipex_llm/transformers/npu_model.py @@ -26,7 +26,7 @@ from transformers.configuration_utils import PretrainedConfig from ipex_llm.utils.common.log4Error import invalidInputError -from ipex_llm.transformers.utils import logger +from ipex_llm.transformers.utils import logger, load_imatrix_data from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post @@ -137,6 +137,12 @@ def from_pretrained(cls, *args, **kwargs): convert_model = kwargs.pop('convert_model', False) save_directory = kwargs.pop('save_directory', None) fuse_layers = kwargs.pop('fuse_layers', None) + imatrix_file = kwargs.pop('imatrix_file', None) + + if imatrix_file is not None: + imatrix_data = load_imatrix_data(imatrix_file) + else: + imatrix_data = None invalidInputError( quantization_group_size in [0, 32, 64, 128], @@ -205,7 +211,8 @@ def from_pretrained(cls, *args, **kwargs): "transpose_value_cache": transpose_value_cache, "convert_model": convert_model, "save_directory": save_directory, - "fuse_layers": fuse_layers + "fuse_layers": fuse_layers, + "imatrix_data": imatrix_data } model = cls.optimize_npu_model(*args, **optimize_kwargs) else: @@ -213,7 +220,8 @@ def from_pretrained(cls, *args, **kwargs): optimize_llm(model) with torch.no_grad(): cls.load_convert(qtype, model, "cpu", modules_to_not_convert, - quantization_group_size, *args, **kwargs) + quantization_group_size, imatrix_data=imatrix_data, + *args, **kwargs) if hasattr(model, "llm"): create_npu_kernels(model.llm) else: @@ -246,6 +254,7 @@ def optimize_npu_model(cls, *args, **kwargs): convert_model = kwargs.pop('convert_model', False) save_directory = kwargs.pop('save_directory', None) fuse_layers = kwargs.pop('fuse_layers', None) + imatrix_data = kwargs.pop('imatrix_data', None) if hasattr(model, "llm"): llm = model.llm @@ -258,7 +267,8 @@ def optimize_npu_model(cls, *args, **kwargs): optimize_llm_pre(model, qtype, mixed_precision, quantization_group_size=quantization_group_size) cls.load_convert(qtype, model, "cpu", modules_to_not_convert, - quantization_group_size, *args, **kwargs) + quantization_group_size, imatrix_data, + *args, **kwargs) create_npu_kernels(llm) model = model.eval() logger.info(f"Finish to convert model") @@ -305,12 +315,12 @@ def optimize_npu_model(cls, *args, **kwargs): @classmethod def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert, - group_size=0, *arg, **kwarg): + group_size=0, imatrix_data=None, *arg, **kwarg): from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear replace_with_QuantizedLinear(optimize_model, q_k, device=device, modules_to_not_convert=modules_to_not_convert, - group_size=group_size) + group_size=group_size, imatrix=imatrix_data) @classmethod def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert, diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 461ec731cdf..9cae68ae16f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -19,11 +19,11 @@ import torch import importlib from ipex_llm.transformers.npu_models.linear import QuantizedLinear -import tempfile import time from typing import Callable, List, Optional from transformers import GenerationConfig, \ LogitsProcessorList, StoppingCriteriaList +from ipex_llm.transformers.utils import module_name_process def module_optimization(func) -> torch.nn.Module: @@ -39,7 +39,7 @@ def module_optimization(func) -> torch.nn.Module: """ def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert, - group_size=0, *args, **kwargs): + group_size=0, imatrix=None, full_name="", *args, **kwargs): """Recursively apply the optimization function. Args: @@ -49,23 +49,40 @@ def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert, """ for name, layer in model.named_children(): + if full_name == "": + cur_full_name = name + else: + cur_full_name = full_name + "." + name + cur_imatrix = None + if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"): + new_module_name, _, cur_module_name, dq_idx = module_name_process(cur_full_name) + if imatrix is not None and new_module_name in imatrix: + cur_imatrix = imatrix[new_module_name] + if cur_imatrix.shape[0] != layer.weight.shape[1]: + ws = layer.weight.shape[1] + cur_imatrix = cur_imatrix[ws * dq_idx: ws * (dq_idx + 1)] if name not in modules_to_not_convert: new_layer = func(layer, qtype, device, modules_to_not_convert, - group_size=group_size, *args, **kwargs) + group_size=group_size, imatrix=cur_imatrix, + *args, **kwargs) if new_layer: model.add_module(name, new_layer) wrapper(new_layer, qtype, device, modules_to_not_convert, - group_size=group_size, *args, **kwargs) + group_size=group_size, imatrix=imatrix, + full_name=cur_full_name, + *args, **kwargs) else: wrapper(layer, qtype, device, modules_to_not_convert, - group_size=group_size, *args, **kwargs) + group_size=group_size, imatrix=imatrix, + full_name=cur_full_name, + *args, **kwargs) return wrapper @module_optimization def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, - group_size): + group_size, imatrix): from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype from ipex_llm.ggml.quantize import ggml_tensor_qtype iqtype = ggml_tensor_qtype[qtype] @@ -79,7 +96,8 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0" qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32), iqtype, device=device, - enable_scale_search=enable_scale_search) + enable_scale_search=enable_scale_search, + imatrix=imatrix) return QuantizedLinear(qweights, scale, layer.bias, group_size=group_size) diff --git a/python/llm/src/ipex_llm/transformers/utils.py b/python/llm/src/ipex_llm/transformers/utils.py index 1e2228172d6..2ec0dcf456f 100644 --- a/python/llm/src/ipex_llm/transformers/utils.py +++ b/python/llm/src/ipex_llm/transformers/utils.py @@ -247,6 +247,10 @@ def module_name_process(full_module_name): else: super_module_name = None exp_id = None + new_module_name = None + layer = None + cur_module = None + dq_idx = None if super_module_name == 'block_sparse_moe': # handle mixtral moe here moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"} @@ -265,11 +269,24 @@ def module_name_process(full_module_name): layer = module_name_list[2] cur_module = module_name_list[-1][:-5] new_module_name = '_'.join([layer, cur_module]) + elif len(module_name_list) == 6 and 'dq' in module_name_list[-1]: + # for NPU dq_list linear + layer = module_name_list[2] + cur_module = module_name_list[-1] + try: + dq_idx = int(cur_module[-2:]) + except: + dq_idx = int(cur_module[-1:]) + if cur_module[0] in 'qkvo': + cur_module = cur_module[0] + elif cur_module[:2] == "up": + cur_module = cur_module[:2] + elif cur_module[:4] == "gate" or cur_module[:4] == "down": + cur_module = cur_module[:4] + new_module_name = '_'.join([layer, cur_module]) elif len(module_name_list) == 1: new_module_name = module_name_list[0] - layer = None - cur_module = None - return new_module_name, layer, cur_module + return new_module_name, layer, cur_module, dq_idx def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None): @@ -283,7 +300,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"], ggml_tensor_qtype["gguf_iq1_s"]]: # For quantization which needs importance matrix - new_module_name, layer, cur_module = module_name_process(full_module_name) + new_module_name, layer, cur_module, _ = module_name_process(full_module_name) # custom mixed quantization strategy if model_type == "mixtral": if cur_module == 'v': @@ -312,7 +329,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi if new_module_name == 'lm_head': cur_qtype = ggml_tensor_qtype['sym_int8'] elif qtype == ggml_tensor_qtype["q2_k"]: - new_module_name, layer, cur_module = module_name_process(full_module_name) + new_module_name, layer, cur_module, _ = module_name_process(full_module_name) if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]): # TODO: q2_k need others k-quants type here cur_qtype = ggml_tensor_qtype['q2_k'] @@ -325,7 +342,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_confi cur_qtype = ggml_tensor_qtype['sym_int8'] elif qtype > 100: # gguf mixed precision - new_module_name, layer, cur_module = module_name_process(full_module_name) + new_module_name, layer, cur_module, _ = module_name_process(full_module_name) num_hidden_layers = getattr(model_config, "num_hidden_layers", None) if qtype in [gguf_mixed_qtype["gguf_q4k_s"], gguf_mixed_qtype["gguf_q4k_m"]] and \ new_module_name == 'lm_head':