Skip to content

Commit

Permalink
optimize phi3 again: use quantize kv if possible (#10953)
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 authored May 7, 2024
1 parent aa2fa9f commit c801c37
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@
rotate_half, should_use_fuse_rope,
apply_rotary_pos_emb_cache_freq_xpu
)
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU, use_new_esimd_sdp_fp16
from ipex_llm.transformers.kv import DynamicNormalCache
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
from ipex_llm.transformers.models.utils import use_new_esimd_sdp_fp16, use_quantize_kv_cache
from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache

from typing import Optional, Tuple, List
from transformers.models.phi.modeling_phi import repeat_kv
Expand Down Expand Up @@ -93,10 +95,18 @@ def attention_forward(
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)

if use_new_esimd_sdp_fp16(q_len, kv_seq_len, self.head_dim, query_states):
if (isinstance(past_key_value, DynamicFp8Cache) and
use_sdp_fp8(q_len, kv_seq_len, query_states)):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, attention_mask)
elif (isinstance(past_key_value, DynamicNormalCache) and
use_new_esimd_sdp_fp16(q_len, kv_seq_len, self.head_dim, query_states)):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
else:
if isinstance(past_key_value, DynamicFp8Cache):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
Expand Down Expand Up @@ -179,8 +189,12 @@ def model_forward(
):
# IPEX-LLM OPT: kv cache but no sdp (its head_dim 96, cannot use sdp)
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_quantize_kv = (use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids) and
self.config.hidden_size // self.config.num_attention_heads in [64, 128])
if use_cache:
if not isinstance(past_key_values, DynamicNormalCache):
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
return origin_model_forward(
self=self,
Expand Down

0 comments on commit c801c37

Please sign in to comment.