@@ -1,19 +1,19 @@
 import base64
-from io import BytesIO
 import os
-from typing import List, Optional, Tuple, Union
 import uuid
 import warnings
+from io import BytesIO
+from typing import List, Optional, Tuple, Union
 
 warnings.simplefilter("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore")
 
+import torch
 from accelerate import Accelerator, DistributedType
 from loguru import logger as eval_logger
 from PIL import Image
-import torch
 from tqdm import tqdm
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
+from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
 
 from lmms_eval import utils
 from lmms_eval.api.instance import Instance
@@ -208,16 +208,10 @@ def _collate(x):
                 message.append({"role": "user", "content": [{"type": "text", "text": context}]})
             else:
                 message.append({"role": "user", "content": [{"type": "text", "text": context}]})
-
+
             messages.append(message)
 
-        input_data = self.processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_tensors='pt',
-            return_dict=True
-        ).to(self.device).to(self.dtype)
+        input_data = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True).to(self.device).to(self.dtype)
 
         # preconfigure gen_kwargs with defaults
         if "max_new_tokens" not in gen_kwargs:
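A note on the collapsed apply_chat_template call: with tokenize=True and return_dict=True, the processor returns a BatchFeature of tensors rather than a formatted prompt string, so the chained .to() calls move the whole batch to the device and cast its floating-point tensors (e.g. pixel_values) to the model dtype, while integer tensors such as input_ids are left untouched. A minimal standalone sketch of the same pattern, assuming a recent transformers release; the checkpoint name below is a placeholder, not the model this PR adds:

    import torch
    from transformers import AutoProcessor

    # Placeholder checkpoint for illustration only.
    processor = AutoProcessor.from_pretrained("org/some-vlm-checkpoint")

    conversation = [{"role": "user", "content": [{"type": "text", "text": "Describe the image."}]}]
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,  # open an assistant turn for the model to complete
        tokenize=True,               # return token ids instead of a prompt string
        return_tensors="pt",
        return_dict=True,            # BatchFeature rather than a bare tensor
    ).to("cuda").to(torch.bfloat16)  # dtype cast applies only to floating-point tensors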
@@ -228,21 +222,18 @@ def _collate(x):
             gen_kwargs["top_p"] = None
         if "num_beams" not in gen_kwargs:
             gen_kwargs["num_beams"] = 1
-        if 'do_sample' not in gen_kwargs:
-            gen_kwargs['do_sample'] = False
-        if 'repetition_penalty' not in gen_kwargs:
-            gen_kwargs['repetition_penalty'] = 1.0
-        if 'length_penalty' not in gen_kwargs:
-            gen_kwargs['length_penalty'] = 0.9
-        if 'min_length' not in gen_kwargs:
-            gen_kwargs['min_length'] = 4
-        if 'eos_token_id' not in gen_kwargs:
-            gen_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
-
-        prompt = self.processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True
-        )
+        if "do_sample" not in gen_kwargs:
+            gen_kwargs["do_sample"] = False
+        if "repetition_penalty" not in gen_kwargs:
+            gen_kwargs["repetition_penalty"] = 1.0
+        if "length_penalty" not in gen_kwargs:
+            gen_kwargs["length_penalty"] = 0.9
+        if "min_length" not in gen_kwargs:
+            gen_kwargs["min_length"] = 4
+        if "eos_token_id" not in gen_kwargs:
+            gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
+
+        prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
         cont = self.model.generate(
             input_data,
             **gen_kwargs,
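The if-chain reformatted above is a plain defaulting pattern: each generation parameter is written only when the caller has not supplied it, so task-level gen_kwargs always win. An equivalent, more compact formulation with dict.setdefault (a sketch for illustration, not code from this PR; tokenizer stands in for self.tokenizer in the surrounding method):

    defaults = {
        "do_sample": False,                       # greedy decoding by default
        "repetition_penalty": 1.0,                # 1.0 means no penalty
        "length_penalty": 0.9,                    # consulted only by beam search (num_beams > 1)
        "min_length": 4,                          # lower bound on total sequence length (prompt + new tokens)
        "eos_token_id": tokenizer.eos_token_id,
    }
    for key, value in defaults.items():
        gen_kwargs.setdefault(key, value)         # fill only missing keys; caller-supplied values are kept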