|
| 1 | +import logging |
| 2 | +import os |
| 3 | +from typing import List, Tuple |
| 4 | + |
| 5 | +import decord |
| 6 | +import numpy as np |
| 7 | +import torch |
| 8 | +import torchvision.transforms as T |
| 9 | +from accelerate import Accelerator, DistributedType |
| 10 | +from decord import VideoReader, cpu |
| 11 | + |
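|  | +# make decord return torch tensors (rather than numpy arrays) from VideoReader.get_batch |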
| 12 | +decord.bridge.set_bridge("torch") |
| 13 | +import torch.nn.functional as F |
| 14 | +from PIL import Image |
| 15 | +from tqdm import tqdm |
| 16 | +from transformers import AutoModel, AutoTokenizer |
| 17 | + |
| 18 | +from lmms_eval.api.instance import Instance |
| 19 | +from lmms_eval.api.model import lmms |
| 20 | +from lmms_eval.api.registry import register_model |
| 21 | + |
| 22 | +eval_logger = logging.getLogger("eval_logger") |
| 23 | + |
| 24 | + |
| 25 | +from datetime import timedelta |
| 26 | + |
| 27 | +from accelerate.state import AcceleratorState |
| 28 | +from accelerate.utils import InitProcessGroupKwargs |
| 29 | + |
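|  | +# default greedy-decoding settings; generate_until fills these in and drops any keys not listed here |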
| 30 | +DEFAULT_GEN_KWARGS = dict( |
| 31 | + num_beams=1, |
| 32 | + max_new_tokens=1024, |
| 33 | + do_sample=False, |
| 34 | +) |
| 35 | + |
| 36 | +# def get_index(num_frames, num_segments): |
| 37 | +# seg_size = float(num_frames - 1) / num_segments |
| 38 | +# start = int(seg_size / 2) |
| 39 | +# offsets = np.array([ |
| 40 | +# start + int(np.round(seg_size * idx)) for idx in range(num_segments) |
| 41 | +# ]) |
| 42 | +# return offsets |
| 43 | + |
| 44 | + |
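|  | +# Uniformly sample num_segments frame indices (one per segment, taken near the segment midpoint); |
|  | +# an optional (start, end) bound in seconds clips the sampled range. |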
| 45 | +def get_index(max_frame, num_segments, fps, first_idx=0, bound=None): |
| 46 | + if bound: |
| 47 | + start, end = bound[0], bound[1] |
| 48 | + if start is None: |
| 49 | + start, end = -100000, 100000 |
| 50 | + else: |
| 51 | + start, end = -100000, 100000 |
| 52 | + start_idx = max(first_idx, round(start * fps)) |
| 53 | + end_idx = min(round(end * fps), max_frame) |
| 54 | + seg_size = float(end_idx - start_idx) / num_segments |
| 55 | + frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)]) |
| 56 | + return frame_indices |
| 57 | + |
| 58 | + |
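|  | +# Load a single image, apply the no-padding HD transform, and return normalized resolution x resolution |
|  | +# local crops plus one globally resized view, stacked along the first dimension. |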
| 59 | +def load_image(image_path, resolution=224, hd_num=6): |
| 60 | + image = Image.open(image_path).convert("RGB") |
| 61 | + image_tensor = T.PILToTensor()(image).unsqueeze(0) |
| 62 | + image_tensor = HD_transform_no_padding(image_tensor.float(), image_size=resolution, hd_num=hd_num) |
| 63 | + T_, C, H, W = image_tensor.shape |
| 64 | + |
| 65 | + mean = (0.485, 0.456, 0.406) |
| 66 | + std = (0.229, 0.224, 0.225) |
| 67 | + |
| 68 | + transform = T.Compose([T.Lambda(lambda x: x.float().div(255.0)), T.Normalize(mean, std)]) |
| 69 | + image_tensor = transform(image_tensor).cuda() |
| 70 | + |
| 71 | + sub_img = image_tensor.reshape(1, T_, 3, H // resolution, resolution, W // resolution, resolution).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T_, 3, resolution, resolution).contiguous() |
| 72 | + |
| 73 | + glb_img = F.interpolate(image_tensor.float(), size=(resolution, resolution), mode="bicubic", align_corners=False).to(sub_img.dtype).unsqueeze(0) |
| 74 | + |
| 75 | + image_tensor = torch.cat([sub_img, glb_img]) # .unsqueeze(0) |
| 76 | + return image_tensor |
| 77 | + |
| 78 | + |
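|  | +# Decode a video with decord, sample num_segments frames, apply the HD transform, and return a |
|  | +# (1, blocks + 1, T, 3, resolution, resolution) tensor; with return_msg=True also return a string listing the sampled timestamps. |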
| 79 | +def load_video(video_path, num_segments=16, return_msg=False, resolution=224, hd_num=6, padding=False): |
| 80 | + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) |
| 81 | + num_frames = len(vr) - 1 |
| 82 | + |
| 83 | + frame_indices = get_index(max_frame=num_frames, num_segments=num_segments, fps=float(vr.get_avg_fps()), first_idx=0, bound=None) |
| 84 | + mean = (0.485, 0.456, 0.406) |
| 85 | + std = (0.229, 0.224, 0.225) |
| 86 | + |
| 87 | + transform = T.Compose([T.Lambda(lambda x: x.float().div(255.0)), T.Normalize(mean, std)]) |
| 88 | + |
| 89 | + frames = vr.get_batch(frame_indices) |
| 90 | + frames = frames.permute(0, 3, 1, 2) |
| 91 | + |
| 92 | + if padding: |
| 93 | + frames = HD_transform_padding(frames.float(), image_size=resolution, hd_num=hd_num) |
| 94 | + else: |
| 95 | + frames = HD_transform_no_padding(frames.float(), image_size=resolution, hd_num=hd_num) |
| 96 | + |
| 97 | + frames = transform(frames) |
| 98 | + T_, C, H, W = frames.shape |
| 99 | + |
| 100 | + sub_img = frames.reshape(1, T_, 3, H // resolution, resolution, W // resolution, resolution).permute(0, 3, 5, 1, 2, 4, 6).reshape(-1, T_, 3, resolution, resolution).contiguous() |
| 101 | + |
| 102 | + glb_img = F.interpolate(frames.float(), size=(resolution, resolution), mode="bicubic", align_corners=False).to(sub_img.dtype).unsqueeze(0) |
| 103 | + |
| 104 | + frames = torch.cat([sub_img, glb_img]).unsqueeze(0) |
| 105 | + |
| 106 | + if return_msg: |
| 107 | + fps = float(vr.get_avg_fps()) |
| 108 | + sec = ", ".join([str(round(f / fps, 1)) for f in frame_indices]) |
| 109 | + # a space should be added at the start and end of the message |
| 110 | + msg = f"The video contains {len(frame_indices)} frames sampled at {sec} seconds." |
| 111 | + return frames, msg |
| 112 | + else: |
| 113 | + return frames |
| 114 | + |
| 115 | + |
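|  | +# Resize the frames so that at most hd_num tiles of size image_size fit, then pad the height up to a multiple of 224 with white (255) pixels. |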
| 116 | +def HD_transform_padding(frames, image_size=224, hd_num=6): |
| 117 | + def _padding_224(frames): |
| 118 | + _, _, H, W = frames.shape |
| 119 | + tar = int(np.ceil(H / 224) * 224) |
| 120 | + top_padding = (tar - H) // 2 |
| 121 | + bottom_padding = tar - H - top_padding |
| 122 | + left_padding = 0 |
| 123 | + right_padding = 0 |
| 124 | + |
| 125 | + padded_frames = F.pad(frames, pad=[left_padding, right_padding, top_padding, bottom_padding], mode="constant", value=255) |
| 126 | + return padded_frames |
| 127 | + |
| 128 | + _, _, H, W = frames.shape |
| 129 | + trans = False |
| 130 | + if W < H: |
| 131 | + frames = frames.flip(-2, -1) |
| 132 | + trans = True |
| 133 | + width, height = H, W |
| 134 | + else: |
| 135 | + width, height = W, H |
| 136 | + |
| 137 | + ratio = width / height |
| 138 | + scale = 1 |
| 139 | + while scale * np.ceil(scale / ratio) <= hd_num: |
| 140 | + scale += 1 |
| 141 | + scale -= 1 |
| 142 | + new_w = int(scale * image_size) |
| 143 | + new_h = int(new_w / ratio) |
| 144 | + |
| 145 | + resized_frames = F.interpolate(frames, size=(new_h, new_w), mode="bicubic", align_corners=False) |
| 146 | + padded_frames = _padding_224(resized_frames) |
| 147 | + |
| 148 | + if trans: |
| 149 | + padded_frames = padded_frames.flip(-2, -1) |
| 150 | + |
| 151 | + return padded_frames |
| 152 | + |
| 153 | + |
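|  | +# Choose the candidate tile grid whose aspect ratio is closest to the input's; |
|  | +# on ties, prefer the larger grid when the source image area is big enough to fill it. |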
| 154 | +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): |
| 155 | + best_ratio_diff = float("inf") |
| 156 | + best_ratio = (1, 1) |
| 157 | + area = width * height |
| 158 | + for ratio in target_ratios: |
| 159 | + target_aspect_ratio = ratio[0] / ratio[1] |
| 160 | + ratio_diff = abs(aspect_ratio - target_aspect_ratio) |
| 161 | + if ratio_diff < best_ratio_diff: |
| 162 | + best_ratio_diff = ratio_diff |
| 163 | + best_ratio = ratio |
| 164 | + elif ratio_diff == best_ratio_diff: |
| 165 | + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: |
| 166 | + best_ratio = ratio |
| 167 | + return best_ratio |
| 168 | + |
| 169 | + |
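|  | +# Resize frames directly to a grid of image_size tiles without padding; by default the grid is fixed to 2x1 via fix_ratio. |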
| 170 | +def HD_transform_no_padding(frames, image_size=224, hd_num=6, fix_ratio=(2, 1)): |
| 171 | + min_num = 1 |
| 172 | + max_num = hd_num |
| 173 | + _, _, orig_height, orig_width = frames.shape |
| 174 | + aspect_ratio = orig_width / orig_height |
| 175 | + |
| 176 | + # enumerate candidate tile grids (i, j) with between min_num and max_num tiles |
| 177 | + target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) |
| 178 | + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) |
| 179 | + |
| 180 | + # find the closest aspect ratio to the target |
| 181 | + if fix_ratio: |
| 182 | + target_aspect_ratio = fix_ratio |
| 183 | + else: |
| 184 | + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size) |
| 185 | + |
| 186 | + # calculate the target width and height |
| 187 | + target_width = image_size * target_aspect_ratio[0] |
| 188 | + target_height = image_size * target_aspect_ratio[1] |
| 189 | + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] |
| 190 | + |
| 191 | + # resize the frames |
| 192 | + resized_frame = F.interpolate(frames, size=(target_height, target_width), mode="bicubic", align_corners=False) |
| 193 | + return resized_frame |
| 194 | + |
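|  | +# Registered with lmms_eval under the name "InternVideo2"; it would typically be selected with |
|  | +# --model InternVideo2 and configured via --model_args (e.g. pretrained=OpenGVLab/InternVideo2_chat_8B_HD,num_segments=8,hd_num=6), |
|  | +# though the exact CLI invocation depends on the lmms-eval setup. |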
| 195 | + |
| 196 | +@register_model("InternVideo2") |
| 197 | +class InternVideo2(lmms): |
| 198 | + def __init__( |
| 199 | + self, |
| 200 | + pretrained: str = "OpenGVLab/InternVideo2_chat_8B_HD", |
| 201 | + modality: str = "video", |
| 202 | + device: str = "cuda:0", |
| 203 | + device_map: str = "cuda:0", |
| 204 | + batch_size: str = "1", |
| 205 | + num_segments: str = "8", |
| 206 | + hd_num: str = "6", |
| 207 | + **kwargs, |
| 208 | + ): |
| 209 | + super().__init__() |
| 210 | + self.path = pretrained |
| 211 | + self.instruction = "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons.\n" |
| 212 | + |
| 213 | + self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True, use_fast=False) |
| 214 | + self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, trust_remote_code=True).eval().cuda() |
|  | + self._config = self._model.config |
| 215 | + batch_size = int(batch_size) |
| 216 | + self.num_segments = int(num_segments) |
| 217 | + self.hd_num = int(hd_num) |
| 218 | + assert batch_size == 1, f"Batch size should be 1 for InternVideo2, but got {batch_size}." |
| 219 | + self.batch_size_per_gpu = batch_size |
| 220 | + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) |
| 221 | + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) |
| 222 | + self.accelerator = accelerator |
| 223 | + if accelerator.num_processes > 1: |
| 224 | + self._device = torch.device(f"cuda:{accelerator.local_process_index}") |
| 225 | + self.device_map = f"cuda:{accelerator.local_process_index}" |
| 226 | + elif accelerator.num_processes == 1 and device_map == "auto": |
| 227 | + self._device = torch.device(device) |
| 228 | + self.device_map = device_map |
| 229 | + else: |
| 230 | + self._device = torch.device(f"cuda:{accelerator.local_process_index}") |
| 231 | + self.device_map = f"cuda:{accelerator.local_process_index}" |
| 232 | + |
| 233 | + if accelerator.num_processes > 1: |
| 234 | + assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP, FSDP, and DEEPSPEED (ZeRO stage 0) are supported." |
| 235 | + # To use DistributedType.DEEPSPEED, you have to run `accelerate config` before using the model. |
| 236 | + # You also have to select ZeRO stage 0 (equivalent to DDP) for the model preparation below to work; |
| 237 | + # setting other parameters in the kwargs to make the default ZeRO stage 2 work did not succeed. |
| 238 | + if accelerator.distributed_type == DistributedType.DEEPSPEED: |
| 239 | + kwargs = { |
| 240 | + "train_micro_batch_size_per_gpu": self.batch_size_per_gpu, |
| 241 | + "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes, |
| 242 | + } |
| 243 | + AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **kwargs) |
| 244 | + eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0") |
| 245 | + |
| 246 | + if accelerator.distributed_type == DistributedType.FSDP or accelerator.distributed_type == DistributedType.DEEPSPEED: |
| 247 | + self._model = accelerator.prepare(self.model) |
| 248 | + else: |
| 249 | + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) |
| 250 | + self.accelerator = accelerator |
| 251 | + if self.accelerator.is_local_main_process: |
| 252 | + eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism") |
| 253 | + self._rank = self.accelerator.local_process_index |
| 254 | + self._world_size = self.accelerator.num_processes |
| 255 | + elif accelerator.num_processes == 1 and device_map == "auto": |
| 256 | + eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism") |
| 257 | + self._rank = 0 |
| 258 | + self._world_size = 1 |
| 259 | + else: |
| 260 | + eval_logger.info(f"Using single device: {self._device}") |
| 261 | + self.model.to(self._device) |
| 262 | + self._rank = 0 |
| 263 | + self._world_size = 1 |
| 264 | + |
| 265 | + self.modality = modality |
| 266 | + |
| 267 | + @property |
| 268 | + def config(self): |
| 269 | + # return the associated transformers.AutoConfig for the given pretrained model. |
| 270 | + return self._config |
| 271 | + |
| 272 | + @property |
| 273 | + def tokenizer(self): |
| 274 | + return self._tokenizer |
| 275 | + |
| 276 | + @property |
| 277 | + def model(self): |
| 278 | + # returns the model, unwrapping it if using Accelerate |
| 279 | + if hasattr(self, "accelerator"): |
| 280 | + return self.accelerator.unwrap_model(self._model) |
| 281 | + else: |
| 282 | + return self._model |
| 283 | + |
| 284 | + @property |
| 285 | + def batch_size(self): |
| 286 | + return self.batch_size_per_gpu |
| 287 | + |
| 288 | + @property |
| 289 | + def device(self): |
| 290 | + return self._device |
| 291 | + |
| 292 | + @property |
| 293 | + def rank(self): |
| 294 | + return self._rank |
| 295 | + |
| 296 | + @property |
| 297 | + def world_size(self): |
| 298 | + return self._world_size |
| 299 | + |
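|  | + # flatten a list of lists of visuals into a single flat list |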
| 300 | + def flatten(self, input): |
| 301 | + new_list = [] |
| 302 | + for i in input: |
| 303 | + for j in i: |
| 304 | + new_list.append(j) |
| 305 | + return new_list |
| 306 | + |
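|  | + # For each request: merge DEFAULT_GEN_KWARGS into gen_kwargs, load the image or video referenced by the doc, |
|  | + # and call the underlying model's chat() to produce a response (video prompts are prefixed with self.instruction). |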
| 307 | + def generate_until(self, requests) -> List[str]: |
| 308 | + res = [] |
| 309 | + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") |
| 310 | + |
| 311 | + for contexts, gen_kwargs, doc_to_visual, doc_id, task, split in [reg.args for reg in requests]: |
| 312 | + if "until" in gen_kwargs: |
| 313 | + gen_kwargs.pop("until") |
| 314 | + for k, v in DEFAULT_GEN_KWARGS.items(): |
| 315 | + if k not in gen_kwargs: |
| 316 | + gen_kwargs[k] = v |
| 317 | + |
| 318 | + pop_keys = [] |
| 319 | + for k, v in gen_kwargs.items(): |
| 320 | + if k not in DEFAULT_GEN_KWARGS: |
| 321 | + pop_keys.append(k) |
| 322 | + |
| 323 | + for k in pop_keys: |
| 324 | + gen_kwargs.pop(k) |
| 325 | + |
| 326 | + visuals = [doc_to_visual(self.task_dict[task][split][doc_id])] |
| 327 | + visuals = self.flatten(visuals) |
| 328 | + if self.modality == "image": |
| 329 | + image_path = visuals[0] |
| 330 | + pixel_values = load_image(image_path, resolution=224, hd_num=self.hd_num) |
| 331 | + pixel_values = pixel_values.to(torch.bfloat16).cuda() |
| 332 | + question = contexts |
| 333 | + response, history = self.model.chat(self.tokenizer, msg="", user_prompt=question, media_type="image", media_tensor=pixel_values, instruction=None, chat_history=[], return_history=True, **gen_kwargs) |
| 334 | + elif self.modality == "video": |
| 335 | + assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos. [META-INFO]{visuals}" |
| 336 | + video_path = visuals[0] |
| 337 | + if "mvbench" in task: |
| 338 | + answer_prompt = "Best Option:(" |
| 339 | + else: |
| 340 | + answer_prompt = None |
| 341 | + pixel_values = load_video(video_path, num_segments=self.num_segments, return_msg=False, resolution=224, hd_num=self.hd_num) |
| 342 | + pixel_values = pixel_values.to(torch.bfloat16).cuda() |
| 343 | + question = self.instruction + contexts |
| 344 | + response, history = self.model.chat( |
| 345 | + self.tokenizer, |
| 346 | + msg="", |
| 347 | + user_prompt=question, |
| 348 | + media_type="video", |
| 349 | + media_tensor=pixel_values, |
| 350 | + instruction=self.instruction, |
| 351 | + chat_history=[], |
| 352 | + return_history=True, |
| 353 | + generation_config=gen_kwargs, |
| 354 | + answer_prompt=answer_prompt, |
| 355 | + debug_conv=False, |
| 356 | + ) |
| 357 | + res.append(response) |
| 358 | + pbar.update(1) |
| 359 | + pbar.close() |
| 360 | + return res |
| 361 | + |
| 362 | + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: |
| 363 | + raise NotImplementedError("Loglikelihood is not implemented for InternVideo2.") |
| 364 | + |
| 365 | + def generate_until_multi_round(self, requests) -> List[str]: |
| 366 | + raise NotImplementedError("TODO: Implement multi-round generation for InternVideo2") |