
Commit 4abf677

[Fix] Regular Linting

- Added missing comma in AVAILABLE_MODELS for consistency.
- Reordered import statements in vora.py for better readability.
- Simplified input_data generation by condensing method calls into a single line.
- Ensured default generation parameters are set correctly in the VoRA class.

1 parent 8b171bb · commit 4abf677

3 files changed, +25 -34 lines changed

lmms_eval/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@
     "videochat_flash": "VideoChat_Flash",
     "whisper": "Whisper",
     "whisper_vllm": "WhisperVllm",
-    "vora": "VoRA"
+    "vora": "VoRA",
 }
 
 
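For context, AVAILABLE_MODELS maps a module name to the class it exports, so model backends can be resolved lazily by key; the trailing comma just keeps future one-line additions diff-clean. A minimal sketch of how such a registry is typically consumed (the get_model helper below is a hypothetical illustration, not lmms-eval's actual loader):

import importlib

AVAILABLE_MODELS = {
    "whisper": "Whisper",
    "vora": "VoRA",
}

def get_model(name: str):
    # Import lmms_eval.models.<name> on demand and pull out its class.
    class_name = AVAILABLE_MODELS[name]
    module = importlib.import_module(f"lmms_eval.models.{name}")
    return getattr(module, class_name)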
lmms_eval/models/openai_compatible.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,12 @@ def generate_until(self, requests) -> List[str]:
189189
payload["max_output_tokens"] = gen_kwargs["max_new_tokens"]
190190
payload["temperature"] = gen_kwargs["temperature"]
191191

192-
if 'o1' in self.model_version or 'o3' in self.model_version:
193-
del payload['max_output_tokens']
194-
del payload['temperature']
195-
payload['reasoning_effort'] = 'medium'
196-
payload['response_format'] = {'type': 'text'}
197-
192+
if "o1" in self.model_version or "o3" in self.model_version:
193+
del payload["max_output_tokens"]
194+
del payload["temperature"]
195+
payload["reasoning_effort"] = "medium"
196+
payload["response_format"] = {"type": "text"}
197+
198198
for attempt in range(self.max_retries):
199199
try:
200200
response = self.client.chat.completions.create(**payload)
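Beyond the quote normalization, this block encodes a real constraint: OpenAI's o1/o3 reasoning models do not accept the usual sampling controls, so the payload drops max_output_tokens and temperature and pins reasoning_effort and response_format instead. A standalone sketch of that guard (the is_reasoning_model helper is a hypothetical illustration, not this file's API):

def is_reasoning_model(model_version: str) -> bool:
    # o1/o3-family endpoints take reasoning_effort instead of the usual
    # sampling controls, so those keys must be stripped from the payload.
    return "o1" in model_version or "o3" in model_version

payload = {"model": "o3-mini", "max_output_tokens": 1024, "temperature": 0.0}
if is_reasoning_model(payload["model"]):
    payload.pop("max_output_tokens", None)  # pop() tolerates a missing key
    payload.pop("temperature", None)
    payload["reasoning_effort"] = "medium"
    payload["response_format"] = {"type": "text"}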

lmms_eval/models/vora.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import base64
2-
from io import BytesIO
32
import os
4-
from typing import List, Optional, Tuple, Union
53
import uuid
64
import warnings
5+
from io import BytesIO
6+
from typing import List, Optional, Tuple, Union
77

88
warnings.simplefilter("ignore", category=DeprecationWarning)
99
warnings.filterwarnings("ignore")
1010

11+
import torch
1112
from accelerate import Accelerator, DistributedType
1213
from loguru import logger as eval_logger
1314
from PIL import Image
14-
import torch
1515
from tqdm import tqdm
16-
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
16+
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
1717

1818
from lmms_eval import utils
1919
from lmms_eval.api.instance import Instance
@@ -208,16 +208,10 @@ def _collate(x):
                 message.append({"role": "user", "content": [{"type": "text", "text": context}]})
             else:
                 message.append({"role": "user", "content": [{"type": "text", "text": context}]})
-
+
             messages.append(message)
 
-            input_data = self.processor.apply_chat_template(
-                messages,
-                add_generation_prompt=True,
-                tokenize=True,
-                return_tensors='pt',
-                return_dict=True
-            ).to(self.device).to(self.dtype)
+            input_data = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True).to(self.device).to(self.dtype)
 
             # preconfigure gen_kwargs with defaults
             if "max_new_tokens" not in gen_kwargs:
@@ -228,21 +222,18 @@ def _collate(x):
                 gen_kwargs["top_p"] = None
             if "num_beams" not in gen_kwargs:
                 gen_kwargs["num_beams"] = 1
-            if 'do_sample' not in gen_kwargs:
-                gen_kwargs['do_sample'] = False
-            if 'repetition_penalty' not in gen_kwargs:
-                gen_kwargs['repetition_penalty'] = 1.0
-            if 'length_penalty' not in gen_kwargs:
-                gen_kwargs['length_penalty'] = 0.9
-            if 'min_length' not in gen_kwargs:
-                gen_kwargs['min_length'] = 4
-            if 'eos_token_id' not in gen_kwargs:
-                gen_kwargs['eos_token_id'] = self.tokenizer.eos_token_id
-
-            prompt = self.processor.apply_chat_template(
-                messages,
-                add_generation_prompt=True
-            )
+            if "do_sample" not in gen_kwargs:
+                gen_kwargs["do_sample"] = False
+            if "repetition_penalty" not in gen_kwargs:
+                gen_kwargs["repetition_penalty"] = 1.0
+            if "length_penalty" not in gen_kwargs:
+                gen_kwargs["length_penalty"] = 0.9
+            if "min_length" not in gen_kwargs:
+                gen_kwargs["min_length"] = 4
+            if "eos_token_id" not in gen_kwargs:
+                gen_kwargs["eos_token_id"] = self.tokenizer.eos_token_id
+
+            prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
             cont = self.model.generate(
                 input_data,
                 **gen_kwargs,
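The defaulting chain above only changes quote style; behavior is identical. For reference, the same fallbacks can be expressed with dict.setdefault, which writes a key only when the caller has not supplied it. A minimal sketch (standalone helper, not part of the commit):

def apply_generation_defaults(gen_kwargs: dict, eos_token_id: int) -> dict:
    # setdefault writes each key only if the caller did not supply it,
    # matching the if-chain in the diff above.
    defaults = {
        "do_sample": False,
        "repetition_penalty": 1.0,
        "length_penalty": 0.9,
        "min_length": 4,
        "eos_token_id": eos_token_id,
    }
    for key, value in defaults.items():
        gen_kwargs.setdefault(key, value)
    return gen_kwargs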
