
Commit af395ae

yinanhe and heyinan authored
[Feat] Add support for evaluation of InternVideo2-Chat && Fix evaluation for mvbench (#280)
* [add] add internvideo2 support && change mvbench to video branch
* [add] answer_prompt of internvideo2
* [add] change video type of internvideo2
* [fix] update template of mvbench
* [reformat]
* [fix] generate_until_multi_round
* [Feat] videochat2 support

---------

Co-authored-by: heyinan <[email protected]>
1 parent 7c2d91c commit af395ae

26 files changed (+933, -55 lines)

lmms_eval/api/task.py

Lines changed: 2 additions & 1 deletion
@@ -937,7 +937,8 @@ def _download_from_youtube(path):
 if accelerator.is_main_process:
     force_download = dataset_kwargs.get("force_download", False)
     force_unzip = dataset_kwargs.get("force_unzip", False)
-    cache_path = snapshot_download(repo_id=self.DATASET_PATH, repo_type="dataset", force_download=force_download, etag_timeout=60)
+    revision = dataset_kwargs.get("revision", "main")
+    cache_path = snapshot_download(repo_id=self.DATASET_PATH, revision=revision, repo_type="dataset", force_download=force_download, etag_timeout=60)
     zip_files = glob(os.path.join(cache_path, "**/*.zip"), recursive=True)
     tar_files = glob(os.path.join(cache_path, "**/*.tar*"), recursive=True)
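The new `revision` key in a task's `dataset_kwargs` lets the dataset snapshot be pinned instead of always tracking `main`. A minimal sketch of the underlying call; the repo id and revision below are placeholders, not values taken from this commit:

from huggingface_hub import snapshot_download

# Hypothetical illustration: pinning `revision` keeps dataset downloads reproducible.
cache_path = snapshot_download(
    repo_id="OpenGVLab/MVBench",  # placeholder dataset repo id
    repo_type="dataset",
    revision="main",              # branch, tag, or commit hash; "main" is the default used above
    force_download=False,
    etag_timeout=60,
)
print(cache_path)  # local cache directory containing the snapshot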

lmms_eval/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -44,8 +44,10 @@
     "video_llava": "VideoLLaVA",
     "vila": "VILA",
     "xcomposer2_4KHD": "XComposer2_4KHD",
+    "internvideo2": "InternVideo2",
     "xcomposer2d5": "XComposer2D5",
     "oryx": "Oryx",
+    "videochat2": "VideoChat2",
 }
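These registry entries map the model-name strings (`internvideo2`, `videochat2`) to their wrapper classes, which is how they are typically selected on the lmms-eval command line (e.g. `--model internvideo2`). As a rough, hedged sketch of constructing the new wrapper directly with the defaults visible in the file below (running it downloads the checkpoint and requires a CUDA GPU):

from lmms_eval.models.internvideo2 import InternVideo2

# Sketch only: these values mirror the constructor defaults shown below;
# on the CLI they would normally be passed through --model_args instead.
model = InternVideo2(
    pretrained="OpenGVLab/InternVideo2_chat_8B_HD",
    modality="video",   # or "image"
    batch_size="1",     # the wrapper asserts batch_size == 1
    num_segments="8",   # frames sampled per video
    hd_num="6",         # HD tiling budget
)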

lmms_eval/models/internvideo2.py

Lines changed: 366 additions & 0 deletions
@@ -0,0 +1,366 @@
import logging
import os
from typing import List, Tuple

import decord
import numpy as np
import torch
import torchvision.transforms as T
from accelerate import Accelerator, DistributedType
from decord import VideoReader, cpu

decord.bridge.set_bridge("torch")
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

eval_logger = logging.getLogger("eval_logger")


from datetime import timedelta

from accelerate.state import AcceleratorState
from accelerate.utils import InitProcessGroupKwargs

DEFAULT_GEN_KWARGS = dict(
    num_beams=1,
    max_new_tokens=1024,
    do_sample=False,
)

# def get_index(num_frames, num_segments):
#     seg_size = float(num_frames - 1) / num_segments
#     start = int(seg_size / 2)
#     offsets = np.array([
#         start + int(np.round(seg_size * idx)) for idx in range(num_segments)
#     ])
#     return offsets


def get_index(max_frame, num_segments, fps, first_idx=0, bound=None):
    if bound:
        start, end = bound[0], bound[1]
        if start is None:
            start, end = -100000, 100000
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
    return frame_indices


def load_image(image_path, resolution=224, hd_num=6):
    image = Image.open(image_path).convert("RGB")
    image_tensor = T.PILToTensor()(image).unsqueeze(0)
    image_tensor = HD_transform_no_padding(image_tensor.float(), image_size=resolution, hd_num=hd_num)
    T_, C, H, W = image_tensor.shape

    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    transform = T.Compose([T.Lambda(lambda x: x.float().div(255.0)), T.Normalize(mean, std)])
    image_tensor = transform(image_tensor).cuda()

    sub_img = image_tensor.reshape(1, T_, 3, H // resolution, resolution, W // resolution, resolution).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T_, 3, resolution, resolution).contiguous()

    glb_img = F.interpolate(image_tensor.float(), size=(resolution, resolution), mode="bicubic", align_corners=False).to(sub_img.dtype).unsqueeze(0)

    image_tensor = torch.cat([sub_img, glb_img])  # .unsqueeze(0)
    return image_tensor


def load_video(video_path, num_segments=16, return_msg=False, resolution=224, hd_num=6, padding=False):
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    num_frames = len(vr) - 1

    frame_indices = get_index(max_frame=num_frames, num_segments=num_segments, fps=float(vr.get_avg_fps()), first_idx=0, bound=None)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    transform = T.Compose([T.Lambda(lambda x: x.float().div(255.0)), T.Normalize(mean, std)])

    frames = vr.get_batch(frame_indices)
    frames = frames.permute(0, 3, 1, 2)

    if padding:
        frames = HD_transform_padding(frames.float(), image_size=resolution, hd_num=hd_num)
    else:
        frames = HD_transform_no_padding(frames.float(), image_size=resolution, hd_num=hd_num)

    frames = transform(frames)
    T_, C, H, W = frames.shape

    sub_img = frames.reshape(1, T_, 3, H // resolution, resolution, W // resolution, resolution).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T_, 3, resolution, resolution).contiguous()

    glb_img = F.interpolate(frames.float(), size=(resolution, resolution), mode="bicubic", align_corners=False).to(sub_img.dtype).unsqueeze(0)

    frames = torch.cat([sub_img, glb_img]).unsqueeze(0)

    if return_msg:
        fps = float(vr.get_avg_fps())
        sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices])
        # " " should be added in the start and end
        msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds."
        return frames, msg
    else:
        return frames


def HD_transform_padding(frames, image_size=224, hd_num=6):
    def _padding_224(frames):
        _, _, H, W = frames.shape
        tar = int(np.ceil(H / 224) * 224)
        top_padding = (tar - H) // 2
        bottom_padding = tar - H - top_padding
        left_padding = 0
        right_padding = 0

        padded_frames = F.pad(frames, pad=[left_padding, right_padding, top_padding, bottom_padding], mode="constant", value=255)
        return padded_frames

    _, _, H, W = frames.shape
    trans = False
    if W < H:
        frames = frames.flip(-2, -1)
        trans = True
        width, height = H, W
    else:
        width, height = W, H

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * image_size)
    new_h = int(new_w / ratio)

    resized_frames = F.interpolate(frames, size=(new_h, new_w), mode="bicubic", align_corners=False)
    padded_frames = _padding_224(resized_frames)

    if trans:
        padded_frames = padded_frames.flip(-2, -1)

    return padded_frames


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def HD_transform_no_padding(frames, image_size=224, hd_num=6, fix_ratio=(2, 1)):
    min_num = 1
    max_num = hd_num
    _, _, orig_height, orig_width = frames.shape
    aspect_ratio = orig_width / orig_height

    # calculate the existing video aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    if fix_ratio:
        target_aspect_ratio = fix_ratio
    else:
        target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the frames
    resized_frame = F.interpolate(frames, size=(target_height, target_width), mode="bicubic", align_corners=False)
    return resized_frame

@register_model("InternVideo2")
class InternVideo2(lmms):
    def __init__(
        self,
        pretrained: str = "OpenGVLab/InternVideo2_chat_8B_HD",
        modality: str = "video",
        device: str = "cuda:0",
        device_map: str = "cuda:0",
        batch_size: str = "1",
        num_segments: str = "8",
        hd_num: str = "6",
        **kwargs,
    ):
        super().__init__()
        self.path = pretrained
        self.instruction = "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons.\n"

        self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, use_fast=False)
        self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda()
        batch_size = int(batch_size)
        self.num_segments = int(num_segments)
        self.hd_num = int(hd_num)
        assert batch_size == 1, f"Batch size should be 1 for InternVideo2, but got {batch_size}."
        self.batch_size_per_gpu = batch_size
        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        self.accelerator = accelerator
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        elif accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model.
            # You also have to select zero stage 0 (equivalent to DDP) in order to make prepare(model) work.
            # Setting different parameters in the kwargs to make the default zero stage 2 work did not succeed.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")

            if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

        self.modality = modality
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests) -> List[str]:
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]:
            if "until" in gen_kwargs:
                gen_kwargs.pop("until")
            for k, v in DEFAULT_GEN_KWARGS.items():
                if k not in gen_kwargs:
                    gen_kwargs[k] = v

            pop_keys = []
            for k, v in gen_kwargs.items():
                if k not in DEFAULT_GEN_KWARGS:
                    pop_keys.append(k)

            for k in pop_keys:
                gen_kwargs.pop(k)

            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]
            visuals = self.flatten(visuals)
            if self.modality == "image":
                image_path = visuals[0]
                pixel_values = load_image(image_path, resolution=224, hd_num=self.hd_num)
                pixel_values = pixel_values.to(torch.bfloat16).cuda()
                question = contexts
                response, history = self.model.chat(self.tokenizer, msg="", user_prompt=question, media_type="image", media_tensor=pixel_values, instruction=None, chat_history=[], return_history=True, **gen_kwargs)
            elif self.modality == "video":
                assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos. [META-INFO]{visuals}"
                video_path = visuals[0]
                if "mvbench" in task:
                    answer_prompt = "Best Option:("
                else:
                    answer_prompt = None
                pixel_values = load_video(video_path, num_segments=self.num_segments, return_msg=False, resolution=224, hd_num=self.hd_num)
                pixel_values = pixel_values.to(torch.bfloat16).cuda()
                question = self.instruction + contexts
                response, history = self.model.chat(
                    self.tokenizer,
                    msg="",
                    user_prompt=question,
                    media_type="video",
                    media_tensor=pixel_values,
                    instruction=self.instruction,
                    chat_history=[],
                    return_history=True,
                    generation_config=gen_kwargs,
                    answer_prompt=answer_prompt,
                    debug_conv=False,
                )
            res.append(response)
            pbar.update(1)
        pbar.close()
        return res

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        assert False, "Not implemented yet."

    def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation for InternVideo2")
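To make the frame sampling in `get_index`/`load_video` above concrete, here is a small worked example with illustrative numbers (a 10-second clip at 30 fps, so `max_frame = 299`, and the wrapper's default `num_segments = 8`); it repeats the index arithmetic inline rather than importing the module:

import numpy as np

# Illustrative numbers only: 10 s at 30 fps -> 300 frames, max_frame = 299, no bound given.
start_idx, end_idx, num_segments = 0, 299, 8
seg_size = float(end_idx - start_idx) / num_segments  # 37.375 frames per segment
frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
print(frame_indices)  # [ 18  55  93 130 168 205 242 280] -> roughly one frame per segment midpoint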
