From 3047c88087b79f07470e238ced141f1f88422d6b Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 4 Jul 2024 09:37:35 +0800 Subject: [PATCH] add qwen2 npu support --- .../transformers/npu_models/convert.py | 27 +- .../ipex_llm/transformers/npu_models/qwen2.py | 305 ++++++++++++++++++ 2 files changed, 327 insertions(+), 5 deletions(-) create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/qwen2.py diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 535f5ddce45..40efff4740a 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -71,15 +71,32 @@ def convert_forward(m, target_m, new_forward): def optimize_llm(model: torch.nn.Module): if model.config.model_type == "llama": from ipex_llm.transformers.npu_models.llama import merge_qkv - model.apply(merge_qkv) from ipex_llm.transformers.npu_models.llama import merge_mlp + model.apply(merge_qkv) model.apply(merge_mlp) + from ipex_llm.transformers.npu_models.llama import llama_model_forward - from transformers.models.llama.modeling_llama import LlamaModel - convert_forward(model, LlamaModel, llama_model_forward) from ipex_llm.transformers.npu_models.llama import llama_attention_forward - from transformers.models.llama.modeling_llama import LlamaAttention - convert_forward(model, LlamaAttention, llama_attention_forward) from ipex_llm.transformers.npu_models.llama import llama_mlp_forward + from transformers.models.llama.modeling_llama import LlamaModel + from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import LlamaMLP + convert_forward(model, LlamaModel, llama_model_forward) + convert_forward(model, LlamaAttention, llama_attention_forward) convert_forward(model, LlamaMLP, llama_mlp_forward) + + elif model.config.model_type == "qwen2": + from ipex_llm.transformers.npu_models.qwen2 import merge_qkv + from ipex_llm.transformers.npu_models.qwen2 import merge_mlp + model.apply(merge_qkv) + model.apply(merge_mlp) + + from ipex_llm.transformers.npu_models.qwen2 import qwen2_model_forward + from ipex_llm.transformers.npu_models.qwen2 import qwen2_attention_forward + from ipex_llm.transformers.npu_models.qwen2 import qwen2_mlp_forward + from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention + from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP + convert_forward(model, Qwen2Model, qwen2_model_forward) + convert_forward(model, Qwen2Attention, qwen2_attention_forward) + convert_forward(model, Qwen2MLP, qwen2_mlp_forward) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2.py new file mode 100644 index 00000000000..ef811bab4a1 --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2.py @@ -0,0 +1,305 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/qwen2/modeling_qwen2.py +# which is licensed under Apache License 2.0: +# +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import math +from typing import Optional, Tuple, Union, List + +import torch + +from ipex_llm.transformers.npu_models.common import merge_linear +from ipex_llm.transformers.kv import DynamicNormalCache +from ipex_llm.utils.common import invalidInputError + +from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP +from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv +from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa +from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.cache_utils import Cache + + +def merge_qkv(module: torch.nn.Module): + if isinstance(module, Qwen2Attention): + qkv_proj = merge_linear([ + module.q_proj, + module.k_proj, + module.v_proj + ]) + module.qkv_proj = qkv_proj + del module.q_proj, module.k_proj, module.v_proj + + +def merge_mlp(module: torch.nn.Module): + if isinstance(module, Qwen2MLP): + gate_up_proj = merge_linear([ + module.gate_proj, + module.up_proj, + ]) + module.gate_up_proj = gate_up_proj + del module.gate_proj, module.up_proj + + +def qwen2_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + output_attentions = output_attentions if output_attentions is not None else \ + self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else + self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + invalidInputError(False, + "You cannot specify both decoder_input_ids and " + "decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + invalidInputError(False, + "You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + past_key_values_length = 0 + + # ipex-llm changes start + if use_cache and not isinstance(past_key_values, DynamicNormalCache): + past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + # ipex-llm changes end + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, + dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + flash_attn_2 = self._attn_implementation == "flash_attention_2" + if attention_mask is not None and flash_attn_2 and use_cache: + + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + invalidInputError( + False, + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Qwen2." + " Make sure to call `tokenizer.padding_side = 'left'` before tokenizing " + "the input. " + ) + + if self._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and + 0 in attention_mask) else None + elif self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # ipex-llm changes start + next_cache = next_decoder_cache + # ipex-llm changes end + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, + all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def qwen2_attention_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim) + qkv = qkv.transpose(1, 2) + query_states, key_states, value_states = qkv.split([self.num_heads, + self.num_key_value_heads, + self.num_key_value_heads], dim=1) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, None) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = None + if query_states.size(2) == key_states.size(2): + # first token + from intel_npu_acceleration_library.functional import scaled_dot_product_attention + attn_output = scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=q_len > 1 and bsz == 1, + ) + else: + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, + training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value + + +def qwen2_mlp_forward(self, x): + gate_up_proj = self.gate_up_proj(x) + gate_proj, up_proj = gate_up_proj.chunk(2, dim=-1) + down_proj = self.down_proj(self.act_fn(gate_proj) * up_proj) + return down_proj