Commit 4b9c57c

Support compress kv with lookahead (#11752)
* support compress kv with lookahead
* pass the missing q_len param to the enough_kv_room check
1 parent 93455aa · commit 4b9c57c

7 files changed: +32 −12 lines changed

python/llm/src/ipex_llm/transformers/models/chatglm2.py

Lines changed: 3 additions & 1 deletion
@@ -287,7 +287,9 @@ def chatglm2_attention_forward(
     else:
         from transformers.configuration_utils import PretrainedConfig
         self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                      self.layer_number - 1,
+                                                      q_len)
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_number - 1,
             query_states, attention_mask, n_head // n_kv_head,
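
The same pattern repeats across the model files below: the real query length is now forwarded into the enough-kv-room check. With lookahead decoding, a single step appends several tokens at once, so checking for room for only one token can report success just before the preallocated KV buffer overflows. A minimal sketch of the idea, assuming a buffer-backed cache; the helper and its fields are illustrative, not the actual ipex_llm internals:

import torch

def has_enough_kv_room(k_cache: torch.Tensor, seen_tokens: int, seq_len: int = 1) -> bool:
    # k_cache: [batch, n_kv_heads, allocated_len, head_dim], a preallocated buffer.
    # The step is only safe if the whole incoming block of seq_len tokens fits,
    # not just a single token (seq_len defaults to 1 for plain decoding).
    return k_cache.size(2) - seen_tokens >= seq_len

k_cache = torch.empty(1, 2, 256, 64)
print(has_enough_kv_room(k_cache, seen_tokens=254))              # True: one more token fits
print(has_enough_kv_room(k_cache, seen_tokens=254, seq_len=4))   # False: a lookahead block does not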

python/llm/src/ipex_llm/transformers/models/chatglm4.py

Lines changed: 3 additions & 1 deletion
@@ -213,7 +213,9 @@ def chatglm4_attention_forward(
     else:
         from transformers.configuration_utils import PretrainedConfig
         self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                      self.layer_number - 1,
+                                                      q_len)
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_number - 1,
             query_states, attention_mask, n_head // n_kv_head,

python/llm/src/ipex_llm/transformers/models/minicpm.py

Lines changed: 6 additions & 3 deletions
@@ -127,7 +127,8 @@ def minicpm_attention_forward_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -408,7 +409,8 @@ def minicpm_attention_forward_quantized(
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
@@ -821,7 +823,8 @@ def minicpm_attention_forward_original_4_39(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     no_tp = not self.config.pretraining_tp > 1
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,

python/llm/src/ipex_llm/transformers/models/mistral.py

Lines changed: 7 additions & 3 deletions
@@ -699,7 +699,8 @@ def mistral_attention_forward_4_36_quantized(
     original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  seq_len=q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -916,7 +917,9 @@ def mistral_attention_forward_4_36_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                  self.layer_idx,
+                                                  q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
@@ -1172,7 +1175,8 @@ def mistral_attention_forward_4_39_original(
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
-    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                  q_len)
     decoding_fast_path = use_decoding_fast_path(self.q_proj,
                                                 use_fuse_rope,
                                                 enough_kv_room,
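
In the Mistral forwards the room check also feeds use_decoding_fast_path, a path meant for the plain one-token decode step. Under compressed KV with lookahead, q_len can exceed 1 during decoding, so the check has to see the real query length or the fast path could be entered without room for the whole block. A hedged illustration of such a gate; the function below is illustrative, not the ipex_llm helper:

def should_take_decoding_fast_path(use_fuse_rope: bool, enough_kv_room: bool, q_len: int) -> bool:
    # The fast path assumes exactly one new token with a cache slot ready for it;
    # a lookahead step with q_len > 1 falls back to the general attention path.
    return use_fuse_rope and enough_kv_room and q_len == 1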

python/llm/src/ipex_llm/transformers/models/phi3.py

Lines changed: 3 additions & 1 deletion
@@ -135,7 +135,9 @@ def attention_forward(
     if past_key_value is not None:
         # [CompressKV]
         if use_compresskv:
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                          self.layer_idx,
+                                                          q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
                 query_states, attention_mask, self.num_key_value_groups,

python/llm/src/ipex_llm/transformers/models/qwen2.py

Lines changed: 4 additions & 1 deletion
@@ -440,7 +440,8 @@ def qwen2_attention_forward(
     if past_key_value is not None:
         # [CompressKV]
         if use_compresskv:
-            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                          q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
                 query_states, attention_mask, self.num_key_value_groups,
@@ -471,6 +472,8 @@
                                             is_causal=True).to(hidden_states.dtype)
     elif use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            attention_mask = None
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
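
In the second qwen2 hunk the precomputed attention mask is dropped before the fused SDP call when the compressed cache is in use: the compressed key/value tensors keep fewer entries than the kv_seq_len the mask was built for, and every retained entry should stay visible to the new queries. A small sketch of the shape mismatch this avoids, illustrated with PyTorch's scaled_dot_product_attention rather than the xe_addons kernels; all tensor sizes are made up:

import torch
import torch.nn.functional as F

q = torch.randn(1, 32, 1, 128)           # one decode-step query
k = torch.randn(1, 32, 300, 128)          # compressed cache keeps 300 entries
v = torch.randn(1, 32, 300, 128)
mask_full = torch.zeros(1, 1, 1, 1024)    # mask sized for the uncompressed length

# F.scaled_dot_product_attention(q, k, v, attn_mask=mask_full) would try to apply a
# 1024-wide mask to 300 keys and fail; with attn_mask=None every retained entry is
# attendable, which matches the compressed-cache semantics.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
print(out.shape)  # torch.Size([1, 32, 1, 128])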

python/llm/src/ipex_llm/transformers/speculative.py

Lines changed: 6 additions & 2 deletions
@@ -460,12 +460,16 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l

 def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
     if version.parse(trans_version) >= version.parse("4.36.0"):
-        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
-        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
+        from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache,\
+            DynamicCompressCache
+        if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache,
+                                        DynamicCompressCache)):
             if hasattr(past_key_values, "_seen_tokens"):
                 past_key_values._seen_tokens -= new_cache_size
             else:
                 past_key_values.seen_tokens -= new_cache_size
+            if isinstance(past_key_values, DynamicCompressCache):
+                past_key_values.real_kv_len -= new_cache_size

             for i, k in enumerate(past_key_values.key_cache):
                 past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
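
When lookahead or speculative drafts are rejected, _crop_past_key_values trims the trailing entries from the cache; this change teaches it that DynamicCompressCache also tracks real_kv_len, which must shrink by the same amount or later room checks and cache updates would drift from the actual tensors. A toy sketch of that bookkeeping; the class below is a stand-in, not the ipex_llm DynamicCompressCache:

import torch

class ToyCompressCache:
    def __init__(self):
        self.key_cache = [torch.randn(1, 2, 16, 8)]
        self.value_cache = [torch.randn(1, 2, 16, 8)]
        self._seen_tokens = 16   # tokens the model has processed
        self.real_kv_len = 16    # entries actually kept in the cache

def crop(cache: ToyCompressCache, new_cache_size: int) -> None:
    # new_cache_size is used here as in the diff above: the number of trailing
    # (rejected draft) entries to drop. Both counters must follow the tensors.
    cache._seen_tokens -= new_cache_size
    cache.real_kv_len -= new_cache_size
    for i, k in enumerate(cache.key_cache):
        cache.key_cache[i] = k[:, :, :-new_cache_size, :]
    for i, v in enumerate(cache.value_cache):
        cache.value_cache[i] = v[:, :, :-new_cache_size, :]

cache = ToyCompressCache()
crop(cache, 3)
print(cache.real_kv_len, cache.key_cache[0].shape)  # 13 torch.Size([1, 2, 13, 8])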
