diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
index d04961ece87..8f39ef2e985 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/Pipeline-Models/qwen.py
@@ -67,7 +67,6 @@
                                                      torch_dtype=torch.float16,
                                                      attn_implementation="eager",
                                                      transpose_value_cache=not args.disable_transpose_value_cache,
-                                                     mixed_precision=True,
                                                      trust_remote_code=True)
     else:
         model = AutoModelForCausalLM.load_low_bit(
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md
index 53f47df7946..1a3e277346f 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/README.md
@@ -22,7 +22,7 @@ In the example [generate.py](./generate.py), we show a basic use case for a phi-
 ### 1. Install
 #### 1.1 Installation on Windows
 We suggest using conda to manage environment:
-```bash
+```cmd
 conda create -n llm python=3.10 libuv
 conda activate llm
 
@@ -100,7 +100,7 @@ The examples below show how to run the **_optimized HuggingFace & FunASR model i
 - [Speech_Paraformer-Large](./speech_paraformer-large.py)
 
 ### 4.1 Run MiniCPM-Llama3-V-2_5 & MiniCPM-V-2_6
-```bash
+```cmd
 # to run MiniCPM-Llama3-V-2_5
 python minicpm-llama3-v2.5.py
 
@@ -117,6 +117,12 @@ Arguments info:
 - `--max-prompt-len MAX_PROMPT_LEN`: Defines the maximum number of tokens that the input prompt can contain. It is default to be `512`.
 - `--disable-transpose-value-cache`: Disable the optimization of transposing value cache.
 
+For [MiniCPM-V-2_6](./minicpm_v_2_6.py), you could also try to enable mixed precision optimization when encountering output problems:
+
+```cmd
+python minicpm_v_2_6.py --mixed-precision
+```
+
 #### Sample Output
 ##### [openbmb/MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6)
 
@@ -131,7 +137,7 @@ The image features a young child holding and showing off a white teddy bear wear
 ```
 
 ### 4.2 Run Speech_Paraformer-Large
-```bash
+```cmd
 # to run Speech_Paraformer-Large
 python speech_paraformer-large.py
 ```
@@ -154,7 +160,7 @@ rtf_avg: 0.232: 100%|███████████████████
 ```
 
 ### 4.3 Run Bce-Embedding-Base-V1
-```bash
+```cmd
 # to run Bce-Embedding-Base-V1
 python bce-embedding.py
 ```
diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/minicpm_v_2_6.py b/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/minicpm_v_2_6.py
index 1a524a5b2dc..f25f3409f2b 100644
--- a/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/minicpm_v_2_6.py
+++ b/python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal/minicpm_v_2_6.py
@@ -41,6 +41,7 @@
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--intra-pp", type=int, default=None)
    parser.add_argument("--inter-pp", type=int, default=None)
+    parser.add_argument("--mixed-precision", action='store_true')
 
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
@@ -57,6 +58,7 @@
         intra_pp=args.intra_pp,
         inter_pp=args.inter_pp,
         transpose_value_cache=not args.disable_transpose_value_cache,
+        mixed_precision=args.mixed_precision,
     )
 
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index eb684bce715..601bf7720f7 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -183,6 +183,10 @@ def from_pretrained(cls, *args, **kwargs):
             from intel_npu_acceleration_library.compiler import create_npu_kernels
 
             if optimize_model:
+                # TODO: enable mixed_precision when pipeline=True
+                if pipeline:
+                    mixed_precision = False
+
                 invalidInputError(
                     max_prompt_len < max_context_len,
                     (
@@ -282,7 +286,8 @@ def optimize_npu_model(cls, *args, **kwargs):
                             group_size=quantization_group_size,
                             qtype=qtype,
                             convert_model=convert_model,
-                            save_directory=save_directory)
+                            save_directory=save_directory,
+                            mixed_precision=mixed_precision)
             model.save_low_bit = types.MethodType(save_low_bit, model)
         return model
 
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
index 2e98c1eb937..b9236fdee24 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -129,10 +129,14 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
     if quantization_group_size == 0:
         n_splits_linear = 1
         if qtype == "sym_int8_rtn":
-            # do not split mlp down_proj for Qwen2-7B & sym_int8
+            # do not split mlp down_proj for Qwen2-7B/MiniCPM-V-2_6 & sym_int8
             n_splits_down_proj = 1
         else:
-            n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
+            if (not mixed_precision) and model.config.intermediate_size == 18944:
+                # For Qwen2-7B and MiniCPM-V-2_6
+                n_splits_down_proj = 16
+            else:
+                n_splits_down_proj = 1
     else:
         invalidInputError(
             model.config.hidden_size % quantization_group_size == 0 and
@@ -170,10 +174,10 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
     # for Qwen2-7B-Insturct and MiniCPM-V 2.6, divide lm_head into 14 parts
     if model.config.hidden_size == 3584 and (model.config.vocab_size == 152064 or
                                              model.config.vocab_size == 151666) and not cpu_lm_head:
-        # Do not split lm_head and use sym_int8 instead when mixed_precison is True
         if quantization_group_size == 0:
-            # Do not split lm_head and use sym_int8 instead when mixed_precison is True
-            is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
+            # TODO: may further adjust strategy, use sym_int8 for now
+            # is_split = (not mixed_precision) and qtype == "sym_int4_rtn"
+            is_split = False
             split_num = 14 if is_split else 1
             new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
                                        bias=model.lm_head.bias, use_split=False)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
index ccf6e242d90..da1213658d1 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py
@@ -423,7 +423,7 @@ def feed_forward_sanm_decoder(self, x, w_1_bias, norm_weights, norm_bias):
         w_2 = self.linear(w_1_norm, 512, 2048, bias=False, wt_dtype=self.dtype)
         return w_2
 
-    def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
+    def mlp(self, hidden_states, seq_len=-1, mode="prefill", mixed_precision=False):
         mm1 = self.linear(
             hidden_states, self.intermediate_size, self.hidden_size,
             bias=False, wt_dtype=self.dtype, n_splits=self.n_splits_linear,
@@ -438,8 +438,9 @@ def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
         )  # type: ignore[attr-defined]
         mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
+        wt_dtype = torch.int8 if mixed_precision else self.dtype
         hidden_states = self.linear(
-            mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype,
+            mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=wt_dtype,
             n_splits=self.n_splits_down_proj, scale_factor=(self.group_size == 0),
             is_prefill=(mode == "prefill")
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index 015efe10031..a77ac7690e0 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -97,7 +97,8 @@ def __init__(
         intermediate_size,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        mixed_precision: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -117,6 +118,7 @@ def __init__(
         self.rms_norm_eps = rms_norm_eps
         self.transpose_value = transpose_value
         self.num_layers = num_layers
+        self.mixed_precision = mixed_precision
 
         cos = self.constant(self.cached_cos)
         self.cos = self.unsqueeze(cos, axis=0)
@@ -279,7 +281,7 @@ def build_decoder(
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode, self.mixed_precision)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
 
@@ -311,6 +313,7 @@ def __init__(
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
         group_size: int = 0,
+        mixed_precision: bool = False,
     ):
         super().__init__()
 
@@ -375,7 +378,8 @@ def __init__(
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                mixed_precision=mixed_precision,
             )
             self.backend_decoders.append(decoder)
 
@@ -461,6 +465,7 @@ def __init__(
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
         group_size: int = 0,
+        mixed_precision: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -491,7 +496,8 @@ def __init__(
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            mixed_precision=mixed_precision,
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -571,6 +577,7 @@ def run_decode(
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
     group_size = getattr(model.config, "group_size", 0)
+    mixed_precision = getattr(model.config, "mixed_precision", False)
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
@@ -630,7 +637,8 @@
         do_print=False,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        mixed_precision=mixed_precision,
     )
 
     dist.barrier()
@@ -802,6 +810,7 @@ def run_prefill(
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
     group_size = getattr(model.config, "group_size", 0)
+    mixed_precision = getattr(model.config, "mixed_precision", False)
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
@@ -850,7 +859,8 @@
         transpose_value=transpose_value_cache,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        mixed_precision=mixed_precision,
     )
 
     layer_weights.extend(weights)
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
index 50448bd684b..f1bb597e6ee 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/convert_pipeline.py
@@ -194,8 +194,9 @@ def convert_llm(model: torch.nn.Module,
                 transpose_value_cache: bool,
                 group_size: int,
                 qtype: str,
-                convert_model: bool=False,
-                save_directory: str=None):
+                convert_model: bool = False,
+                save_directory: str = None,
+                mixed_precision: bool = False):
     # whether to set layernorm weight as const
     layernorm_const = os.environ.get("IPEX_LLM_LAYERNORM_CONST", "1") == "1"
     if group_size == 0:
@@ -204,7 +205,11 @@ def convert_llm(model: torch.nn.Module,
             # do not split mlp down_proj for Qwen2-7B & sym_int8
             n_splits_down_proj = 1
         else:
-            n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1
+            if (not mixed_precision) and model.config.intermediate_size == 18944:
+                # For Qwen2-7B
+                n_splits_down_proj = 16
+            else:
+                n_splits_down_proj = 1
     else:
         n_splits_linear = model.config.hidden_size // group_size
         n_splits_down_proj = model.config.intermediate_size // group_size
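For context, a minimal usage sketch of the `mixed_precision` option that this patch threads through the NPU `AutoModel.from_pretrained` path. It mirrors the updated `minicpm_v_2_6.py` example; the model path, low-bit setting, and context lengths below are illustrative assumptions, and only `mixed_precision` itself is the parameter introduced here:

```python
# Usage sketch only, not part of the patch. Keyword arguments other than
# `mixed_precision` are assumptions that mirror the existing NPU examples.
import torch
from ipex_llm.transformers.npu_model import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-V-2_6",        # assumed model path
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="eager",
    load_in_low_bit="sym_int4",     # assumed low-bit type
    optimize_model=True,
    max_context_len=1024,           # assumed lengths
    max_prompt_len=512,
    # New in this patch: if sym_int4 output looks wrong, quantize the MLP
    # down_proj with int8 instead of splitting it across multiple linears.
    mixed_precision=True,
)
```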