diff --git a/gradual_block_quant.py b/gradual_block_quant.py
new file mode 100644
index 00000000000..f26d5d687d1
--- /dev/null
+++ b/gradual_block_quant.py
@@ -0,0 +1,336 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+GBQ (Gradual Block Quantization) functions.
+"""
+import gc
+import os
+import time
+import numpy as np
+import copy
+
+import paddle
+from paddle import nn
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
+from paddle.quantization import PTQ
+
+from paddleformers.transformers.qwen3_moe.modeling import Qwen3MoeDecoderLayer, Qwen3MoeModel
+
+from paddlenlp.utils.log import logger
+from paddlenlp.quantization.quantization_linear import (
+    ColumnParallelQuantizationLinear,
+    QuantizationLinear,
+    RowParallelQuantizationLinear,
+)
+
+import paddleslim
+from paddleslim.quant.advanced import GPTQ
+from paddleslim.common.wrapper_function import FuncWrapper
+
+from quant_utils import (
+    init_params,
+    _clear_params,
+    prepare_qconfig,
+    get_scales,
+    save_scales,
+    save_quant_model,
+    save_moe_quant_w4a8_model,
+    get_ptq_params,
+    apply_gptq,
+    show_progress,
+    load_sharded_checkpoint,
+)
+from custom_attention import QuantizedCustomAttentionLayer
+
+
+def get_mean_scale_for_moe(act_scales, num_experts=128):
+    """Replace the activation scales of experts that were never sampled (scale == 0) with the mean scale of the sampled experts."""
+    new_act_scales = {}
+    for k, value in act_scales.items():
+        if '.mlp.experts' in k:
+            idx = k.split(".")[2]
+            gate_proj_scale = []
+            up_proj_scale = []
+            down_proj_scale = []
+            for j in range(num_experts):
+                if act_scales["model.layers.{}.mlp.experts.{}.up_proj.activation_quanter".format(idx, j)] > 0.0:
+                    gate_proj_scale.append(act_scales["model.layers.{}.mlp.experts.{}.gate_proj.activation_quanter".format(idx, j)])
+                    up_proj_scale.append(act_scales["model.layers.{}.mlp.experts.{}.up_proj.activation_quanter".format(idx, j)])
+                    down_proj_scale.append(act_scales["model.layers.{}.mlp.experts.{}.down_proj.activation_quanter".format(idx, j)])
+            gate_mean = sum(gate_proj_scale) / len(gate_proj_scale)
+            up_mean = sum(up_proj_scale) / len(up_proj_scale)
+            down_mean = sum(down_proj_scale) / len(down_proj_scale)
+            for j in range(num_experts):
+                if act_scales["model.layers.{}.mlp.experts.{}.up_proj.activation_quanter".format(idx, j)] > 0.0:
+                    new_act_scales["model.layers.{}.mlp.experts.{}.gate_proj.activation_quanter".format(idx, j)] = act_scales["model.layers.{}.mlp.experts.{}.gate_proj.activation_quanter".format(idx, j)]
+                    new_act_scales["model.layers.{}.mlp.experts.{}.up_proj.activation_quanter".format(idx, j)] = act_scales["model.layers.{}.mlp.experts.{}.up_proj.activation_quanter".format(idx, j)]
+                    new_act_scales["model.layers.{}.mlp.experts.{}.down_proj.activation_quanter".format(idx, j)] = act_scales["model.layers.{}.mlp.experts.{}.down_proj.activation_quanter".format(idx, j)]
+                else:
+                    new_act_scales["model.layers.{}.mlp.experts.{}.gate_proj.activation_quanter".format(idx, j)] = gate_mean
+                    new_act_scales["model.layers.{}.mlp.experts.{}.up_proj.activation_quanter".format(idx, j)] = up_mean
+                    new_act_scales["model.layers.{}.mlp.experts.{}.down_proj.activation_quanter".format(idx, j)] = down_mean
+        else:
+            new_act_scales[k] = value
+    return new_act_scales
+
+
+@paddle.no_grad()
+def apply_block_gptq(model, predictor, ptq_dials, tgt_dials, args):
+    """
+    Gradual Block Quantization (GBQ) for IQ.
+    Only a single complete calibration pass is needed: the PSS, AWQ, AutoClip,
+    GPTQ and PTQ calibration steps are all performed here.
+    ptq_dials: batch_source_texts
+    tgt_dials: batch_target_texts
+    """
+
+    logger.info("Starting block quantization...")
+    last_layer_outputs = []
+    pp_id = args.pp_id
+    dp_degree = 1
+    activation, weight, cachekv, q_config = prepare_qconfig(args)
+    act_scales = {}
+    weight_scales = {}
+    cachekv_scales = {}
+    model_to_quant = {}
+    best_quant_policies = {}
+    try:
+        hcg = fleet.get_hybrid_communicate_group()
+        rank = hcg.get_model_parallel_rank()
+        nranks = hcg.get_model_parallel_world_size()
+        dp_id = hcg.get_data_parallel_rank()
+    except:
+        rank = dist.get_rank()
+        nranks = dist.get_world_size()
+        dp_id = 0
+    if args.lazy_load:
+        if nranks == 1:
+            state_dict = load_sharded_checkpoint(args.model_name_or_path, return_numpy=True)
+        else:
+            # For EB3.5, should load from xx.pdparams file now
+            model_path = os.path.join(args.model_name_or_path, f"model_state.tp0{rank}.pdparams")
+            state_dict = paddle.load(model_path, return_numpy=True)
+    ptq_state_dict = {}
+
+    def get_block_out(sub_layer, layer_out, use_flash_attention, return_output=True):
+
+        with paddle.amp.auto_cast(dtype="bfloat16"):
+            decode_out = sub_layer(
+                layer_out[0].cuda(),
+                attention_mask=layer_out[1].cuda() if not use_flash_attention else None,
+                position_embeddings=(layer_out[2][0].cuda(), layer_out[2][1].cuda())
+            )
+        if return_output:
+            return decode_out
+
+    num_layers = model.config.num_hidden_layers + 2
+    start = time.perf_counter()
+    block_index = 0
+
+    model.to("cpu")
+    paddle.device.cuda.empty_cache()
+    gc.collect()
+    time.sleep(10)
+    paddle.device.cuda.empty_cache()
+
+    for sub_name, sub_layer in model.named_sublayers():
+        logger.info(f'processing: {sub_name} - {sub_layer.full_name()} - {type(sub_layer)}')
+        if 'embed_tokens' in sub_name:
+            logger.info(f"Block {block_index}: {sub_name}")
+            sub_layer.to("gpu")
+            # get embedding output
+            logger.debug("Getting Embedding Output")
+            if args.lazy_load:
+                init_params(sub_layer, state_dict, sub_name, args.dtype)
+                logger.debug(f"{sub_name} init params done")
+            in_tokens = []
+            for count, text in enumerate(ptq_dials):
+                # TODO: GPU memory issue, so cap the number of calibration samples for now
+                if count > 499:
+                    break
+                tokens = predictor._preprocess(text, tgt_dials[count])
+                in_tokens.append(tokens)
+            logger.info(f"ALL samples: {len(in_tokens)}")
+            for idx in range(0, len(in_tokens)):
+                logger.info(f'embed_tokens infer step: {idx}')
+                input_map = in_tokens[idx]
+                if input_map is None:
+                    print('input map is None')
+                    continue
+                input_ids = input_map["input_ids"]
+                # print("input_ids:", input_ids.tolist())
+                attention_mask = input_map["attention_mask"] if "attention_mask" in input_map else None
+                position_ids = input_map["position_ids"] if "position_ids" in input_map else None
+                if position_ids is None:
+                    past_length = 0
+                    position_ids = paddle.arange(
+                        past_length, paddle.shape(input_ids)[-1] + past_length, dtype=input_ids.dtype
+                    )
+                    position_ids = position_ids.unsqueeze(0)
+                    position_ids = paddle.expand_as(position_ids, input_ids)
+
+                embedding_output = sub_layer(input_ids)
+                position_embeddings = model.model.rotary_emb(embedding_output,
position_ids) + + batch_size, seq_length = input_ids.shape[:2] + attention_mask = ( + paddle.ones((batch_size, seq_length), dtype=paddle.bool) + if attention_mask is None + else attention_mask + ) + attention_mask = Qwen3MoeModel._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), 0, embedding_output.dtype + ) # [bs, 1, seq_len, seq_len] + if model.config.use_flash_attention: + attention_mask = None if ((paddle.triu(attention_mask) == attention_mask).all().item()) else attention_mask + + if args.offload_data: + layer_out = (embedding_output.cpu(), attention_mask.cpu(), (position_embeddings[0].cpu(), position_embeddings[1].cpu())) + else: + layer_out = (embedding_output, attention_mask, position_embeddings) + last_layer_outputs.append(layer_out) + del embedding_output, attention_mask, position_ids, position_embeddings + + del in_tokens + + show_progress(start, block_index, num_layers) + sub_layer.to("cpu") + paddle.device.cuda.empty_cache() + gc.collect() + + elif isinstance(sub_layer, Qwen3MoeDecoderLayer): + sub_layer.to("gpu") + logger.info(f"Block {block_index}: {sub_name}") + block_index += 1 + + layer_idx = int(sub_name.split('.')[-1]) + + if args.lazy_load: + layer_name = 'layers.' + sub_name.split('.')[-1] + init_params(sub_layer, state_dict, sub_name, args.dtype) + logger.info(f"{layer_name} init done") + + cur_layer_outputs = [] + + # Original layers outputs + for idx, layer_out in enumerate(last_layer_outputs): + logger.info(f'Decoder Layer infer step: {idx}') + decode_out = get_block_out(sub_layer, layer_out, model.config.use_flash_attention) + if args.offload_data: + cur_layer_outputs.append((decode_out.cpu(), layer_out[1].cpu(), (layer_out[2][0].cpu(), layer_out[2][1].cpu()))) + else: + cur_layer_outputs.append((decode_out, layer_out[1], layer_out[2])) + + # GPTQ for WINT4 + if args.gptq: + logger.debug("Step: GPTQ") + gptq = apply_gptq(sub_layer, predictor, args, ptq_dials, create_only=True) + for idx, layer_out in enumerate(last_layer_outputs): + # if idx >=128: + # break + get_block_out(sub_layer, layer_out, model.config.use_flash_attention, return_output=False) + logger.debug(f"gptq: {idx}") + gptq.fasterquantmoe() + del gptq + paddle.device.cuda.empty_cache() + gc.collect() + + # PTQ linears in current transformer block + logger.debug("Step: PTQ Preparation") + for cur_layer_name, linear_layer in sub_layer.named_sublayers(): + if type(linear_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]: + q_config.add_name_config([linear_layer.full_name()], activation=activation, weight=weight) + logger.debug(f"w4a8: {cur_layer_name} {linear_layer.full_name()}") + + if type(linear_layer) in [FuncWrapper]: + # set both act and weight for attention, actually act-k and act-v are quantized + q_config.add_name_config([linear_layer.full_name()], weight=cachekv[0], activation=cachekv[1],) + logger.debug(f"[Cache-KV Quant] {linear_layer.full_name()}") + ptq = PTQ(q_config) + sub_layer = ptq.quantize(sub_layer, inplace=True) + + # PTQ sampling + for cur_layer_name, cur_layer in sub_layer.named_sublayers(): + if type(cur_layer) in [ColumnParallelQuantizationLinear, QuantizationLinear, RowParallelQuantizationLinear]: + cur_layer.remove_dequantize_weight() + if args.quant_type in ["WINT4", "WINT8", "W4A16", "W8A16"]: + # only one forward needed for weight only + get_block_out(sub_layer, layer_out[0], model.config.use_flash_attention, return_output=False) + else: + ptq_step = 0 + for layer_out in last_layer_outputs: + ptq_step += 1 + 
logger.debug(f"ptq: {ptq_step}") + get_block_out(sub_layer, layer_out, model.config.use_flash_attention, return_output=False) + + act_scales, weight_scales, cachekv_scales = get_scales(model, act_scales, weight_scales, \ + cachekv_scales, dp_degree, nranks, rank, best_quant_policies) + sub_layer = ptq.convert(sub_layer, inplace=True) + + if args.lazy_load: + ptq_state_dict = get_ptq_params(sub_layer, ptq_state_dict, sub_name) + + for cur_layer_name, cur_layer in sub_layer.named_sublayers(): + if type(cur_layer) in [ColumnParallelQuantizationLinear, QuantizationLinear, RowParallelQuantizationLinear]: + cur_layer.remove_dequantize_weight() + del last_layer_outputs + last_layer_outputs = cur_layer_outputs + + del cur_layer_outputs + + if args.lazy_load: + _clear_params(sub_layer, state_dict, sub_name) + show_progress(start, block_index, num_layers) + sub_layer.to("cpu") + paddle.device.cuda.empty_cache() + gc.collect() + + act_scales=get_mean_scale_for_moe(act_scales) + save_scales(args, act_scales, weight_scales, cachekv_scales, mp_id=rank, dp_id=dp_id) + + + paddle.device.cuda.empty_cache() + gc.collect() + time.sleep(60*int(rank)) + + + if nranks == 1: + model_path = os.path.join(args.save_path, "model_state.pdparams") + else: + model_path = os.path.join(args.save_path, f"model_state.tp0{rank}.pdparams") + + if args.lazy_load: + # get uncleared params + for k, v in state_dict.items(): + ptq_state_dict[k] = v + # save model first, since init new params may cause gpu memory overflow + save_quant_model(ptq_state_dict, model_path, dp_id=dp_id) + for k, v in ptq_state_dict.items(): + ptq_state_dict[k] = paddle.to_tensor(v, dtype=args.dtype) + if 'scale' in k: + ptq_state_dict[k] = ptq_state_dict[k].cast('float32') + # cleared params are not initialized, need re-init + for k, v in model.state_dict().items(): + if not v._is_initialized(): + v.get_tensor()._share_data_with(ptq_state_dict[k].get_tensor()) + model.set_state_dict(ptq_state_dict) + else: + gc.collect() + # time.sleep(30*int(rank)) + paddle.device.cuda.empty_cache() + gc.collect() + save_moe_quant_w4a8_model(args,model.state_dict(), model_path, pp_id=pp_id, weight_scales=weight_scales) + logger.info(f"Save quant model to {args.save_path}") + # time.sleep(40*int(8-rank)) + logger.debug("-------------------gptq Done------------------") diff --git a/quant_utils.py b/quant_utils.py new file mode 100644 index 00000000000..795a98d8c75 --- /dev/null +++ b/quant_utils.py @@ -0,0 +1,1149 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PTQ functions. 
+""" +import json +import os +import random +import time +import numpy as np +from functools import partial +import paddle +import paddleslim +import paddle.distributed.fleet as fleet +from paddle.distributed.fleet.meta_parallel import ( + ColumnParallelLinear, + RowParallelLinear, +) +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from paddle.nn.quant import weight_quantize +import sys +import os + +from paddle.quantization import PTQ, QAT, QuantConfig +from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver +from paddlenlp.peft.lora import LoRALinear +from paddlenlp.peft.lora.lora_quant_layers import QuantedLoRALinear +from paddleslim.utils.log import logger +from paddlenlp.transformers.model_utils import _add_variant, load_state_dict +from paddleslim.quant.advanced import ( + AutoClip, + AWQSearch, + EMASampler, + LayerWiseQuantError, + MultiStepSampler, + PieceWiseSearch, + SmoothSearchV2, + Shift, + ReorderFFNWeight, + Smooth, + GPTQ, + moe_shared_scale, + TokenWiseClipping +) +from paddleslim.quant.advanced.utils import find_parent_layer_and_sub_name +from paddleslim.quant.layers import ( + QuantizedColumnParallelLinear, + QuantizedRowParallelLinear +) +from paddleslim.common.wrapper_function import FuncWrapper +from custom_attention import QuantizedCustomAttentionLayer +from abq import AdaptiveBaggingQuant +from paddleslim.quant.observers import ( + AbsMaxChannelWiseWeightObserver, + AbsmaxObserver, + AvgHeadwiseObserver, + GroupWiseWeightObserver, + TokenQuantileObserver, + KCacheChannelWiseObserver, + AsymCacheKVObserver +) +from paddleslim.quant.quanters import PACTQuanter +from paddleslim.quant.quanters.channel_wise_abs_max import ( + FakeQuanterChannelWiseAbsMaxObserver, +) +from paddleslim.quant.observers.abs_max_weight import AbsMaxChannelWiseWeightObserverLayer +from paddleslim.quant.observers.abs_max import AbsmaxObserverLayer +from paddleslim.quant.observers.token_quantile import TokenQuantileObserverLayer +from paddleslim.quant.observers.groupwise import GroupWiseWeightObserverLayer +from paddleslim.quant.observers.avg_headwise import AvgHeadwiseObserverLayer +from paddleslim.quant.observers.kcache_channelwise import KCacheChannelWiseObserverLayer +from paddleslim.quant.observers.asym_cachekv import AsymCacheKVObserverLayer +from paddleslim.quant.observers.abs_max_tokenwise import AbsmaxTokenwiseObserverLayer +import paddle.distributed as dist + +def load_sharded_checkpoint(folder, variant=None, return_numpy=False): + """ + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. + variant (`str`): The model variant. 
+ + """ + # Load the index + pdparams_file = os.path.join(folder, _add_variant("model_state.pdparams", variant)) + lora_pdparams_file = os.path.join(folder, _add_variant("lora_model_state.pdparams", variant)) + safetensors_file = os.path.join(folder, _add_variant("model.safetensors", variant)) + if os.path.isfile(pdparams_file): + return paddle.load(pdparams_file, return_numpy=return_numpy) + if os.path.isfile(lora_pdparams_file): + return paddle.load(lora_pdparams_file, return_numpy=return_numpy) + if os.path.isfile(safetensors_file): + try: + from paddlenlp.utils.safetensors import fast_load_file as safe_load_file + except: + from safetensors.numpy import load_file as safe_load_file + + state_dict = safe_load_file(safetensors_file) + if not return_numpy: + for key in list(state_dict.keys()): + if isinstance(state_dict[key], np.ndarray): + state_dict[key] = paddle.Tensor(state_dict.pop(key), zero_copy=True) + return state_dict + + index_file = os.path.join(folder, _add_variant(PADDLE_WEIGHTS_INDEX_NAME, variant)) + safe_index_file = os.path.join(folder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)) + safe_master_file = os.path.join(folder, _add_variant(SAFE_MASTER_WEIGHTS_INDEX_NAME, variant)) + safe_peft_file = os.path.join(folder, _add_variant(SAFE_PEFT_WEIGHTS_INDEX_NAME, variant)) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + safe_master_present = os.path.isfile(safe_master_file) + safe_peft_present = os.path.isfile(safe_peft_file) + + load_safe = False + load_index = None + if safe_index_present: + load_safe = True # load safe due to preference + load_index = safe_index_file + elif safe_master_present: + load_safe = True + load_index = safe_master_file + elif index_present: + load_index = index_file + elif safe_peft_present: + load_safe = True + load_index = safe_peft_file + else: + raise ValueError(f"Could not find {index_file} or {safe_index_file} or {safe_peft_file}") + + if load_safe: + try: + from paddlenlp.utils.safetensors import fast_load_file as safe_load_file + except: + from safetensors.numpy import load_file as safe_load_file + + with open(load_index, "r", encoding="utf-8") as f: + index = json.load(f) + + shard_files = list(set(index["weight_map"].values())) + loader = safe_load_file if load_safe else partial(paddlenlp_load, map_location="np" if return_numpy else "cpu") + + ret = {} + for shard_file in tqdm(shard_files): + state_dict = loader(os.path.join(folder, shard_file)) + ret.update(state_dict) + + if not return_numpy: + for key in list(ret.keys()): + if isinstance(ret[key], np.ndarray): + ret[key] = paddle.Tensor(ret.pop(key), zero_copy=True) + + return ret + + +def show_progress(start, idx, steps): + """ + Show progress + """ + c = idx / steps * 100 + a = "*" * int(c) + b = "·" * (100 - int(c)) + dur = time.perf_counter() - start + logger.info("\r{:.2f}%[{}->{}] Cost time {:.2f}s".format(c, a, b, dur)) + time.sleep(0.1) + + +def get_ptq_params(model, ptq_state_dict, sub_name): + """ + Get ptq params from quant model + """ + for name, param in model.named_parameters(): + full_name = sub_name + "." + name + ptq_state_dict[full_name] = np.array(param.value().get_tensor()) + return ptq_state_dict + +@paddle.no_grad() +def _clear_params(model, state_dict=None, sub_name=None): + """ + Clear params + """ + for k, v in model.state_dict().items(): + # 清除参数的值 + v.value().get_tensor()._clear() + # if state_dict is not None: + # 拼接参数名 + # name = sub_name + "." 
+ k + # if name in state_dict: + # 如果拼接后的参数名在state_dict中存在 + # if name in state_dict: + # 从state_dict中删除该参数 + # del state_dict[sub_name + "." + k] + + +def init_params(sub_layer, state_dict, sub_name, dtype): + """ + Init params and set state_dict + """ + new_dict = {} + for k, v in state_dict.items(): + if sub_name in k: + weight_name = k.replace(sub_name + ".", "") + # load from numpy, so we need to convert to bfloat16 firstly and then cast to other dtype + new_dict[weight_name] = paddle.to_tensor(v, dtype='bfloat16').cast(dtype).cuda() + for k, v in sub_layer.state_dict().items(): + if not v._is_initialized(): + v.get_tensor()._share_data_with(new_dict[k].get_tensor()) + sub_layer.set_state_dict(new_dict) + + +def prepare_qconfig(args): + """ + Prepare qconfig + """ + if 'C8' in args.quant_type: + quant_type = args.quant_type.replace('C8', '') + cachekv_quant = True + cachekv_quant_bits = 8 + elif 'C4' in args.quant_type: + quant_type = args.quant_type.replace('C4', '') + cachekv_quant = True + cachekv_quant_bits = 4 + elif 'C2' in args.quant_type: + quant_type = args.quant_type.replace('C2', '') + cachekv_quant = True + cachekv_quant_bits = 2 + else: + quant_type = args.quant_type.replace('C16', '') + cachekv_quant = False + + q_config = QuantConfig(activation=None, weight=None) + if quant_type == "W8A8": + activation = AbsmaxObserver(quant_bits=8) + weight = AbsMaxChannelWiseWeightObserver(quant_bits=8) + elif quant_type in ["WINT4", "W4A16"]: + activation = None + weight = GroupWiseWeightObserver(quant_bits=4, group_size=args.group_size) + elif quant_type in ["WINT8", "W8A16"]: + activation = None + weight = AbsMaxChannelWiseWeightObserver(quant_bits=8) + elif quant_type == "W4A8": + activation = AbsmaxObserver(quant_bits=8) + weight = AbsMaxChannelWiseWeightObserver(quant_bits=4) + else: + raise ValueError("quant_type should be in ['W8A8', 'WINT4', 'WINT8', 'W4A8', 'W4A16', 'W8A16']") + + q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear) + q_config.add_qat_layer_mapping(RowParallelLinear, QuantizedRowParallelLinear) + + cachekv = None + if cachekv_quant: + if cachekv_quant_bits == 8: + cachekv = [AvgHeadwiseObserver(quant_bits=cachekv_quant_bits, moving_avg=True, quant_axis=1,do_fp8_quant=True), + AvgHeadwiseObserver(quant_bits=cachekv_quant_bits, moving_avg=True, quant_axis=1,do_fp8_quant=True)] + # cachekv = [KCacheChannelWiseObserver(quant_bits=cachekv_quant_bits, symmetric=True), \ + # KCacheChannelWiseObserver(quant_bits=cachekv_quant_bits, symmetric=True)] + q_config.add_qat_layer_mapping(FuncWrapper, QuantizedCustomAttentionLayer) + elif cachekv_quant_bits == 4: + if args.abq: + cachekv = [AsymCacheKVObserver(quant_bits=cachekv_quant_bits, symmetric=False, quant_axis=[1, 3]), \ + AsymCacheKVObserver(quant_bits=cachekv_quant_bits, symmetric=False, quant_axis=[1, 3])] + else: + cachekv = [KCacheChannelWiseObserver(quant_bits=cachekv_quant_bits, symmetric=False), \ + KCacheChannelWiseObserver(quant_bits=cachekv_quant_bits, symmetric=False)] + q_config.add_qat_layer_mapping(FuncWrapper, QuantizedCustomAttentionLayer) + else: + raise ValueError('cachekv_quant_bits should be 8 or 4, 2bit is not supported for now.') + return activation, weight, cachekv, q_config + + +def get_scales(model, act_scales, weight_scales, cachekv_scales, + dp_degree=1, mp_degree=1, mp_id=0, best_quant_policies=None): + """ + get scales + """ + def gather_scale(cur_layer, dp_degree, mp_degree, mp_id): + scale = cur_layer.scales() + if dp_degree > 1: + scale_list = [] + 
paddle.distributed.all_gather(scale_list, scale) + gathered_scale = paddle.concat( + [ + paddle.reshape_( + scale_list[r * mp_degree + mp_id], + shape=[1] + scale_list[r * mp_degree + mp_id].shape) for r in range(dp_degree) + ], + axis=0).max(axis=0, keepdim=False) + paddle.assign(gathered_scale, cur_layer._scale) + return gathered_scale + else: + return scale + + def gather_min_max(cur_layer, max_values, min_values, dp_degree, mp_degree, mp_id, quant_bits): + bnt = (1 << (quant_bits - 1)) - 1 + qmin = -bnt - 1 + qmax = bnt + if dp_degree > 1: + max_list, min_list = [], [] + paddle.distributed.all_gather(max_list, max_values) + gathered_max = paddle.concat( + [ + paddle.reshape_( + max_list[r * mp_degree + mp_id], + shape=[1] + max_list[r * mp_degree + mp_id].shape) for r in range(dp_degree) + ], + axis=0).max(axis=0, keepdim=False) + paddle.distributed.all_gather(min_list, min_values) + gathered_min = paddle.concat( + [ + paddle.reshape_( + min_list[r * mp_degree + mp_id], + shape=[1] + min_list[r * mp_degree + mp_id].shape) for r in range(dp_degree) + ], + axis=0).min(axis=0, keepdim=False) + else: + gathered_max = max_values + gathered_min = min_values + gathered_scale = gathered_max - gathered_min + gathered_scale = paddle.to_tensor(gathered_scale / float(qmax - qmin), dtype="float32") + gathered_zp = qmin - paddle.round(gathered_min / gathered_scale) + gathered_zp = paddle.clip(gathered_zp, qmin, qmax) + cur_layer._scale = gathered_scale + cur_layer._zero_point = gathered_zp + return gathered_scale, gathered_zp + + for cur_name, cur_layer in model.named_sublayers(): + if 'layer.' in cur_name: + cur_name = cur_name.replace('layer.', '') + if type(cur_layer) in [AbsMaxChannelWiseWeightObserverLayer, GroupWiseWeightObserverLayer] \ + and "_observer" not in cur_name: + scale = gather_scale(cur_layer, dp_degree, mp_degree, mp_id) + weight_scales[cur_name] = scale.cast("float32").numpy().tolist() + if type(cur_layer) in [AbsmaxObserverLayer, TokenQuantileObserverLayer, AbsmaxTokenwiseObserverLayer] and "_observer" not in cur_name: + scale = gather_scale(cur_layer, dp_degree, mp_degree, mp_id) + if type(scale) in [int]: + act_scales[cur_name] = float(scale) + else: + act_scales[cur_name] = float(scale.cast("float32")) + logger.debug(f"{cur_name}, {act_scales[cur_name]}") + # 对量化层只能包一层oberserver,需要下述代码 + # if type(cur_layer) in [AbsmaxObserverLayer]: + # scale = gather_scale(cur_layer, dp_degree, mp_degree, mp_id) + # if type(scale) in [int]: + # act_scales[cur_name] = float(scale) + # else: + # act_scales[cur_name] = float(scale.cast("float32")) + # logger.debug(f"{cur_name}, {act_scales[cur_name]}") + if type(cur_layer) in [AvgHeadwiseObserverLayer, KCacheChannelWiseObserverLayer] \ + and "_observer" not in cur_name: + # hard code for inference + cur_name = cur_name.replace('attn_func.activation_quanter_v', 'cachev_matmul.activation_quanter') + cur_name = cur_name.replace('attn_func.activation_quanter_k', 'cachek_matmul.activation_quanter') + scale = gather_scale(cur_layer, dp_degree, mp_degree, mp_id) + cachekv_scales[cur_name] = scale.cast("float32").numpy().tolist() + logger.debug(f"{cur_name}, {cachekv_scales[cur_name][0]}") + # save zeropints in scale file if its not 0 or list of 0s + cachekv_scales[cur_name + '.zero_point'] = cur_layer.zero_points().cast("float32").numpy().tolist() + logger.debug(f"{cur_name + '.zero_point'}, {cachekv_scales[cur_name + '.zero_point']}") + if type(cur_layer) == AsymCacheKVObserverLayer and "_observer" not in cur_name: + cur_name = 
cur_name.replace('attn_func.activation_quanter_v', 'cachev_matmul.activation_quanter') + cur_name = cur_name.replace('attn_func.activation_quanter_k', 'cachek_matmul.activation_quanter') + layerid = int(cur_name.split('.')[3]) + kv_flag = cur_name.split('.')[5][5] + "_int4" + kv_max_name = cur_name.split('.')[5][5] + "_max" + kv_min_name = cur_name.split('.')[5][5] + "_min" + best_kv_max = best_quant_policies[layerid].get(kv_max_name) + best_kv_min = best_quant_policies[layerid].get(kv_min_name) + + scales, zps = gather_min_max(cur_layer, best_kv_max, best_kv_min, dp_degree, + mp_degree, mp_id, cur_layer._quant_bits) + kv_losses = best_quant_policies[layerid].get('kv_loss') + cachekv_scales[cur_name] = scales.cast("float32").numpy().tolist() + cachekv_scales[cur_name + '.zero_point'] = zps.cast("float32").numpy().tolist() + logger.debug(f"quant_bits: {cur_layer._quant_bits}, kv_losses: {kv_losses}, \ + scales: {cur_layer.scales()}, zps: {cur_layer.zero_points()}") + return act_scales, weight_scales, cachekv_scales + + +def save_scales(args, act_scales, weight_scales, cachekv_scales, mp_id=0, dp_id=0): + """ + save scales + """ + if dp_id == 0: + if act_scales: + with open(f"{args.save_path}/act_scales_{mp_id}.json", "w") as outfile: + json.dump(act_scales, outfile) + logger.debug("save act scales") + if weight_scales: + with open(f"{args.save_path}/weight_scales_{mp_id}.json", "w") as outfile: + json.dump(weight_scales, outfile) + logger.debug("save weight scales") + if cachekv_scales: + with open(f"{args.save_path}/cachekv_scales_{mp_id}.json", "w") as outfile: + json.dump(cachekv_scales, outfile) + logger.debug("save cachekv scales") + + +def save_quant_model(state_dict, save_path, dp_id=0): + """ + Save quant model + """ + if dp_id == 0: + new_scale_dict = {} + for k, v in state_dict.items(): + # hard code for inference + # if 'weight_quanter._dequanter._scales' in k: + # continue + if 'layer.' in k: + new_scale_dict[k.replace('layer.', '')] = v + else: + new_scale_dict[k] = v + # for k, v in state_dict.items(): + # if 'ernie.layers.' in k: + # new_k=k.split('ernie.layers.')[1].split('.')[0] + # orin_index=int(new_k) + # new_k=str(int(new_k)+27) + # new_k='ernie.layers.'+new_k+k.split('ernie.layers.'+str(orin_index))[1] + # new_scale_dict[new_k] = v + # else: + # new_scale_dict[k] = v + # del state_dict + # import gc + # gc.collect() + for k,v in new_scale_dict.items(): + if 'lm_head' in k: + continue + if '.experts' in k: + continue + else: + if '_proj.weight' in k and len(v.shape) == 2: + paddle.assign(v.cast(paddle.int8), v) + paddle.save(new_scale_dict, save_path) + logger.info(f"Save model to {save_path}") + +def merge_and_valid_shared_weights(tensor_list): + assert len(tensor_list) > 0, "smooth weights or shift biases must not be empty" + ret_tensor = tensor_list[0] + for i in range(1, len(tensor_list)): + cur_tensor = tensor_list[i] + compare = paddle.where(ret_tensor==cur_tensor, 0, 1) + assert paddle.sum(compare) == 0, "smooth or shift is not shared" + return ret_tensor + +def save_moe_quant_w4a8_model(args,state_dict, save_path, pp_id=0, weight_scales=None, share_smooth=True): + """ + Save quant model + """ + num_experts = 64 + paddle.set_default_dtype("bfloat16") + paddle.set_device("cpu") + ffn_hidden_size = 28672 + hidden_size = 8192 + convert_scale_dict={} + + for k, v in state_dict.items(): + if 'quanter._scales' in k or "quanter._zero_point" in k: + continue + else: + if 'layer.' in k: continue + + elif 'model.layers.' 
in k: + if '.mlp.experts' in k: + if ".weight" in k and ("up_proj" in k or "gate_proj" in k or "down_proj" in k ): + v = v.cast(paddle.int8) + convert_scale_dict[k] = v + logger.debug('casting {} to int8'.format(k)) + else: + convert_scale_dict[k] = v + else: + convert_scale_dict[k] = v + + paddle.save(convert_scale_dict, save_path) + logger.info(f"Save model to {save_path}") + +def save_moe_quant_model(state_dict, save_path, pp_id=0): + """ + Save quant model + """ + num_experts = 48 + paddle.set_default_dtype("bfloat16") + paddle.set_device("cpu") + ffn_hidden_size = 36864 + num_layers = 27 + num_attention_heads = 96 + num_key_value_heads = 8 + hidden_size = 12288 + mp_size = 8 + num_experts = 48 + export_model_type = 'WINT8' + int8_moe_method = "weight-only-int4" + convert_scale_dict = {} + hcg = fleet.get_hybrid_communicate_group() + rank = hcg.get_model_parallel_rank() + ffn_hidden_size = ffn_hidden_size // mp_size + for k, v in state_dict.items(): + if k.endswith("self_attn.qkv_proj.weight"): + idx = k.split(".")[2] + idx_for_save = str(int(idx) + pp_id*27) + print("idx", idx) + up_gate_proj_weight = [] + up_gate_proj_bias = [] + down_proj_weight = [] + down_proj_bias = [] + for j in range(num_experts): + up_gate_proj_weight.append(paddle.to_tensor(state_dict["ernie.layers.{}.mlp.experts.{}.up_gate_proj.weight".format(idx, j)], dtype=paddle.get_default_dtype())) + up_gate_proj_bias.append(paddle.to_tensor(state_dict["ernie.layers.{}.mlp.experts.{}.up_gate_proj.bias".format(idx, j)], dtype=paddle.get_default_dtype())) + down_proj_weight.append(paddle.to_tensor(state_dict["ernie.layers.{}.mlp.experts.{}.down_proj.weight".format(idx, j)], dtype=paddle.get_default_dtype())) + down_proj_bias.append(paddle.to_tensor(state_dict["ernie.layers.{}.mlp.experts.{}.down_proj.bias".format(idx, j)], dtype=paddle.get_default_dtype())) + # hard code for inference + # if 'weight_quanter._dequanter._scales' in k: + # continue + ffn1_weight_tensor = paddle.to_tensor(paddle.concat(up_gate_proj_weight, axis=0), dtype=paddle.get_default_dtype()).reshape([num_experts, hidden_size, -1]) + ffn1_weight_tensor_list = [] + ffn1_weight_scale_tensor_list = [] + for i in range(num_experts): + ffn1_weight_tensor_i, ffn1_weight_scale_tensor_i = weight_quantize( + ffn1_weight_tensor[i], algo="weight_only_int4", arch=80 + ) + ffn1_weight_tensor_list.append(ffn1_weight_tensor_i.reshape([hidden_size, ffn_hidden_size // mp_size])) + ffn1_weight_scale_tensor_list.append(ffn1_weight_scale_tensor_i) + ffn1_weight_tensor = paddle.concat(ffn1_weight_tensor_list, axis=0) + ffn1_weight_scale_tensor_list = paddle.concat(ffn1_weight_scale_tensor_list, axis=0) + convert_scale_dict["ffn1_weights_scales_{}".format(idx_for_save)] = ffn1_weight_scale_tensor_list.cast(paddle.get_default_dtype()).reshape([num_experts, -1]) + convert_scale_dict["ffn1_weights_{}".format(idx_for_save)] = ffn1_weight_tensor.reshape([num_experts, hidden_size, -1]) + convert_scale_dict["ffn1_biases_{}".format(idx_for_save)] = paddle.to_tensor(paddle.concat(up_gate_proj_bias, axis=0), dtype=paddle.get_default_dtype()).reshape([num_experts, -1]) + ffn2_weight_tensor = paddle.to_tensor(paddle.concat(down_proj_weight, axis=0), dtype=paddle.get_default_dtype()).reshape([num_experts, -1, hidden_size]) + ffn2_baisss = paddle.concat(down_proj_bias, axis=0) + print("rank", rank) + if rank > 0: + ffn2_baisss.zero_() + print(f'removing bias for rank:{rank}') + else: + print(f'keeping bias for rank:{rank}') + + ffn2_weight_tensor = 
paddle.to_tensor(paddle.concat(down_proj_weight, axis=0), dtype=paddle.get_default_dtype()).reshape([num_experts, -1, hidden_size]) + ffn2_weight_tensor_list = [] + ffn2_weight_scale_tensor_list = [] + for i in range(num_experts): + ffn2_weight_tensor_i, ffn2_weight_scale_tensor_i = weight_quantize( + ffn2_weight_tensor[i], algo="weight_only_int4", arch=80 + ) + ffn2_weight_tensor_list.append(ffn2_weight_tensor_i.reshape([ffn_hidden_size // mp_size, hidden_size // 2])) + ffn2_weight_scale_tensor_list.append(ffn2_weight_scale_tensor_i) + ffn2_weight_tensor = paddle.concat(ffn2_weight_tensor_list, axis=0) + ffn2_weight_scale_tensor_list = paddle.concat(ffn2_weight_scale_tensor_list, axis=0) + convert_scale_dict["ffn2_weights_scales_{}".format(idx_for_save)] = ffn2_weight_scale_tensor_list.cast(paddle.get_default_dtype()).reshape([num_experts, -1]) + convert_scale_dict["ffn2_weights_{}".format(idx_for_save)] = ffn2_weight_tensor.reshape([num_experts, ffn_hidden_size // mp_size, -1]) + + if k.endswith("self_attn.qkv_proj.weight"): + idx = k.split(".")[2] + for k, v in state_dict.items(): + if '.experts' in k: + continue + elif 'weight_quanter._dequanter._scales' in k: + continue + else: + if pp_id>0: + if 'lm_head' in k: + continue + elif 'layer.' in k: + if 'ernie.layers.' in k: + new_k=k.split('ernie.layers.')[1].split('.')[0] + orin_index=int(new_k) + new_k=str(int(new_k)+27*pp_id) + new_k='ernie.layers.'+new_k+k.split('ernie.layers.'+str(orin_index))[1] + convert_scale_dict[new_k.replace('layer.', '')] = v + else: + convert_scale_dict[k.replace('layer.', '')] = v + else: + if 'ernie.layers.' in k: + new_k=k.split('ernie.layers.')[1].split('.')[0] + orin_index=int(new_k) + new_k=str(int(new_k)+27*pp_id) + new_k='ernie.layers.'+new_k+k.split('ernie.layers.'+str(orin_index))[1] + convert_scale_dict[new_k] = v + else: + convert_scale_dict[k] = v + else: + if 'layer.' in k: + if 'ernie.layers.' in k: + new_k=k.split('ernie.layers.')[1].split('.')[0] + orin_index=int(new_k) + new_k=str(int(new_k)+27*pp_id) + new_k='ernie.layers.'+new_k+k.split('ernie.layers.'+str(orin_index))[1] + convert_scale_dict[new_k.replace('layer.', '')] = v + else: + convert_scale_dict[k.replace('layer.', '')] = v + else: + if 'ernie.layers.' 
in k: + new_k=k.split('ernie.layers.')[1].split('.')[0] + orin_index=int(new_k) + new_k=str(int(new_k)+27*pp_id) + new_k='ernie.layers.'+new_k+k.split('ernie.layers.'+str(orin_index))[1] + convert_scale_dict[new_k] = v + else: + convert_scale_dict[k] = v + + paddle.save(convert_scale_dict, save_path) + logger.info(f"Save model to {save_path}") + +def calibration(predictor, ptq_dials, args, max_step=None, smooth=None): + """ + Calibration + """ + max_step = len(ptq_dials) if max_step is None else max_step + with paddle.no_grad(): + for idx in range(0, len(ptq_dials), args.batch_size): + batch_dials = ptq_dials[idx : idx + args.batch_size] + out = predictor.predict_ptq(batch_dials) + if smooth is not None and out != {"result": 'No result for invalid input'}: + smooth.step += 1 + if idx % 10 == 0: + logger.debug(f"Sample Step: {idx}") + if idx >= max_step: + break + +def apply_shift(model, predictor, args, ptq_model_config, ptq_dials, create_only=False): + """ + Shift + """ + shift_sampler = EMASampler() + shift = Shift( + model=model, + model_config=ptq_model_config, + sample_function=shift_sampler, + shift_all_linears=True, + ) + if create_only: + return shift + calibration(predictor, ptq_dials, args) + shift.update_weight() + del shift, shift_sampler + +def apply_smooth(model, predictor, args, ptq_model_config, ptq_dials, + max_step=None, no_search=False, create_only=False): + """ + Smooth + """ + logger.debug("------------------Start Smooth-------------------") + smooth_sampler = MultiStepSampler() + if args.smooth_method == "smoothquant": + if args.smooth_search_v2: + search_func = SmoothSearchV2( + weight_bits_length=8, + act_bits_length=8, + search_min=0.1, + search_step=100, + weight_quant_method='abs_max_channel_wise', + act_quant_method="abs_max", + dp_degree=args.data_parallel_degree, + ) + else: + search_func = PieceWiseSearch( + k_piece=args.k_piece, + bits_length=8, + search_piece=args.search_piece, + search_alpha_min=0.1, + search_alpha_max=0.9, + search_scale_min=1.0, + search_scale_max=10.0, + use_clip=args.use_clip, + weight_quant_method="abs_max_channel_wise", + act_quant_method="abs_max", + dp_degree=args.data_parallel_degree, + ) + elif args.smooth_method == "awq": + search_func = AWQSearch( + n_grid=20, + bits_length=4, + weight_quant_method="abs_max_channel_wise", + ) + smooth = Smooth( + model, + ptq_model_config, + alpha=0.5, + smooth_all_linears=True, + sample_function=smooth_sampler, + search_function=search_func if not no_search else None, + start_sample_step=args.start_sample_step, + smooth_method=args.smooth_method, + ) + if create_only: + return smooth + + calibration(predictor, ptq_dials, args, max_step=max_step, smooth=smooth) + smooth.update_weight() + del smooth, smooth_sampler, search_func + + if args.load_smooth_model and not args.load_quant_model: + logger.info(f"Load model checkpoint from {args.load_smooth_path}") + model_dict = load_sharded_checkpoint(args.load_smooth_path, return_numpy=True) + model.set_dict(model_dict) + + if args.save_smooth_model: + try: + hcg = fleet.get_hybrid_communicate_group() + rank = hcg.get_model_parallel_rank() + nranks = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + except: + rank = dist.get_rank() + nranks = dist.get_world_size() + dp_id = 0 + if nranks == 1: + model_path = os.path.join(args.save_smooth_path, "model_state.pdparams") + else: + model_path = os.path.join(args.save_path, f"model_state.tp0{rank}.pdparams") + save_quant_model(model.state_dict(), model_path, dp_id=dp_id) + +def 
apply_token_wise_clipping(model, predictor, args, ptq_model_config, ptq_dials, max_step=None):
+    """
+    Token-wise clipping
+    """
+    logger.debug("------------------Start Token_wise_clipping-------------------")
+    token_wise_clipping = TokenWiseClipping(
+        model,
+        ptq_model_config,
+    )
+    max_step = len(ptq_dials) if max_step is None else max_step
+
+    fp_input, fp_output = [], []
+
+    with paddle.no_grad():
+        for idx in range(0, len(ptq_dials), args.batch_size):
+            batch_dials = ptq_dials[idx : idx + args.batch_size]
+            input_map = predictor.preprocess(batch_dials, extra_infos=None, fast_ptq_sampling=True)
+            if input_map is None:
+                continue
+            fp_input.append(input_map)
+            output = model(**input_map)
+            fp_output.append(output[0])
+            if idx % 10 == 0:
+                logger.debug(f"Token Wise Clipping Sample Step: {idx}")
+            if idx >= max_step:
+                break
+    token_wise_clipping.token_wise_clipping(fp_input, fp_output)
+
+def apply_autoclip(model, predictor, args, ptq_dials, create_only=False):
+    """
+    AutoClip
+    """
+    logger.debug("-------------------Start AutoClip------------------")
+    smooth_sampler = MultiStepSampler()
+    auto_clip = AutoClip(model, weight_bits=4, sample_function=smooth_sampler, n_grid=20, max_shrink=0.5)
+    if create_only:
+        return auto_clip
+    calibration(predictor, ptq_dials, args, max_step=len(ptq_dials) - args.start_sample_step)
+    auto_clip.auto_clip()
+
+def apply_gptq(model, predictor, args, ptq_dials, create_only=False):
+    """
+    GPTQ
+    """
+    gptq = GPTQ(model,
+        quant_bits=4,
+        weight_quant_method='abs_max_channel_wise',
+        blocksize=128,
+        percdamp=.2,
+        groupsize=args.group_size,
+        actorder=False,
+    )
+    if create_only:
+        return gptq
+    calibration(predictor, ptq_dials, args)
+    gptq.fasterquant()
+
+def apply_moe_shared_scale(model, predictor, args, ptq_dials, create_only=False):
+    """
+    Search a shared quantization scale for MoE expert layers
+    """
+    moe_sharedscale = moe_shared_scale(model,
+        quant_bits=8,
+        quant_method='abs_max',
+    )
+    if create_only:
+        return moe_sharedscale
+    calibration(predictor, ptq_dials, args)
+    moe_sharedscale.search_best_scale()
+
+def apply_analysis(model, predictor, args, ptq_dials):
+    """
+    Calculate the quantization error of each layer
+    Return a list of layer names to skip [skip_layer_name]
+    """
+    logger.debug("-------------------Start Analysis------------------")
+    analysis_loss_dict = {}
+    skip_list_analysis = []
+    for cur_name, cur_layer in model.named_sublayers():
+        if type(cur_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]:
+            parent_layer, sub_name = find_parent_layer_and_sub_name(model, cur_name)
+            cur_quant_layer = LayerWiseQuantError(cur_layer)
+            setattr(parent_layer, sub_name, cur_quant_layer)
+
+    for idx in range(0, len(ptq_dials), args.batch_size):
+        batch_dials = ptq_dials[idx : idx + args.batch_size]
+        predictor.predict_ptq(batch_dials)
+        if idx % 10 == 0:
+            logger.debug(f"Sample Error Step: {idx}")
+
+    for cur_name, analysis_layer in model.named_sublayers():
+        if type(analysis_layer) == LayerWiseQuantError:
+            loss = paddle.to_tensor(analysis_layer.losses, dtype="float32").mean()
+            parent_layer, sub_name = find_parent_layer_and_sub_name(model, cur_name)
+            setattr(parent_layer, sub_name, analysis_layer.layer)
+            analysis_loss_dict[analysis_layer.layer.full_name()] = float(loss)
+            del analysis_layer
+
+    ranklist = sorted(analysis_loss_dict, key=analysis_loss_dict.get, reverse=True)
+
+    for i, name in enumerate(ranklist):
+        logger.debug(f"layer name: {name}, loss: {analysis_loss_dict[name]}")
+        if analysis_loss_dict[name] > 5:
+            skip_list_analysis.append(name)
+    logger.debug(f"skip length: 
{len(skip_list_analysis)}, skip list: {skip_list_analysis}") + return skip_list_analysis + + +def load_quant_model(model, args, ptq_dials, skip_list_analysis): + """ + Load quantized model and its scales + """ + activation, weight, cachekv, q_config = prepare_qconfig(args) + for cur_name, cur_layer in model.named_sublayers(): + if "out_linear" in cur_name: + continue + if cur_layer.full_name() in skip_list_analysis: + logger.debug(f"skip: {cur_name}, {cur_layer.full_name()}") + continue + if type(cur_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]: + q_config.add_name_config([cur_layer.full_name()], activation=activation, weight=weight) + if type(cur_layer) in [FuncWrapper]: + # set both act and weight for attention, actually act-k and act-v are quantized + q_config.add_name_config([cur_layer.full_name()], weight=cachekv[0], activation=cachekv[1]) + + ptq = PTQ(q_config) + model = ptq.quantize(model, inplace=True) + + logger.info("Load quant model...") + rank = dist.get_rank() + nranks = dist.get_world_size() + if activation is not None: + with open(f"{args.load_quant_path}/act_scales_{rank}.json") as outfile: + act_scales = json.load(outfile) + else: + act_scales = {} + # if 'C8' in args.quant_type or 'C4' in args.quant_type: + # with open(f"{args.load_quant_path}/cachekv_scales_{rank}.json") as outfile: + # cachekv_scales = json.load(outfile) + # else: + cachekv_scales = {} + with open(f"{args.load_quant_path}/weight_scales_{rank}.json") as outfile: + weight_scales = json.load(outfile) + + for cur_name, cur_layer in model.named_sublayers(): + if 'layer.' in cur_name: + cur_name = cur_name.replace('layer.', '') + if hasattr(cur_layer, 'scales'): + if type(cur_layer) in [AbsMaxChannelWiseWeightObserverLayer, GroupWiseWeightObserverLayer]: + cur_layer._scale = paddle.to_tensor(weight_scales[cur_name], dtype=args.dtype) + if type(cur_layer) in [AbsmaxObserverLayer, TokenQuantileObserverLayer]: + assert activation is not None, "AbsmaxObserverLayer must set observer" + cur_layer._scale = paddle.to_tensor(act_scales[cur_name], dtype=args.dtype) + if type(cur_layer) in [AvgHeadwiseObserverLayer, KCacheChannelWiseObserverLayer]: + cur_name = cur_name.replace('attn_func.activation_quanter_v', 'cachev_matmul.activation_quanter') + cur_name = cur_name.replace('attn_func.activation_quanter_k', 'cachek_matmul.activation_quanter') + cur_layer._scale = paddle.to_tensor(cachekv_scales[cur_name], dtype=args.dtype) + if cur_name + '.zero_point' in cachekv_scales: + cur_layer._zero_point = paddle.to_tensor(cachekv_scales[cur_name + '.zero_point'], dtype=args.dtype) + model = ptq.convert(model, inplace=True) + if nranks == 1: + model_path = os.path.join(args.load_quant_path, "model_state.pdparams") + model_dict = load_sharded_checkpoint(args.load_quant_path, return_numpy=True) + else: + model_path = os.path.join(args.load_quant_path, f"model_state.tp0{rank}.pdparams") + model_dict = paddle.load(model_path, return_numpy=True) + logger.info(f"Load model checkpoint from {model_path}") + cur_model_dict = model.state_dict().keys() + for key in cur_model_dict: + if 'layer.' in key and 'scale' not in key: + saved_key = key.replace('layer.', '') + if saved_key in model_dict: + model_dict[key] = np.array(model_dict[saved_key], dtype=np.uint16) + # hard code here + if 'scale' in key or 'scales' in key: + scales_key = '.'.join(key.split('.')[:-2]) + new_key = key + if 'layer.' 
in scales_key: + scales_key = scales_key.replace('layer.', '') + new_key = key.replace('layer.', '') + if 'activation_quanter' in key: + new_key = new_key.replace('cachev_matmul.activation_quanter', 'attn_func.activation_quanter_v') + new_key = new_key.replace('cachek_matmul.activation_quanter', 'attn_func.activation_quanter_k') + model_dict[key] = np.array(model_dict[new_key], dtype=np.float32) + if 'weight_quanter' in key: + model_dict[key] = np.array(weight_scales[scales_key], dtype=np.float32) + else: + model_dict[key] = np.array(model_dict[key], dtype=np.float32) + if 'zero_point' in key: + model_dict[key] = np.array(model.state_dict()[key], dtype=np.float32) + model.set_dict(model_dict) + + +def apply_ptq(model, predictor, args, ptq_dials, skip_list_analysis): + """ + PTQ calibration process and save quantized model + """ + # logger.info("-------------------GPTQ start------------------") + + # gptq = GPTQ(model, + # quant_bits=4, + # weight_quant_method='abs_max_channel_wise', + # blocksize=128, + # percdamp=.1, + # actorder=True + # ) + + # calibration(predictor, ptq_dials, args,max_step=5) + # gptq.fasterquant() + # logger.info("-------------------GPTQ Done------------------") + dp_degree = args.data_parallel_degree + try: + hcg = fleet.get_hybrid_communicate_group() + rank = hcg.get_model_parallel_rank() + nranks = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + except: + rank = dist.get_rank() + nranks = dist.get_world_size() + dp_id = 0 + logger.info("-------------------Start PTQ------------------") + activation, weight, cachekv, q_config = prepare_qconfig(args) + weight_4bit = AbsMaxChannelWiseWeightObserver(quant_bits=4) + for cur_name, cur_layer in model.named_sublayers(): + if "out_linear" in cur_name: + logger.debug(f"skip {cur_name} {cur_layer.full_name()}") + continue + if cur_layer.full_name() in skip_list_analysis: + logger.debug(f"skip {cur_name} {cur_layer.full_name()}") + continue + # if "experts" in cur_name: + # q_config.add_name_config([cur_layer.full_name()], activation=activation, weight=weight_4bit) + # logger.debug(f"weight_w_4bit: {cur_name} {cur_layer.full_name()}") + # continue + if type(cur_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]: + # if ('linear2' in cur_layer.full_name() or 'linear2' in cur_name) and args.token_clip: + # logger.debug(f'token clip layer: {cur_layer.full_name()}, {cur_name}') + # activation1 = TokenQuantileObserver(quant_bits=8, percentile=1.0) + # q_config.add_name_config([cur_layer.full_name()], activation=activation1, weight=weight) + # if "experts" in cur_name: + # q_config.add_name_config([cur_layer.full_name()], activation=activation, weight=weight_4bit) + # logger.debug(f"weight_w_4bit_using_gptq: {cur_name} {cur_layer.full_name()}") + # else: + # q_config.add_name_config([cur_layer.full_name()], activation=activation, weight=weight) + q_config.add_name_config([cur_layer.full_name()], activation=None, weight=weight) + if type(cur_layer) in [FuncWrapper]: + # set both act and weight for attention, actually act-k and act-v are quantized + q_config.add_name_config([cur_layer.full_name()], weight=cachekv[0], activation=cachekv[1]) + + ptq = PTQ(q_config) + model = ptq.quantize(model, inplace=True) + args.token_clip=False + + if args.token_clip: + apply_token_wise_clipping(model, predictor, args, None, ptq_dials, max_step=16) + + for idx in range(0, len(ptq_dials), args.batch_size): + batch_dials = ptq_dials[idx : idx + args.batch_size] + predictor.predict_ptq(batch_dials) + if 
idx % 10 == 0: + logger.info(f"Sample PTQ Step: {idx}") + if idx >= 128: + break + + best_quant_policies = None + if args.abq: + best_quant_policies = apply_abq(model, predictor, args, ptq_dials, max_step=args.abq_step) + + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + + # act_scales, weight_scales, cachekv_scales = get_scales( + # model, {}, {}, {}, dp_degree, nranks, rank, best_quant_policies) + # save_scales(args, act_scales, weight_scales, cachekv_scales, mp_id=rank, dp_id=dp_id) + + model = ptq.convert(model, inplace=True) + + act_scales, weight_scales, cachekv_scales = get_scales( + model, {}, {}, {}, dp_degree, nranks, rank, best_quant_policies) + save_scales(args, act_scales, weight_scales, cachekv_scales, mp_id=rank, dp_id=dp_id) + + if nranks == 1: + model_path = os.path.join(args.save_path, "model_state.pdparams") + else: + model_path = os.path.join(args.save_path, f"model_state.tp0{rank}.pdparams") + state_dict = model.state_dict() + save_quant_model(state_dict, model_path, dp_id=dp_id) + logger.info(f"Save quant model to {args.save_path}") + logger.info("-------------------PTQ Done------------------") + +def apply_abq(model, predictor, args, ptq_dials, max_step=16): + """ + ABQ process, search best quantization policy + """ + max_step = min(max_step, int(len(ptq_dials) // args.batch_size)) + abq = AdaptiveBaggingQuant(args, model, max_step) + for cur_name, cur_layer in model.named_sublayers(): + if type(cur_layer) == QuantizedCustomAttentionLayer: + cur_layer.quant_info = abq.quant_info + cur_layer.enable_fake_quant = True + with paddle.no_grad(): + for idx in range(0, len(ptq_dials), args.batch_size): + batch_dials = ptq_dials[idx : idx + args.batch_size] + input_map = predictor.preprocess(batch_dials, extra_infos=None, fast_ptq_sampling=True) + if input_map is None: + continue + output = model(**input_map) + if idx % 10 == 0: + logger.debug(f"ABQ Sample Step: {idx}") + if idx >= max_step: + break + abq.search() + logger.info("========cachekv search done========") + return abq.best_quant_policies + + +def apply_layerwise_quant(model, predictor, args, ptq_dials, skip_list_analysis, layer_num=4): + """ + Quant model layer by layer + For each layer, complete calibration process will be repeated + """ + logger.debug("-------------------Start Layerwise PTQ------------------") + dp_degree = args.data_parallel_degree + try: + hcg = fleet.get_hybrid_communicate_group() + rank = hcg.get_model_parallel_rank() + nranks = hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + except: + rank = dist.get_rank() + nranks = dist.get_world_size() + dp_id = 0 + all_layers = [] + for _, cur_layer in model.named_sublayers(): + if type(cur_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]: + all_layers.append(cur_layer.full_name()) + if type(cur_layer) in [FuncWrapper]: + all_layers.append(cur_layer.full_name()) + activation, weight, cachekv, q_config = prepare_qconfig(args) + + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + act_scales = {} + weight_scales = {} + cachekv_scales = {} + for i in range(0, len(all_layers), layer_num): + cur_layer_name = all_layers[i : i + layer_num] + cachekv_name = [] + for n in cur_layer_name: + if 'cache_kv' in n: + cachekv_name.append(n) + for n in cachekv_name: + cur_layer_name.remove(n) + for skip_layer in skip_list_analysis: + if skip_layer in cur_layer_name: + cur_layer_name.remove(skip_layer) + + logger.debug(f"Quantizing step {i} / {len(all_layers)}") + 
logger.debug(f"{cur_layer_name} {cachekv_name}") + if not cur_layer_name: + continue + + q_config.add_name_config(cur_layer_name, activation=activation, weight=weight) + # set both act and weight for attention, actually act-k and act-v are quantized + q_config.add_name_config(cachekv_name, weight=cachekv[0], activation=cachekv[1]) + ptq = PTQ(q_config) + model = ptq.quantize(model, inplace=True) + for idx in range(0, len(ptq_dials), args.batch_size): + batch_dials = ptq_dials[idx : idx + args.batch_size] + predictor.predict_ptq(batch_dials) + if idx % 10 == 0: + logger.info(f"Sample PTQ Step: {idx}") + + act_scales, weight_scales, cachekv_scales = get_scales(model, act_scales, weight_scales, \ + cachekv_scales, dp_degree, nranks, rank) + model = ptq.convert(model, inplace=True) + + + + save_scales(args, act_scales, weight_scales, cachekv_scales, rank, dp_id) + + rank = dist.get_rank() + nranks = dist.get_world_size() + if nranks == 1: + model_path = os.path.join(args.save_path, "model_state.pdparams") + else: + model_path = os.path.join(args.save_path, f"model_state.tp0{rank}.pdparams") + save_quant_model(model.state_dict(), model_path, dp_id=dp_id) + logger.info(f"Save quant model to {args.save_path}") + logger.debug("-------------------Layerwise PTQ Done------------------") + + +def create_qat_model(model, args, dtype): + """ + Create QAT model + """ + q_config = QuantConfig(activation=None, weight=None) + q_config.add_qat_layer_mapping(LoRALinear, QuantedLoRALinear) + q_config.add_qat_layer_mapping(ColumnParallelLinear, QuantizedColumnParallelLinear) + q_config.add_qat_layer_mapping(RowParallelLinear, QuantizedRowParallelLinear) + if args is None or args.quant_type == "W8A8": + activation = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserver(), init_value=20, dtype=dtype) + weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype=dtype) + elif args.quant_type in ["WINT4", "W4A16"]: + activation = None + weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype=dtype) + elif args.quant_type in ["WINT8", "W8A16"]: + activation = None + weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=8, dtype=dtype) + elif args.quant_type == "W4A8": + activation = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserver(), init_value=20, dtype=dtype) + weight = FakeQuanterChannelWiseAbsMaxObserver(bit_length=4, dtype=dtype) + else: + raise ValueError("quant_type should be one of ['W8A8', 'WINT4', 'WINT8', 'W4A8', 'W4A16', 'W8A16']") + for cur_name, cur_layer in model.named_sublayers(): + if "out_linear" in cur_name: + continue + if type(cur_layer) in [ColumnParallelLinear, RowParallelLinear, paddle.nn.Linear]: + q_config.add_name_config([cur_layer.full_name()], activation=activation, weight=weight) + + qat = QAT(q_config) + model = qat.quantize(model, inplace=True) + return model + \ No newline at end of file diff --git a/qwen_quant.py b/qwen_quant.py new file mode 100644 index 00000000000..433658dc339 --- /dev/null +++ b/qwen_quant.py @@ -0,0 +1,645 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import copy +import json +import os +import sys +import time +from abc import abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass, field +from threading import Thread +from typing import List + +import numpy as np +import paddle +import paddle.incubate.multiprocessing as mp +from paddle.base.framework import in_cinn_mode, in_pir_executor_mode +from paddle.distributed import fleet + +from paddlenlp.generation import GenerationConfig, TextIteratorStreamer +from paddlenlp.trainer import PdArgumentParser + +from paddleformers.transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + PretrainedConfig, + PretrainedModel, + PretrainedTokenizer, +) +from paddlenlp.trl import llm_utils +from paddlenlp.utils.import_utils import ( + auto_dynamic_graph_pybind, + is_paddlenlp_ops_available, +) +from paddlenlp.utils.log import logger +import gc + +from gradual_block_quant import apply_block_gptq +from quant_utils import load_quant_model + + +@dataclass +class PredictorArgument: + model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."}) + model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"}) + save_path: str = field(default="./gbq_model", metadata={"help": "the path to save model"}) + src_length: int = field(default=None, metadata={"help": "The max length of source text."}) + min_length: int = field(default=1, metadata={"help": "the min length for decoding."}) + max_length: int = field(default=1024, metadata={"help": "the max length for decoding."}) + top_k: int = field(default=0, metadata={"help": "top_k parameter for generation"}) + top_p: float = field(default=0.7, metadata={"help": "top_p parameter for generation"}) + temperature: float = field(default=0.95, metadata={"help": "temperature parameter for generation"}) + repetition_penalty: float = field(default=1.0, metadata={"help": "repetition penalty parameter for generation"}) + device: str = field(default="gpu", metadata={"help": "Device"}) + dtype: str = field(default=None, metadata={"help": "Model dtype"}) + lora_path: str = field(default=None, metadata={"help": "The directory of LoRA parameters. Default to None"}) + export_precache: bool = field(default=False, metadata={"help": "whether use prefix weight to do infer"}) + prefix_path: str = field( + default=None, metadata={"help": "The directory of Prefix Tuning parameters. Default to None"} + ) + decode_strategy: str = field( + default="sampling", + metadata={ + "help": "the decoding strategy of generation, which should be one of ['sampling', 'greedy_search', 'beam_search']. Default to sampling" + }, + ) + use_flash_attention: bool = field( + default=False, + metadata={"help": "Whether to use flash attention"}, + ) + + mode: str = field( + default="dynamic", metadata={"help": "the type of predictor, it should be one of [dynamic, static]"} + ) + inference_model: bool = field(default=False, metadata={"help": "whether use InferenceModel to do generation"}) + quant_type: str = field( + default="", + metadata={ + "help": "Quantization type. 
Supported values: a8w8, a8w8c8, a8w8_fp8, a8w8c8_fp8, weight_only_int4, weight_only_int8" + }, + ) + avx_model: bool = field( + default=False, metadata={"help": "whether use AvxModel to do generation when using cpu inference"} + ) + avx_type: str = field( + default=None, + metadata={ + "help": "avx compute type. Supported values: fp16, bf16,fp16_int8\ + fp16: first_token and next_token run in fp16\ + fp16_int8 : first_token run in fp16, next token run in int8" + }, + ) + avx_cachekv_type: str = field( + default="fp16", + metadata={"help": "avx cachekv type. Supported values: fp16,int8"}, + ) + batch_size: int = field(default=1, metadata={"help": "The batch size of data."}) + benchmark: bool = field( + default=False, + metadata={ + "help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. " + }, + ) + use_fake_parameter: bool = field(default=False, metadata={"help": "use fake parameter, for ptq scales now."}) + block_attn: bool = field(default=False, metadata={"help": "whether use block attention"}) + block_size: int = field(default=64, metadata={"help": "the block size for cache_kvs."}) + cachekv_int8_type: str = field( + default=None, + metadata={ + "help": "If cachekv_int8_type set as `dynamic`, cache kv would be quantized to int8 dynamically. If cachekv_int8_type set as `static`, cache kv would be quantized to int8 Statically." + }, + ) + + append_attn: bool = field(default=False, metadata={"help": "whether use append attention"}) + + chat_template: str = field( + default=None, + metadata={ + "help": "the path of `chat_template.json` file to handle multi-rounds conversation. " + "If is None(do not set --chat_template argument), it will use the default `chat_template.json`;" + "If is equal with `model_name_or_path`, it will use the default loading; " + "If is directory, it will find the `chat_template.json` under the directory; If is file, it will load it." + "If is none string, it will not use chat_template.json." + }, + ) + + total_max_length: int = field( + default=4096, metadata={"help": "Super parameter. Maximum sequence length(encoder+decoder)."} + ) + speculate_method: str = field( + default=None, + metadata={ + "help": "speculate method, it should be one of ['None', 'inference_with_reference', 'eagle', 'mtp']" + }, + ) + speculate_max_draft_token_num: int = field( + default=1, + metadata={"help": "the max length of draft tokens for speculate method."}, + ) + speculate_max_ngram_size: int = field(default=1, metadata={"help": "the max ngram size of speculate method."}) + speculate_verify_window: int = field( + default=2, metadata={"help": "the max length of verify window for speculate method."} + ) + speculate_max_candidate_len: int = field(default=5, metadata={"help": "the max length of candidate tokens."}) + draft_model_name_or_path: str = field(default=None, metadata={"help": "The directory of eagle or draft model"}) + draft_model_quant_type: str = field( + default="", + metadata={"help": "Draft model quantization type. Reserved for future"}, + ) + return_full_hidden_states: bool = field(default=False, metadata={"help": "whether return full hidden_states"}) + + mla_use_matrix_absorption: bool = field(default=False, metadata={"help": "implement mla with matrix-absorption."}) + weightonly_group_size: int = field(default=-1, metadata={"help": "the max length of candidate tokens."}) + weight_block_size: List[int] = field( + default_factory=lambda: [128, 128], + metadata={"help": "Quantitative granularity of weights. 
Supported values: [128, 128]."},
+    )
+    moe_quant_type: str = field(
+        default="",
+        metadata={"help": "Quantization type of MoE layers. Supported values: weight_only_int4, weight_only_int8"},
+    )
+    output_via_mq: bool = field(
+        default=True,
+        metadata={"help": "Controls whether the message queue is enabled for output"},
+    )
+    dynamic_insert: bool = field(default=False, metadata={"help": "Whether to use dynamic insert"})
+    total_request_num: int = field(default=None, metadata={"help": "The total number of requests"})
+    lazy_load: bool = field(
+        default=False,
+        metadata={"help": "Whether to load the checkpoint lazily"},
+    )
+    gptq: bool = field(
+        default=False,
+        metadata={"help": "Whether to use GPTQ"},
+    )
+    iq: bool = field(
+        default=False,
+        metadata={"help": "Whether to use IQ"},
+    )
+    # Referenced as args.abq / args.abq_step by apply_block_gptq; the defaults here are assumed.
+    abq: bool = field(
+        default=False,
+        metadata={"help": "Whether to search quantization policies with ABQ (AdaptiveBaggingQuant)"},
+    )
+    abq_step: int = field(
+        default=16,
+        metadata={"help": "Max calibration steps for the ABQ policy search"},
+    )
+    ptq_samples: int = field(
+        default=-1,
+        metadata={"help": "Number of PTQ calibration samples, -1 means use all"},
+    )
+    offload_data: bool = field(
+        default=False,
+        metadata={"help": "Whether to offload data"},
+    )
+    debug: bool = field(
+        default=False,
+        metadata={"help": "Whether to enable debug mode"},
+    )
+    use_hessian: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Hessian information when quantizing weights"},
+    )
+    use_tq: bool = field(
+        default=False,
+        metadata={"help": "Whether to use TQ quantization"},
+    )
+    use_wint4: bool = field(
+        default=False,
+        metadata={"help": "Whether to quantize weights to INT4"},
+    )
+    wint4_all: bool = field(
+        default=False,
+        metadata={"help": "Whether to apply INT4 weight quantization to all layers"},
+    )
+    pp_id: int = field(
+        default=0,
+        metadata={"help": "Pipeline parallel stage id"},
+    )
+    group_size: int = field(
+        default=128,
+        metadata={"help": "Group size for group-wise weight quantization"},
+    )
+    load_quant_path: str = field(
+        default=None,
+        metadata={"help": "Path of a previously saved quantized model to load"},
+    )
+
+    def __post_init__(self):
+        if self.speculate_method is not None:
+            self.append_attn = True
+        if self.append_attn:
+            self.block_attn = True
+        if self.block_attn:
+            self.inference_model = True
+        assert self.max_length < self.total_max_length, "max_length should be smaller than total_max_length."
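+        # Derived defaults: if src_length is not given, it takes whatever budget is
+        # left after reserving max_length decode tokens out of total_max_length.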
+ if self.src_length is None: + self.src_length = self.total_max_length - self.max_length + # update config parameter for inference predictor + if self.decode_strategy == "greedy_search": + self.top_p = 0.0 + self.temperature = 1.0 + if self.total_request_num is None: + self.total_request_num = self.batch_size + + +@dataclass +class ModelArgument: + model_type: str = field( + default=None, + metadata={"help": "the type of the model, which can be one of ['gpt-3', 'ernie-3.5-se', 'llama-img2txt']"}, + ) + data_file: str = field(default=None, metadata={"help": "data file directory"}) + output_file: str = field(default="output.json", metadata={"help": "predict result file directory"}) + + +def batchfy_text(texts, batch_size): + batch_texts = [] + batch_start = 0 + while batch_start < len(texts): + batch_texts += [texts[batch_start : min(batch_start + batch_size, len(texts))]] + batch_start += batch_size + return batch_texts + + +class BasePredictor: + def __init__( + self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, model: PretrainedModel = None + ): + if model is not None and hasattr(model, "config"): + self.model_config = model.config + else: + self.model_config = AutoConfig.from_pretrained(config.model_name_or_path) + + self.config: PredictorArgument = config + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path, padding_side="left") + + self.tokenizer = tokenizer + + self.return_tensors = "pd" + self.tensor_parallel_rank, self.tensor_parallel_degree = llm_utils.init_dist_env() + self.model_config.tensor_parallel_rank, self.model_config.tensor_parallel_degree = ( + self.tensor_parallel_rank, + self.tensor_parallel_degree, + ) + + try: + self.generation_config = GenerationConfig.from_pretrained(config.model_name_or_path) + except: + logger.warning( + "Can't find generation config, so it will not use generation_config field in the model config" + ) + self.generation_config = None + + def _preprocess(self, source, tgt=None): + # if self.tokenizer.chat_template is not None: + # # for str -> List[str] eg. "hello" + # # for List[str] -> List[str] eg. ["hello", "hello new"] + # # for List[List[str]] -> List[List[List[str]]] eg. 历史对话形式,一轮 + # # [ [ "Hello, how are you?", "I'm doing great. 
How can I help you today?"], + # # ["I'd like to show off how chat templating works!"], ] + # # for List[Dict] -> List[List[Dict]] [{'role': 'user', 'content': 'hello'}, {'role': 'assistant', 'content': 'nice'}] + # # -> [[{'role': 'user', 'content': 'hello'}, {'role': 'assistant', 'content': 'nice'}]] + # if not isinstance(source, list) or not isinstance(source[0], str): + # source = [source] + # source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source] + # if tgt is not None: + # source = [source[0] + tgt[0]] + + tokenized_source = self.tokenizer( + source, + max_length=self.config.src_length, + truncation=True, + return_attention_mask=True, + return_tensors=self.return_tensors, + padding=True, + # when use chat_template, it should not add special tokens + # chatglm2 prefix-tokens can not be tokenized into ids + add_special_tokens=self.tokenizer.chat_template is None, + ) + return tokenized_source + + @abstractmethod + def _infer(self, inputs): + raise NotImplementedError + + def _postprocess(self, predictions, return_tokens=False): + decoded_predictions = self.tokenizer.batch_decode( + predictions, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + if return_tokens: + return decoded_predictions, predictions + else: + return decoded_predictions + + def predict(self, input_texts: str | list[str], return_tokens=False): + tokenized_source = self._preprocess(input_texts) + # Synchronize the HPU device for the static graph predictor + # Ensure that configuration data read from the CPU is updated to the HPU device + paddle.device.synchronize() + predictions = self._infer(tokenized_source) + decoded_predictions = self._postprocess(predictions, return_tokens=return_tokens) + return decoded_predictions + + +class DygraphPredictor(BasePredictor): + def __init__( + self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, model: PretrainedModel = None, **kwargs + ): + super().__init__(config, tokenizer, model) + self.model = model + if config.dtype is not None: + dtype = config.dtype + else: + raise ValueError("Please specific the model dtype.") + + if self.model is None: + self.model = AutoModelForCausalLM.from_pretrained( + config.model_name_or_path, + use_flash_attention=config.use_flash_attention, + dtype=dtype, + convert_from_hf=True, + tensor_parallel_degree=self.tensor_parallel_degree, + tensor_parallel_rank=self.tensor_parallel_rank, + ) + self.model.eval() + + @paddle.no_grad() + def _infer(self, inputs: dict[str, paddle.Tensor]): + result = self.model.generate( + **inputs, + # max_new_tokens=self.config.max_length, + # bos_token_id=self.tokenizer.bos_token_id, + # eos_token_id=llm_utils.get_eos_token_id(self.tokenizer, self.generation_config), + # pad_token_id=self.tokenizer.pad_token_id, + # decode_strategy=self.config.decode_strategy, + # temperature=self.config.temperature, + # top_k=self.config.top_k, + # top_p=self.config.top_p, + # repetition_penalty=self.config.repetition_penalty, + max_new_tokens=1024, + temperature=0.1, + top_p=0.7, + repetition_penalty=1, + ) + result = result[0] + return result + + +class AutoPredictor: + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path).`" + ) + + @classmethod + def create_predictor( + cls, + predictor_args: PredictorArgument, + config: PretrainedConfig, + model_args: ModelArgument, + tokenizer: PretrainedTokenizer = 
None, + model: PretrainedModel = None, + **kwargs, + ): + """ + Create a predictor + + Args: + predictor_args (PredictorArgument): The predictor arguments. + config (PretrainedConfig): The model configuration. + model_args (ModelArgument): The model arguments. + tokenizer (PretrainedTokenizer): The tokenizer. + **kwargs: Additional keyword arguments. + Returns: + Predictor: The predictor. + """ + cache_kvs_shape = None # used for not block_attn/append_attn + cache_k_shapes = None # used for block_attn/append_attn + cache_v_shapes = None # used for block_attn/append_attn + + # static or dynamic + execute_mode = "Dygraph" if predictor_args.mode == "dynamic" else "StaticGraph" + + # infer/ no infer + inference_mode = "" + + predictor_class_name = execute_mode + inference_mode + "Predictor" + + import_class = sys.modules[__name__] + + # import class + predictor_class = getattr(import_class, predictor_class_name) + + # instance + predictor = predictor_class( + predictor_args, + tokenizer=tokenizer, + model=model, + cache_k_shapes=cache_k_shapes, + cache_v_shapes=cache_v_shapes, + cache_kvs_shape=cache_kvs_shape, + model_args=model_args, + **kwargs, + ) + return predictor + + +def create_predictor( + predictor_args: PredictorArgument, + model_args: ModelArgument, + **kwargs, +): + paddle.set_device(predictor_args.device) + paddle.set_default_dtype(predictor_args.dtype) + + from paddlenlp.utils.env import USE_FAST_TOKENIZER + + tokenizer = AutoTokenizer.from_pretrained( + predictor_args.model_name_or_path + ) + + # init chat_template for tokenizer + llm_utils.init_chat_template(tokenizer, predictor_args.model_name_or_path, predictor_args.chat_template) + + # # TODO(wj-Mcat): fix llama tokenzier pad_token bug + + config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) + + tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env() + + model = None + + # model loading + if False: # predictor_args.inference_model: #---# + pass #wangna + # model = AutoInferenceModelForCausalLM.from_pretrained( + # predictor_args.model_name_or_path, + # config=config, + # predictor_args=predictor_args, + # model_args=model_args, + # dtype=predictor_args.dtype, + # tensor_parallel_degree=tensor_parallel_degree, + # tensor_parallel_rank=tensor_parallel_rank, + # ) + else: + if predictor_args.mode == "dynamic": + # model import (gpt-3,ernie) or AutoModel + if model_args.model_type == "gpt-3": + sys.path.append("./gpt-3") + from modeling import GPTForCausalLM + + model = GPTForCausalLM.from_pretrained( + predictor_args.model_name_or_path, + dtype=predictor_args.dtype, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_rank=tensor_parallel_rank, + tensor_parallel_output=False, + ) + elif model_args.model_type == "ernie-3.5-se": + sys.path.append("./ernie-3.5-se") + from modeling import Ernie35ForCausalLM + + tensor_parallel_degree = paddle.distributed.get_world_size() + tensor_parallel_rank = paddle.distributed.get_rank() + model = Ernie35ForCausalLM.from_pretrained( + predictor_args.model_name_or_path, + dtype=predictor_args.dtype, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_rank=tensor_parallel_rank, + tensor_parallel_output=False, + ) + else: + with paddle.LazyGuard(): + with paddle.no_grad(): + model = AutoModelForCausalLM.from_pretrained( + predictor_args.model_name_or_path, + dtype=predictor_args.dtype, + convert_from_hf=True, + use_flash_attention=predictor_args.use_flash_attention, + tensor_parallel_degree=tensor_parallel_degree, + 
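+                            # The rank/degree obtained from llm_utils.init_dist_env() let each
+                            # model-parallel worker load only its own shard of the checkpoint.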
tensor_parallel_rank=tensor_parallel_rank, + tensor_parallel_output=False, + ) + predictor = AutoPredictor.create_predictor(predictor_args, config, model_args, tokenizer, model=model, **kwargs) + + return predictor + + +def predict(): + parser = PdArgumentParser((PredictorArgument, ModelArgument)) + predictor_args, model_args = parser.parse_args_into_dataclasses() + + llm_utils.set_triton_cache(predictor_args.model_name_or_path, predictor_args.mode) + try: + from paddle.utils import try_import + + try_import("paddlenlp_ops") + except ImportError: + logger.warning("paddlenlp_ops does not exist, please install paddlenlp_ops.") + return + tensor_parallel_degree = paddle.distributed.get_world_size() + if tensor_parallel_degree > 1: + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": tensor_parallel_degree, + "pp_degree": 1, + "sharding_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + + predictor = create_predictor(predictor_args, model_args) + + source_texts = [] + target_texts = [] + if model_args.data_file: + with open(model_args.data_file, "r", encoding="utf-8") as f: + for line in f: + example = json.loads(line) + # src tgt + if isinstance(example["instruction"], str) or predictor.tokenizer.chat_template is None: + if isinstance(example["instruction"], str): + source_texts.append(example["instruction"]) + target_texts.append(example["output"]) + else: + # load multi-rounds dataset + source_texts.append(example["instruction"][0]) + target_texts.append(example["output"][0]) + else: + source_texts.append(list(zip(example["instruction"], example["output"]))) + target_texts.append("") + + else: + source_texts = [ + "济南燃气结清费用需要带什么资料", "中通快递从辽宁到新疆要多久", "度小满逾期如何协商停催", "顺丰快递是昼夜不停的运吗?" + ] + target_texts = ["", "", "", ""] + + batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size) + batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size) + + apply_block_gptq(predictor.model, predictor, batch_source_texts, batch_target_texts, predictor_args) + # exit() + + if predictor_args.load_quant_path: + load_quant_model(predictor.model, predictor_args, None, []) + + if predictor_args.benchmark: + benchmark(predictor) + + + +def benchmark(predictor, predictor_args, model_args): + # Just construct a simple benchmark input. We pad input to the src_length. + test_texts = "你是百度AI,请**参考公开资料提供的信息,回答用户问题**,做到**时效性高,专业权威,客观无偏见**。\n\n### 公开资料说明\n1. 如果不同的公开资料出现矛盾且都符合正常逻辑,务必参考权威性更高的公开资料。如果根据权威性无法区分,请给用户提供多种说法。\n2. 如果用户问题对时间信息比较敏感,结合当前时间和公开资料的发布时间选择合适公开资料。\n\n### 引用说明\n1. 将相关公开资料索引用方括号包裹,置于相关内容后,例如:\"这是相关内容。[1][2]\"。\n\n### 回答要求\n1. 优先满足用户的主要需求,并且从用户问题和公开资料中挖掘用户可能的潜在需求进行满足。\n2. 优先参考公开资料中提供的信息,如果公开资料确实没有用户需求的相关信息,请你说明公开资料没有提及相关内容,并基于自身知识回答。\n3. 如果用户包含负向情绪,如焦虑/不安/困惑/气愤/孤独无助等,请你用更有人情味的风格回答。\n4. 
如果用户问题涉及网站访问、平台查询、资源获取、工具使用等需求,并且公开资料中提供了对应的准确链接时,请以\"[网址名称](URL)\"的格式给出。\n\n### 背景信息\n当前时间:2025年09月12日星期五\n当前所在地:山东省济南市\n用户画像:性别: 女;年龄: 中年;手机型号: vivo XFold3;\n用户检索历史: [济南济华燃气王官庄服务站电话: 2025-09-12 16:30:28; 王官庄济华燃气营业厅: 2025-09-12 16:30:21; 燃气结清费用去哪: 2025-09-12 16:28:09; 2024年济南高中录取分数线: 2025-09-12 13:29:08; 2024年高中录取分数线: 2025-09-12 13:28:54; 2024年高中寒假放假时间表: 2025-09-12 13:28:49; 济南385分能上高中吗: 2025-09-12 13:27:40; 385分能上高中吗: 2025-09-12 13:27:26; ABS材质对人体有害吗: 2025-09-12 11:30:55; abs是什么材质?: 2025-09-12 11:30:16; 公租房小区儿童娱乐区域属于配套建设吗: 2025-09-12 09:32:17; 公租房小区没有娱乐设施吗为什么: 2025-09-12 09:30:32; 2026年取消公租房最新通知: 2025-09-12 09:30:07; 公租房小区没有娱乐设施吗: 2025-09-12 09:29:53; 公租房小区没有娱乐设施吗: 2025-09-12 09:28:02; 公租房小区没有娱乐设施合法吗: 2025-09-12 09:26:40; 小区娱乐设施谁安装: 2025-09-12 09:17:26; 监控用漏电保护器还是空气开关: 2025-09-12 08:48:27; 漏电保护器和空气开关有什么区别: 2025-09-12 08:45:56; 秋天吃什么食物最好: 2025-09-12 08:25:17; 红花如意丸的功效与作用: 2025-09-11 10:57:15; 雪莲果是凉性的还是热性的: 2025-09-11 09:09:30; 狗狗不吃饭但精神很好: 2025-09-10 20:26:16; 狗狗不吃饭是怎么回事: 2025-09-10 20:25:04; 狗用脚挠痒痒怎么回事: 2025-09-10 20:24:23; 百香果的功效和作用: 2025-09-10 13:02:05; 劳务合同属于什么合同类型: 2025-09-09 16:28:58; 劳务合同属于行政合同吗: 2025-09-09 16:28:35; 自来水地埋管是什么材料做的: 2025-09-09 14:38:29; 2025年1月灭火器更换新规定: 2025-09-09 09:27:37; 新灭火器第一次换粉是什么时候: 2025-09-09 09:26:43; 口臭是胃火还是肝火: 2025-09-08 19:42:29; ]\n\n### 公开资料\n[1] 标题: 奥德集团有限公司费县分公司办事服务 \n参考特征:内容权威性非常高, 作者权威性非常高, 时效性较高, 发布于2025-05-27\n正文: (一)报装资料、申请受理: 1.开发商及村委集体用户安装:单体楼的建设工程规划许可证复印件、单体楼建筑施工图、小区整体平面图电子版一份。 2.零散居民户安装:即前期社区整体安装燃气时未安装的用户,出示房产证或者购房合同原件。 3.非居民用户:工商用户及小微用户等非居民用户燃气报装资料需提供用气地址的产权资料、营业执照、有效身份证明。\n\n\n[2] 标题: 燃气缴费、维修及相关服务办理程序、线上线下办理渠道、时限、网点设置、服务标准、服务承诺和便民措施 \n参考特征:内容权威性较高, 作者权威性非常高, 时效性较高, 发布于2025-02-19\n正文: 1、用户充值流程: (1)营业厅充值流程: 用户携带燃气卡、本到营业网点柜台→递交燃气卡、本→用户付款→营业员核对信息→系统充值→打印收款收据→递还燃气卡、本、收据→充值完成。 (2)政务大厅充值流程: 用户携带身份证、燃气卡、本到政务大厅→自助叫号机叫号→柜台钱等候叫号→递交燃气卡、本→用户付款→营业员核对信息→系统充值→打印收款收据→递还燃气卡、本、收据→充值完成。 (3)线上充值流程: 微信关注“奥德悦生活”微信公众号→首次登陆绑定用户编号→选择购气量→在线支付→到就近自助写卡机(实时更新,可在公众号内查询)写卡。\n\n\n[3] 标题: 燃气民用销户办事指南 \n参考特征:内容权威性非常高, 作者权威性非常高, 时效性较高, 发布于2025-06-06\n正文: 用户携带户主身份证原件及复印件、天然气用户卡、银行卡复印件(退还预存燃气费),并提交销户申请。\n\n\n[4] 标题: 【办事服务】2025年山东长乐集团民生燃气有限公司用气申请、过户、销户等项目办事服务指南\n参考特征:内容权威性较高, 作者权威性非常高, 时效性较高, 发布于2025-06-03\n正文: 一、申请 单位或是小区统一安装由单位、小区物业办公室或开发公司向燃气公司提出安装申请,填写申请单,并提供小区平面图;散户安装(1)持有效身份证件及现金(如开发商代收,需带天然气配套设施收费票据或证明)(2)签订居民燃气供用气合同。\n\n### 用户问题\n济南燃气结清费用需要带什么资料", + + benchmark_texts = [ + test_texts + ] + + batch_benchmark_texts = batchfy_text(benchmark_texts, 1) + print("***********Start Benchmark**********") + + warmup_time = 5 + test_time = 20 + + print("***********Start Warmup**********") + for _ in range(warmup_time): + for bs, batch_source_text in enumerate(batch_benchmark_texts): + predictor.predict(batch_source_text) + + print("***********Start Speed Test**********") + start = time.perf_counter() + output_tokens = 0 + for _ in range(test_time): + for bs, batch_source_text in enumerate(batch_benchmark_texts): + results = predictor.predict(batch_source_text, return_tokens=True) + if predictor.tensor_parallel_rank == 0: + output_tokens += sum([len(tokens) for tokens in results[-1]]) + end = time.perf_counter() + if predictor.tensor_parallel_rank == 0: + print("Avg Elapse time is: ", (end - start) / test_time) + print("Output tokens is: ", output_tokens) + print( + "Input length is: {}, Output length is: {}, bs is: {}, IPS: {:.3f} tokens/s, QPS: {:.3f} requests/s. 
".format( + 16384, + 1024, + 1, + (output_tokens / (end - start)), + (1 * test_time / (end - start)), + ) + ) + + +if __name__ == "__main__": + predict() diff --git a/run.sh b/run.sh new file mode 100644 index 00000000000..f0cad76f533 --- /dev/null +++ b/run.sh @@ -0,0 +1,39 @@ +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT +unset PADDLE_TRAINERS_NUM + +export DISTRIBUTED_TRAINER_ENDPOINTS=`hostname -i` +export PYTHONPATH="/root/paddlejob/workspace/env_run/output/wangna11/PaddleFormers/third_party/PaddleNLP/":$PYTHONPATH +export PYTHONPATH="/root/paddlejob/workspace/env_run/output/wangna11/PaddleFormers/third_party/PaddleSlim/":$PYTHONPATH +export LD_LIBRARY_PATH=/root/paddlejob/workspace/env_run/output/wangna11/miniconda3/envs/qwen/lib/python3.10/site-packages/nvidia/cudnn/lib/:$LD_LIBRARY_PATH +export LOAD_STATE_DICT_THREAD_NUM=128 +export DISABLE_FASTER_SET_STATE_DICT=1 #---# + +export CUDA_VISIBLE_DEVICES=4,5,6,7 + +model_name_or_path=Qwen/qwen30b_a3b_model_1119/ +data_path=/root/paddlejob/workspace/env_run/output/wangna11/wwb/sft_1119.jsonl + +save_name=qwen30b_a3b_model_1_tmp +log_dir=/root/paddlejob/workspace/env_run/output/wangna11/PaddleFormers/log_${save_name} +save_path=/root/paddlejob/workspace/env_run/output/wangna11/PaddleFormers/output/${save_name} + +rm -rf ${log_dir} +rm -rf ${save_path} +mkdir -p ${save_path} + +python -u -m paddle.distributed.launch \ + --log_dir ${log_dir} \ + qwen_quant.py \ + --model_name_or_path ${model_name_or_path} \ + --dtype bfloat16 \ + --mode dynamic \ + --total_max_length 8192 \ + --data_file ${data_path} \ + --save_path ${save_path} \ + --quant_type W4A8C8 \ + --gptq True \ + --output_file ./output.txt \