
Commit

fix cmt
plusbang committed Feb 7, 2025
1 parent 289356c commit 1fd0b6c
Showing 3 changed files with 174 additions and 217 deletions.
161 changes: 161 additions & 0 deletions python/llm/src/ipex_llm/transformers/npu_pipeline_model/common.py
@@ -173,6 +173,105 @@ def __init__(
        self.compile()


class Llama32Embedding(NNFactory):
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        embedding_weight,
        padding_idx,
        inv_freq,
        attention_scaling,
        dtype,  # fp16
        device: str = "NPU",
    ):
        super().__init__(False, device)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.attention_scaling = attention_scaling
        self.dtype = dtype

        # define input
        weight = self.constant(embedding_weight)
        input = self.parameter((1, 1), dtype=np.int32)
        position_ids = self.parameter((1, 1), dtype=np.int64)
        inv_freq = self.constant(inv_freq)

        # embed_tokens module
        if padding_idx == -1:
            padding_idx += vocab_size

        axis_node = self.constant(np.array([0], dtype=np.int64))
        if padding_idx is not None:
            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
            masked_embeddings[padding_idx, :] = 0.0  # mask

            node_mask = self.constant(masked_embeddings)
            node_masked_w = self.eltwise_mul(weight, node_mask)
            res = self.gather(node_masked_w, input, axis_node, 0)
        else:
            res = self.gather(weight, input, axis_node, 0)

        # rotary_emb module
        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
        position_ids = self.reshape(position_ids, (1, 1, 1))
        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
                                 self.convert_to_fp32(position_ids))
        freqs = self.transpose(freqs, [0, 2, 1])
        emb = self.concat(freqs, freqs, axis=2)
        cos = self.cos(emb)
        sin = self.sin(emb)
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        # define outputs
        res = self.convert_to_fp16(res)
        cos = self.convert_to_fp32(cos)
        sin = self.convert_to_fp32(sin)

        print("start compiling")
        self.compile()


class Llama32PostEmbedding(NNFactory):
    def __init__(
        self,
        inv_freq,
        attention_scaling,
        input_len: int = 1,
        device: str = "NPU",
    ):
        super().__init__(False, device)
        self.attention_scaling = attention_scaling

        # define input
        position_ids = self.parameter((1, input_len), dtype=np.int64)
        inv_freq = self.constant(inv_freq)

        # rotary_emb module
        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
        position_ids = self.reshape(position_ids, (1, 1, input_len))
        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
                                 self.convert_to_fp32(position_ids))
        freqs = self.transpose(freqs, [0, 2, 1])
        emb = self.concat(freqs, freqs, axis=2)
        cos = self.cos(emb)
        sin = self.sin(emb)
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling
        if input_len > 1:
            cos = self.unsqueeze(cos, [1])
            sin = self.unsqueeze(sin, [1])

        # define outputs
        cos = self.convert_to_fp32(cos)
        sin = self.convert_to_fp32(sin)

        print("start compiling")
        self.compile()


def obtain_weight_from_single_layer(attn_layer, mlp_layer):
    weights = []
    if hasattr(attn_layer, "q_proj_dq_list"):
@@ -216,3 +315,65 @@ def obtain_qkv_bias_from_single_layer(attn_layer):
    k_bias = attn_layer.k_proj.bias.to(torch.float16)
    v_bias = attn_layer.v_proj.bias.to(torch.float16)
    return q_bias, k_bias, v_bias


def obtain_embedding_from_model(model, convert_model, temp_dir, weight_dir,
                                max_prompt_len, keep_ir, compile_blob):
    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
        # llama-2-7B & llama-3-8B
        embedding_layer = model.model.embed_tokens
        new_embedding = LLMEmbedding(
            vocab_size=model.config.vocab_size,
            embedding_dim=model.config.hidden_size,
            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
            padding_idx=model.config.pad_token_id,
            dtype=np.float16,
        )
        if convert_model:
            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
            first_blob_path = None
        else:
            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
                                                                 temp_dir, keep_ir=keep_ir,
                                                                 compile_blob=compile_blob)
            os.remove(os.path.join(temp_dir, "embedding.bin"))
    else:
        # llama-3.2-3B & llama-3.2-1B
        # for transformers >= 4.45.0
        embedding_layer = model.model.embed_tokens
        new_embedding = Llama32Embedding(
            vocab_size=model.config.vocab_size,
            embedding_dim=model.config.hidden_size,
            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
            padding_idx=model.config.pad_token_id,
            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
            attention_scaling=model.model.rotary_emb.attention_scaling,
            dtype=np.float16,
        )
        if convert_model:
            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
            first_blob_path = None
            # save embedding post module
            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
            attention_scaling = model.model.rotary_emb.attention_scaling
            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
                                                  attention_scaling=attention_scaling,
                                                  input_len=1)
            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
                                                          attention_scaling=attention_scaling,
                                                          input_len=max_prompt_len)
            update_names_of_IR_and_export_blob(embedding_post_prefill,
                                               "embedding_post_prefill",
                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
            os.remove(os.path.join(temp_dir, "embedding_post.bin"))
            os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin"))
        else:
            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
                                                                 temp_dir, keep_ir=keep_ir,
                                                                 compile_blob=compile_blob)
            os.remove(os.path.join(temp_dir, "embedding.bin"))
    return first_blob_path
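
For readers checking the graph construction above, here is a reference-only NumPy sketch (not part of this commit; the function name rotary_cos_sin is illustrative) of the rotary cos/sin math that the Llama32Embedding and Llama32PostEmbedding graphs encode, assuming inv_freq has shape (head_dim // 2,) and position_ids has shape (1, input_len):

import numpy as np

def rotary_cos_sin(inv_freq, position_ids, attention_scaling):
    # mirror the graph ops: reshape -> eltwise_mul -> transpose -> concat -> cos/sin -> scale
    inv_freq = inv_freq.reshape(1, -1, 1).astype(np.float32)   # (1, head_dim // 2, 1)
    pos = position_ids.reshape(1, 1, -1).astype(np.float32)    # (1, 1, input_len)
    freqs = (inv_freq * pos).transpose(0, 2, 1)                # (1, input_len, head_dim // 2)
    emb = np.concatenate([freqs, freqs], axis=-1)              # (1, input_len, head_dim)
    return np.cos(emb) * attention_scaling, np.sin(emb) * attention_scaling

The NPU graph additionally inserts a singleton axis via unsqueeze when input_len > 1, which this sketch omits.
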
164 changes: 6 additions & 158 deletions python/llm/src/ipex_llm/transformers/npu_pipeline_model/llama.py
@@ -18,108 +18,8 @@
import torch
import numpy as np
import os
from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \
    obtain_weight_from_single_layer
from intel_npu_acceleration_library.backend.factory import NNFactory


class Llama32Embedding(NNFactory):
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        embedding_weight,
        padding_idx,
        inv_freq,
        attention_scaling,
        dtype,  # fp16
        device: str = "NPU",
    ):
        super().__init__(False, device)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.attention_scaling = attention_scaling
        self.dtype = dtype

        # define input
        weight = self.constant(embedding_weight)
        input = self.parameter((1, 1), dtype=np.int32)
        position_ids = self.parameter((1, 1), dtype=np.int64)
        inv_freq = self.constant(inv_freq)

        # embed_tokens module
        if padding_idx == -1:
            padding_idx += vocab_size

        axis_node = self.constant(np.array([0], dtype=np.int64))
        if padding_idx is not None:
            masked_embeddings = np.ones(weight.shape, dtype=np.float16)
            masked_embeddings[padding_idx, :] = 0.0  # mask

            node_mask = self.constant(masked_embeddings)
            node_masked_w = self.eltwise_mul(weight, node_mask)
            res = self.gather(node_masked_w, input, axis_node, 0)
        else:
            res = self.gather(weight, input, axis_node, 0)

        # rotary_emb module
        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
        position_ids = self.reshape(position_ids, (1, 1, 1))
        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
                                 self.convert_to_fp32(position_ids))
        freqs = self.transpose(freqs, [0, 2, 1])
        emb = self.concat(freqs, freqs, axis=2)
        cos = self.cos(emb)
        sin = self.sin(emb)
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        # define outputs
        res = self.convert_to_fp16(res)
        cos = self.convert_to_fp32(cos)
        sin = self.convert_to_fp32(sin)

        print("start compiling")
        self.compile()


class Llama32PostEmbedding(NNFactory):
    def __init__(
        self,
        inv_freq,
        attention_scaling,
        input_len: int = 1,
        device: str = "NPU",
    ):
        super().__init__(False, device)
        self.attention_scaling = attention_scaling

        # define input
        position_ids = self.parameter((1, input_len), dtype=np.int64)
        inv_freq = self.constant(inv_freq)

        # rotary_emb module
        inv_freq = self.reshape(inv_freq, (1, inv_freq.shape[0], 1))
        position_ids = self.reshape(position_ids, (1, 1, input_len))
        freqs = self.eltwise_mul(self.convert_to_fp32(inv_freq),
                                 self.convert_to_fp32(position_ids))
        freqs = self.transpose(freqs, [0, 2, 1])
        emb = self.concat(freqs, freqs, axis=2)
        cos = self.cos(emb)
        sin = self.sin(emb)
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling
        if input_len > 1:
            cos = self.unsqueeze(cos, [1])
            sin = self.unsqueeze(sin, [1])

        # define outputs
        cos = self.convert_to_fp32(cos)
        sin = self.convert_to_fp32(sin)

        print("start compiling")
        self.compile()
from .common import update_names_of_IR_and_export_blob, LowBitLLMLMHead, \
    obtain_weight_from_single_layer, obtain_embedding_from_model


def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
@@ -197,62 +97,10 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
weight.tofile(bin_file)

    if hasattr(model.model.layers[0].self_attn.rotary_emb, "cos_cached"):
        # llama-2-7B & llama-3-8B
        embedding_layer = model.model.embed_tokens
        new_embedding = LLMEmbedding(
            vocab_size=model.config.vocab_size,
            embedding_dim=model.config.hidden_size,
            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
            padding_idx=model.config.pad_token_id,
            dtype=np.float16,
        )
        if convert_model:
            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
            first_blob_path = None
        else:
            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
                                                                 temp_dir, keep_ir=keep_ir,
                                                                 compile_blob=compile_blob)
            os.remove(os.path.join(temp_dir, "embedding.bin"))
    else:
        # llama-3.2-3B & llama-3.2-1B
        embedding_layer = model.model.embed_tokens
        new_embedding = Llama32Embedding(
            vocab_size=model.config.vocab_size,
            embedding_dim=model.config.hidden_size,
            embedding_weight=embedding_layer.weight.to(torch.float16).detach().numpy(),
            padding_idx=model.config.pad_token_id,
            inv_freq=model.model.rotary_emb.inv_freq.to(torch.float16),
            attention_scaling=model.model.rotary_emb.attention_scaling,
            dtype=np.float16,
        )
        if convert_model:
            bin_file = os.path.join(weight_dir, f"model_embedding_input_0.bin")
            embedding_layer.weight.to(torch.float16).detach().numpy().tofile(bin_file)
            first_blob_path = None
            # save embedding post module
            inv_freq = model.model.rotary_emb.inv_freq.to(torch.float16)
            attention_scaling = model.model.rotary_emb.attention_scaling
            embedding_post = Llama32PostEmbedding(inv_freq=inv_freq,
                                                  attention_scaling=attention_scaling,
                                                  input_len=1)
            update_names_of_IR_and_export_blob(embedding_post, "embedding_post",
                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
            embedding_post_prefill = Llama32PostEmbedding(inv_freq=inv_freq,
                                                          attention_scaling=attention_scaling,
                                                          input_len=max_prompt_len)
            update_names_of_IR_and_export_blob(embedding_post_prefill,
                                               "embedding_post_prefill",
                                               temp_dir, keep_ir=keep_ir, compile_blob=compile_blob)
            os.remove(os.path.join(temp_dir, "embedding_post.bin"))
            os.remove(os.path.join(temp_dir, "embedding_post_prefill.bin"))
        else:
            first_blob_path = update_names_of_IR_and_export_blob(new_embedding, "embedding",
                                                                 temp_dir, keep_ir=keep_ir,
                                                                 compile_blob=compile_blob)
            os.remove(os.path.join(temp_dir, "embedding.bin"))
    first_blob_path = obtain_embedding_from_model(model, convert_model,
                                                  temp_dir, weight_dir,
                                                  max_prompt_len,
                                                  keep_ir, compile_blob)

    return first_blob_path, last_blob_path
