diff --git a/vlmeval/config.py b/vlmeval/config.py
index f7e7d35de..92d181d1a 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -45,6 +45,8 @@
     "PLLaVA-34B": partial(
         PLLaVA, model_path="ermu2001/pllava-34b", dir_root=PLLaVA_ROOT
    ),
+    "InternVideo-2.5-Chat-8B": partial(InternVideo, model_path="OpenGVLab/InternVideo2_5_Chat_8B", max_new_tokens=4096),
+    "VideoLLaMA3-7B": partial(VideoLLaMA3, model_path="DAMO-NLP-SG/VideoLLaMA3-7B", max_new_tokens=4096),
 }
 
 ungrouped = {
@@ -633,6 +635,8 @@
     ),
     "MiniCPM-V-2_6": partial(MiniCPM_V_2_6, model_path="openbmb/MiniCPM-V-2_6"),
     "MiniCPM-o-2_6": partial(MiniCPM_o_2_6, model_path="openbmb/MiniCPM-o-2_6"),
+    "MiniCPM-V-4": partial(MiniCPM_V_4, model_path="openbmb/MiniCPM-V-4"),
+    "MiniCPM-V-4_5": partial(MiniCPM_V_4_5, model_path="openbmb/MiniCPM-V-4_5", max_new_tokens=8192),
 }
 
 xtuner_series = {
@@ -692,6 +696,10 @@
     "Thyme-7B": partial(Thyme, model_path="Kwai-Keye/Thyme-RL")
 }
 
+keye_vl_series = {
+    "Keye-VL-1_5-8B": partial(KeyeVL, model_path="Kwai-Keye/Keye-VL-1_5-8B", use_vllm=False, max_new_tokens=4096)
+}
+
 llava_series = {
     "llava_v1.5_7b": partial(LLaVA, model_path="liuhaotian/llava-v1.5-7b"),
     "llava_v1.5_13b": partial(LLaVA, model_path="liuhaotian/llava-v1.5-13b"),
@@ -759,6 +767,9 @@
     "llava_video_qwen2_72b": partial(
         LLaVA_OneVision, model_path="lmms-lab/LLaVA-Video-72B-Qwen2"
     ),
+    "llava_onevision_1_5_8b": partial(
+        LLaVA_OneVision_1_5, model_path="lmms-lab/LLaVA-OneVision-1.5-8B-Instruct", max_new_tokens=4096
+    ),
 }
 
 varco_vision_series = {
@@ -1672,7 +1683,7 @@
     aria_series, smolvlm_series, sail_series, valley_series, vita_series,
     ross_series, emu_series, ola_series, ursa_series, gemma_series,
     long_vita_series, ristretto_series, kimi_series, aguvis_series, hawkvl_series,
-    flash_vl, kimi_vllm_series, oryx_series, treevgr_series, varco_vision_series, qtunevl_series, xvl_series, thyme_series
+    flash_vl, kimi_vllm_series, oryx_series, treevgr_series, varco_vision_series, qtunevl_series, xvl_series, thyme_series, keye_vl_series
 ]
 
 for grp in model_groups:
diff --git a/vlmeval/dataset/mlvu.py b/vlmeval/dataset/mlvu.py
index bcad3e961..1868ab212 100644
--- a/vlmeval/dataset/mlvu.py
+++ b/vlmeval/dataset/mlvu.py
@@ -24,7 +24,7 @@ class MLVU(ConcatVideoDataset):
     def __init__(self, dataset='MLVU', nframe=0, fps=-1):
         self.DATASET_SETS[dataset] = ['MLVU_MCQ', 'MLVU_OpenEnded']
         self.type_data_dict = {
-            'M-Avg':['plotQA', 'needle', 'ego', 'count', 'anomaly_reco', 'topic_reasoning'],
+            'M-Avg':['plotQA', 'needle', 'ego', 'count', 'anomaly_reco', 'topic_reasoning', 'order'],
             'G-Avg':['sub_scene', 'summary']
         }
         super().__init__(dataset=dataset, nframe=nframe, fps=fps)
diff --git a/vlmeval/dataset/video_base.py b/vlmeval/dataset/video_base.py
index a1854be48..0d6fb5b98 100644
--- a/vlmeval/dataset/video_base.py
+++ b/vlmeval/dataset/video_base.py
@@ -67,7 +67,7 @@ def frame_paths_fps(self, video, num_frames):
         return [osp.join(frame_root, self.frame_tmpl_fps.format(i, num_frames, self.fps))
                 for i in range(1, num_frames + 1)]
 
-    def save_video_frames(self, video):
+    def save_video_frames(self, video, max_frames=-1):
         import decord
         if self.fps > 0:
             vid_path = osp.join(self.data_root, video + '.mp4')
@@ -80,12 +80,18 @@ def save_video_frames(self, video):
             # 计算需要提取的总帧数
             required_frames = int(total_duration * self.fps)
-
-            # 计算提取帧的间隔
-            step_size = video_fps / self.fps
-
-            # 计算提取帧的索引
-            indices = [int(i * step_size) for i in range(required_frames)]
+            if max_frames > 0 and required_frames > max_frames:
+                print(f"video {video} requires {self.fps} fps sampling, but the required frame count "
+                      f"{required_frames} exceeds max_frames {max_frames}; downsampling to {max_frames} frames")
+                required_frames = max_frames
+                step_size = total_frames / (required_frames + 1)
+                indices = [int(i * step_size) for i in range(1, required_frames + 1)]
+            else:
+                # compute the sampling interval between frames
+                step_size = video_fps / self.fps
+
+                # compute the indices of the frames to extract
+                indices = [int(i * step_size) for i in range(required_frames)]
 
             # 提取帧并保存
             frame_paths = self.frame_paths_fps(video, len(indices))
diff --git a/vlmeval/smp/file.py b/vlmeval/smp/file.py
index 53925e018..adf7f8abf 100644
--- a/vlmeval/smp/file.py
+++ b/vlmeval/smp/file.py
@@ -169,7 +169,7 @@ def get_pred_file_format():
     if pred_format == '':
         return 'xlsx'  # default format
     else:
-        assert pred_format in ['tsv', 'xlsx', 'json'], f'Unsupported PRED_FORMAT {pred_format}'
+        assert pred_format in ['tsv', 'xlsx', 'csv'], f'Unsupported PRED_FORMAT {pred_format}'
         return pred_format
diff --git a/vlmeval/vlm/__init__.py b/vlmeval/vlm/__init__.py
index 8f4451259..d73299237 100644
--- a/vlmeval/vlm/__init__.py
+++ b/vlmeval/vlm/__init__.py
@@ -6,6 +6,7 @@
 from .base import BaseModel
 from .hawk_vl import HawkVL
 from .thyme import Thyme
+from .keye_vl import KeyeVL
 from .cogvlm import CogVlm, GLM4v, GLMThinking
 from .emu import Emu, Emu3_chat, Emu3_gen
 from .eagle_x import Eagle
@@ -20,10 +21,11 @@
     LLaVA_Next2,
     LLaVA_OneVision,
     LLaVA_OneVision_HF,
+    LLaVA_OneVision_1_5,
 )
 from .vita import VITA, VITAQwen2
 from .long_vita import LongVITA
-from .minicpm_v import MiniCPM_V, MiniCPM_Llama3_V, MiniCPM_V_2_6, MiniCPM_o_2_6
+from .minicpm_v import MiniCPM_V, MiniCPM_Llama3_V, MiniCPM_V_2_6, MiniCPM_o_2_6, MiniCPM_V_4, MiniCPM_V_4_5
 from .minigpt4 import MiniGPT4
 from .mmalaya import MMAlaya, MMAlaya2
 from .monkey import Monkey, MonkeyChat
@@ -67,6 +69,8 @@
     LLaMAVID,
     VideoChat2_HD,
     PLLaVA,
+    InternVideo,
+    VideoLLaMA3,
 )
 from .vila import VILA, NVILA
 from .ovis import Ovis, Ovis1_6, Ovis1_6_Plus, Ovis2, OvisU1
diff --git a/vlmeval/vlm/base.py b/vlmeval/vlm/base.py
index bb1ab95cf..2ce5a4566 100644
--- a/vlmeval/vlm/base.py
+++ b/vlmeval/vlm/base.py
@@ -90,7 +90,7 @@ def preproc_content(self, inputs):
                 assert 'type' in item and 'value' in item
                 mime, s = parse_file(item['value'])
                 if mime is None:
-                    assert item['type'] == 'text'
+                    assert item['type'] == 'text', f'Invalid input type: {item}'
                 else:
                     assert mime.split('/')[0] == item['type']
                     item['value'] = s
diff --git a/vlmeval/vlm/keye_vl.py b/vlmeval/vlm/keye_vl.py
new file mode 100644
index 000000000..32751c72e
--- /dev/null
+++ b/vlmeval/vlm/keye_vl.py
@@ -0,0 +1,182 @@
+
+
+from transformers import AutoProcessor
+
+from .base import BaseModel
+
+class KeyeVL(BaseModel):
+
+    INSTALL_REQ = True
+    INTERLEAVE = True
+    VIDEO_LLM = True
+
+    def __init__(self, model_path="Kwai-Keye/Keye-VL-1_5-8B", use_vllm=True, **kwargs):
+
+        # check that vllm and keye_vl_utils are installed
+        if use_vllm:
+            try:
+                from vllm import LLM, SamplingParams
+                from keye_vl_utils import process_vision_info
+            except Exception as e:
+                raise ImportError(
+                    f"vllm and keye_vl_utils are not installed, please install them first ({e}). "
+                    "You can install them by running: "
+                    "pip install keye-vl-utils==1.5.2 vllm>=0.10.2"
+                )
+        else:
+            try:
+                from transformers import AutoModel, AutoTokenizer
+                from keye_vl_utils import process_vision_info
+            except Exception as e:
+                raise ImportError(
+                    f"transformers and keye_vl_utils are not installed, please install them first ({e}). "
+                    "You can install them
by running: " + "pip install keye-vl-utils==1.5.2 transformers>=4.56.1" + ) + + self.use_vllm = use_vllm + self.fps = 1 + self.max_frames = 64 # 1024 + self.kwargs = kwargs + # min_pixels = 32 * 28 * 28 + # max_pixels = 1280 * 28 * 28 + + self.model_path = model_path + if use_vllm: + try: + # Prefer eager mode to avoid torch.compile tracing of generators in custom model code + self.llm = LLM( + model=model_path, + limit_mm_per_prompt={"image": 10, "video": 10}, + trust_remote_code=True, + enforce_eager=True, + ) + except TypeError: + # Fallback for older vLLM versions without enforce_eager + self.llm = LLM( + model=model_path, + limit_mm_per_prompt={"image": 10, "video": 10}, + trust_remote_code=True, + tensor_parallel_size=1, + gpu_memory_utilization=0.8, + max_num_batched_tokens=32768, + max_model_len=32768, + ) + sampling_params = SamplingParams( + temperature=0.3, + max_tokens=4096, + ) + self.sampling_params = sampling_params + self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + else: + self.model = AutoModel.from_pretrained( + model_path, + torch_dtype="auto", + trust_remote_code=True, + # flash_attention_2 is recommended for better performance + attn_implementation="flash_attention_2", + ).eval() + self.model.to("cuda") + self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + def generate_inner_vllm(self, message, dataset=None): + print(f'{self.model_path} is a video-llm model using vllm, can not set fps or nframe, using default sampling method in keye_vl_utils') + content_list = [] + for msg in message: + if msg["type"] == "text": + content_list.append( + {"type": "text", "text": msg["value"]} + ) + elif msg["type"] == "image": + content_list.append( + {"type": "image", "image": msg["value"]} + ) + elif msg["type"] == "video": + content_list.append( + {"type": "video", "video": msg["value"]} + ) + else: + raise ValueError(f"Invalid message type: {msg['type']}, {msg}") + conversation = [ + {"role": "user", "content": content_list} + ] + prompt = self.processor.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=True, + ) + from keye_vl_utils import process_vision_info + image_inputs, video_inputs, video_kwargs = process_vision_info( + conversation + ) + + mm_data = {} + if image_inputs is not None: + mm_data["image"] = image_inputs + if video_inputs is not None: + mm_data["video"] = video_inputs + + llm_inputs = { + "prompt": prompt, + "multi_modal_data": mm_data, + # FPS will be returned in video_kwargs + "mm_processor_kwargs": video_kwargs, + } + + outputs = self.llm.generate([llm_inputs], sampling_params=self.sampling_params) + generated_text = outputs[0].outputs[0].text + + return generated_text + + def generate_inner_transformers(self, message, dataset=None): + content_list = [] + for msg in message: + if msg["type"] == "text": + content_list.append( + {"type": "text", "text": msg["value"]} + ) + elif msg["type"] == "image": + content_list.append( + {"type": "image", "image": msg["value"]} + ) + elif msg["type"] == "video": + content_list.append( + {"type": "video", "video": msg["value"], "fps": self.fps, "max_frames": self.max_frames} + ) + else: + raise ValueError(f"Invalid message type: {msg['type']}, {msg}") + conversation = [ + {"role": "user", "content": content_list} + ] + # Preparation for inference + text = self.processor.apply_chat_template( + conversation, tokenize=False, add_generation_prompt=True + ) + from keye_vl_utils import process_vision_info + image_inputs, video_inputs, 
mm_processor_kwargs = process_vision_info(conversation)
+        inputs = self.processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+            **mm_processor_kwargs
+        )
+        inputs = inputs.to("cuda")
+
+        # Inference: Generation of the output
+        generated_ids = self.model.generate(**inputs, **self.kwargs)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+        # batch_decode returns a list with a single entry here; return the decoded string
+        return output_text[0]
+
+
+    def generate_inner(self, message, dataset=None):
+        if self.use_vllm:
+            return self.generate_inner_vllm(message, dataset)
+        else:
+            return self.generate_inner_transformers(message, dataset)
+
+
diff --git a/vlmeval/vlm/llava/__init__.py b/vlmeval/vlm/llava/__init__.py
index 9ad9a644a..341dbc007 100644
--- a/vlmeval/vlm/llava/__init__.py
+++ b/vlmeval/vlm/llava/__init__.py
@@ -1,4 +1,4 @@
-from .llava import LLaVA, LLaVA_Next, LLaVA_Next2, LLaVA_OneVision, LLaVA_OneVision_HF
+from .llava import LLaVA, LLaVA_Next, LLaVA_Next2, LLaVA_OneVision, LLaVA_OneVision_HF, LLaVA_OneVision_1_5
 from .llava_xtuner import LLaVA_XTuner
 
-__all__ = ['LLaVA', 'LLaVA_Next', 'LLaVA_XTuner', 'LLaVA_Next2', 'LLaVA_OneVision', 'LLaVA_OneVision_HF']
+__all__ = ['LLaVA', 'LLaVA_Next', 'LLaVA_XTuner', 'LLaVA_Next2', 'LLaVA_OneVision', 'LLaVA_OneVision_HF', 'LLaVA_OneVision_1_5']
diff --git a/vlmeval/vlm/llava/llava.py b/vlmeval/vlm/llava/llava.py
index 889c01167..7ad4c2665 100644
--- a/vlmeval/vlm/llava/llava.py
+++ b/vlmeval/vlm/llava/llava.py
@@ -501,6 +501,7 @@ class LLaVA_OneVision(BaseModel):
     IMAGE_TOKEN_INDEX = -200
 
     def __init__(self, model_path="lmms-lab/llava-onevision-qwen2-7b-si", **kwargs):
+        self.model_path = model_path
         assert model_path is not None
         try:
             from llava.model.builder import load_pretrained_model
@@ -518,7 +519,7 @@ def __init__(self, model_path="lmms-lab/llava-onevision-qwen2-7b-si", **kwargs):
             raise err
 
         video_kwargs_default = dict(
-            overwrite=True, mm_spatial_pool_mode="average", force_sample=True
+            overwrite=True, mm_spatial_pool_mode="average", force_sample=False
         )
         video_kwargs_default.update(kwargs)
         self.video_kwargs = video_kwargs_default
@@ -549,6 +550,7 @@ def __init__(self, model_path="lmms-lab/llava-onevision-qwen2-7b-si", **kwargs):
         conv_mode = "qwen_1_5"
         if 'llava-video' in model_path.lower():
             self.nframe = 64
+            self.fps = 1
         else:
             self.nframe = 16
             if "72b" in model_path.lower():
@@ -622,7 +624,7 @@ def generate_inner_image(self, message, dataset=None):
         return text_outputs
 
     def generate_inner_video(self, message, dataset=None):
-        content, text_content, visual_content, videos = "", "", "", []
+        content, text_content, visual_content, videos, images = "", "", "", [], []
 
         for msg in message:
             if msg["type"] == "text":
@@ -636,9 +638,15 @@
                         "LLaVA-OneVision does not support multiple videos as input."
                    )
-        video_frames, frame_time, video_time = self.load_video(
-            videos[0], self.nframe, 1, self.force_sample
-        )
+        if self.fps is not None and self.fps > 0:
+            print(f'{self.model_path} is a video-llm model, using fps {self.fps} to sample the video, with at most 64 frames')
+            video_frames, frame_time, video_time = self.load_video(
+                videos[0], 64, self.fps, self.force_sample
+            )  # cap the number of sampled frames at 64
+        else:
+            video_frames, frame_time, video_time = self.load_video(
+                videos[0], self.nframe, 1, self.force_sample
+            )
 
         time_instruciton = (
             f"The video lasts for {video_time:.2f} seconds,"
@@ -715,12 +723,81 @@ def load_video(self, video_path, max_frames_num, fps=1, force_sample=False):
         frame_time = [i / vr.get_avg_fps() for i in frame_idx]
         frame_time = ",".join([f"{i:.2f}s" for i in frame_time])
         spare_frames = vr.get_batch(frame_idx).asnumpy()
-        # import pdb;pdb.set_trace()
         return spare_frames, frame_time, video_time
 
+    def generate_inner_image_and_video(self, message, dataset=None):
+        content, images, videos = "", [], []
+        for msg in message:
+            if msg["type"] == "text":
+                content += msg["value"]
+            elif msg["type"] == "image":
+                images.append(msg["value"])
+            elif msg["type"] == "video":
+                videos.append(msg["value"])
+
+        if len(videos) > 1:
+            raise ValueError("LLaVA-OneVision does not support multiple videos as input.")
+
+        if self.fps > 0:
+            print(f'{self.model_path} is a video-llm model, using fps {self.fps} to sample the video, with at most 64 frames')
+            video, frame_time, video_time = self.load_video(
+                videos[0], 64, self.fps, self.force_sample
+            )  # cap the number of sampled frames at 64
+        else:
+            video, frame_time, video_time = self.load_video(
+                videos[0], self.nframe, 1, self.force_sample
+            )
+
+        frame_height, frame_width = video.shape[1:3]
+        images_array = []
+        for image_path in images:
+            image = Image.open(image_path).convert("RGB")
+            image = image.resize((frame_width, frame_height))
+            image_array = np.array(image)
+            images_array.append(image_array)
+        images_array = np.array(images_array)
+
+        image_sizes = [(frame_width, frame_height) for frame in video] + [(frame_width, frame_height) for image in images_array]
+
+        if video.dtype != images_array.dtype:
+            images_array = images_array.astype(video.dtype)
+
+        print(f"Video shape: {video.shape}, Images array shape: {images_array.shape}")
+        if video.shape[1:] != images_array.shape[1:]:
+            raise ValueError(f"Shape mismatch: video shape {video.shape}, image shape {images_array.shape}")
+
+        video = self.image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half().cuda()
+        image = self.image_processor.preprocess(images_array, return_tensors="pt")["pixel_values"].half().cuda()
+        video_image = [video, image]
+
+        conv_template = "qwen_1_5"  # Make sure you use the correct chat template for different models
+        time_instruciton = f"The video lasts for {video_time:.2f} seconds, and {len(video[0])} frames are uniformly sampled from it. These frames are located at {frame_time}. Please answer the following questions related to this video."
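+        # The first image token in `question` below stands in for the sampled video frames,
+        # the second one for the still images appended after them (passed together as `video_image`).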
+ question = self.DEFAULT_IMAGE_TOKEN + f"\n{time_instruciton}\n" + self.DEFAULT_IMAGE_TOKEN + content + conv = copy.deepcopy(self.conv_templates[conv_template]) + conv.append_message(conv.roles[0], question) + conv.append_message(conv.roles[1], None) + prompt_question = conv.get_prompt() + input_ids = self.tokenizer_image_token(prompt_question, self.tokenizer, self.IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.model.device) + cont = self.model.generate( + input_ids, + images=video_image, + image_sizes=image_sizes, + modalities= ["video", "image"], + do_sample=False, + temperature=0, + max_new_tokens=4096, + ) + text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0].strip() + return text_outputs + + def generate_inner(self, message, dataset=None): if DATASET_MODALITY(dataset) == 'VIDEO' and 'megabench' not in dataset.lower(): - return self.generate_inner_video(message, dataset) + msg_type = [m['type'] for m in message] + if 'image' in msg_type: + return self.generate_inner_image_and_video(message, dataset) + else: + return self.generate_inner_video(message, dataset) else: return self.generate_inner_image(message, dataset) @@ -843,3 +920,100 @@ def generate_inner(self, message, dataset=None): return self.generate_inner_video(message, dataset) else: return self.generate_inner_image(message, dataset) + +class LLaVA_OneVision_1_5(BaseModel): + + def __init__(self, model_path = "lmms-lab/LLaVA-One-Vision-1.5-8B-Instruct", **kwargs): + from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM + from qwen_vl_utils import process_vision_info + + # default: Load the model on the available device(s) + self.model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True + ) + + # default processer + self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + self.kwargs = kwargs + self.fps = kwargs.pop('fps', 2) + self.nframe = kwargs.pop('nframe', 128) + self.FRAME_FACTOR = 2 + + def ensure_image_url(self, image: str) -> str: + prefixes = ['http://', 'https://', 'file://', 'data:image;'] + if any(image.startswith(prefix) for prefix in prefixes): + return image + if os.path.exists(image): + return 'file://' + image + raise ValueError(f'Invalid image: {image}') + + + def ensure_video_url(self, video: str) -> str: + prefixes = ['http://', 'https://', 'file://', 'data:video;'] + if any(video.startswith(prefix) for prefix in prefixes): + return video + if os.path.exists(video): + return 'file://' + video + raise ValueError(f'Invalid video: {video}') + + def generate_inner(self, message, dataset=None): + content_list = [] + for msg in message: + if msg["type"] == "text": + content_list.append({"type": "text", "text": msg["value"]}) + elif msg["type"] == "video": + item = { + 'type': 'video', + 'video': self.ensure_video_url(msg['value']) + } + item['min_pixels'] = 128 * 28 * 28 + item['max_pixels'] = 768 * 28 * 28 + item['total_pixels'] = 24576 * 28 * 28 + if self.fps is not None and self.fps > 0: + item['fps'] = self.fps + elif self.nframe is not None and self.nframe > 0: + import cv2 + video = cv2.VideoCapture(msg['value']) + frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + video.release() + if frame_count < self.nframe: + new_frame_count = frame_count // self.FRAME_FACTOR * self.FRAME_FACTOR + print(f"use {new_frame_count} for {msg['value']}") + item['nframes'] = new_frame_count + else: + item['nframes'] = self.nframe + content_list.append(item) + elif 
msg["type"] == "image": + content_list.append({"type": "image", "image": self.ensure_image_url(msg["value"])}) + else: + raise ValueError(f"Invalid message type: {msg['type']}, {msg}") + + messages = [ + {"role": "user", "content": content_list} + ] + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + from qwen_vl_utils import process_vision_info + image_inputs, video_inputs = process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to("cuda") + + # Inference: Generation of the output + generated_ids = self.model.generate(**inputs, **self.kwargs) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + output_text = self.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + # print(len(output_text), output_text[0]) + return output_text[0] \ No newline at end of file diff --git a/vlmeval/vlm/minicpm_v.py b/vlmeval/vlm/minicpm_v.py index 0f9b52576..6d91da898 100644 --- a/vlmeval/vlm/minicpm_v.py +++ b/vlmeval/vlm/minicpm_v.py @@ -790,3 +790,483 @@ def generate_inner(self, message, dataset=None): res = self.extract_answer(res, dataset) return res + + +class MiniCPM_V_4(BaseModel): + INSTALL_REQ = False + INTERLEAVE = True + + def __init__(self, model_path='openbmb/MiniCPM-V-4', **kwargs): + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + assert model_path is not None + self.model_path = model_path + print(f'load from path {self.model_path}') + self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True) + self.model = self.model.to(dtype=torch.bfloat16) + self.model.eval().cuda() + self.kwargs = kwargs + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) + torch.cuda.empty_cache() + + self.num_beams = 3 + self.max_new_tokens = 2048 + self.options_suffix_prompt = '''\nAnswer with the option's letter from the given choices directly.''' + self.wo_options_system_prompt = 'Carefully read the following question. Answer the question directly.' + self.detail_system_prompt = 'Answer this question in detail.' + self.vqa_prompt = 'Answer the question using a single word or phrase.' + self.multi_choice_cot_prompt = ('''Carefully read the following multichoice question, solve it step ''' + '''by step and finally pick the option associated with the correct ''' + '''answer in the format of "Answer: selected option\n\n''') + self.short_ans_cot_prompt = ('''Read the following question carefully, solve it step by step, and ''' + '''then output the final answer in the format of "Answer: single number ''' + '''or single word or phrase".\n\n''') + self.ocrbench_cot_prompt = 'Carefully observe the image and answer the OCR-related questions below. 
\n\n' + + def use_custom_prompt(self, dataset=None): + if dataset is None: + return False + if listinstr(['MCQ', 'VQA', 'Y/N'], DATASET_TYPE(dataset)) and DATASET_MODALITY(dataset) != 'VIDEO': + return True + return False + + def use_cot(self, dataset=None): + if dataset is None: + return False + if listinstr([ + 'MMMU', 'MathVista', 'MMStar', 'HallusionBench', 'OCRBench', + 'ChartQA', 'MathVision', 'MathVerse_MINI_Vision_Only' + ], dataset): + return True + elif listinstr([ + 'MMVet', 'MMBench', 'AI2D', 'RealWorldQA', 'POPE', 'ScienceQA', + 'TextVQA', 'DocVQA' + ], dataset): + return False + else: + return False + + def use_upsize(self, dataset=None): + if dataset is None: + return False + if listinstr([ + 'MathVista', 'MMVet', 'MMStar', 'AI2D', 'OCRBench', 'ChartQA', + 'TextVQA' + ], dataset): + return True + else: + return False + + def build_prompt(self, line, dataset=None): + if isinstance(line, int): + line = self.data.iloc[line] + + tgt_path = self.dump_image(line, dataset) + system_prompt, prompt = '', '' + + question = line['question'] + + if not self.use_cot(dataset): + if DATASET_TYPE(dataset) == 'MCQ': + options = { + cand: line[cand] + for cand in string.ascii_uppercase + if cand in line and not pd.isna(line[cand]) + } + options_prompt = 'Options:\n' + for key, item in options.items(): + options_prompt += f'{key}. {item}\n' + hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None + if hint is not None: + prompt += f'Hint: {hint}\n' + prompt += f'Question: {question}\n' + if len(options): + prompt += options_prompt + prompt += self.options_suffix_prompt + else: + system_prompt = self.wo_options_system_prompt + + if 'MMMU' in dataset: + if len(system_prompt) > 0: + prompt = system_prompt + '\n' + prompt + system_prompt = '' + elif dataset is not None and listinstr(['HallusionBench'], dataset): + question += ' Yes or No?' + prompt = question + elif dataset is not None and listinstr(['OCRBench'], dataset): + system_prompt = self.vqa_prompt + prompt = question + elif DATASET_TYPE(dataset) == 'VQA': + if listinstr(['LLaVABench'], dataset): + system_prompt = '' + elif listinstr(['MMVet'], dataset): + system_prompt = self.detail_system_prompt + else: + system_prompt = self.vqa_prompt + prompt = question + else: + prompt = question + else: + has_options = True + if DATASET_TYPE(dataset) == 'MCQ': + options = { + cand: line[cand] + for cand in string.ascii_uppercase + if cand in line and not pd.isna(line[cand]) + } + options_prompt = '' + for key, item in options.items(): + options_prompt += f'{key}. 
{item}\n' + hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None + if hint is not None: + prompt += f'Hint: {hint}\n' + prompt += f'{question}\n' + + if len(options): + prompt += options_prompt + else: + has_options = False + + if 'MMMU' in dataset: + if len(system_prompt) > 0: + prompt = system_prompt + '\n' + prompt + system_prompt = '' + else: + prompt = question + + if DATASET_TYPE(dataset) in ['MCQ', 'Y/N', 'VQA']: + if DATASET_TYPE(dataset) == 'MCQ': + if has_options: + prompt = self.multi_choice_cot_prompt + prompt + else: + prompt = self.short_ans_cot_prompt + prompt + elif DATASET_TYPE(dataset) == 'Y/N': + prompt = self.short_ans_cot_prompt + prompt + elif listinstr(['OCRBench'], dataset): + prompt = self.ocrbench_cot_prompt + prompt + else: + prompt = self.short_ans_cot_prompt + prompt + + msgs = [] + if system_prompt: + msgs.append(dict(type='text', value=system_prompt)) + if isinstance(tgt_path, list): + msgs.extend([dict(type='image', value=p) for p in tgt_path]) + else: + msgs = [dict(type='image', value=tgt_path)] + msgs.append(dict(type='text', value=prompt)) + + if dataset.startswith('MMMU_'): + from .. import MMMUDataset + msgs = MMMUDataset.split_MMMU(msgs) + + return msgs + + def extract_answer(self, res, dataset=None): + if dataset is None: + return res + if self.use_cot(dataset): + if DATASET_TYPE(dataset) == 'MCQ': + pattern = r'Answer:\s*([A-Ia-i])(?![A-Za-z])' + matches = re.findall(pattern, res, re.DOTALL) + if matches: + extracted_res = matches[-1].strip() + else: + extracted_res = res + return extracted_res + elif DATASET_TYPE(dataset) == 'VQA' and not listinstr(['OCRBench', 'MMVet'], dataset): + pattern = r'Answer:\s*(.*)\s*$' + match = re.search(pattern, res, re.DOTALL) + if match: + extracted_res = match.group(1) + else: + extracted_res = res + return extracted_res + elif DATASET_TYPE(dataset) == 'Y/N': + pattern = r'Answer:\s*(.*)\s*$' + match = re.search(pattern, res, re.DOTALL) + if match: + extracted_res = match.group(1) + else: + extracted_res = res + return extracted_res + return res + + def generate_inner(self, message, dataset=None): + if self.use_cot(dataset): + max_new_tokens = self.max_new_tokens + else: + max_new_tokens = 1024 + default_kwargs = dict( + max_new_tokens=max_new_tokens, + sampling=False, + num_beams=self.num_beams, + ) + default_kwargs.update(self.kwargs) + + content = [] + + for x in message: + if x['type'] == 'text': + content.append(x['value']) + elif x['type'] == 'image': + image = Image.open(x['value']).convert('RGB') + if not self.use_upsize(dataset): + content.append(image) + else: + img_width, img_height = image.width, image.height + if (img_width * img_height) >= (1344 * 1344): + content.append(image) + else: + ratio = math.sqrt((1344 * 1344) / (img_width * img_height)) + max_img_width = int(img_width * ratio) + new_img_width = random.randint(img_width, max_img_width) + new_img_height = int(new_img_width / img_width * img_height) + resized_image = image.resize((new_img_width, new_img_height)) + content.append(resized_image) + msgs = [{'role': 'user', 'content': content}] + + res = self.model.chat( + image=None, + msgs=msgs, + context=None, + tokenizer=self.tokenizer, + max_inp_length=8192, + **default_kwargs + ) + + if isinstance(res, tuple) and len(res) > 0: + res = res[0] + res = self.extract_answer(res, dataset) + + return res + + +class MiniCPM_V_4_5(MiniCPM_V_4): + INSTALL_REQ = False + INTERLEAVE = True + + def __init__(self, model_path='openbmb/MiniCPM-V-4_5', **kwargs): + 
super().__init__(model_path, **kwargs) + from transformers import AutoProcessor + self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True) + self._original_chat_template = self.tokenizer.chat_template + self._long_cot_chat_template = self._original_chat_template.replace( + "{{- '\\n' }}", "{{- '\\nI' }}") + + def use_long_cot(self, dataset=None): + if dataset is None: + return False + if listinstr([ + 'MMMU', 'MathVista', 'MMVet', 'MMBench', 'HallusionBench', + 'MMStar', 'MathVision', 'MathVerse_MINI', + 'MathVerse_MINI_Vision_Only', 'DynaMath', 'LogicVista' + ], dataset): + return True + else: + return False + + def use_cot(self, dataset=None): + if dataset is None: + return False + if listinstr([ + 'MMMU', 'MathVista', 'MMBench', 'HallusionBench', 'MMStar', + 'OCRBench', 'ChartQA', 'MathVision', 'MathVerse_MINI', + 'MathVerse_MINI_Vision_Only', 'DynaMath', 'LogicVista' + ], dataset): + return True + else: + return False + + def use_upsize(self, dataset=None): + if dataset is None: + return False + if self.use_long_cot(dataset): + return True + if listinstr(['AI2D', 'OCRBench', 'ChartQA', 'TextVQA'], dataset): + return True + else: + return False + + def build_prompt(self, line, dataset=None): + if self.use_long_cot(dataset): + self.tokenizer.chat_template = self._long_cot_chat_template + else: + self.tokenizer.chat_template = self._original_chat_template + + if isinstance(line, int): + line = self.data.iloc[line] + + tgt_path = self.dump_image(line, dataset) + system_prompt, prompt = '', '' + question = line['question'] + + if not self.use_cot(dataset): + if DATASET_TYPE(dataset) == 'MCQ': + options = { + cand: line[cand] + for cand in string.ascii_uppercase + if cand in line and not pd.isna(line[cand]) + } + options_prompt = 'Options:\n' + for key, item in options.items(): + options_prompt += f'{key}. {item}\n' + hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None + if hint is not None: + prompt += f'Hint: {hint}\n' + prompt += f'Question: {question}\n' + if len(options): + prompt += options_prompt + prompt += self.options_suffix_prompt + else: + system_prompt = self.wo_options_system_prompt + + if 'MMMU' in dataset: + if len(system_prompt) > 0: + prompt = system_prompt + '\n' + prompt + system_prompt = '' + elif dataset is not None and listinstr(['HallusionBench'], dataset): + question += ' Yes or No?' + prompt = question + elif dataset is not None and listinstr(['OCRBench'], dataset): + system_prompt = self.vqa_prompt + prompt = question + elif DATASET_TYPE(dataset) == 'VQA': + if listinstr(['LLaVABench'], dataset): + system_prompt = '' + elif listinstr(['MMVet'], dataset): + system_prompt = self.detail_system_prompt + else: + system_prompt = self.vqa_prompt + prompt = question + else: + prompt = question + else: + has_options = True + if DATASET_TYPE(dataset) == 'MCQ': + options = { + cand: line[cand] + for cand in string.ascii_uppercase + if cand in line and not pd.isna(line[cand]) + } + options_prompt = '' + for key, item in options.items(): + options_prompt += f'{key}. 
{item}\n' + hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None + if hint is not None: + prompt += f'Hint: {hint}\n' + prompt += f'{question}\n' + + if len(options): + prompt += options_prompt + else: + has_options = False + + if 'MMMU' in dataset: + if len(system_prompt) > 0: + prompt = system_prompt + '\n' + prompt + system_prompt = '' + else: + prompt = question + + if DATASET_TYPE(dataset) in ['MCQ', 'Y/N', 'VQA']: + if DATASET_TYPE(dataset) == 'MCQ': + if has_options: + prompt = self.multi_choice_cot_prompt + prompt + else: + prompt = self.short_ans_cot_prompt + prompt + elif DATASET_TYPE(dataset) == 'Y/N': + prompt = self.short_ans_cot_prompt + prompt + elif listinstr(['OCRBench'], dataset): + prompt = self.ocrbench_cot_prompt + prompt + else: + prompt = self.short_ans_cot_prompt + prompt + + msgs = [] + if system_prompt: + msgs.append(dict(type='text', value=system_prompt)) + if isinstance(tgt_path, list): + msgs.extend([dict(type='image', value=p) for p in tgt_path]) + else: + msgs = [dict(type='image', value=tgt_path)] + msgs.append(dict(type='text', value=prompt)) + + if dataset.startswith('MMMU_'): + from .. import MMMUDataset + msgs = MMMUDataset.split_MMMU(msgs) + + return msgs + + def generate_inner(self, message, dataset=None): + if self.use_long_cot(dataset): + default_kwargs = dict( + enable_thinking=True, + max_new_tokens=8192, + sampling=True, + temperature=0.7, + num_beams=1, + top_p=1.0, + top_k=0, + repetition_penalty=1.0, + no_repeat_ngram_size=0 + ) + elif self.use_cot(dataset): + default_kwargs = dict( + max_new_tokens=2048, + sampling=False, + num_beams=self.num_beams, + repetition_penalty=1.2 + ) + else: + default_kwargs = dict( + max_new_tokens=1024, + sampling=False, + num_beams=self.num_beams, + repetition_penalty=1.2 + ) + + default_kwargs.update(self.kwargs) + + content = [] + for x in message: + if x['type'] == 'text': + content.append(x['value']) + elif x['type'] == 'image': + image = Image.open(x['value']).convert('RGB') + if not self.use_upsize(dataset): + content.append(image) + else: + img_width, img_height = image.width, image.height + if (img_width * img_height) >= (1344 * 1344): + content.append(image) + else: + ratio = math.sqrt((1344 * 1344) / (img_width * img_height)) + max_img_width = int(img_width * ratio) + new_img_width = random.randint(img_width, max_img_width) + new_img_height = int(new_img_width / img_width * img_height) + resized_image = image.resize((new_img_width, new_img_height)) + content.append(resized_image) + msgs = [{'role': 'user', 'content': content}] + + self.processor.tokenizer = self.tokenizer + + res = self.model.chat( + image=None, + msgs=msgs, + context=None, + tokenizer=self.tokenizer, + processor=self.processor, + max_inp_length=8192, + max_slice_nums=1, + **default_kwargs + ) + + if isinstance(res, tuple) and len(res) > 0: + res = res[0] + + res = res.replace('\n', '\nI ') + res = self.extract_answer(res, dataset) + + return res diff --git a/vlmeval/vlm/video_llm/__init__.py b/vlmeval/vlm/video_llm/__init__.py index f8c566aeb..fef2a469e 100644 --- a/vlmeval/vlm/video_llm/__init__.py +++ b/vlmeval/vlm/video_llm/__init__.py @@ -4,5 +4,10 @@ from .video_chatgpt import VideoChatGPT from .llama_vid import LLaMAVID from .pllava import PLLaVA +from .internvideo import InternVideo +from .videollama3 import VideoLLaMA3 -__all__ = ['VideoLLaVA', 'VideoLLaVA_HF', 'Chatunivi', 'VideoChatGPT', 'LLaMAVID', 'VideoChat2_HD', 'PLLaVA'] +__all__ = [ + 'VideoLLaVA', 'VideoLLaVA_HF', 'Chatunivi', 'VideoChatGPT', 
'LLaMAVID', 'VideoChat2_HD', 'PLLaVA', + 'InternVideo', 'VideoLLaMA3' +] diff --git a/vlmeval/vlm/video_llm/internvideo.py b/vlmeval/vlm/video_llm/internvideo.py new file mode 100644 index 000000000..d1bbd742b --- /dev/null +++ b/vlmeval/vlm/video_llm/internvideo.py @@ -0,0 +1,231 @@ +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +import transformers +from transformers import AutoModel, AutoTokenizer + +import warnings +import copy as cp + +import sys +import os +import logging +from ..base import BaseModel +from ...smp import isimg, listinstr, version_cmp +from ...dataset import DATASET_TYPE +from PIL import Image + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose([T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=MEAN, std=STD)]) + return transform + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + +def get_indices_by_fps(bound, src_fps, max_frame, target_fps, first_idx=0): + if bound: + start, end = bound[0], bound[1] + else: + start, end = 0.0, max_frame / src_fps + if target_fps is None or target_fps <= 0: + raise ValueError("target_fps must be a positive number") + times = np.arange(start, end, 1.0 / target_fps) + frame_indices = np.round(times * src_fps).astype(int) + frame_indices = np.clip(frame_indices, first_idx, max_frame) + frame_indices = np.unique(frame_indices) + return frame_indices + + +def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size, ((i % (target_width // image_size)) + 1) * image_size, ((i // (target_width // image_size)) + 1) * image_size) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = 
image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + +def load_image(image, input_size=448, max_num=6): + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + +def get_index(bound, fps, max_frame, first_idx=0, num_segments=32): + if bound: + start, end = bound[0], bound[1] + else: + start, end = -100000, 100000 + start_idx = max(first_idx, round(start * fps)) + end_idx = min(round(end * fps), max_frame) + seg_size = float(end_idx - start_idx) / num_segments + frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)]) + return frame_indices + +def get_num_frames_by_duration(duration): + local_num_frames = 4 + num_segments = int(duration // local_num_frames) + if num_segments == 0: + num_frames = local_num_frames + else: + num_frames = local_num_frames * num_segments + + num_frames = min(512, num_frames) + num_frames = max(128, num_frames) + + return num_frames + +def load_video( + video_path, bound=None, input_size=448, max_num=1, num_segments=32, target_fps=-1, get_frame_by_duration = False +): + from decord import VideoReader, cpu + vr = VideoReader(video_path, ctx=cpu(0), num_threads=1) + max_frame = len(vr) - 1 + fps = float(vr.get_avg_fps()) + + pixel_values_list, num_patches_list = [], [] + transform = build_transform(input_size=input_size) + + if target_fps is not None and target_fps > 0: + frame_indices = get_indices_by_fps(bound, fps, max_frame, target_fps, first_idx=0) + else: + if get_frame_by_duration: + duration = max_frame / fps + num_segments = get_num_frames_by_duration(duration) + frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB") + img = dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(tile) for tile in img] + pixel_values = torch.stack(pixel_values) + num_patches_list.append(pixel_values.shape[0]) + pixel_values_list.append(pixel_values) + pixel_values = torch.cat(pixel_values_list) + return pixel_values, num_patches_list + +def load_image(image_path, input_size=448, max_num=6): + image = Image.open(image_path).convert('RGB') + transform = build_transform(input_size=input_size) + images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values, pixel_values.shape[0] + + +class InternVideo(BaseModel): + INTERLEAVE = False + VIDEO_LLM = True + + def __init__(self, model_path='OpenGVLab/InternVideo2_5_Chat_8B', **kwargs): + self.model_path = model_path + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda().to(torch.bfloat16) + self.generation_config = dict( + do_sample=False, + temperature=0.0, + max_new_tokens=1024, + top_p=0.1, + num_beams=1 + ) + self.generation_config.update(kwargs) + + self.nframe = 128 + self.fps = 1 + + # recommend using transformers==4.40.0 + # full env recommendation in https://huggingface.co/OpenGVLab/InternVideo2_5_Chat_8B + self.transformers_version = 
'4.40.1' + assert version_cmp(transformers.__version__, self.transformers_version, 'eq') + + + def generate_inner(self, message, dataset=None): + text_content, videos, images = "", [], [] + + for msg in message: + if msg["type"] == "text": + text_content += msg["value"] + elif msg["type"] == "image": + # images.append(msg["value"]) + continue + else: + videos.append(msg["value"]) + + if len(videos) > 1: + raise ValueError( + "InternVideo does not support multiple videos as input." + ) + + with torch.no_grad(): + video_pixel_values, video_num_patches_list = load_video(videos[0], num_segments=self.nframe, target_fps=self.fps, max_num=1, get_frame_by_duration=False) + video_pixel_values = video_pixel_values.to(torch.bfloat16).to(self.model.device) + video_prefix = "".join([f"Frame{i+1}: \n" for i in range(len(video_num_patches_list))]) + + img_pixel_values_list, img_num_patches_list = [], [] + img_prefix = "" + for img_pth in images: + img_pixel_values, num_patches = load_image(img_pth, max_num=1) + img_pixel_values_list.append(img_pixel_values) + img_num_patches_list.append(num_patches) + img_prefix += f"\n" + + if len(img_pixel_values_list): + img_pixel_values = torch.cat(img_pixel_values_list, dim=0).to(torch.bfloat16).to(self.model.device) + pixel_values = torch.cat((img_pixel_values, video_pixel_values), dim=0) + else: + pixel_values = video_pixel_values + + # 顺序:图片在前、视频在后(与 question 和 pixel_values 一致) + num_patches_list_all = img_num_patches_list + video_num_patches_list + + # 对于pixel_values,每张图片重复四遍,使得后面可以merge + pixel_values = pixel_values.repeat_interleave(4, dim=0) + # num_patches_list_all = [x for x in num_patches_list_all for _ in range(4)] + num_patches_list_all = [x * 4 for x in num_patches_list_all] + question = img_prefix + video_prefix + text_content + # assert pixel_values.shape[0] % 4 == 0, "pixel_values.shape[0] must be divisible by 4" + output, chat_history = self.model.chat(self.tokenizer, pixel_values, question, self.generation_config, num_patches_list=num_patches_list_all, history=None, return_history=True) + + return output diff --git a/vlmeval/vlm/video_llm/videollama3.py b/vlmeval/vlm/video_llm/videollama3.py new file mode 100644 index 000000000..3cb0f8681 --- /dev/null +++ b/vlmeval/vlm/video_llm/videollama3.py @@ -0,0 +1,56 @@ +from ..base import BaseModel + +import torch +from transformers import AutoModelForCausalLM, AutoProcessor + + +class VideoLLaMA3(BaseModel): + INSTALL_REQ = False + INTERLEAVE = False + VIDEO_LLM = True + + def __init__(self, model_path="DAMO-NLP-SG/VideoLLaMA3-7B", **kwargs): + self.model_path = model_path + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + trust_remote_code=True, + device_map={"": "cuda"}, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + ) + self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + self.kwargs = kwargs + self.fps = 1 + self.max_frames = 64 + + def generate_inner(self, message, dataset=None): + content_list = [] + for msg in message: + if msg["type"] == "text": + content_list.append({"type": "text", "text": msg["value"]}) + elif msg["type"] == "video": + content_list.append( + {"type": "video", "video": {"video_path": msg["value"], "fps": self.fps, "max_frames": self.max_frames}} + ) + elif msg["type"] == "image": + content_list.append({"type": "image", "image": {"image_path": msg["value"]}}) + else: + raise ValueError(f"Invalid message type: {msg['type']}, {msg}") + + conversation = [ + {"role": "system", "content": "You are a helpful 
assistant."}, + {"role": "user", "content": content_list} + ] + + inputs = self.processor( + conversation=conversation, + add_system_prompt=True, + add_generation_prompt=True, + return_tensors="pt" + ) + inputs = {k: v.to(self.model.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16) + output_ids = self.model.generate(**inputs, **self.kwargs) + response = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip() + return response \ No newline at end of file