From 1fd0b6cd0884cf0fedbebc07d551a14811997bb7 Mon Sep 17 00:00:00 2001
From: plusbang
Date: Fri, 7 Feb 2025 16:22:49 +0800
Subject: [PATCH] fix cmt

---
 .../transformers/npu_pipeline_model/common.py | 161 +++++++++++++++++
 .../transformers/npu_pipeline_model/llama.py  | 164 +-----------------
 .../transformers/npu_pipeline_model/qwen.py   |  66 +------
 3 files changed, 174 insertions(+), 217 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
index fbccd683d70..13dbb013a43 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
@@ -173,6 +173,105 @@ def __init__(
         self.compile()


+class Llama32Embedding(NNFactory):
+    def __init__(
+        self,
+        vocab_size,
+        embedding_dim,
+        embedding_weight,
+        padding_idx,
+        inv_freq,
+        attention_scaling,
+        dtype,  # fp16
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+        self.attention_scaling = attention_scaling
+        self.dtype = dtype
+
+        # define input
+        weight = self.constant(embedding_weight)
+        input = self.parameter((1, 1), dtype=np.int32)
+        position_ids = self.parameter((1, 1), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # embed_tokens module
+        if padding_idx == -1:
+            padding_idx += vocab_size
+
+        axis_node = self.constant(np.array([0], dtype=np.int64))
+        if padding_idx is not None:
+            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
+            masked_embeddings[padding_idx, :] = 0.0  # mask
+
+            node_mask = self.constant(masked_embeddings)
+            node_masked_w = self.eltwise_mul(weight, node_mask)
+            res = self.gather(node_masked_w, input, axis_node, 0)
+        else:
+            res = self.gather(weight, input, axis_node, 0)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, 1))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        # define outputs
+        res = self.convert_to_fp16(res)
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
+
+
+class Llama32PostEmbedding(NNFactory):
+    def __init__(
+        self,
+        inv_freq,
+        attention_scaling,
+        input_len: int = 1,
+        device: str = "NPU",
+    ):
+        super().__init__(False, device)
+        self.attention_scaling = attention_scaling
+
+        # define input
+        position_ids = self.parameter((1, input_len), dtype=np.int64)
+        inv_freq = self.constant(inv_freq)
+
+        # rotary_emb module
+        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
+        position_ids = self.reshape(position_ids, (1, 1, input_len))
+        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
+                                 self.convert_to_fp32(position_ids))
+        freqs = self.transpose(freqs, [0, 2, 1])
+        emb = self.concat(freqs, freqs, axis=2)
+        cos = self.cos(emb)
+        sin = self.sin(emb)
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+        if input_len > 1:
+            cos = self.unsqueeze(cos, [1])
+            sin = self.unsqueeze(sin, [1])
+
+        # define outputs
+        cos = self.convert_to_fp32(cos)
+        sin = self.convert_to_fp32(sin)
+
+        print("start compiling")
+        self.compile()
+
+
 def obtain_weight_from_single_layer(attn_layer, mlp_layer):
     weights = []
     if hasattr(attn_layer, "q_proj_dq_list"):
@@ -216,3 +315,65 @@ def obtain_qkv_bias_from_single_layer(attn_layer):
     k_bias = attn_layer.k_proj.bias.to(torch.float16)
     v_bias = attn_layer.v_proj.bias.to(torch.float16)
     return q_bias, k_bias, v_bias
+
+
+def obtain_embedding_from_model(model, convert_model, temp_dir, weight_dir,
+                                max_prompt_len, keep_ir, compile_blob):
+    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
+        # llama-2-7B & llama-3-8B
+        embedding_layer = model.model.embed_tokens
+        new_embedding = LLMEmbedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            dtype=np.float16,
+        )
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir, keep_ir=keep_ir,
+                                                                 compile_blob=compile_blob)
+            os.remove(os.path.join(temp_dir, "embedding.bin"))
+    else:
+        # llama-3.2-3B & llama-3.2-1B
+        # for transformers >= 4.45.0
+        embedding_layer = model.model.embed_tokens
+        new_embedding = Llama32Embedding(
+            vocab_size=model.config.vocab_size,
+            embedding_dim=model.config.hidden_size,
+            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
+            padding_idx=model.config.pad_token_id,
+            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
+            attention_scaling=model.model.rotary_emb.attention_scaling,
+            dtype=np.float16,
+        )
+        if convert_model:
+            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
+            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
+            first_blob_path = None
+            # save embedding post module
+            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
+            attention_scaling = model.model.rotary_emb.attention_scaling
+            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                  attention_scaling=attention_scaling,
+                                                  input_len=1)
+            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
+                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
+            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
+                                                          attention_scaling=attention_scaling,
+                                                          input_len=max_prompt_len)
+            update_names_of_IR_and_export_blob(embedding_post_prefill,
+                                               "embedding_post_prefill",
+                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
+            os.remove(os.path.join(temp_dir, "embedding_post.bin"))
+            os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin"))
+        else:
+            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
+                                                                 temp_dir, keep_ir=keep_ir,
+                                                                 compile_blob=compile_blob)
+            os.remove(os.path.join(temp_dir, "embedding.bin"))
+    return first_blob_path
diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
index dea8c0f32a4..714213796dd 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -18,108 +18,8 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \
-    obtain_weight_from_single_layer
-from intel_npu_acceleration_library.backend.factory import NNFactory
-
-
-class Llama32Embedding(NNFactory):
-    def __init__(
-        self,
-        vocab_size,
-        embedding_dim,
-        embedding_weight,
-        padding_idx,
-        inv_freq,
-        attention_scaling,
-        dtype,  # fp16
-        device: str = "NPU",
-    ):
-        super().__init__(False, device)
-        self.vocab_size = vocab_size
-        self.embedding_dim = embedding_dim
-        self.padding_idx = padding_idx
-        self.attention_scaling = attention_scaling
-        self.dtype = dtype
-
-        # define input
-        weight = self.constant(embedding_weight)
-        input = self.parameter((1, 1), dtype=np.int32)
-        position_ids = self.parameter((1, 1), dtype=np.int64)
-        inv_freq = self.constant(inv_freq)
-
-        # embed_tokens module
-        if padding_idx == -1:
-            padding_idx += vocab_size
-
-        axis_node = self.constant(np.array([0], dtype=np.int64))
-        if padding_idx is not None:
-            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
-            masked_embeddings[padding_idx, :] = 0.0  # mask
-
-            node_mask = self.constant(masked_embeddings)
-            node_masked_w = self.eltwise_mul(weight, node_mask)
-            res = self.gather(node_masked_w, input, axis_node, 0)
-        else:
-            res = self.gather(weight, input, axis_node, 0)
-
-        # rotary_emb module
-        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
-        position_ids = self.reshape(position_ids, (1, 1, 1))
-        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
-                                 self.convert_to_fp32(position_ids))
-        freqs = self.transpose(freqs, [0, 2, 1])
-        emb = self.concat(freqs, freqs, axis=2)
-        cos = self.cos(emb)
-        sin = self.sin(emb)
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
-
-        # define outputs
-        res = self.convert_to_fp16(res)
-        cos = self.convert_to_fp32(cos)
-        sin = self.convert_to_fp32(sin)
-
-        print("start compiling")
-        self.compile()
-
-
-class Llama32PostEmbedding(NNFactory):
-    def __init__(
-        self,
-        inv_freq,
-        attention_scaling,
-        input_len: int = 1,
-        device: str = "NPU",
-    ):
-        super().__init__(False, device)
-        self.attention_scaling = attention_scaling
-
-        # define input
-        position_ids = self.parameter((1, input_len), dtype=np.int64)
-        inv_freq = self.constant(inv_freq)
-
-        # rotary_emb module
-        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
-        position_ids = self.reshape(position_ids, (1, 1, input_len))
-        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
-                                 self.convert_to_fp32(position_ids))
-        freqs = self.transpose(freqs, [0, 2, 1])
-        emb = self.concat(freqs, freqs, axis=2)
-        cos = self.cos(emb)
-        sin = self.sin(emb)
-        cos = cos * self.attention_scaling
-        sin = sin * self.attention_scaling
-        if input_len > 1:
-            cos = self.unsqueeze(cos, [1])
-            sin = self.unsqueeze(sin, [1])
-
-        # define outputs
-        cos = self.convert_to_fp32(cos)
-        sin = self.convert_to_fp32(sin)
-
-        print("start compiling")
-        self.compile()
+from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead, \
+    obtain_weight_from_single_layer, obtain_embedding_from_model


 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
@@ -197,62 +97,10 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
         weight.tofile(bin_file)

-    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
-        # llama-2-7B & llama-3-8B
-        embedding_layer = model.model.embed_tokens
-        new_embedding = LLMEmbedding(
-            vocab_size=model.config.vocab_size,
-            embedding_dim=model.config.hidden_size,
-            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-            padding_idx=model.config.pad_token_id,
-            dtype=np.float16,
-        )
-        if convert_model:
-            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-            first_blob_path = None
-        else:
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                                 temp_dir, keep_ir=keep_ir,
-                                                                 compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding.bin"))
-    else:
-        # llama-3.2-3B & llama-3.2-1B
-        embedding_layer = model.model.embed_tokens
-        new_embedding = Llama32Embedding(
-            vocab_size=model.config.vocab_size,
-            embedding_dim=model.config.hidden_size,
-            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-            padding_idx=model.config.pad_token_id,
-            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
-            attention_scaling=model.model.rotary_emb.attention_scaling,
-            dtype=np.float16,
-        )
-        if convert_model:
-            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-            first_blob_path = None
-            # save embedding post module
-            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
-            attention_scaling = model.model.rotary_emb.attention_scaling
-            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
-                                                  attention_scaling=attention_scaling,
-                                                  input_len=1)
-            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
-                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
-            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
-                                                          attention_scaling=attention_scaling,
-                                                          input_len=max_prompt_len)
-            update_names_of_IR_and_export_blob(embedding_post_prefill,
-                                               "embedding_post_prefill",
-                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding_post.bin"))
-            os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin"))
-        else:
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                                 temp_dir, keep_ir=keep_ir,
-                                                                 compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding.bin"))
+    first_blob_path = obtain_embedding_from_model(model, convert_model,
+                                                  temp_dir, weight_dir,
+                                                  max_prompt_len,
+                                                  keep_ir, compile_blob)

     return first_blob_path, last_blob_path

diff --git a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
index 94d63e5319f..076fc70bbc9 100644
--- a/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/npu_pipeline_model/qwen.py
@@ -18,8 +18,9 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \
-    obtain_weight_from_single_layer, obtain_qkv_bias_from_single_layer
+from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead, \
+    obtain_weight_from_single_layer, obtain_qkv_bias_from_single_layer, \
+    obtain_embedding_from_model
 from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead


@@ -107,63 +108,10 @@ def convert_lm_head_and_embedding(model, temp_dir, weight_dir,
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
         weight.tofile(bin_file)

-    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
-        embedding_layer = model.model.embed_tokens
-        new_embedding = LLMEmbedding(
-            vocab_size=model.config.vocab_size,
-            embedding_dim=model.config.hidden_size,
-            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-            padding_idx=model.config.pad_token_id,
-            dtype=np.float16,
-            input_length=1,
-        )
-        if convert_model:
-            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-            first_blob_path = True
-        else:
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, f"embedding",
-                                                                 temp_dir, keep_ir=keep_ir,
-                                                                 compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding.bin"))
-    else:
-        # transformers >= 4.45.0
-        from .llama import Llama32Embedding, Llama32PostEmbedding
-        embedding_layer = model.model.embed_tokens
-        new_embedding = Llama32Embedding(
-            vocab_size=model.config.vocab_size,
-            embedding_dim=model.config.hidden_size,
-            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
-            padding_idx=model.config.pad_token_id,
-            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
-            attention_scaling=model.model.rotary_emb.attention_scaling,
-            dtype=np.float16,
-        )
-        if convert_model:
-            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
-            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
-            first_blob_path = None
-            # save embedding post module
-            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
-            attention_scaling = model.model.rotary_emb.attention_scaling
-            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
-                                                  attention_scaling=attention_scaling,
-                                                  input_len=1)
-            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
-                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
-            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
-                                                          attention_scaling=attention_scaling,
-                                                          input_len=max_prompt_len)
-            update_names_of_IR_and_export_blob(embedding_post_prefill,
-                                               "embedding_post_prefill",
-                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding_post.bin"))
-            os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin"))
-        else:
-            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
-                                                                 temp_dir, keep_ir=keep_ir,
-                                                                 compile_blob=compile_blob)
-            os.remove(os.path.join(temp_dir, "embedding.bin"))
+    first_blob_path = obtain_embedding_from_model(model, convert_model,
+                                                  temp_dir, weight_dir,
+                                                  max_prompt_len,
+                                                  keep_ir, compile_blob)

     return first_blob_path, last_blob_path
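
For reference, below is a minimal NumPy sketch (not part of the patch) of what the Llama32Embedding graph moved into common.py computes: the padding-masked token-embedding gather plus the attention-scaled rotary cos/sin tables. The function and argument names here are illustrative only; the real build goes through the NNFactory ops shown above.

# Illustrative sketch only -- mirrors the NNFactory graph in common.py above.
import numpy as np

def llama32_embedding_reference(embedding_weight, input_ids, position_ids,
                                inv_freq, attention_scaling, padding_idx=None):
    # embed_tokens: zero the padding row, then gather rows for the input ids
    weight = embedding_weight.astype(np.float16)
    if padding_idx is not None:
        mask = np.ones_like(weight)
        mask[padding_idx, :] = 0.0
        weight = weight * mask
    res = weight[input_ids]                       # gather along axis 0

    # rotary_emb: inv_freq (head_dim // 2,) times position_ids (1, input_len),
    # duplicated along the last axis, then cos/sin scaled by attention_scaling
    input_len = position_ids.shape[1]
    freqs = inv_freq.reshape(1, -1, 1).astype(np.float32) * \
        position_ids.reshape(1, 1, input_len).astype(np.float32)
    freqs = freqs.transpose(0, 2, 1)              # (1, input_len, head_dim // 2)
    emb = np.concatenate([freqs, freqs], axis=2)  # (1, input_len, head_dim)
    cos = np.cos(emb) * attention_scaling
    sin = np.sin(emb) * attention_scaling
    return res, cos, sin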