From 2aa3522a5fe7aa2dd18561122c40fc8840e3b2f5 Mon Sep 17 00:00:00 2001 From: Tuomas Rintamaki Date: Wed, 11 Dec 2024 05:42:18 -0800 Subject: [PATCH] ADLR/megatron-lm!2432 - Small VLM example --- examples/multimodal/config.py | 50 +++----- examples/multimodal/evaluate_ai2d.py | 22 ++-- examples/multimodal/evaluate_chartqa.py | 13 +- examples/multimodal/evaluate_coco.py | 18 ++- examples/multimodal/evaluate_mathvista.py | 12 +- examples/multimodal/evaluate_mmmu.py | 4 + examples/multimodal/evaluate_ocrbench.py | 12 +- examples/multimodal/evaluate_textvqa.py | 25 ++-- examples/multimodal/evaluate_vqav2.py | 16 ++- examples/multimodal/evaluation_datasets.py | 84 +++++++++++-- examples/multimodal/model.py | 14 +++ examples/multimodal/multimodal_args.py | 6 +- .../run_text_generation_qwen25_7b_siglip.sh | 111 ++++++++++++++++++ examples/multimodal/run_text_generation.py | 26 ++-- .../tokenizer/multimodal_tokenizer.py | 2 +- 15 files changed, 324 insertions(+), 91 deletions(-) create mode 100755 examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py index 343fcd5896..ee404604b6 100644 --- a/examples/multimodal/config.py +++ b/examples/multimodal/config.py @@ -7,34 +7,20 @@ def get_language_model_config(config): - if config.language_model_type == "2b": + if config.language_model_type == "llama3_8b": + config.activation_func = torch.nn.functional.silu config.add_bias_linear = False config.bias_activation_fusion = False config.gated_linear_unit = True - config.apply_query_key_layer_scaling = True - config.layernorm_zero_centered_gamma = True - config.bias_dropout_fusion = False - config.rotary_percent = 0.5 - config.apply_rope_fusion = False - config.attention_softmax_in_fp32 = True - elif config.language_model_type == "8b": - config.add_bias_linear = False - config.bias_activation_fusion = False - config.gated_linear_unit = False - config.apply_query_key_layer_scaling = True - config.layernorm_zero_centered_gamma = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) config.bias_dropout_fusion = False - config.rotary_percent = 0.5 - config.attention_dropout = 0.0 config.apply_rope_fusion = False - config.activation_func = squared_relu - config.ffn_hidden_size = 16384 - config.masked_softmax_fusion = True config.attention_softmax_in_fp32 = True - config.num_query_groups = 32 - config.kv_channels = 128 - config.rotary_interleaved = False - elif config.language_model_type == "llama3_8b": + config.ffn_hidden_size = 14336 + elif config.language_model_type == "mistral_7b": config.activation_func = torch.nn.functional.silu config.add_bias_linear = False config.bias_activation_fusion = False @@ -47,7 +33,7 @@ def get_language_model_config(config): config.apply_rope_fusion = False config.attention_softmax_in_fp32 = True config.ffn_hidden_size = 14336 - elif config.language_model_type == "mistral_7b": + elif config.language_model_type == "yi-34b": config.activation_func = torch.nn.functional.silu config.add_bias_linear = False config.bias_activation_fusion = False @@ -59,10 +45,11 @@ def get_language_model_config(config): config.bias_dropout_fusion = False config.apply_rope_fusion = False config.attention_softmax_in_fp32 = True - config.ffn_hidden_size = 14336 - elif config.language_model_type == "yi-34b": + config.ffn_hidden_size = 20480 + elif config.language_model_type == "qwen2.5_7B": config.activation_func = 
torch.nn.functional.silu config.add_bias_linear = False + config.add_qkv_bias = True config.bias_activation_fusion = False config.gated_linear_unit = True config.apply_query_key_layer_scaling = False @@ -72,7 +59,7 @@ def get_language_model_config(config): config.bias_dropout_fusion = False config.apply_rope_fusion = False config.attention_softmax_in_fp32 = True - config.ffn_hidden_size = 20480 + config.ffn_hidden_size = 18944 elif config.language_model_type == "qwen2.0_72B": config.activation_func = torch.nn.functional.silu config.add_bias_linear = False @@ -168,13 +155,7 @@ def get_vision_projection_config(config, hidden_size): config.bias_activation_fusion = False config.add_bias_linear = False config.hidden_size = hidden_size # Used as the vision projection output size, i.e., the input to the language model. - if config.language_model_type == "2b": - config.ffn_hidden_size = 5440 - config.activation_func = torch.nn.functional.gelu - if config.language_model_type == "8b": - config.ffn_hidden_size = 16384 - config.activation_func = squared_relu - elif config.language_model_type == "llama3_8b": + if config.language_model_type == "llama3_8b": config.ffn_hidden_size = 14336 config.activation_func = torch.nn.functional.gelu elif config.language_model_type == "mistral_7b": @@ -185,6 +166,9 @@ def get_vision_projection_config(config, hidden_size): config.ffn_hidden_size = 20480 config.normalization = "LayerNorm" config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "qwen2.5_7B": + config.ffn_hidden_size = 3584 + config.activation_func = torch.nn.functional.gelu elif config.language_model_type == "qwen2.0_72B": config.ffn_hidden_size = 29568 config.normalization = "LayerNorm" diff --git a/examples/multimodal/evaluate_ai2d.py b/examples/multimodal/evaluate_ai2d.py index 2d5db67b67..39b866ae4a 100644 --- a/examples/multimodal/evaluate_ai2d.py +++ b/examples/multimodal/evaluate_ai2d.py @@ -9,19 +9,25 @@ def merge_input_files(input_path): """Merge input files to a format compatible with the evaluator.""" input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D") - results = [] + results = dict() for input_file_path in input_file_paths: with open(input_file_path, "r") as input_file: for line in input_file: res = json.loads(line) - results.append( - { - "question_id": res["sample_id"], - "answer": res["answer"], - "gt_answer": res["gt_answer"], - } - ) + sample_id = res["sample_id"] + + # Ignore possible duplicates. + if sample_id in results: + continue + + results[sample_id] = { + "question_id": sample_id, + "answer": res["answer"], + "gt_answer": res["gt_answer"], + } + + results = list(results.values()) with open(output_file_path, "w") as output_file: json.dump(results, output_file) diff --git a/examples/multimodal/evaluate_chartqa.py b/examples/multimodal/evaluate_chartqa.py index e9238069d4..53d4944f46 100644 --- a/examples/multimodal/evaluate_chartqa.py +++ b/examples/multimodal/evaluate_chartqa.py @@ -9,15 +9,22 @@ def merge_input_files(input_path): """Merge input files to a format compatible with the evaluator.""" input_file_paths, output_file_path = get_input_output_paths(input_path, task="ChartQA") - results = [] + results = dict() for input_file_path in input_file_paths: with open(input_file_path, "r") as input_file: for line in input_file: res = json.loads(line) - res["question_id"] = res["sample_id"] + sample_id = res["sample_id"] - results.append(res) + # Ignore possible duplicates. 
+ if sample_id in results: + continue + + res["question_id"] = sample_id + results[sample_id] = res + + results = list(results.values()) with open(output_file_path, "w") as output_file: json.dump(results, output_file) diff --git a/examples/multimodal/evaluate_coco.py b/examples/multimodal/evaluate_coco.py index a717090c92..8eeb367e8f 100644 --- a/examples/multimodal/evaluate_coco.py +++ b/examples/multimodal/evaluate_coco.py @@ -11,20 +11,28 @@ def convert_to_coco_format(input_path): """Convert input files to COCO compatible format.""" input_file_paths, output_file_path = get_input_output_paths(input_path, task="captioning") - captions = [] + results = dict() for input_file_path in input_file_paths: with open(input_file_path, "r") as input_file: for line in input_file: res = json.loads(line) + sample_id = res["sample_id"] - question_id = res['sample_id'] - caption = res['caption'].rstrip('.').lower() + # Ignore possible duplicates. + if sample_id in results: + continue - captions.append({"image_id": question_id, "caption": caption}) + caption = res["caption"].rstrip(".").lower() + results[sample_id] = { + "image_id": sample_id, + "caption": caption, + } + + results = list(results.values()) with open(output_file_path, "w") as output_file: - json.dump(captions, output_file, indent=4) + json.dump(results, output_file, indent=4) return output_file_path diff --git a/examples/multimodal/evaluate_mathvista.py b/examples/multimodal/evaluate_mathvista.py index 3474c5f25e..a55f312f21 100644 --- a/examples/multimodal/evaluate_mathvista.py +++ b/examples/multimodal/evaluate_mathvista.py @@ -11,13 +11,21 @@ def merge_input_files(input_path): """Merge input files to a format compatible with the evaluator.""" input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista") - results = [] + results = dict() for input_file_path in input_file_paths: with open(input_file_path, "r") as input_file: for line in input_file: res = json.loads(line) - results.append(res) + sample_id = res["sample_id"] + + # Remove possible duplicates. + if sample_id in results: + continue + + results[sample_id] = res + + results = list(results.values()) with open(output_file_path, "w") as output_file: json.dump(results, output_file) diff --git a/examples/multimodal/evaluate_mmmu.py b/examples/multimodal/evaluate_mmmu.py index 66118fa905..22c3921f25 100644 --- a/examples/multimodal/evaluate_mmmu.py +++ b/examples/multimodal/evaluate_mmmu.py @@ -48,6 +48,10 @@ def convert_to_mmmu_format(input_path): ) # MMMU eval script expects just a sample_id to prediction mapping. + # Skip possible duplicates. + if sample_id in output: + continue + output[sample_id] = prediction with open(output_file_path, "w") as output_file: diff --git a/examples/multimodal/evaluate_ocrbench.py b/examples/multimodal/evaluate_ocrbench.py index bc2b901065..b37473a67d 100644 --- a/examples/multimodal/evaluate_ocrbench.py +++ b/examples/multimodal/evaluate_ocrbench.py @@ -8,13 +8,21 @@ def merge_input_files(input_path): """Merge input files to a format compatible with the evaluator.""" input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench") - results = [] + results = dict() for input_file_path in input_file_paths: with open(input_file_path, "r") as input_file: for line in input_file: res = json.loads(line) - results.append(res) + sample_id = res["sample_id"] + + # Remove possible duplicates. 
+ if sample_id in results: + continue + + results[sample_id] = res + + results = list(results.values()) with open(output_file_path, "w") as output_file: json.dump(results, output_file) diff --git a/examples/multimodal/evaluate_textvqa.py b/examples/multimodal/evaluate_textvqa.py index c9bba7134b..af782bdf03 100644 --- a/examples/multimodal/evaluate_textvqa.py +++ b/examples/multimodal/evaluate_textvqa.py @@ -9,22 +9,25 @@ def merge_input_files(input_path): """Merge input files to a format compatible with the evaluator.""" input_file_paths, output_file_path = get_input_output_paths(input_path, task="TextVQA") - results = [] + results = dict() for input_file_path in input_file_paths: with open(input_file_path, "r") as input_file: for line in input_file: res = json.loads(line) - results.append( - { - "question_id": res["sample_id"], - "answer": res["answer"], - "gt_answer": res["gt_answer"], - } - ) - - # Make order deterministic. - # results = sorted(results, key=lambda d: d["question_id"]) + sample_id = res["sample_id"] + + # Remove possible duplicates. + if sample_id in results: + continue + + results[sample_id] = { + "question_id": sample_id, + "answer": res["answer"], + "gt_answer": res["gt_answer"], + } + + results = list(results.values()) with open(output_file_path, "w") as output_file: json.dump(results, output_file) diff --git a/examples/multimodal/evaluate_vqav2.py b/examples/multimodal/evaluate_vqav2.py index 0b1b9209be..7807d80723 100644 --- a/examples/multimodal/evaluate_vqav2.py +++ b/examples/multimodal/evaluate_vqav2.py @@ -9,15 +9,22 @@ def merge_input_files(input_path): """Merge input files to a format compatible with the evaluator.""" input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2") - results = [] + results = dict() for input_file_path in input_file_paths: with open(input_file_path, "r") as input_file: for line in input_file: res = json.loads(line) - res["question_id"] = res["sample_id"] + sample_id = res["sample_id"] - results.append(res) + # Skip possible duplicates. + if sample_id in results: + continue + + res["question_id"] = sample_id + results[sample_id] = res + + results = list(results.values()) with open(output_file_path, "w") as output_file: json.dump(results, output_file) @@ -57,6 +64,9 @@ def compute_vqa_accuracy(result_file, task): assert len(gt) == 1, "expected exactly one groundtruth answer." gt = gt[0] + pred = pred.rstrip("%") + gt = gt.rstrip("%") + if is_number(pred) and is_number(gt): pred = float(pred) gt = float(gt) diff --git a/examples/multimodal/evaluation_datasets.py b/examples/multimodal/evaluation_datasets.py index 97f9ba926f..50a50d5687 100644 --- a/examples/multimodal/evaluation_datasets.py +++ b/examples/multimodal/evaluation_datasets.py @@ -188,7 +188,7 @@ def __init__( use_tiling, max_num_tiles, use_thumbnail, - single_image, + prompt_style, vision_model_type, ): import datasets @@ -246,7 +246,7 @@ def __init__( self._use_tiling = use_tiling self._max_num_tiles = max_num_tiles self._use_thumbnail = use_thumbnail - self._single_image = single_image + self._prompt_style = prompt_style self._vision_model_type = vision_model_type def __len__(self): @@ -258,7 +258,7 @@ def __getitem__(self, idx): sample = self._dataset[idx] # Use the single image approach from the MMMU repo. 
-        if self._single_image:
+        if self._prompt_style == "single_image":
             sample = process_single_sample(sample)
             sample = construct_prompt(sample, self._config)
 
@@ -274,7 +274,69 @@ def __getitem__(self, idx):
                 vision_model_type=self._vision_model_type,
             )
             sample_num_tiles = [len(sample_imgs)]
-        else:
+
+            prompt = sample["final_input_prompt"]
+            for i in range(8):
+                prompt = prompt.replace(f"<image {i+1}>", "")
+            sample["final_input_prompt"] = f"<image>\n{prompt}"
+        elif self._prompt_style == "vlmevalkit":
+            sample = construct_prompt(sample, self._config)
+
+            if sample["question_type"] == "multiple-choice":
+                question = sample["question"]
+
+                options = ""
+                for k, v in sample["index2ans"].items():
+                    options += f"{k}. {v}\n"
+
+                final_prompt = f"{question}\n"
+                if "hint" in sample:
+                    final_prompt += f"Hint: {sample['hint']}\n"
+
+                if "task_instructions" in sample:
+                    final_prompt += f"Task instructions: {sample['task_instructions']}\n"
+
+                final_prompt += options
+                final_prompt += "Answer with the option's letter from the given choices directly."
+
+                sample["final_input_prompt"] = final_prompt.rstrip()
+            else:
+                question = sample["question"]
+                final_prompt = f"{question}\n"
+                final_prompt += "Answer the question directly."
+                sample["final_input_prompt"] = final_prompt.rstrip()
+
+            sample_imgs = []
+            sample_num_tiles = []
+
+            img_indices = sorted(list(set(re.findall(r"<image (\d+)", sample["final_input_prompt"]))))
+            adjusted_max_num_tiles = max(1, self._max_num_tiles // len(img_indices))
+
+            for img_idx in img_indices:
+                img_key = f"image_{img_idx}"
+                img_str = f"<image {img_idx}>"
+
+                img = sample[img_key]
+                assert img is not None, f"{img_str} is in prompt but not in sample images"
+
+                imgs = get_visual_transform(
+                    img,
+                    self._img_h,
+                    self._img_w,
+                    self._use_tiling,
+                    adjusted_max_num_tiles,
+                    self._use_thumbnail,
+                    augment=False,
+                    vision_model_type=self._vision_model_type,
+                )  # List of tiles.
+
+                sample_imgs.extend(imgs)
+                sample_num_tiles.append(len(imgs))
+
+            sample["final_input_prompt"] = " ".join([f'<image {i + 1}><image>' for i in range(len(img_indices))]) + "\n" + sample["final_input_prompt"]
+        elif self._prompt_style == "multi_image":
             sample = construct_prompt(sample, self._config)
 
             sample_imgs = []
@@ -315,6 +377,8 @@ def __getitem__(self, idx):
                 assert (
                     f"<image {i}>" not in sample["final_input_prompt"]
                 ), "prompt contains unhandled image tags"
+        else:
+            raise ValueError(f"unknown prompt style {self._prompt_style}")
 
         # MMMU specific metadata.
         metadata = {"question_type": sample["question_type"]}
@@ -323,10 +387,6 @@ def __getitem__(self, idx):
             metadata["all_choices"] = sample["all_choices"]
 
         prompt = sample['final_input_prompt']
-        if self._single_image:
-            for i in range(8):
-                prompt = prompt.replace(f"<image {i+1}>", "")
-            prompt = f"<image>\n{prompt}"
 
         tile_count = torch.tensor(sample_num_tiles, dtype=torch.int)
 
@@ -780,8 +840,10 @@ def get_evaluation_dataset(
             vision_model_type,
         )
     elif task == 'MMMU':
-        # Note: single_image=True uses only one image like in the MMMU repo example.
-        # single_image=False uses all images in the sample.
+        # Note:
+        # - prompt_style="single_image" uses only one image like in the MMMU repo example.
+        # - prompt_style="multi_image" uses multiple input images.
+ # - prompt_style="vlmevalkit" is similar to https://github.com/open-compass/VLMEvalKit/blob/5d3cebcf18ef4bfbadc3bd3ef80bdc7aad2c6557/vlmeval/vlm/internvl_chat.py#L499 dataset = MMMUDataset( input_image_path, num_samples_per_partition, @@ -792,7 +854,7 @@ def get_evaluation_dataset( use_tiling, max_num_tiles, use_thumbnail, - single_image=True, + prompt_style="single_image", vision_model_type=vision_model_type, ) elif task == "VideoMME": diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py index 6db834e97a..a28a428325 100644 --- a/examples/multimodal/model.py +++ b/examples/multimodal/model.py @@ -136,6 +136,20 @@ def model_provider( else: vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules + # Toggle --recompute* for the vision and language model separately. + if args.recompute_vision: + if vision_config.recompute_method is not None and vision_config.recompute_granularity is not None: + vision_config.recompute_num_layers = vision_config.num_layers + else: + vision_config.recompute_granularity = None + vision_config.recompute_method = None + vision_config.recompute_num_layers = None + + vision_projection_config.recompute_granularity = None + vision_projection_config.recompute_method = None + vision_projection_config.recompute_num_layers = None + + tokenizer = get_tokenizer() image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) diff --git a/examples/multimodal/multimodal_args.py b/examples/multimodal/multimodal_args.py index 4b2be450af..eb56118e71 100644 --- a/examples/multimodal/multimodal_args.py +++ b/examples/multimodal/multimodal_args.py @@ -49,7 +49,7 @@ def add_multimodal_extra_args(parser): group.add_argument( "--tokenizer-prompt-format", type=str, - choices=["mistral", "llama3", "chatml", "nvlm-yi-34b", "qwen2p0"], + choices=["mistral", "llama3", "chatml", "nvlm-yi-34b", "qwen2p0", "qwen2p5"], required=True, help="Prompt format to use with the tokenizer.", ) @@ -71,5 +71,9 @@ def add_multimodal_extra_args(parser): group.add_argument( "--packing-seq-length", type=int, default=0, help="Packing sequence length. Must be > 0 if using packing." ) + group.add_argument( + "--recompute-vision", action="store_true", default=False, help="Enable activation checkpointing in the vision model" + ) + return parser diff --git a/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh new file mode 100755 index 0000000000..3b6221996c --- /dev/null +++ b/examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh @@ -0,0 +1,111 @@ +#!/bin/bash + +export NCCL_IB_SL=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NVTE_APPLY_QK_LAYER_SCALING=0 +export TOKENIZERS_PARALLELISM="false" + +INPUT_IMAGE_PATH="placeholder" +GROUNDTRUTH_PATH="placeholder" + +while [[ $# -gt 0 ]]; do + case $1 in + -i|--input-image-path) + INPUT_IMAGE_PATH="$2" + shift + shift + ;; + -o|--output-path) + OUTPUT_PATH="$2" + shift + shift + ;; + -m|--model-path) + MODEL_PATH="$2" + shift + shift + ;; + -t|--task) + TASK="$2" + shift + shift + ;; + -g|--gt-path) + GROUNDTRUTH_PATH="$2" + shift + shift + ;; + -*|--*) + echo "Invalid option $1" + exit 1 + ;; + esac +done + +# Please modify these as needed. 
+NUM_PARTITIONS=0
+START=0
+END=0
+
+
+SEQ_LEN=256
+DECODER_SEQ_LEN=8192
+EXTRA_ARGS=" --pixel-shuffle --use-tiling --max-num-tiles 12 --use-thumbnail"
+
+for PARTITION_ID in $( eval echo {$START..$END} )
+do
+    torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \
+        --attention-softmax-in-fp32 \
+        --transformer-impl transformer_engine \
+        --use-te \
+        --use-checkpoint-args \
+        --normalization RMSNorm \
+        --norm-epsilon 1e-06 \
+        --language-model-type=qwen2.5_7B \
+        --untie-embeddings-and-output-weights \
+        --disable-bias-linear \
+        --position-embedding-type rope \
+        --rotary-percent 1.0 \
+        --rotary-base 1000000 \
+        --swiglu \
+        --attention-dropout 0.0 \
+        --hidden-dropout 0.0 \
+        --tensor-model-parallel-size 4 \
+        --pipeline-model-parallel-size 1 \
+        --group-query-attention \
+        --num-query-groups 4 \
+        --num-layers 28 \
+        --hidden-size 3584 \
+        --ffn-hidden-size 18944 \
+        --add-qkv-bias \
+        --num-attention-heads 28 \
+        --max-position-embeddings 32768 \
+        --no-masked-softmax-fusion \
+        --load ${MODEL_PATH} \
+        --tokenizer-type MultimodalTokenizer \
+        --tokenizer-model Qwen/Qwen2.5-7B-Instruct \
+        --tokenizer-prompt-format qwen2p5 \
+        --bf16 \
+        --micro-batch-size 1 \
+        --seq-length ${SEQ_LEN} \
+        --decoder-seq-length ${DECODER_SEQ_LEN} \
+        --out-seq-length 128 \
+        --temperature 1.0 \
+        --img-h 448 \
+        --img-w 448 \
+        --patch-dim 14 \
+        --seed 153 \
+        --top_k 1 \
+        --no-load-rng \
+        --no-load-optim \
+        --input-image-path ${INPUT_IMAGE_PATH} \
+        --num-partitions ${NUM_PARTITIONS} \
+        --partition-id ${PARTITION_ID} \
+        --output-path ${OUTPUT_PATH} \
+        --gt-path ${GROUNDTRUTH_PATH} \
+        --task ${TASK} \
+        ${EXTRA_ARGS} \
+        --special-tokens "<image>" "<img>" "</img>" \
+        --vision-model-type siglip \
+        --ckpt-format torch
+done
diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py
index f4bb5025ff..5b8622c643 100644
--- a/examples/multimodal/run_text_generation.py
+++ b/examples/multimodal/run_text_generation.py
@@ -19,6 +19,8 @@
 from multimodal_args import add_multimodal_extra_args
 
 from megatron.core import parallel_state
+from megatron.core.enums import ModelType
+from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN
 from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
 from megatron.inference.text_generation.api import generate_and_post_process
 from megatron.inference.text_generation.forward_step import ForwardStep
@@ -36,7 +38,7 @@ def add_text_generation_args(parser):
     group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.')
     group.add_argument("--top_k", type=int, default=0, help='Top k sampling.')
     group.add_argument(
-        "--out-seq-length", type=int, default=1024, help='Length of the output generated text.'
+        "--out-seq-length", type=int, default=128, help='Length of the output generated text.'
     )
     group.add_argument("--output-path", type=str, help='Output file path')
     group.add_argument('--input-image-path', type=str, help="Input image directory")
@@ -206,8 +208,8 @@ def generate_samples(model, config: EvaluationConfig, print_output):
             if config.task == "VideoMME":
                 output["questions"][0][output_name] = generated
             else:
-                output[output_name] = generated
                 output["prompt"] = prompt
+                output[output_name] = generated
 
             if config.task == "captioning":
                 output["ground_truth"] = answers
@@ -354,7 +356,7 @@ def _forward(self, tokens, position_ids, attention_mask):
         )
 
     def __call__(self, tokens, position_ids, attention_mask):
-        num_image_tokens = (tokens == self.model.image_token_index).sum().item()
+        num_image_tokens = (tokens == self.model.module.image_token_index).sum().item()
         num_tokens = tokens.size(1)
         recv_buffer_seq_length = None
         if num_image_tokens > 0:
@@ -406,7 +408,7 @@ def get_conversation(task, question):
             {"role": "system", "content": "Answer the questions."},
             {
                 "role": "user",
-                "content": "<image>\nProvide a one-sentence caption for provided image.",
+                "content": f"{IMAGE_TOKEN}\nProvide a one-sentence caption for provided image.",
             },
         ]
     elif task in ("TextVQA", "VQAv2", "ChartQA"):
         conversation = [
             {"role": "system", "content": "Answer the questions."},
             {
                 "role": "user",
-                "content": f"<image>\n{question}\nAnswer the question using a single word or phrase.",
+                "content": f"{IMAGE_TOKEN}\n{question}\nAnswer the question using a single word or phrase.",
             },
         ]
     elif task in ("OCRBench", "MathVista", "AI2D"):
         conversation = [
             {"role": "system", "content": "Answer the questions."},
-            {"role": "user", "content": f"<image>\n{question}"},
+            {"role": "user", "content": f"{IMAGE_TOKEN}\n{question}"},
         ]
     elif task == "MMMU":
         conversation = [
@@ -441,7 +443,7 @@ def get_conversation(task, question):
 
         conversation = [
             {"role": "system", "content": "Answer the questions."},
-            {"role": "user", "content": f"<image>\n{question}"},
+            {"role": "user", "content": f"{IMAGE_TOKEN}\n{question}"},
         ]
 
     return conversation
@@ -464,11 +466,13 @@ def get_prompt_and_generated(prompt_and_generation, prompt_format):
         prompt = splitted[0]
         generated = splitted[1]
         generated = generated.split("<|im_end|>")[0]
-    elif prompt_format in ("nvlm-yi-34b", "qwen2p0"):
+    elif prompt_format in ("nvlm-yi-34b", "qwen2p0", "qwen2p5"):
         splitted = prompt_and_generation.split("<|im_start|>assistant\n")
         prompt = splitted[0]
         generated = splitted[1]
         generated = generated.split("<|im_end|>")[0]
+    else:
+        raise ValueError(f"Prompt format {prompt_format} is not supported.")
 
     # Remove possible garbage.
     generated = generated.strip()
@@ -489,11 +493,11 @@ def main():
 
     args = get_args()
 
-    def wrapped_model_provider(pre_process, post_process):
-        return model_provider(pre_process, post_process, parallel_output=False)
+    def wrapped_model_provider(pre_process, post_process, add_encoder, add_decoder):
+        return model_provider(pre_process, post_process, add_encoder, add_decoder, parallel_output=False)
 
     # Set up model and load checkpoint.
- model = get_model(wrapped_model_provider, wrap_with_ddp=False) + model = get_model(wrapped_model_provider, model_type=ModelType.encoder_and_decoder, wrap_with_ddp=False) if args.load is not None: _ = load_checkpoint(model, None, None) diff --git a/megatron/training/tokenizer/multimodal_tokenizer.py b/megatron/training/tokenizer/multimodal_tokenizer.py index c5ea95c069..605f36f52a 100644 --- a/megatron/training/tokenizer/multimodal_tokenizer.py +++ b/megatron/training/tokenizer/multimodal_tokenizer.py @@ -121,7 +121,7 @@ def __init__( has_bos=False, has_system_role=True, ) - elif prompt_format == "qwen2p0": + elif prompt_format in ("qwen2p0", "qwen2p5"): # "<|im_start|>assistant\n" is the prefix for assistant messages self._prompt_config = PromptConfig( assistant_prefix_len=3,
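
For reference, a hypothetical invocation of the new generation script could look like the following; every path below is a placeholder, and the task name must be one of the tasks handled by run_text_generation.py (e.g. TextVQA):

    examples/multimodal/nvlm/run_text_generation_qwen25_7b_siglip.sh \
        --model-path /path/to/qwen2.5-7b-siglip-checkpoint \
        --task TextVQA \
        --input-image-path /path/to/textvqa/train_images \
        --gt-path /path/to/textvqa/groundtruth.json \
        --output-path /path/to/generation-output

The partitioned generation files can then be merged and scored with the matching evaluate_*.py helper updated in this patch (e.g. evaluate_textvqa.py for TextVQA).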