Commit 93455aa

fix minicpm V 2.6 repeat output (#11753)
1 parent: 7e917d6


python/llm/src/ipex_llm/transformers/models/minicpmv.py

Lines changed: 25 additions & 2 deletions
@@ -15,6 +15,30 @@
 #


+import torch
+from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
+
+
+# todo
+def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    score = torch.gather(scores, 1, input_ids)
+
+    # if score < 0 then repetition penalty has to be
+    # multiplied to reduce the token probabilities
+    score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+
+    # ipex llm changes start: call scatter on CPU
+    device = scores.device
+    scores = scores.to('cpu')
+    input_ids = input_ids.to('cpu')
+    score = score.to('cpu')
+    scores.scatter_(1, input_ids, score)
+    scores = scores.to(device)
+    # ipex llm changes end
+
+    return scores
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         self,
@@ -30,8 +54,7 @@ def generate(
         decode_text=False,
         **kwargs
     ):
-        if kwargs.get("repetition_penalty", None) is not None:
-            kwargs["repetition_penalty"] = 1
+        RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
         return origin_generate(
             self=self,
             input_ids=input_ids,
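
For context, a minimal standalone sketch of what the patch is meant to do: the wrapper replaces RepetitionPenaltyLogitsProcessor.__call__ with patched_repetition_penalty_call before delegating to the original generate, so the penalty math is unchanged and only the scatter_ step is routed through the CPU, presumably working around the device-side scatter implicated in the repeated output named in the commit title. The penalty value and toy tensors below are illustrative placeholders, not taken from the commit, and the sketch assumes patched_repetition_penalty_call from the hunk above is importable in your environment.

import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

# Assumption: the patched function from the diff above is importable from the
# module it is defined in; adjust the import path to your installation.
from ipex_llm.transformers.models.minicpmv import patched_repetition_penalty_call

penalty = 1.05                          # illustrative value, not from the commit
input_ids = torch.tensor([[3, 7, 3]])   # toy token history with one repeated token
scores = torch.randn(1, 10)             # toy logits over a 10-token vocabulary

# Reference result from the stock processor, computed before patching.
reference = RepetitionPenaltyLogitsProcessor(penalty)(input_ids, scores.clone())

# Apply the same monkey-patch that minicpmv_generate_wrapper installs.
RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
patched = RepetitionPenaltyLogitsProcessor(penalty)(input_ids, scores.clone())

# Same math, different scatter device: the results should agree.
assert torch.allclose(reference, patched)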
