diff --git a/ds_config_zero2.json b/ds_config_zero2.json new file mode 100644 index 0000000..4be2c0b --- /dev/null +++ b/ds_config_zero2.json @@ -0,0 +1,52 @@ +{ + "fp16": { + "enabled": "auto", + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 16, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "bf16": { + "enabled": "auto" + }, + "optimizer": { + "type": "AdamW", + "params": { + "lr": "auto", + "betas": "auto", + "eps": "auto", + "weight_decay": "auto" + } + }, + + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": "auto", + "warmup_max_lr": "auto", + "warmup_num_steps": "auto" + } + }, + + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "none", + "pin_memory": true + }, + "allgather_partitions": true, + "allgather_bucket_size": 2e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 2e8, + "contiguous_gradients": true + }, + + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 100, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/finetune_ds.sh b/finetune_ds.sh new file mode 100644 index 0000000..f2e9fd4 --- /dev/null +++ b/finetune_ds.sh @@ -0,0 +1,48 @@ +#!/bin/bash +export CUDA_DEVICE_MAX_CONNECTIONS=1 +DIR=`pwd` + +GPUS_PER_NODE=8 +NNODES=1 +NODE_RANK=0 +MASTER_ADDR=localhost +MASTER_PORT=6001 + +SAVE_PATH=/path/to/experiments/MiniCPM-V-FFT +BASE_MODEL=/path/to/pretrained_model/MiniCPM-V +TRAIN_DATASET=/path/to/train.json +VAL_DATASET=/path/to/test.json + + +DISTRIBUTED_ARGS=" + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +torchrun $DISTRIBUTED_ARGS finetune_minicpmv.py \ + --model_name_or_path $BASE_MODEL \ + --data_path $TRAIN_DATASET \ + --bf16 True \ + --fix_vit True \ + --output_dir $SAVE_PATH \ + --num_train_epochs 2 \ + --per_device_train_batch_size 12 \ + --per_device_eval_batch_size 8 \ + --gradient_accumulation_steps 4 \ + --evaluation_strategy "no" \ + --save_strategy "steps" \ + --save_steps 100 \ + --save_total_limit 2 \ + --learning_rate 1e-5 \ + --weight_decay 0.1 \ + --adam_beta2 0.95 \ + --warmup_ratio 0.01 \ + --lr_scheduler_type "cosine" \ + --logging_steps 2 \ + --report_to "tensorboard" \ + --model_max_length 512 \ + --gradient_checkpointing True \ + --deepspeed ds_config_zero2.json diff --git a/finetune_minicpmv.py b/finetune_minicpmv.py new file mode 100644 index 0000000..2a826d2 --- /dev/null +++ b/finetune_minicpmv.py @@ -0,0 +1,477 @@ +# This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca. + + +from dataclasses import dataclass, field +import json +import math +import logging +import os +from typing import Dict, Optional, List +import torch +from torch.utils.data import Dataset +from deepspeed import zero +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +import transformers +from transformers import Trainer, GPTQConfig, deepspeed +from transformers.trainer_pt_utils import LabelSmoother +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from accelerate.utils import DistributedType +from io import BytesIO +from PIL import Image +import requests +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, NewType +import copy +from collections.abc import Mapping +import numpy as np +from minicpmv.model.modeling_minicpmv import MiniCPMV +from minicpmv.model.configuration_minicpm import MiniCPMVConfig + +InputDataClass = NewType("InputDataClass", Any) + +def _read_from_path( + img_path: Union[str, 'PIL.Image.Image']) -> 'PIL.Image.Image': + if isinstance(img_path, str): + img_path = img_path.strip() + if img_path.startswith('http'): + content = requests.get(img_path).content + image = Image.open(BytesIO(content)) + else: + image = Image.open(img_path) + else: + image = img_path + if image.mode in {'L', 'RGBA'}: + image = image.convert('RGB') + return image + + +@dataclass +class ModelArguments: + model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B") + + +@dataclass +class DataArguments: + data_path: str = field( + default=None, metadata={"help": "Path to the training data."} + ) + eval_data_path: str = field( + default=None, metadata={"help": "Path to the evaluation data."} + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + optim: str = field(default="adamw_torch") + model_max_length: int = field( + default=512, + metadata={ + "help": "Maximum sequence length. Sequences will be left padded (and possibly truncated)." + }, + ) + use_lora: bool = False + fix_vit: bool = True + + +@dataclass +class LoraArguments: + lora_r: int = 64 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_target_modules: List[str] = field( + default_factory=lambda: ["c_attn", "attn.c_proj", "w1", "w2"] ##["in_proj","out_proj","c_fc"] + ) + lora_weight_path: str = "" + lora_bias: str = "none" + q_lora: bool = False + + +def maybe_zero_3(param): + if hasattr(param, "ds_id"): + assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE + with zero.GatheredParameters([param]): + param = param.data.detach().cpu().clone() + else: + param = param.detach().cpu().clone() + return param + + +# Borrowed from peft.utils.get_peft_model_state_dict +def get_peft_state_maybe_zero_3(named_params, bias): + if bias == "none": + to_return = {k: t for k, t in named_params if "lora_" in k} + elif bias == "all": + to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = {} + maybe_lora_bias = {} + lora_bias_names = set() + for k, t in named_params: + if "lora_" in k: + to_return[k] = t + bias_name = k.split("lora_")[0] + "bias" + lora_bias_names.add(bias_name) + elif "bias" in k: + maybe_lora_bias[k] = t + for k, t in maybe_lora_bias: + if bias_name in lora_bias_names: + to_return[bias_name] = t + else: + raise NotImplementedError + to_return = {k: maybe_zero_3(v) for k, v in to_return.items()} + return to_return + +local_rank = None + +def rank0_print(*args): + if local_rank == 0: + print(*args) + + +def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str, bias="none"): + """Collects the state dict and dump to disk.""" + # check if zero3 mode enabled + if deepspeed.is_deepspeed_zero3_enabled(): + state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() + else: + if trainer.args.use_lora: + state_dict = get_peft_state_maybe_zero_3( + trainer.model.named_parameters(), bias + ) + else: + state_dict = trainer.model.state_dict() + if trainer.args.should_save and trainer.args.local_rank == 0: + trainer._save(output_dir, state_dict=state_dict) + + +def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token): + if '' in question_text[0]['content']: + question_text[0]['content'] = question_text[0]['content'].replace( + '', im_st_token + im_patch_token * image_token_len + im_ed_token) + else: + question_text[0]['content'] = im_st_token + im_patch_token * \ + image_token_len + im_ed_token + '\n' + question_text[0]['content'] + return question_text + + +def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"): + items = [] + if isinstance(orig_items[0][key], list): + # batch example + assert isinstance(orig_items[0][key][0], torch.Tensor) + for it in orig_items: + for tr in it[key]: + items.append({key: tr}) + else: + # single example + assert isinstance(orig_items[0][key], torch.Tensor) + items = orig_items + + batch_size = len(items) + shape = items[0][key].shape + dim = len(shape) + assert dim <= 3 + if max_length is None: + max_length = 0 + max_length = max(max_length, max(item[key].shape[-1] for item in items)) + min_length = min(item[key].shape[-1] for item in items) + dtype = items[0][key].dtype + + if dim == 1: + return torch.cat([item[key] for item in items], dim=0) + elif dim == 2: + if max_length == min_length: + return torch.cat([item[key] for item in items], dim=0) + tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value + else: + tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value + + for i, item in enumerate(items): + if dim == 2: + if padding_side == "left": + tensor[i, -len(item[key][0]):] = item[key][0].clone() + else: + tensor[i, : len(item[key][0])] = item[key][0].clone() + elif dim == 3: + if padding_side == "left": + tensor[i, -len(item[key][0]):, :] = item[key][0].clone() + else: + tensor[i, : len(item[key][0]), :] = item[key][0].clone() + + return tensor + + +def preprocess( + example, + tokenizer: transformers.PreTrainedTokenizer, + max_len:int, + image_token_len: int = 64, +) -> Dict: + """ + NOTE: support only one image + example:{"image": "path or url", "conversations":[...]} + """ + pad_token_id = 0 + image_path = example['image'] + # image = _read_from_path(image_path) + + prompt = '' + for i, msg in enumerate(example['conversations'][:-1]): + role = msg['role'] + content = msg['content'] + assert role in ['user', 'assistant'] + if i == 0: + assert role == 'user', 'The role of first msg should be user' + content = tokenizer.im_start + tokenizer.unk_token * image_token_len + tokenizer.im_end + '\n' + content + prompt += '<用户>' if role=='user' else '' + prompt += content + + assert example['conversations'][-1]['role'] == 'assistant' + prompt += '' + + exact_input_len = len(tokenizer.encode(prompt)) + prompt += example['conversations'][-1]['content'] + + # full_text = prompt + tokenizer.eos + if tokenizer.add_bos_token: + input_ids = tokenizer.encode(prompt) + else: + exact_input_len += 1 + input_ids = [tokenizer.bos_id] + tokenizer.encode(prompt) + input_ids = input_ids + [tokenizer.eos_id] + + + cur_example_len = len(input_ids) + # print(f"input_ids: {input_ids}, cur_example_len: {cur_example_len}") + + attention_mask = torch.ones(max_len) + if cur_example_len <= max_len: + # do left padding + padded_len = max_len - len(input_ids) + attention_mask[:padded_len] = 0 + padded_input_ids = [pad_token_id] * padded_len + input_ids + targets = copy.deepcopy(padded_input_ids) + targets[:exact_input_len] = [-100] * exact_input_len + else: + logging.warning( + f"Found example that exceed the max input length of model, ignore the whole label loss: {example}" + ) + padded_input_ids = input_ids[:max_len] + targets = [-100] * max_len + + image_start_tokens = padded_input_ids.index(tokenizer.im_start_id) + # 跳过 im_start + image_start_tokens += 1 + image_end_tokens = padded_input_ids.index(tokenizer.im_end_id) + image_bound = torch.LongTensor([[image_start_tokens, image_end_tokens]]) + + padded_input_ids = torch.LongTensor(padded_input_ids) + targets = torch.LongTensor(targets) + + return dict( + input_ids=padded_input_ids, + label_ids=targets, + attention_mask=attention_mask, + image_bound=image_bound, + image_paths=image_path + ) + + +class LazySupervisedDataset(Dataset): + """Dataset for supervised fine-tuning.""" + + def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, max_len: int, image_token_len: int=64): + super(LazySupervisedDataset, self).__init__() + self.tokenizer = tokenizer + self.max_len = max_len + self.image_token_len = image_token_len + + rank0_print("Formatting inputs...Skip in lazy mode") + self.tokenizer = tokenizer + self.raw_data = raw_data + self.cached_data_dict = {} + + def __len__(self): + return len(self.raw_data) + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + if i in self.cached_data_dict: + return self.cached_data_dict[i] + + ret = preprocess(self.raw_data[i], self.tokenizer, self.max_len, self.image_token_len) + self.cached_data_dict[i] = ret + + return ret + + +def make_supervised_data_module( + tokenizer: transformers.PreTrainedTokenizer, data_args, max_len +) -> Dict: + """Make dataset and collator for supervised fine-tuning.""" + + rank0_print("Loading data...") + train_json = json.load(open(data_args.data_path, "r")) + train_dataset = LazySupervisedDataset(train_json, tokenizer=tokenizer, max_len=max_len) + + if data_args.eval_data_path: + eval_json = json.load(open(data_args.eval_data_path, "r")) + eval_dataset = LazySupervisedDataset(eval_json, tokenizer=tokenizer, max_len=max_len) + else: + eval_dataset = None + + global local_rank + if local_rank == 0: + print(train_json[0]) + print(train_dataset[0]) + + return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) + + +def minicpmv_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: + import torch + + if not isinstance(features[0], Mapping): + features = [vars(f) for f in features] + first = features[0] + batch = {} + + # Special handling for labels. + # Ensure that tensor is created with the correct type + # (it should be automatically the case, but let's make sure of it.) + if "label" in first and first["label"] is not None: + label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] + dtype = torch.long if isinstance(label, int) else torch.float + batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) + elif "label_ids" in first and first["label_ids"] is not None: + if isinstance(first["label_ids"], torch.Tensor): + batch["labels"] = torch.stack([f["label_ids"] for f in features]) + else: + dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float + batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype) + + # Handling of all other possible keys. + # Again, we will use the first element to figure out which key/values are not None for this model. + for k, v in first.items(): + if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): + if isinstance(v, torch.Tensor): + batch[k] = torch.stack([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.tensor(np.stack([f[k] for f in features])) + else: + batch[k] = torch.tensor([f[k] for f in features]) + elif isinstance(v, str): + batch[k] = [f[k] for f in features] + return batch + + +def train(): + global local_rank + + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments, LoraArguments) + ) + ( + model_args, + data_args, + training_args, + lora_args, + ) = parser.parse_args_into_dataclasses() + + if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False): + training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED + + compute_dtype = ( + torch.float16 + if training_args.fp16 + else (torch.bfloat16 if training_args.bf16 else torch.float32) + ) + + local_rank = training_args.local_rank + + device_map = None + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + if lora_args.q_lora: + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None + if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): + logging.warning( + "FSDP or ZeRO3 are not incompatible with QLoRA." + ) + + config = MiniCPMVConfig.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + trust_remote_code=True, + ) + config.use_cache = False + + # Load model and tokenizer + model = MiniCPMV.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=training_args.cache_dir, + device_map=device_map, + trust_remote_code=True, + quantization_config=GPTQConfig( + bits=4, disable_exllama=True + ) + if training_args.use_lora and lora_args.q_lora + else None, + ) + + if not training_args.use_lora: + if training_args.fix_vit and hasattr(model,'transformer') and hasattr(model.transformer,'visual'): + model.transformer.visual.requires_grad_(False) + if hasattr(model.transformer.visual,'attn_pool'): + model.transformer.visual.attn_pool.requires_grad_(True) + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.model_max_length, + padding_side="left", + use_fast=False, + trust_remote_code=True, + ) + + if training_args.use_lora: + if lora_args.q_lora or "chat" in model_args.model_name_or_path.lower(): + modules_to_save = None + else: + modules_to_save = ["wte", "lm_head"] + lora_config = LoraConfig( + r=lora_args.lora_r, + lora_alpha=lora_args.lora_alpha, + target_modules=lora_args.lora_target_modules, + lora_dropout=lora_args.lora_dropout, + bias=lora_args.lora_bias, + task_type="CAUSAL_LM", + modules_to_save=modules_to_save # This argument serves for adding new tokens. + ) + if lora_args.q_lora: + model = prepare_model_for_kbit_training( + model, use_gradient_checkpointing=training_args.gradient_checkpointing + ) + + model = get_peft_model(model, lora_config) + + if training_args.gradient_checkpointing: + model.enable_input_require_grads() + + # Load data + data_module = make_supervised_data_module( + tokenizer=tokenizer, data_args=data_args, max_len=training_args.model_max_length + ) + + # Start trainner + trainer = Trainer( + model=model, args=training_args, data_collator=minicpmv_data_collator, **data_module + ) + + trainer.train() + trainer.save_state() + + safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias) + + +if __name__ == "__main__": + train() diff --git a/minicpmv/model/configuration_minicpm.py b/minicpmv/model/configuration_minicpm.py new file mode 100644 index 0000000..564f5cf --- /dev/null +++ b/minicpmv/model/configuration_minicpm.py @@ -0,0 +1,216 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MiniCPM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class MiniCPMConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the MiniCPM-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MiniCPMModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens, + MiniCPM 2 up to 4096, CodeMiniCPM up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import MiniCPMModel, MiniCPMConfig + + >>> # Initializing a MiniCPM minicpm-7b style configuration + >>> configuration = MiniCPMConfig() + + >>> # Initializing a model from the minicpm-7b style configuration + >>> model = MiniCPMModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "minicpm" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + scale_emb=1, + dim_model_base=1, + scale_depth=1, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.scale_emb = scale_emb + self.dim_model_base = dim_model_base + self.scale_depth = scale_depth + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + + +class MiniCPMVConfig(MiniCPMConfig): + model_type = "minicpmv" + keys_to_ignore_at_inference = ["past_key_values"] + def __init__( + self, + vision_encoder="vit_so400m_patch14_siglip_384.webli", + query_num=64, + image_size=448, + drop_vision_last_layer=True, + **kwargs + ): + self.vision_encoder = vision_encoder + self.query_num = query_num + self.image_size=image_size + self.drop_vision_last_layer = drop_vision_last_layer + super().__init__(**kwargs) diff --git a/minicpmv/model/modeling_minicpm.py b/minicpmv/model/modeling_minicpm.py new file mode 100644 index 0000000..5a61af8 --- /dev/null +++ b/minicpmv/model/modeling_minicpm.py @@ -0,0 +1,1454 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch MiniCPM model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union, Dict + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_minicpm import MiniCPMConfig +import re + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +except: + pass + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MiniCPMConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + warnings.warn( + "Calling `transformers.models.minicpm.modeling_minicpm._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask" + ) + return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + warnings.warn( + "Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask" + ) + return AttentionMaskConverter._make_causal_mask( + input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length + ) + +# @torch.jit.script # type: ignore +def rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float): + old_dtype = hidden.dtype + variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True) + hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype) + return hidden * weight + + +class MiniCPMRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniCPMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return rms_layernorm(hidden_states, self.weight, self.variance_epsilon) + + +ALL_LAYERNORM_LAYERS.append(MiniCPMRMSNorm) + + +class MiniCPMRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + # seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32 + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class MiniCPMLinearScalingRotaryEmbedding(MiniCPMRotaryEmbedding): + """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class MiniCPMDynamicNTKScalingRotaryEmbedding(MiniCPMRotaryEmbedding): + """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + # cos = cos[position_ids].unsqueeze(unsqueeze_dim) + # sin = sin[position_ids].unsqueeze(unsqueeze_dim) + # q_embed = (q * cos) + (rotate_half(q) * sin) + # k_embed = (k * cos) + (rotate_half(k) * sin) + orig_dtype = k.dtype + cos = cos[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(unsqueeze_dim) # [bs, 1, seq_len, dim] + q_fp32 = q.to(dtype=torch.float32, device=q.device) + k_fp32 = k.to(dtype=torch.float32, device=k.device) + q_embed = (q_fp32 * cos) + (rotate_half(q_fp32) * sin) + k_embed = (k_fp32 * cos) + (rotate_half(k_fp32) * sin) + return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype) + +class MiniCPMMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + + +class MiniCPMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MiniCPMConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = MiniCPMRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = MiniCPMLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = MiniCPMDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MiniCPMFlashAttention2(MiniCPMAttention): + """ + MiniCPM flash attention module. This module inherits from `MiniCPMAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MiniCPMFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (MiniCPMRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class MiniCPMSdpaAttention(MiniCPMAttention): + """ + MiniCPM attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MiniCPMAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MiniCPMAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MiniCPMModel is using MiniCPMSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MINICPM_ATTENTION_CLASSES = { + "eager": MiniCPMAttention, + "flash_attention_2": MiniCPMFlashAttention2, + "sdpa": MiniCPMSdpaAttention, +} + + +class MiniCPMDecoderLayer(nn.Module): + def __init__(self, config: MiniCPMConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MINICPM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = MiniCPMMLP(config) + self.input_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.scale_depth = config.scale_depth + self.num_hidden_layers = config.num_hidden_layers + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MINICPM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MiniCPMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.", + MINICPM_START_DOCSTRING, +) +class MiniCPMPreTrainedModel(PreTrainedModel): + config_class = MiniCPMConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MiniCPMDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MINICPM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare MiniCPM Model outputting raw hidden-states without any specific head on top.", + MINICPM_START_DOCSTRING, +) +class MiniCPMModel(MiniCPMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MiniCPMDecoderLayer`] + + Args: + config: MiniCPMConfig + """ + + def __init__(self, config: MiniCPMConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MiniCPMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._use_sdpa = config._attn_implementation == "sdpa" + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + self.norm = MiniCPMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.config.scale_emb + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MiniCPMForCausalLM(MiniCPMPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MiniCPMModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MiniCPMForCausalLM + + >>> model = MiniCPMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states / (self.config.hidden_size / self.config.dim_model_base)) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + @torch.inference_mode() + def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", + max_length: int = 4096, num_beams=1, do_sample=True, top_p=0.8, temperature=0.3, logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor: + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + else: + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + + history.append({"role": role, "content": query}) + history_str = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=False) + inputs = tokenizer(history_str, return_tensors='pt').to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + pattern = re.compile(r".*?(?=|<用户>)", re.DOTALL) + matches = pattern.findall(response) + if len(matches) > 0: + response = matches[0] + history.append({"role": "assistant", "content": response}) + return response, history + + +@add_start_docstrings( + """ + The MiniCPM Model transformer with a sequence classification head on top (linear layer). + + [`MiniCPMForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MINICPM_START_DOCSTRING, +) +class MiniCPMForSequenceClassification(MiniCPMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MiniCPMModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MINICPM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/minicpmv/model/modeling_minicpmv.py b/minicpmv/model/modeling_minicpmv.py new file mode 100644 index 0000000..673a8a5 --- /dev/null +++ b/minicpmv/model/modeling_minicpmv.py @@ -0,0 +1,422 @@ +import math +from typing import List, Optional +import json +import timm +import torch +import torchvision +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from torchvision import transforms +from transformers import LlamaTokenizer +from .configuration_minicpm import MiniCPMVConfig +from .modeling_minicpm import MiniCPMPreTrainedModel, MiniCPMForCausalLM +from .resampler import Resampler +from PIL import Image +import requests + + + +class MiniCPMVPreTrainedModel(MiniCPMPreTrainedModel): + config_class = MiniCPMVConfig + + +class MiniCPMV(MiniCPMVPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.llm = MiniCPMForCausalLM(config) + self.vpm = self.init_vision_module() + self.vision_dim = self.vpm.embed_dim + self.embed_dim = self.llm.config.hidden_size + self.resampler = self.init_resampler(self.embed_dim ,self.vision_dim) + self.transform = self.init_transform() + + + def init_vision_module(self): + model = timm.create_model( + self.config.vision_encoder, + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True + ) + + if isinstance(model, timm.models.VisionTransformer): + if model.attn_pool is not None: + model.attn_pool = torch.nn.Identity() + + if self.config.drop_vision_last_layer: + model.blocks = model.blocks[:-1] + + return model + + def init_resampler(self, embed_dim, vision_dim): + return Resampler( + grid_size=int(math.sqrt(self.config.query_num)), + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + + def init_transform(self): + return transforms.Compose([ + transforms.Resize( + (self.config.image_size, self.config.image_size), + interpolation=torchvision.transforms.InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD) + ]) + + def get_vision_embedding(self, pixel_values): + res = [] + dtype = self.vpm.pos_embed.data.dtype + for pixel_value in pixel_values: + vision_embedding = self.vpm.forward_features(pixel_value.unsqueeze(0).type(dtype)) + if hasattr(self.vpm, 'num_prefix_tokens') and self.vpm.num_prefix_tokens > 0: + vision_embedding = vision_embedding[:, self.vpm.num_prefix_tokens:] + res.append(self.resampler(vision_embedding)) + return torch.vstack(res) + + def get_vllm_embedding(self, data): + if 'vision_hidden_states' not in data: + pixel_values_list = data['pixel_values'] + vision_hidden_states = [] + for pixel_values in pixel_values_list: + if len(pixel_values) > 0: + vision_hidden_states.append(self.get_vision_embedding(pixel_values)) + elif self.training: + dtype = self.vpm.pos_embed.data.dtype + device = self.vpm.pos_embed.data.device + dummy_image = torch.zeros( + (1, 3, 224, 224), + device=device, dtype=dtype + ) + vision_hidden_states.append(self.get_vision_embedding(dummy_image)) + else: + vision_hidden_states.append([]) + + else: + vision_hidden_states = data['vision_hidden_states'] + + vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb + vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance( + i, torch.Tensor) else i for i in vision_hidden_states] + + bs = len(data['input_ids']) + for i in range(bs): + cur_vs_hs = vision_hidden_states[i] + if len(cur_vs_hs) > 0: + cur_vllm_emb = vllm_embedding[i] + cur_image_bound = data['image_bound'][i] + if len(cur_image_bound) > 0: + image_indices = torch.stack( + [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound] + ).to(vllm_embedding.device) + + cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), + cur_vs_hs.view(-1, cur_vs_hs.shape[-1])) + elif self.training: + cur_vllm_emb += cur_vs_hs[0].mean() * 0 + + return vllm_embedding, vision_hidden_states + + def forward(self, + input_ids: torch.LongTensor = None, + image_paths: Optional[List] = None, + image_bound: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None): + pixel_values_list = [] + for image_path in image_paths: + if image_path.startswith("http://") or image_path.startswith("https://"): + image = Image.open(requests.get(image_path, stream=True).raw) + else: + image = Image.open(image_path) + image = image.convert("RGB") + pixel_values = self.transform(image).to(self.device) + # print('single pixel_values',pixel_values, pixel_values.size()) + pixel_values_list.append(pixel_values) + + vision_hidden_states = [] + for pixel_values in pixel_values_list: + # (3, 448, 448) + if len(pixel_values) > 0: + vision_hidden_states.append(self.get_vision_embedding(pixel_values.unsqueeze(0))) + elif self.training: + dtype = self.vpm.pos_embed.data.dtype + device = self.vpm.pos_embed.data.device + dummy_image = torch.zeros( + (1, 3, 224, 224), + device=device, dtype=dtype + ) + vision_hidden_states.append(self.get_vision_embedding(dummy_image)) + else: + vision_hidden_states.append([]) + + vllm_embedding = self.llm.model.embed_tokens(input_ids) * self.llm.config.scale_emb + vision_hidden_states = [i.type(vllm_embedding.dtype) if isinstance( + i, torch.Tensor) else i for i in vision_hidden_states] + + bs = len(input_ids) + for i in range(bs): + cur_vs_hs = vision_hidden_states[i] + if len(cur_vs_hs) > 0: + cur_vllm_emb = vllm_embedding[i] + cur_image_bound = image_bound[i] + if len(cur_image_bound) > 0: + image_indices = torch.stack( + [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound] + ).to(vllm_embedding.device) + + cur_vllm_emb.scatter_(0, image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]), + cur_vs_hs.view(-1, cur_vs_hs.shape[-1])) + elif self.training: + cur_vllm_emb += cur_vs_hs[0].mean() * 0 + + return self.llm( + input_ids=None, + inputs_embeds=vllm_embedding, + labels=labels, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict + ) + + + def _convert_to_tensors(self, tokenizer, input_str, max_inp_length: Optional[int] = None): + if tokenizer.add_bos_token: + input_ids = tokenizer.encode(input_str) + else: + input_ids = [tokenizer.bos_id] + tokenizer.encode(input_str) + if max_inp_length is not None: + input_ids = input_ids[: max_inp_length] + input_ids = torch.tensor(input_ids, dtype=torch.int32) + + image_start_tokens = torch.where(input_ids == tokenizer.im_start_id)[0] + # 跳过 im_start + image_start_tokens += 1 + image_end_tokens = torch.where(input_ids == tokenizer.im_end_id)[0] + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + image_bound = torch.hstack( + [image_start_tokens[: valid_image_nums].unsqueeze(-1), + image_end_tokens[:valid_image_nums].unsqueeze(-1)] + ) + + model_input = {} + model_input["input_ids"] = input_ids.unsqueeze(0).to(self.device) + model_input["image_bound"] = image_bound + + return model_input + + + def _process_list(self, tokenizer, data_list: List[str], max_inp_length: Optional[int] = None): + pad_keys = ['input_ids'] + input_tensors = [] + for data in data_list: + input_tensors.append(self._convert_to_tensors(tokenizer, data, max_inp_length)) + padded = {} + for key in pad_keys: + padded[key] = pad(input_tensors, key, padding_side="left").to(self.device) + padded['image_bound'] = [i['image_bound'] for i in input_tensors] + return padded + + def _decode(self, inputs_embeds, tokenizer, **kwargs): + output = self.llm.generate( + inputs_embeds=inputs_embeds, + pad_token_id=0, + eos_token_id=tokenizer.eos_token_id, + **kwargs + ) + return self._decode_text(output, tokenizer) + + def _decode_text(self, result_ids, tokenizer): + result_text = [] + for result in result_ids: + result = result[result != 0] + if result[0] == tokenizer.bos_id: + result = result[1:] + if result[-1] == tokenizer.eos_id: + result = result[:-1] + result_text.append(tokenizer.decode(result).strip()) + return result_text + + def generate( + self, + data_list=None, + img_list=None, + tokenizer=None, + max_inp_length: Optional[int] = None, + vision_hidden_states=None, + return_vision_hidden_states=False, + **kwargs + ): + + assert data_list is not None + bs = len(data_list) + if img_list == None: + img_list = [[] for i in range(bs)] + assert bs == len(img_list) + + model_inputs = self._process_list(tokenizer, data_list, max_inp_length) + + if vision_hidden_states is None: + pixel_values = [] + for i in range(bs): + img_inps = [] + for img in img_list[i]: + img_inps.append(self.transform(img)) + if img_inps: + pixel_values.append(torch.stack(img_inps).to(self.device)) + else: + pixel_values.append([]) + model_inputs['pixel_values'] = pixel_values + else: + model_inputs['vision_hidden_states'] = vision_hidden_states + + with torch.inference_mode(): + model_inputs['inputs_embeds'], vision_hidden_states = self.get_vllm_embedding(model_inputs) + + result = self._decode(model_inputs['inputs_embeds'], tokenizer, **kwargs) + + if return_vision_hidden_states: + return result, vision_hidden_states + + return result + + + def chat(self, image, msgs, context, tokenizer, vision_hidden_states=None, max_new_tokens=2048, sampling=False, **kwargs): + if isinstance(msgs, str): + msgs = json.loads(msgs) + # msgs to prompt + prompt = '' + for i, msg in enumerate(msgs): + role = msg['role'] + content = msg['content'] + assert role in ['user', 'assistant'] + if i == 0: + assert role == 'user', 'The role of first msg should be user' + content = tokenizer.im_start + tokenizer.unk_token * self.config.query_num + tokenizer.im_end + '\n' + content + prompt += '<用户>' if role=='user' else '' + prompt += content + prompt += '' + final_input = prompt + + if sampling: + generation_config = { + 'top_p': 0.8, + 'top_k': 100, + 'temperature':0.6, + 'do_sample': True + } + else: + generation_config = { + 'num_beams': 3, + 'repetition_penalty': 1.2, + } + + generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()) + + with torch.inference_mode(): + res, vision_hidden_states = self.generate( + data_list=[final_input], + max_inp_length=2048, + img_list=[[image]], + tokenizer=tokenizer, + max_new_tokens=max_new_tokens, + vision_hidden_states=vision_hidden_states, + return_vision_hidden_states=True, + **generation_config + ) + answer = res[0] + context = msgs + context.append({'role':'assistant', 'content': answer}) + + return answer, context, generation_config + + +class LlamaTokenizerWrapper(LlamaTokenizer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.im_start = "" + self.im_end = "" + self.ref_start = "" + self.ref_end = "" + self.box_start = "" + self.box_end = "" + self.quad_start = "" + self.quad_end = "" + + @property + def eos_id(self): + return self.sp_model.eos_id() + + @property + def bos_id(self): + return self.sp_model.bos_id() + + @property + def unk_id(self): + return self.sp_model.unk_id() + + @property + def im_start_id(self): + return self._convert_token_to_id(self.im_start) + + @property + def im_end_id(self): + return self._convert_token_to_id(self.im_end) + + +def pad(orig_items, key, max_length=None, padding_value=0, padding_side="left"): + items = [] + if isinstance(orig_items[0][key], list): + assert isinstance(orig_items[0][key][0], torch.Tensor) + for it in orig_items: + for tr in it[key]: + items.append({key: tr}) + else: + assert isinstance(orig_items[0][key], torch.Tensor) + items = orig_items + + batch_size = len(items) + shape = items[0][key].shape + dim = len(shape) + assert dim <= 3 + if max_length is None: + max_length = 0 + max_length = max(max_length, max(item[key].shape[-1] for item in items)) + min_length = min(item[key].shape[-1] for item in items) + dtype = items[0][key].dtype + + if dim == 1: + return torch.cat([item[key] for item in items], dim=0) + elif dim == 2: + if max_length == min_length: + return torch.cat([item[key] for item in items], dim=0) + tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value + else: + tensor = torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + padding_value + + for i, item in enumerate(items): + if dim == 2: + if padding_side == "left": + tensor[i, -len(item[key][0]):] = item[key][0].clone() + else: + tensor[i, : len(item[key][0])] = item[key][0].clone() + elif dim == 3: + if padding_side == "left": + tensor[i, -len(item[key][0]):, :] = item[key][0].clone() + else: + tensor[i, : len(item[key][0]), :] = item[key][0].clone() + + return tensor diff --git a/minicpmv/model/resampler.py b/minicpmv/model/resampler.py new file mode 100644 index 0000000..dde0034 --- /dev/null +++ b/minicpmv/model/resampler.py @@ -0,0 +1,164 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import OrderedDict +import math +import requests +from io import BytesIO +from functools import partial +from PIL import Image +from typing import Callable, Optional, Sequence, Tuple, List, Union +import numpy as np + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.init import trunc_normal_ +from torchvision import transforms +from torchvision.transforms import InterpolationMode + +def get_abs_pos(abs_pos, tgt_size): + # abs_pos: L, C + # tgt_size: M + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + return F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size, tgt_size), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + else: + return abs_pos + + +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class Resampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__( + self, + grid_size, + embed_dim, + num_heads, + kv_dim=None, + norm_layer=partial(nn.LayerNorm, eps=1e-6) + ): + super().__init__() + self.num_queries = grid_size ** 2 + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.pos_embed = nn.Parameter( + torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float() + ).requires_grad_(False) + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=.02) + + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) + else: + self.kv_proj = nn.Identity() + + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + + self.ln_post = norm_layer(embed_dim) + self.proj = nn.Parameter((embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x, attn_mask=None): + + pos_embed = get_abs_pos(self.pos_embed, x.size(1)) + + x = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask)[0] + x = out.permute(1, 0, 2) + + x = self.ln_post(x) + x = x @ self.proj + return x + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) \ No newline at end of file