Skip to content

Commit

Permalink
support chatglm4 in lookup (#11855)
Browse files — browse the repository at this point in the history
  • Loading branch information
cyita authored Aug 21, 2024
1 parent 0236de3 commit cc27321
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions python/llm/src/ipex_llm/transformers/speculative.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -493,12 +493,19 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
for k, v in past_key_values
]
elif self.config.model_type == "chatglm":
# for chatglm, cache shape is [sl, bs, nh, hn]
past_key_values = [
(k[:-(new_cache_size), :, :, :],
v[:-(new_cache_size), :, :, :])
for k, v in past_key_values
]
if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'):
past_key_values = [
(k[:, :, :-(new_cache_size), :],
v[:, :, :-(new_cache_size), :])
for k, v in past_key_values
]
else:
# for chatglm, cache shape is [sl, bs, nh, hn]
past_key_values = [
(k[:-(new_cache_size), :, :, :],
v[:-(new_cache_size), :, :, :])
for k, v in past_key_values
]
elif self.config.model_type in ["baichuan", "gptj"]:
past_key_values = [
(k[:, :, :-(new_cache_size), :],
Expand Down

0 comments on commit cc27321

Please sign in to comment.