python/llm/src/ipex_llm/transformers/models
1 file changed: 25 additions, 2 deletions
@@ -15,6 +15,30 @@
 #


+import torch
+from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor
+
+
+# todo
+def patched_repetition_penalty_call(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    score = torch.gather(scores, 1, input_ids)
+
+    # if score < 0 then repetition penalty has to be
+    # multiplied to reduce the token probabilities
+    score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+
+    # ipex llm changes start: call scatter on CPU
+    device = scores.device
+    scores = scores.to('cpu')
+    input_ids = input_ids.to('cpu')
+    score = score.to('cpu')
+    scores.scatter_(1, input_ids, score)
+    scores = scores.to(device)
+    # ipex llm changes end
+
+    return scores
+
+
 def minicpmv_generate_wrapper(origin_generate):
     def generate(
         self,
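The new helper mirrors the stock gather / where / scatter formulation of RepetitionPenaltyLogitsProcessor.__call__, but performs the scatter_ on CPU and moves the result back to the original device afterwards. A minimal sketch (not part of the PR) of how the helper could be exercised, assuming patched_repetition_penalty_call has been imported from the patched module; the penalty value and tensor shapes are illustrative, and the comparison only holds if the stock processor uses the same math:

import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

processor = RepetitionPenaltyLogitsProcessor(penalty=1.05)
input_ids = torch.tensor([[0, 3, 5]])   # previously generated token ids
scores = torch.randn(1, 8)              # logits over a toy vocabulary of 8 tokens

# clone() because the patched helper scatters in place when scores is already on CPU
reference = processor(input_ids, scores.clone())
patched = patched_repetition_penalty_call(processor, input_ids, scores.clone())
assert torch.allclose(reference, patched)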
@@ -30,8 +54,7 @@ def generate(
         decode_text=False,
         **kwargs
     ):
-        if kwargs.get("repetition_penalty", None) is not None:
-            kwargs["repetition_penalty"] = 1
+        RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call
         return origin_generate(
             self=self,
             input_ids=input_ids,
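Previously the wrapper forced repetition_penalty to 1, which silently disabled the penalty; it now monkey-patches the processor class instead. Because Python resolves __call__ through the class, every RepetitionPenaltyLogitsProcessor that transformers builds inside generate() picks up the CPU-scatter version. A minimal sketch (not part of the PR) of that dispatch; the penalty value and shapes are illustrative:

import torch
from transformers.generation.logits_process import RepetitionPenaltyLogitsProcessor

# swap the method on the class, as the wrapper does before delegating to origin_generate
RepetitionPenaltyLogitsProcessor.__call__ = patched_repetition_penalty_call

proc = RepetitionPenaltyLogitsProcessor(penalty=1.1)
out = proc(torch.tensor([[1, 2]]), torch.randn(1, 4))  # dispatches to the patched call

Note that the assignment is global for the process and is not restored after generation, so other pipelines using the same processor class would also take the CPU path.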