Support imatrix-guided quantization for NPU CW (intel-analytics#12468)
* init commit

* remove print

* add interface

* fix

* fix

* fix style
rnwang04 authored and przemekmatusiak committed Dec 10, 2024
1 parent 6771e9e commit 46514b2
Showing 5 changed files with 104 additions and 21 deletions.
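Background for reviewers: an importance matrix ("imatrix", typically collected with llama.cpp's imatrix tool over a calibration set) gives each input channel a weight that roughly tracks how strongly that channel contributes to activations, and an imatrix-guided quantizer trades error on low-importance channels for accuracy on high-importance ones. The kernel behind this commit sits in the native `ggml_quantize_tensor_rtn_with_weights` function, which is not part of this diff; the toy sketch below only illustrates the general idea of importance-weighted round-to-nearest (RTN) with a scale search and should not be read as that kernel's actual algorithm.

```python
# Conceptual sketch only: NOT the kernel added by this commit. It illustrates
# importance-weighted RTN: error on channels with larger imatrix values is
# penalized more when searching a per-row scale.
import torch

def weighted_rtn_int4_row(row: torch.Tensor, importance: torch.Tensor,
                          n_candidates: int = 20) -> float:
    """Pick a symmetric int4 scale for one weight row by minimizing the
    importance-weighted squared quantization error. Purely illustrative."""
    best_scale, best_err = 0.0, float("inf")
    base = row.abs().max().item() / 7.0          # naive RTN scale for [-7, 7]
    for step in range(1, n_candidates + 1):
        scale = base * step / n_candidates       # candidate scales up to the naive one
        q = torch.clamp(torch.round(row / scale), -8, 7)
        err = (importance * (row - q * scale) ** 2).sum().item()
        if err < best_err:
            best_scale, best_err = scale, err
    return best_scale

row = torch.randn(4096)          # one output channel of a weight matrix
importance = torch.rand(4096)    # per-input-channel importance values
print(weighted_rtn_int4_row(row, importance))
```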
29 changes: 29 additions & 0 deletions python/llm/src/ipex_llm/ggml/model/llama/llama_cpp.py
@@ -1018,6 +1018,35 @@ def ggml_quantize_tensor_rtn(
_lib.ggml_quantize_tensor_rtn.restype = ctypes.c_size_t


def ggml_quantize_tensor_rtn_with_weights(
src, # type: ctypes.Array[ctypes.c_float] # type: ignore
dst: ctypes.c_void_p,
scale_ptr, # type: ctypes.Array[ctypes.c_float] # type: ignore
qtype: ctypes.c_int,
n: ctypes.c_size_t,
k: ctypes.c_int,
hist, # type: ctypes.Array[ctypes.c_int64] # type: ignore
scale_search: ctypes.c_bool,
weights, # type: ctypes.Array[ctypes.c_float] # type: ignore
) -> int:
return _lib.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr, qtype, n, k,
hist, scale_search, weights)


_lib.ggml_quantize_tensor_rtn_with_weights.argtypes = [
ctypes.POINTER(ctypes.c_float),
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_float),
ctypes.c_int,
ctypes.c_size_t,
ctypes.c_int,
ctypes.POINTER(ctypes.c_int64),
ctypes.c_bool,
ctypes.POINTER(ctypes.c_float),
]
_lib.ggml_quantize_tensor_rtn_with_weights.restype = ctypes.c_size_t


def ggml_type_size(qtype: ctypes.c_int) -> int:
return _lib.ggml_type_size(qtype)

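The binding mirrors the existing `ggml_quantize_tensor_rtn` entry point and only appends a trailing `weights` pointer for the per-channel importance values. A minimal sketch of the extra pointer plumbing on the Python side follows; the tensor and its length are placeholders, and the call itself is left as a comment because `src`, `dst`, `scale_ptr` and `hist` are prepared by the caller (see low_bit_linear.py below).

```python
# Minimal sketch of how the new `weights` argument is fed to the binding,
# assuming the imatrix is held as a float32 torch tensor; the length is a
# placeholder and the surrounding buffers come from the caller.
import ctypes
import torch

imatrix = torch.rand(4096, dtype=torch.float32)   # one value per input channel
weights_ptr = ctypes.cast(imatrix.data.data_ptr(),
                          ctypes.POINTER(ctypes.c_float))
# weights_ptr is then passed as the trailing argument:
# ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr, qtype,
#                                       n, k, hist, scale_search, weights_ptr)
```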
13 changes: 11 additions & 2 deletions python/llm/src/ipex_llm/transformers/low_bit_linear.py
@@ -246,8 +246,17 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
if qtype not in [IQ2_XXS, IQ2_XS, Q2_K, IQ1_S, Q4_K, Q6_K, Q5_K, FP6_K]:
if qtype in [SYM_INT8_RTN, SYM_INT4_RTN]:
scale_ptr = ctypes.cast(scale.data.data_ptr(), ctypes.POINTER(ctypes.c_float))
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
k, hist, enable_scale_search)
if imatrix is None:
ggml.ggml_quantize_tensor_rtn(src, dst, scale_ptr, qtype, n,
k, hist, enable_scale_search)
else:
imatrix = imatrix.data.data_ptr()
imatrix = ctypes.cast(imatrix, ctypes.POINTER(ctypes.c_float))
ggml.ggml_quantize_tensor_rtn_with_weights(src, dst, scale_ptr,
qtype, n,
k, hist,
enable_scale_search,
imatrix)
return dst_tensor, scale.type(torch.float16)
else:
ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist, enable_scale_search)
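For reference, this is roughly how the new branch can be exercised on its own: quantize an fp32 weight channel-wise with an importance vector over its input channels. The `"sym_int4_rtn"` qtype key, shapes and device are illustrative assumptions; in the commit itself the call is issued from `replace_with_QuantizedLinear` (npu_models/convert.py below).

```python
# Standalone sketch, not part of the commit: the qtype key, shapes and
# device are illustrative assumptions.
import torch
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype

weight = torch.randn(4096, 4096, dtype=torch.float32)       # (out_features, in_features)
imatrix = torch.rand(weight.shape[1], dtype=torch.float32)  # importance per input channel

qweights, scale = ggml_convert_qtype(weight,
                                     ggml_tensor_qtype["sym_int4_rtn"],
                                     device="cpu",
                                     imatrix=imatrix)
print(qweights.dtype, scale.dtype)
```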
22 changes: 16 additions & 6 deletions python/llm/src/ipex_llm/transformers/npu_model.py
@@ -26,7 +26,7 @@
from transformers.configuration_utils import PretrainedConfig

from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.transformers.utils import logger
from ipex_llm.transformers.utils import logger, load_imatrix_data
from ipex_llm.transformers.npu_models.convert import optimize_llm, optimize_llm_post


@@ -137,6 +137,12 @@ def from_pretrained(cls, *args, **kwargs):
convert_model = kwargs.pop('convert_model', False)
save_directory = kwargs.pop('save_directory', None)
fuse_layers = kwargs.pop('fuse_layers', None)
imatrix_file = kwargs.pop('imatrix_file', None)

if imatrix_file is not None:
imatrix_data = load_imatrix_data(imatrix_file)
else:
imatrix_data = None

invalidInputError(
quantization_group_size in [0, 32, 64, 128],
@@ -205,15 +211,17 @@ def from_pretrained(cls, *args, **kwargs):
"transpose_value_cache": transpose_value_cache,
"convert_model": convert_model,
"save_directory": save_directory,
"fuse_layers": fuse_layers
"fuse_layers": fuse_layers,
"imatrix_data": imatrix_data
}
model = cls.optimize_npu_model(*args, **optimize_kwargs)
else:
from ipex_llm.transformers.npu_models.convert import optimize_llm
optimize_llm(model)
with torch.no_grad():
cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
quantization_group_size, *args, **kwargs)
quantization_group_size, imatrix_data=imatrix_data,
*args, **kwargs)
if hasattr(model, "llm"):
create_npu_kernels(model.llm)
else:
@@ -246,6 +254,7 @@ def optimize_npu_model(cls, *args, **kwargs):
convert_model = kwargs.pop('convert_model', False)
save_directory = kwargs.pop('save_directory', None)
fuse_layers = kwargs.pop('fuse_layers', None)
imatrix_data = kwargs.pop('imatrix_data', None)

if hasattr(model, "llm"):
llm = model.llm
@@ -258,7 +267,8 @@
optimize_llm_pre(model, qtype, mixed_precision,
quantization_group_size=quantization_group_size)
cls.load_convert(qtype, model, "cpu", modules_to_not_convert,
quantization_group_size, *args, **kwargs)
quantization_group_size, imatrix_data,
*args, **kwargs)
create_npu_kernels(llm)
model = model.eval()
logger.info(f"Finish to convert model")
@@ -305,12 +315,12 @@ def optimize_npu_model(cls, *args, **kwargs):

@classmethod
def load_convert(cls, q_k, optimize_model, device, modules_to_not_convert,
group_size=0, *arg, **kwarg):
group_size=0, imatrix_data=None, *arg, **kwarg):
from ipex_llm.transformers.npu_models.convert import replace_with_QuantizedLinear

replace_with_QuantizedLinear(optimize_model, q_k, device=device,
modules_to_not_convert=modules_to_not_convert,
group_size=group_size)
group_size=group_size, imatrix=imatrix_data)

@classmethod
def load_convert_cpu(cls, q_k, optimize_model, device, modules_to_not_convert,
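From the user's side, the whole feature reduces to one new `from_pretrained` kwarg. A usage sketch follows; apart from `imatrix_file`, which this commit introduces, the model id, low-bit format and remaining flags are placeholder assumptions drawn from the existing NPU options, and the imatrix file is assumed to be in the format `load_imatrix_data` already consumes.

```python
# Usage sketch only. `imatrix_file` is the kwarg added by this commit; every
# other argument below is an illustrative assumption, not prescribed here.
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",      # placeholder model id
    load_in_low_bit="sym_int4_rtn",       # NPU round-to-nearest weight quantization
    quantization_group_size=0,            # 0 => channel-wise (CW) quantization
    imatrix_file="llama-2-7b.imatrix",    # placeholder path, parsed by load_imatrix_data
    optimize_model=True,
    trust_remote_code=True,
)
```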
32 changes: 25 additions & 7 deletions python/llm/src/ipex_llm/transformers/npu_models/convert.py
@@ -19,11 +19,11 @@
import torch
import importlib
from ipex_llm.transformers.npu_models.linear import QuantizedLinear
import tempfile
import time
from typing import Callable, List, Optional
from transformers import GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList
from ipex_llm.transformers.utils import module_name_process


def module_optimization(func) -> torch.nn.Module:
@@ -39,7 +39,7 @@ def module_optimization(func) -> torch.nn.Module:
"""

def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert,
group_size=0, *args, **kwargs):
group_size=0, imatrix=None, full_name="", *args, **kwargs):
"""Recursively apply the optimization function.
Args:
@@ -49,23 +49,40 @@ def wrapper(model: torch.nn.Module, qtype, device, modules_to_not_convert,
"""
for name, layer in model.named_children():
if full_name == "":
cur_full_name = name
else:
cur_full_name = full_name + "." + name
cur_imatrix = None
if isinstance(layer, torch.nn.Linear) and not hasattr(layer, "qtype"):
new_module_name, _, cur_module_name, dq_idx = module_name_process(cur_full_name)
if imatrix is not None and new_module_name in imatrix:
cur_imatrix = imatrix[new_module_name]
if cur_imatrix.shape[0] != layer.weight.shape[1]:
ws = layer.weight.shape[1]
cur_imatrix = cur_imatrix[ws * dq_idx: ws * (dq_idx + 1)]
if name not in modules_to_not_convert:
new_layer = func(layer, qtype, device, modules_to_not_convert,
group_size=group_size, *args, **kwargs)
group_size=group_size, imatrix=cur_imatrix,
*args, **kwargs)
if new_layer:
model.add_module(name, new_layer)
wrapper(new_layer, qtype, device, modules_to_not_convert,
group_size=group_size, *args, **kwargs)
group_size=group_size, imatrix=imatrix,
full_name=cur_full_name,
*args, **kwargs)
else:
wrapper(layer, qtype, device, modules_to_not_convert,
group_size=group_size, *args, **kwargs)
group_size=group_size, imatrix=imatrix,
full_name=cur_full_name,
*args, **kwargs)

return wrapper


@module_optimization
def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
group_size):
group_size, imatrix):
from ipex_llm.transformers.low_bit_linear import ggml_convert_qtype
from ipex_llm.ggml.quantize import ggml_tensor_qtype
iqtype = ggml_tensor_qtype[qtype]
Expand All @@ -79,7 +96,8 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert,
enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0"
qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32),
iqtype, device=device,
enable_scale_search=enable_scale_search)
enable_scale_search=enable_scale_search,
imatrix=imatrix)
return QuantizedLinear(qweights, scale, layer.bias,
group_size=group_size)

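The dq handling in the wrapper slices the importance vector so that each split ("dq") sub-linear only receives the input channels it actually owns. The arithmetic, with made-up sizes:

```python
# Minimal sketch of the imatrix slicing performed in the wrapper above,
# with made-up sizes: a projection with 4096 input channels split into
# "dq" sub-linears of 2048 input channels each.
import torch

full_imatrix = torch.rand(4096)   # importance per input channel of the original linear
ws = 2048                         # layer.weight.shape[1] of one dq sub-linear
dq_idx = 1                        # index parsed from the sub-linear's module name
cur_imatrix = full_imatrix[ws * dq_idx: ws * (dq_idx + 1)]
assert cur_imatrix.shape[0] == ws   # second half of the importance vector
```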
29 changes: 23 additions & 6 deletions python/llm/src/ipex_llm/transformers/utils.py
@@ -247,6 +247,10 @@ def module_name_process(full_module_name):
else:
super_module_name = None
exp_id = None
new_module_name = None
layer = None
cur_module = None
dq_idx = None
if super_module_name == 'block_sparse_moe':
# handle mixtral moe here
moe_mapping = {"w1": "gate", "w2": "down", "w3": "up"}
@@ -265,11 +269,24 @@
layer = module_name_list[2]
cur_module = module_name_list[-1][:-5]
new_module_name = '_'.join([layer, cur_module])
elif len(module_name_list) == 6 and 'dq' in module_name_list[-1]:
# for NPU dq_list linear
layer = module_name_list[2]
cur_module = module_name_list[-1]
try:
dq_idx = int(cur_module[-2:])
except:
dq_idx = int(cur_module[-1:])
if cur_module[0] in 'qkvo':
cur_module = cur_module[0]
elif cur_module[:2] == "up":
cur_module = cur_module[:2]
elif cur_module[:4] == "gate" or cur_module[:4] == "down":
cur_module = cur_module[:4]
new_module_name = '_'.join([layer, cur_module])
elif len(module_name_list) == 1:
new_module_name = module_name_list[0]
layer = None
cur_module = None
return new_module_name, layer, cur_module
return new_module_name, layer, cur_module, dq_idx


def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
@@ -283,7 +300,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
if qtype in [ggml_tensor_qtype["gguf_iq2_xxs"], ggml_tensor_qtype["gguf_iq2_xs"],
ggml_tensor_qtype["gguf_iq1_s"]]:
# For quantization which needs importance matrix
new_module_name, layer, cur_module = module_name_process(full_module_name)
new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
# custom mixed quantization strategy
if model_type == "mixtral":
if cur_module == 'v':
@@ -312,7 +329,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
if new_module_name == 'lm_head':
cur_qtype = ggml_tensor_qtype['sym_int8']
elif qtype == ggml_tensor_qtype["q2_k"]:
new_module_name, layer, cur_module = module_name_process(full_module_name)
new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
if cur_module == 'v' or (cur_module == 'down' and int(layer) in [0, 1, 10, 11]):
# TODO: q2_k need others k-quants type here
cur_qtype = ggml_tensor_qtype['q2_k']
@@ -325,7 +342,7 @@ def get_cur_qtype_and_imatrix(qtype, full_module_name, imatrix_data, model_config=None):
cur_qtype = ggml_tensor_qtype['sym_int8']
elif qtype > 100:
# gguf mixed precision
new_module_name, layer, cur_module = module_name_process(full_module_name)
new_module_name, layer, cur_module, _ = module_name_process(full_module_name)
num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
if qtype in [gguf_mixed_qtype["gguf_q4k_s"], gguf_mixed_qtype["gguf_q4k_m"]] and \
new_module_name == 'lm_head':
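`module_name_process` now returns a 4-tuple, and `dq_idx` is populated only for the new NPU dq-style leaf names; the existing call sites above simply discard it. A small sketch of the widened contract (the module name is hypothetical, so only the tuple shape and the `None` dq_idx are implied by the diff):

```python
# Sketch of the widened return value. The module name below is hypothetical;
# only the tuple shape and the fact that dq_idx stays None for non-dq names
# follow from the diff above.
from ipex_llm.transformers.utils import module_name_process

new_name, layer, cur_module, dq_idx = module_name_process(
    "model.layers.0.self_attn.q_proj")   # hypothetical non-dq module name
assert dq_idx is None                    # dq_idx is only set for "...dq..." leaf names
```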
