diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 740bb4b0719c61..6afdbdd3f69fe8 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -818,6 +818,8 @@ title: MatCha - local: model_doc/mgp-str title: MGP-STR + - local: model_doc/mplugdocowl + title: mPLUGDocOwl - local: model_doc/nougat title: Nougat - local: model_doc/oneformer diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 92cbdd44d7c0ea..a303ecf2537329 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -214,6 +214,7 @@ Flax), PyTorch, and/or TensorFlow. | [MobileNetV2](model_doc/mobilenet_v2) | ✅ | ❌ | ❌ | | [MobileViT](model_doc/mobilevit) | ✅ | ✅ | ❌ | | [MobileViTV2](model_doc/mobilevitv2) | ✅ | ❌ | ❌ | +| [mPLUGDocOwl](model_doc/mplugdocowl) | ✅ | ❌ | ❌ | | [MPNet](model_doc/mpnet) | ✅ | ✅ | ❌ | | [MPT](model_doc/mpt) | ✅ | ❌ | ❌ | | [MRA](model_doc/mra) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/mplugdocowl.md b/docs/source/en/model_doc/mplugdocowl.md new file mode 100644 index 00000000000000..e1b8d4e453d616 --- /dev/null +++ b/docs/source/en/model_doc/mplugdocowl.md @@ -0,0 +1,75 @@ + + +# mPLUG-DocOwl1.5 + +## Overview + +The mPLUG-DocOwl1.5 model was proposed in [mPLUG-DocOwl 1.5: Unified Structure Learning for OCR-free Document Understanding](https://arxiv.org/pdf/2403.12895) by Anwen Hu, Haiyang Xu, Jiabo Ye, Ming Yan +Liang Zhang, Bo Zhang, Chen Li, Ji Zhang, Qin Jin, Fei Huang, Jingren Zhou. + +MPLUG-DocOwl1.5 is a multimodal model designed for text-rich images. It features the H-Reducer vision-to-text module, which preserves spatial relationships and efficiently processes high-resolution document images by merging visual features horizontally. + +The model employs Unified Structure Learning with structure-aware parsing tasks and multi-grained text localization tasks, teaching it to parse text using line feeds, spaces, and extended Markdown syntax, which enhances the model's ability to correlate text with specific positions in the image. + +DocOwl 1.5 undergoes a two-stage training process: Unified Structure Learning followed by Multi-task Tuning among Downstream Tasks. The high-quality DocReason25K dataset boosts reasoning abilities, allowing DocOwl 1.5-Chat to balance concise answers and detailed explanations. + +The abstract from the paper is the following: + +*Structure information is critical for understanding the semantics of text-rich images, such as documents, tables, and charts. Existing Multimodal Large Language Models (MLLMs) for Visual Document Understanding are equipped with text recognition ability but lack general structure understanding abilities for text-rich document images. In this work, we emphasize the importance of structure information in Visual Document Understanding and propose the Unified Structure Learning to boost the performance of MLLMs. Our Unified Structure Learning comprises structure-aware parsing tasks and multi-grained text localization tasks across 5 domains: document, webpage, table, chart, and natural image. To better encode structure information, we design a simple and effective vision-to-text module H-Reducer, which can not only maintain the layout information but also reduce the length of visual features by merging horizontal adjacent patches through convolution, enabling the LLM to understand high-resolution images more efficiently. Furthermore, by constructing structure-aware text sequences and multi-grained pairs of texts and bounding boxes for publicly available text-rich images, we build a comprehensive training set DocStruct4M to support structure learning. Finally, we construct a small but high-quality reasoning tuning dataset DocReason25K to trigger the detailed explanation ability in the document domain. Our model DocOwl 1.5 achieves state-of-the-art performance on 10 visual document understanding benchmarks, improving the SOTA performance of MLLMs with a 7B LLM by more than 10 points in 5/10 benchmarks.* + +Tips: + +DocOowl-Chat: For more accurate and stable generation, set do_sample=False. Performs better on most of the samples compared to the DocOwl-Omni checkpoint. +DocOwl-Omni: For optimal performance, use do_sample=True and top_p=0.7 as recommended in the original code. + +This model was contributed by [danaaubakirova](https://huggingface.co/danaaubakirova). +The original code can be found [here](https://github.com/X-PLUG/mPLUG-DocOwl/tree/main/DocOwl1.5). + + +## MPLUGDocOwlConfig + +[[autodoc]] MPLUGDocOwlConfig + +## MPLUGDocOwlImageProcessor +[[autodoc]] MPLUGDocOwlImageProcessor + +## MPLUGDocOwlProcessor +[[autodoc]] MPLUGDocOwlProcessor + +## MPLUGDocOwlHReducer +[[autodoc]] MPLUGDocOwlHReducer + +## MPLUGDocOwlForCausalLM +[[autodoc]] MPLUGDocOwlForCausalLM + - forward + +## MPLUGDocOwlLanguageModel +[[autodoc]] MPLUGDocOwlLanguageModel + +## MPLUGDocOwlPreTrainedLanguageModel +[[autodoc]] MPLUGDocOwlPreTrainedLanguageModel + +## MPLUGDocOwlVisionModel +[[autodoc]] MPLUGDocOwlVisionModel + +## MPLUGDocOwlVisionTransformer +[[autodoc]] MPLUGDocOwlVisionTransformer + +## MPLUGDocOwlForConditionalGeneration + +[[autodoc]] MPLUGDocOwlForConditionalGeneration + - forward \ No newline at end of file diff --git a/examples_multi_col_60204.png b/examples_multi_col_60204.png new file mode 100644 index 00000000000000..7541f52e0b732d Binary files /dev/null and b/examples_multi_col_60204.png differ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4c953bab6be4b0..e8bdd1e219daac 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -576,6 +576,10 @@ "models.mobilenet_v2": ["MobileNetV2Config"], "models.mobilevit": ["MobileViTConfig"], "models.mobilevitv2": ["MobileViTV2Config"], + "models.mplugdocowl": [ + "MPLUGDocOwlConfig", + "MPLUGDocOwlProcessor", + ], "models.mpnet": [ "MPNetConfig", "MPNetTokenizer", @@ -1170,6 +1174,7 @@ _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"]) _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"]) _import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"]) + _import_structure["models.mplugdocowl"].extend(["MPLUGDocOwlImageProcessor"]) _import_structure["models.nougat"].append("NougatImageProcessor") _import_structure["models.oneformer"].extend(["OneFormerImageProcessor"]) _import_structure["models.owlv2"].append("Owlv2ImageProcessor") @@ -2667,6 +2672,19 @@ "MobileViTV2PreTrainedModel", ] ) + _import_structure["models.mplugdocowl"].extend( + [ + "MPLUGDocOwlAttention", + "MPLUGDocOwlForCausalLM", + "MPLUGDocOwlForConditionalGeneration", + "MPLUGDocOwlHReducer", + "MPLUGDocOwlLanguageModel", + "MPLUGDocOwlPreTrainedLanguageModel", + "MPLUGDocOwlPreTrainedModel", + "MPLUGDocOwlVisionModel", + "MPLUGDocOwlVisionTransformer", + ] + ) _import_structure["models.mpnet"].extend( [ "MPNetForMaskedLM", @@ -5266,6 +5284,10 @@ from .models.mobilevitv2 import ( MobileViTV2Config, ) + from .models.mplugdocowl import ( + MPLUGDocOwlConfig, + MPLUGDocOwlProcessor, + ) from .models.mpnet import ( MPNetConfig, MPNetTokenizer, @@ -5895,6 +5917,7 @@ MobileNetV2ImageProcessor, ) from .models.mobilevit import MobileViTFeatureExtractor, MobileViTImageProcessor + from .models.mplugdocowl import MPLUGDocOwlImageProcessor from .models.nougat import NougatImageProcessor from .models.oneformer import OneFormerImageProcessor from .models.owlv2 import Owlv2ImageProcessor @@ -7122,6 +7145,17 @@ MobileViTV2Model, MobileViTV2PreTrainedModel, ) + from .models.mplugdocowl import ( + MPLUGDocOwlAttention, + MPLUGDocOwlForCausalLM, + MPLUGDocOwlForConditionalGeneration, + MPLUGDocOwlHReducer, + MPLUGDocOwlLanguageModel, + MPLUGDocOwlPreTrainedLanguageModel, + MPLUGDocOwlPreTrainedModel, + MPLUGDocOwlVisionModel, + MPLUGDocOwlVisionTransformer, + ) from .models.mpnet import ( MPNetForMaskedLM, MPNetForMultipleChoice, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index cc1e41b3fc4076..7f515f21a925fb 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -152,6 +152,7 @@ mobilenet_v2, mobilevit, mobilevitv2, + mplugdocowl, mpnet, mpt, mra, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 512c1eaaf5e01a..44ad7feb023be3 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -169,6 +169,7 @@ ("mobilenet_v2", "MobileNetV2Config"), ("mobilevit", "MobileViTConfig"), ("mobilevitv2", "MobileViTV2Config"), + ("mplugdocowl", "MPLUGDocOwlConfig"), ("mpnet", "MPNetConfig"), ("mpt", "MptConfig"), ("mra", "MraConfig"), @@ -461,6 +462,7 @@ ("mobilenet_v2", "MobileNetV2"), ("mobilevit", "MobileViT"), ("mobilevitv2", "MobileViTV2"), + ("mplugdocowl", "mPLUGDocOwl"), ("mpnet", "MPNet"), ("mpt", "MPT"), ("mra", "MRA"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 8bfc61b9bea349..22b752a4b1e6cc 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -106,6 +106,7 @@ ("mobilenet_v2", ("MobileNetV2ImageProcessor",)), ("mobilevit", ("MobileViTImageProcessor",)), ("mobilevitv2", ("MobileViTImageProcessor",)), + ("mplugdocowl", ("MPLUGDocOwlImageProcessor",)), ("nat", ("ViTImageProcessor", "ViTImageProcessorFast")), ("nougat", ("NougatImageProcessor",)), ("oneformer", ("OneFormerImageProcessor",)), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d096abf4342614..b6e64cf46e3601 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -312,6 +312,7 @@ ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForPreTraining"), ("mobilebert", "MobileBertForPreTraining"), + ("mplugdocowl", "MPLUGDocOwlForConditionalGeneration"), ("mpnet", "MPNetForMaskedLM"), ("mpt", "MptForCausalLM"), ("mra", "MraForMaskedLM"), @@ -711,6 +712,7 @@ ("llava", "LlavaForConditionalGeneration"), ("llava-next-video", "LlavaNextVideoForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"), + ("mplugdocowl", "MPLUGDocOwlForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), ("video_llava", "VideoLlavaForConditionalGeneration"), diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 1ab136a1e74ca7..13a635b111e913 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -76,6 +76,7 @@ ("markuplm", "MarkupLMProcessor"), ("mctct", "MCTCTProcessor"), ("mgp-str", "MgpstrProcessor"), + ("mplugdocowl", "MPLUGDocOwlProcessor"), ("oneformer", "OneFormerProcessor"), ("owlv2", "Owlv2Processor"), ("owlvit", "OwlViTProcessor"), diff --git a/src/transformers/models/mplugdocowl/__init__.py b/src/transformers/models/mplugdocowl/__init__.py new file mode 100644 index 00000000000000..045002f9da1496 --- /dev/null +++ b/src/transformers/models/mplugdocowl/__init__.py @@ -0,0 +1,107 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available + + +_import_structure = { + "configuration_mplugdocowl": ["MPLUGDocOwlConfig"], + "modeling_mplugdocowl": [ + "MPLUGDocOwlAttention", + "MPLUGDocOwlForCausalLM", + "MPLUGDocOwlForConditionalGeneration", + "MPLUGDocOwlHReducer", + "MPLUGDocOwlLanguageModel", + "MPLUGDocOwlPreTrainedLanguageModel", + "MPLUGDocOwlPreTrainedModel", + "MPLUGDocOwlVisionModel", + "MPLUGDocOwlVisionTransformer", + ], + "processing_mplugdocowl": ["MPLUGDocOwlProcessor"], +} + +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["image_processing_mplugdocowl"] = ["MPLUGDocOwlImageProcessor"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mplugdocowl"] = [ + "MPLUGDocOwlAttention", + "MPLUGDocOwlForCausalLM", + "MPLUGDocOwlForConditionalGeneration", + "MPLUGDocOwlHReducer", + "MPLUGDocOwlLanguageModel", + "MPLUGDocOwlPreTrainedLanguageModel", + "MPLUGDocOwlPreTrainedModel", + "MPLUGDocOwlVisionModel", + "MPLUGDocOwlVisionTransformer", + ] + + +if TYPE_CHECKING: + from .configuration_mplugdocowl import MPLUGDocOwlConfig + from .modeling_mplugdocowl import ( + MPLUGDocOwlAttention, + MPLUGDocOwlForCausalLM, + MPLUGDocOwlForConditionalGeneration, + MPLUGDocOwlHReducer, + MPLUGDocOwlLanguageModel, + MPLUGDocOwlPreTrainedLanguageModel, + MPLUGDocOwlPreTrainedModel, + MPLUGDocOwlVisionModel, + MPLUGDocOwlVisionTransformer, + ) + from .processing_mplugdocowl import MPLUGDocOwlProcessor + + try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .image_processing_mplugdocowl import MPLUGDocOwlImageProcessor + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mplugdocowl import ( + MPLUGDocOwlAttention, + MPLUGDocOwlForCausalLM, + MPLUGDocOwlForConditionalGeneration, + MPLUGDocOwlHReducer, + MPLUGDocOwlLanguageModel, + MPLUGDocOwlPreTrainedLanguageModel, + MPLUGDocOwlPreTrainedModel, + MPLUGDocOwlVisionModel, + MPLUGDocOwlVisionTransformer, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) diff --git a/src/transformers/models/mplugdocowl/configuration_mplugdocowl.py b/src/transformers/models/mplugdocowl/configuration_mplugdocowl.py new file mode 100644 index 00000000000000..91d6f8a2a28660 --- /dev/null +++ b/src/transformers/models/mplugdocowl/configuration_mplugdocowl.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MPLUGDocOwl model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ..auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class MPLUGDocOwlConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MPLUGDocOwlForConditionalGeneration`]. It is used to instantiate an + MPLUGDocOwl model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the MPLUGDocOwl-Chat. + + e.g. [mplugdocowl-hf/mplugdocowl-Chat](https://huggingface.co/mplugdocowl-hf/mplugdocowl-Chat) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + hreducer_hidden_size (`int`, *optional*, defaults to 1024): The hidden size for the hreducer. + hreducer_layer_norm (`float`, *optional*, defaults to 1e-06): The layer normalization parameter for the hreducer. + hreducer_conv_shape (`str`, *optional*, defaults to `"1x4"`): The kernel size for the convolutional layer in the hreducer. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + + Example: + + ```python + >>> from transformers import MPLUGDocOwlForConditionalGeneration, MPLUGDocOwlConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a MPLUGDocOwl mplugdocowl-1.5-Chat style configuration + >>> configuration = MPLUGDocOwlConfig(vision_config, text_config) + + >>> # Initializing a model from the mplugdocowl-1.5-Chat style configuration + >>> model = MPLUGDocOwlForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mplugdocowl" + is_composition = False + + def __init__( + self, + vision_config=None, + text_config=None, + hreducer_hidden_size=1024, + hreducer_layer_norm=1e-6, + hreducer_conv_shape="1x4", + ignore_index=-100, + image_token_index=32000, + **kwargs, + ): + self.ignore_index = ignore_index + self.image_token_index = image_token_index + + if isinstance(vision_config, dict): + vision_config["model_type"] = ( + vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model" + ) + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=448, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + layer_norm_eps=1e-6, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + hidden_act="quick_gelu", + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self._vocab_size = self.text_config.vocab_size + self.hreducer_hidden_size = hreducer_hidden_size + self.hreducer_layer_norm = hreducer_layer_norm + self.hreducer_conv_shape = hreducer_conv_shape + super().__init__(**kwargs) diff --git a/src/transformers/models/mplugdocowl/convert_mplugdocowl_weights_to_hf.py b/src/transformers/models/mplugdocowl/convert_mplugdocowl_weights_to_hf.py new file mode 100644 index 00000000000000..1a94ce4ceaba77 --- /dev/null +++ b/src/transformers/models/mplugdocowl/convert_mplugdocowl_weights_to_hf.py @@ -0,0 +1,159 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import re + +import torch +from huggingface_hub import hf_hub_download + +from transformers import ( + AddedToken, + AutoConfig, + AutoTokenizer, + MPLUGDocOwlConfig, + MPLUGDocOwlForConditionalGeneration, + MPLUGDocOwlProcessor, +) +from transformers.models.mplugdocowl.image_processing_mplugdocowl import MPLUGDocOwlImageProcessor + + +KEYS_TO_MODIFY_MAPPING = { + r"model\.vision_model\.embeddings\.position_embedding": r"vision_tower.vision_model.embeddings.position_embedding", + r"model\.vision_model\.encoder\.layers\.(\d+)\.input_layernorm": r"vision_tower.vision_model.encoder.layers.\1.layer_norm1", + r"model\.vision_model\.encoder\.layers\.(\d+)\.post_attention_layernorm": r"vision_tower.vision_model.encoder.layers.\1.layer_norm2", + r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn.dense": r"vision_tower.vision_model.encoder.layers.\1.self_attn.out_proj", + r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn.query_key_value": r"vision_tower.vision_model.encoder.layers.\1.self_attn.q_v_k_proj", + r"model\.vision_model\.embeddings\.pre_layernorm": r"vision_tower.vision_model.embeddings.pre_layernorm", + r"model\.vision_model\.embeddings\.patch_embed": r"vision_tower.vision_model.embeddings.patch_embedding", + r"model\.vision_model\.embeddings\.cls_token": r"vision_tower.vision_model.embeddings.class_embedding", + r"model\.vision_model\.": r"vision_tower.vision_model.", + r"model\.layers\.": r"language_model.model.layers.", + r"model\.mm_projector": r"multi_modal_projector", + r"lm_head": r"language_model.lm_head", + r"model\.norm\.": r"language_model.model.norm.", + r"model\.embed_tokens": r"language_model.model.embed_tokens", + r"model\.vision2text": r"multi_modal_projector", + r"ln_q": r"layer_norm", +} + + +def convert_state_dict_to_hf(state_dict): + new_state_dict = {} + for key, value in state_dict.items(): + if key.endswith(".inv_freq"): + continue + original_key = key + for pattern, replacement in KEYS_TO_MODIFY_MAPPING.items(): + if re.search(pattern, key): + key = re.sub(pattern, replacement, key) + + new_state_dict[key] = value + print(f"Converted {original_key} to {key}") + return new_state_dict + + +def convert_mplugdocowl_llama_to_hf( + text_model_id, vision_model_id, output_hub_path, old_state_dict_id, pretrained=True +): + if not pretrained: + torch.set_default_dtype(torch.float16) + text_config = AutoConfig.from_pretrained(text_model_id) + + tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=False) + tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) + tokenizer.add_special_tokens({"pad_token": ""}) + + image_processor = MPLUGDocOwlImageProcessor() + processor = MPLUGDocOwlProcessor(tokenizer=tokenizer, image_processor=image_processor) + config = MPLUGDocOwlConfig(text_config=text_config) + config.pad_token_id = 32001 + + with torch.device("cuda:0"): + model = MPLUGDocOwlForConditionalGeneration(config).eval() + + # Pad to 64 for performance reasons + pad_shape = 64 + + state_dict_path = hf_hub_download(old_state_dict_id, "pytorch_model.bin") + + state_dict = torch.load(state_dict_path, map_location="cpu") + + state_dict = convert_state_dict_to_hf(state_dict) + + state_dict["multi_modal_projector.reducer_before.0.weight"] = state_dict[ + "multi_modal_projector.reducer_before.0.weight" + ].contiguous() + state_dict["multi_modal_projector.reducer.weight"] = state_dict[ + "multi_modal_projector.reducer.weight" + ].contiguous() + + model.load_state_dict(state_dict, strict=True, assign=True) + + pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model + model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape) + model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack( + tuple( + (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0])) + ), + dim=0, + ) + model.language_model.lm_head.weight.data[32000:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))), + dim=0, + ) + model.to(torch.float16) + model.save_pretrained("tmp/hf_models/mplugdocowl1.5-Chat-hf/") + processor.save_pretrained("tmp/hf_models/mplugdocowl1.5-Chat-hf") + else: + model = MPLUGDocOwlForConditionalGeneration.from_pretrained("tmp/hf_models/mplugdocowl1.5-Chat-hf") + model.to(torch.float16) + processor = MPLUGDocOwlProcessor.from_pretrained("tmp/hf_models/mplugdocowl1.5-Chat-hf") + model.push_to_hub(output_hub_path) + processor.push_to_hub(output_hub_path) + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--text_model_id", + help="Hub location of the text model", + ) + parser.add_argument( + "--vision_model_id", + help="Hub location of the vision model", + ) + parser.add_argument( + "--output_hub_path", + help="Location on the hub of the converted model", + ) + parser.add_argument( + "--old_state_dict_id", + help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`", + ) + args = parser.parse_args() + convert_mplugdocowl_llama_to_hf( + args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/mplugdocowl/image_processing_mplugdocowl.py b/src/transformers/models/mplugdocowl/image_processing_mplugdocowl.py new file mode 100644 index 00000000000000..60061215384790 --- /dev/null +++ b/src/transformers/models/mplugdocowl/image_processing_mplugdocowl.py @@ -0,0 +1,752 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for MPLUGDocOwl.""" + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import ( + convert_to_rgb, + get_resize_output_image_size, + resize, + to_channel_dimension_format, +) +from ...image_utils import ( + OPENAI_CLIP_MEAN, + OPENAI_CLIP_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_kwargs, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_vision_available, logging + + +logger = logging.get_logger(__name__) + + +if is_vision_available(): + import PIL + +GRID_DICT = { + "grid_1": [(1, 1)], + "grid_4": [(1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (2, 2), (1, 4), (4, 1)], + "grid_9": [ + (1, 1), + (1, 2), + (2, 1), + (1, 3), + (3, 1), + (2, 2), + (1, 4), + (4, 1), + (1, 5), + (5, 1), + (1, 6), + (6, 1), + (2, 3), + (3, 2), + (1, 7), + (7, 1), + (4, 2), + (2, 4), + (1, 8), + (8, 1), + (3, 3), + (1, 9), + (9, 1), + ], + "grid_3x3": [(3, 3)], + "grid_20": [ + (1, 1), + (1, 2), + (2, 1), + (1, 3), + (3, 1), + (1, 4), + (2, 2), + (4, 1), + (1, 5), + (5, 1), + (1, 6), + (2, 3), + (3, 2), + (6, 1), + (1, 7), + (7, 1), + (1, 8), + (2, 4), + (4, 2), + (8, 1), + (1, 9), + (3, 3), + (9, 1), + (1, 10), + (2, 5), + (5, 2), + (10, 1), + (1, 11), + (11, 1), + (2, 6), + (3, 4), + (4, 3), + (6, 2), + (2, 7), + (7, 2), + (3, 5), + (5, 3), + (2, 8), + (4, 4), + (8, 2), + (2, 9), + (3, 6), + (6, 3), + (9, 2), + (2, 10), + (4, 5), + (5, 4), + (10, 2), + ], +} + + +def box_area(boxes): + r""" + Compute the area of each bounding box in a given set of bounding boxes. + + Args: + boxes (np.ndarray): An array of shape (N, 4) containing N bounding boxes, + boxes (`np.ndarray`): An array of shape (N, 4) containing N bounding boxes, + each represented by the coordinates [x_min, y_min, x_max, y_max]. + + Returns: + `np.ndarray`: An array of shape (N,) containing the area of each bounding box. + """ + return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + + +def box_iou(boxes1, area1, boxes2, eps=1e-5): + r""" + Compute the Intersection over Union (IoU) between two sets of bounding boxes. + + Args: + boxes1 (np.ndarray): An array of shape (N, 4) containing N bounding boxes. + area1 (np.ndarray): An array of shape (N,) containing the area of each bounding box in boxes1. + boxes2 (np.ndarray): An array of shape (M, 4) containing M bounding boxes. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-5. + boxes1 (`np.ndarray`): An array of shape (N, 4) containing N bounding boxes. + area1 (`np.ndarray`): An array of shape (N,) containing the area of each bounding box in boxes1. + boxes2 (`np.ndarray`): An array of shape (M, 4) containing M bounding boxes. + eps (`float`, *optional*): A small value to avoid division by zero. Defaults to 1e-5. + + Returns: + `tuple`: A tuple containing: + - `np.ndarray`: An array of shape (N, M) containing the IoU between each pair of boxes from boxes1 and boxes2. + - `np.ndarray`: An array of shape (N, M) containing the union areas of each pair of boxes. + """ + area2 = box_area(boxes2) + + top_left = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + bottom_right = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + width_height = np.clip(bottom_right - top_left, a_min=0, a_max=None) # [N,M,2] + intersection = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - intersection + + iou = intersection / (union + eps) + + return iou, union + + +def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5): + r""" + Rank anchors based on their IoU and shape-adaptive IoU with respect to an input image size. + + Args: + anchors (np.ndarray): An array of shape (N, 4) containing N anchors. + anchors_areas (np.ndarray): An array of shape (N,) containing the area of each anchor. + input_image_size (tuple): A tuple (height, width) representing the size of the input image. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-5. + anchors (`np.ndarray`): An array of shape (N, 4) containing N anchors. + anchors_areas (`np.ndarray`): An array of shape (N,) containing the area of each anchor. + input_image_size (`tuple`): A tuple (height, width) representing the size of the input image. + eps (`float`, *optional*, defaults to 1e-05): A small value to avoid division by zero. Defaults to 1e-5. + + Returns: + `int`: The index of the selected anchor with the highest rank. + + """ + input_image_bbox = np.array([[0, 0, input_image_size[1], input_image_size[0]]]) + + boxes1 = anchors + boxes2 = input_image_bbox + boxes3 = anchors.copy() + boxes3[:, 3] = input_image_size[0] / input_image_size[1] * anchors[:, 2] # for resolution-independent iou + + area1 = anchors_areas + + iou, _ = box_iou(boxes1, area1, boxes2) + iou = iou.squeeze(1) + + shape_iou, _ = box_iou(boxes1, area1, boxes3) + shape_iou = np.diag(shape_iou) # Get diagonal for self-comparison + + index = np.argmax(shape_iou * 100 + iou) + + return index + + +def anchor_resize( + image: ImageInput, + anchors: str = "grid_9", + size: Dict[str, int] = None, + grid_dict: Dict[str, List[Tuple[int, int]]] = GRID_DICT, + resample=PILImageResampling.BICUBIC, +): + r""" + Resize an image based on selected anchor and its associated size. + + Args: + image (`ImageInput`): The input image to be resized. + anchors (`str`, *optional*, defaults to "grid_9"): The key for selecting anchor sizes from the grid_dict. Defaults to "grid_9". + size (`Dict[str, int]`, *optional*): A dictionary containing the target size for resizing. Defaults to None. + grid_dict (`Dict[str, List[Tuple[int, int]]]`, *optional*): A dictionary containing the anchor grid configurations. Defaults to GRID_DICT. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): The resampling method to use. Defaults to PILImageResampling.BICUBIC. + Returns: + tuple: A tuple containing: + - List[np.ndarray]: A list containing the resized image. + - int: The index of the selected anchor. + - `List[np.ndarray]`: A list containing the resized image. + - `int`: The index of the selected anchor. + """ + # Convert anchors to xyxy format + anchors = [tuple(_) for _ in grid_dict[anchors]] + size = size["width"] + anchors = np.array([[0, 0, anchor[1] * size, anchor[0] * size] for anchor in anchors]) + anchor_areas = box_area(anchors) + + # Resize image based on selected anchor + selected_anchor = anchor_rank(anchors, anchor_areas, (image.size[1], image.size[0])) + target_size = anchors[selected_anchor][2:].astype(int) # target width, height + resized_img = image.resize((target_size[0], target_size[1]), resample=resample) + resized_img = np.array(resized_img) + return (resized_img, selected_anchor) + + +def shape_adaptive_cropping( + image_patches: ImageInput, + size: Dict[str, int] = None, + anchors: str = "grid_9", + grid_dict: Dict[str, List[Tuple[int, int]]] = GRID_DICT, + selected_anchor: int = None, +): + r""" + Performs shape-adaptive cropping on image patches based on selected anchor size. + + This function is designed to handle images with various aspect ratios and resolutions by cropping + the image into multiple sub-images using a shape-adaptive grid. The goal is to preserve the resolution + and aspect ratio as much as possible to prevent text blur and distortion, which is critical for tasks + requiring visually-situated language understanding. + + Args: + image_patches (ImageInput): The input image patches to be cropped. + size (Dict[str, int], optional): A dictionary containing the target size for cropping. The size + is expected to have a key "width". Defaults to None. + anchors (str, optional): The key for selecting anchor sizes from the grid_dict. Defaults to "grid_9". + grid_dict (Dict[str, List[Tuple[int, int]]], optional): A dictionary containing the anchor grid + configurations. Defaults to GRID_DICT. + add_global_img (bool, optional): Whether to add the global image to the list of cropped patches. + Defaults to True. + selected_anchor (int, optional): The index of the selected anchor for cropping. If None, the + function will select an anchor based on the shape-adaptive + criteria. Defaults to None. + + Returns: + tuple: A tuple containing: + - List[np.ndarray]: A list of cropped image patches. + - np.ndarray: An array containing the positions of the patches. + - int: The number of patches. + - int: The maximum anchor size. + + Notes: + The function first converts the input anchors to a format suitable for cropping. It then reshapes + the image patches according to the selected anchor size. The resulting sub-images maintain the + resolution and aspect ratio of the original image as much as possible. + Find more details in the paper https://arxiv.org/pdf/2310.05126. + + Example: + Consider: + nh (int): Number of rows in the grid. + nw (int): Number of columns in the grid. + Hv (int): Height of the visual encoder input. + Wv (int): Width of the visual encoder input. + Nc (int): Maximum number of cells (sub-images) in the grid. + + The grid configurations and their selection are based on two main criteria: + 1. Resolution coherence (Srr): This measures the IoU between the input image resolution and the grid resolution. + Srr(I, g) = IoU((H, W), (nh * Hv, nw * Wv)) + 2. Shape similarity (Sra): This measures the IoU between the input image aspect ratio and the grid aspect ratio. + Sra(I, g) = IoU((H, W), (nh, nw)) + + The matched grid is selected by maximizing the matching score: + g* = argmax (Sra(I, g) + Srr(I, g)) + + After selecting the appropriate grid, the input image is resized to (nh * Hv, nw * Wv) and cropped into nh * nw local images. + Additionally, to maintain the global structure information of the image, the input image is resized to (Hv, Wv) as a global image. + + """ + anchors = [tuple(_) for _ in grid_dict[anchors]] + size = size["width"] + + anchor_max = max(max(_) for _ in anchors) + + image_patches = image_patches.transpose(2, 0, 1) + + anchor_size = anchors[selected_anchor] + + num_h, num_w = anchor_size + + image_input = image_patches.reshape(3, num_h, size, num_w, size) + + image_input = image_input.transpose(1, 3, 2, 4, 0) + image_input = image_input.reshape((-1, size, size, 3)) + image_patches_list = [image_input[i] for i in range(image_input.shape[0])] + anchor = anchors[selected_anchor] # w,h + patch_position = np.concatenate( + [ + np.repeat(np.arange(anchor[0])[:, np.newaxis], anchor[1], axis=1)[:, :, np.newaxis], + np.repeat(np.arange(anchor[1])[np.newaxis, :], anchor[0], axis=0)[:, :, np.newaxis], + ], + axis=2, + ) + + patch_position = patch_position.reshape(-1, 2) + patch_position = np.vstack((np.ones((1, 2), dtype=np.int64) * anchor_max, patch_position)) + return image_patches_list, patch_position, patch_position.shape[0], anchor_max + + +def add_global_image(images, patch_images): + """ + This function takes a list of global images and a list of lists containing patch images, + and combines them such that each image is followed by its corresponding patch images. + + :param images: List of global images + :param patch_images: List of lists of patch images corresponding to each image + :return: A new list with images followed by their corresponding patch images + """ + # Create a new list to store the combined elements + combined_images = [] + + # Combine elements + for image, patches in zip(images, patch_images): + combined_images.append(image) + combined_images.extend(patches) + + return combined_images + + +class MPLUGDocOwlImageProcessor(BaseImageProcessor): + r""" + Constructs a MPLUGDocOwlImageProcessor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by + `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`): + Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_center_crop (`bool`, *optional*, defaults to `False`): + Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the + `preprocess` method. + crop_size (`Dict[str, int]` *optional*, defaults to `False`): + Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess` + method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method. + image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*, defaults to `True`): + Whether to convert the image to RGB. + do_shape_adaptive_cropping (`bool`, *optional*, defaults to `True`): Whether to do a shape adaptive cropping of the input image. Should be only called if the do_anchor_resize is called. + do_anchor_resize (`bool`, *optional*, defaults to `True`): Whether to resize the image based on the specified anchor. Should be called before do_shape_adaptive_cropping. + do_add_global_image (`bool`, *optional*, defaults to `True`): Whether to add the global image to the image input. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = False, + crop_size: Dict[str, int] = False, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = True, + do_shape_adaptive_cropping: bool = True, + do_anchor_resize: bool = True, + do_add_global_image: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 448, "width": 448} + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else {"height": 448, "width": 448} + crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_center_crop = do_center_crop + self.crop_size = crop_size + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN + self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD + self.do_convert_rgb = do_convert_rgb + self.do_shape_adaptive_cropping = do_shape_adaptive_cropping + self.do_anchor_resize = do_anchor_resize + self.do_add_global_image = do_add_global_image + self._valid_processor_keys = [ + "images", + "do_resize", + "size", + "resample", + "do_center_crop", + "crop_size", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_convert_rgb", + "return_tensors", + "data_format", + "input_data_format", + "do_shape_adaptive_cropping", + "do_anchor_resize", + "do_add_global_image", + ] + + def anchor_resize( + self, image: ImageInput, size: Dict[str, int] = None, resample: PILImageResampling = PILImageResampling.BICUBIC + ): + r""" + Resizes an image using the specified anchor point and resampling method. + + Args: + image (ImageInput): The image to be resized. + size (Dict[str, int], optional): A dictionary specifying the desired width and height. Default is None. + resample (PILImageResampling, optional): The resampling method to use. Default is PILImageResampling.BICUBIC. + + Returns: + Image: The resized image. + """ + return anchor_resize(image=image, size=size, resample=resample) + + def adaptive_crop( + self, + image_patches: ImageInput, + size: Dict[str, int] = None, + selected_anchor: int = None, + ): + r""" + Performs adaptive cropping on image patches based on a selected anchor point. + + Args: + image_patches (ImageInput): The image patches to be cropped. + size (Dict[str, int], optional): A dictionary specifying the desired width and height. Default is None. + selected_anchor (int, optional): The index of the selected anchor point. Default is None. + + Returns: + Image: The cropped image patches. + """ + return shape_adaptive_cropping(image_patches=image_patches, size=size, selected_anchor=selected_anchor) + + def add_global_image( + self, + images: List, + patch_images: List, + ): + r""" + Adds global image data to a list of patch images. + + Args: + images (List): The list of images to which global image data will be added. + patch_images (List): The list of patch images to be combined with the global image data. + + Returns: + List: The combined list of images with global image data. + """ + return add_global_image(images=images, patch_images=patch_images) + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BICUBIC, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + r""" + Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge + resized to keep the input aspect ratio. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): + Resampling filter to use when resiizing the image. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + default_to_square = True + if "shortest_edge" in size: + size = size["shortest_edge"] + default_to_square = False + elif "height" in size and "width" in size: + size = (size["height"], size["width"]) + else: + raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") + + output_size = get_resize_output_image_size( + image, + size=size, + default_to_square=default_to_square, + input_data_format=input_data_format, + ) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BICUBIC, + do_center_crop: bool = False, + crop_size: int = False, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: bool = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_shape_adaptive_cropping: bool = True, + do_anchor_resize: bool = True, + do_add_global_image: bool = True, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + sizeexi (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with + the longest edge resized to keep the input aspect ratio. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only + has an effect if `do_resize` is set to `True`. + do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`): + Whether to center crop the image. + crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`): + Size of the center crop. Only has an effect if `do_center_crop` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to + `True`. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, param_name="size", default_to_square=True) + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + do_shape_adaptive_cropping = ( + do_shape_adaptive_cropping if do_shape_adaptive_cropping is not None else self.do_shape_adaptive_cropping + ) + do_anchor_resize = do_anchor_resize if do_anchor_resize is not None else self.do_anchor_resize + do_add_global_image = do_add_global_image if do_add_global_image is not None else self.do_add_global_image + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + # 1. Keep global image to be able to work with it later + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + patch_images = images.copy() + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + if do_center_crop: + images = [ + self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images + ] + + if do_resize: + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_anchor_resize: + output = [self.anchor_resize(image, size) for image in patch_images] + + if do_shape_adaptive_cropping: + output = [ + self.adaptive_crop(image_patches=image, size=size, selected_anchor=selected_anchor) + for (image, selected_anchor) in output + ] + patch_images, patch_positions, num_patches, anchor_max = zip(*output) + + if do_add_global_image: + images = self.add_global_image(images, patch_images) + else: + images = [patch for sublist in patch_images for patch in sublist] + patch_positions = [pos[1:] for pos in patch_positions] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = { + "pixel_values": images, + "patch_positions": patch_positions, + "num_patches": num_patches, + "anchor_max": anchor_max, + } + return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/mplugdocowl/modeling_mplugdocowl.py b/src/transformers/models/mplugdocowl/modeling_mplugdocowl.py new file mode 100644 index 00000000000000..b2ed3a80ce92c3 --- /dev/null +++ b/src/transformers/models/mplugdocowl/modeling_mplugdocowl.py @@ -0,0 +1,1943 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MPLUGDocOwl model.""" + +import math +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ... import PreTrainedModel +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, +) +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_mplugdocowl import MPLUGDocOwlConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MPLUGDocOwlConfig" + + +# contrastive loss function, adapted from +# https://sachinruk.github.io/blog/2021-03-07-clip.html +def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: + return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) + + +def clip_loss(similarity: torch.Tensor) -> torch.Tensor: + caption_loss = contrastive_loss(similarity) + image_loss = contrastive_loss(similarity.t()) + return (caption_loss + image_loss) / 2.0 + + +@dataclass +class MPLUGDocOwlCausalLMOutputWithPast(ModelOutput): + """ + Base class for MPLUGDocOwl causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images, + sequence_length, hidden_size)`. + + image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +MPLUGDOCOWL_START_VISION_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MPLUGDocOwlConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare MPLUGDocOwl Model outputting raw hidden-states without any specific head on top.", + MPLUGDOCOWL_START_VISION_DOCSTRING, +) +class MPLUGDocOwlPreTrainedModel(PreTrainedModel): + config_class = MPLUGDocOwlConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MPLUGDocOwlEncoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = False + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA or not. + """ + return self.language_model._supports_sdpa + + +MPLUGDOCOWL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`MPLUGDocOwlImageProcessor.__call__`] for details ([]`MPLUGDocOwlProcessor`] uses + [`MPLUGDocOwlImageProcessor`] for processing images). + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +class MPLUGDocOwlVisionEmbeddings(nn.Module): + def __init__(self, config: MPLUGDocOwlConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.embed_dim)) + + self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) + + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(patch_embeds.dtype) + + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(patch_embeds.dtype) + embeddings = self.pre_layernorm(embeddings) + + return embeddings + + +class MPLUGDocOwlAttention(MPLUGDocOwlPreTrainedModel): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__(config) + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = nn.Dropout(config.attention_dropout) + + self.q_v_k_proj = nn.Linear(self.embed_dim, 3 * self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + bsz, seq_len, embed_dim = hidden_states.size() + + mixed_qkv = self.q_v_k_proj(hidden_states) + + mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute( + 3, 0, 2, 1, 4 + ) # [3, b, np, sq, hn] + query_states, key_states, value_states = ( + mixed_qkv[0], + mixed_qkv[1], + mixed_qkv[2], + ) + # get query proj + attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + + attention_scores = attention_scores * self.scale + + # Normalize the attention scores to probabilities. + attention_probs = torch.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3) + + new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,) + + context_layer = context_layer.reshape(new_context_layer_shape) + + output = self.out_proj(context_layer) + + outputs = (output, attention_probs) if output_attentions else (output, None) + + return outputs + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MPLUGDocOwlVision +class MPLUGDocOwlVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MPLUGDocOwlEncoderLayer(nn.Module): + def __init__(self, config: MPLUGDocOwlConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = MPLUGDocOwlAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = MPLUGDocOwlVisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(config.encoder_attention_heads,)`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + head_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states + residual + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +MPLUGDocOwl_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`MPLUGDocOwlImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class MPLUGDocOwlVisionEncoder(MPLUGDocOwlPreTrainedModel): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + ['MPLUGDocOwlEncoderLayer']. + + Args: + config: MPLUGDocOwlConfig + """ + + def __init__(self, config: MPLUGDocOwlConfig): + super().__init__(config) + self.config = config + self.layers = nn.ModuleList([MPLUGDocOwlEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Causal mask for the text model. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class MPLUGDocOwlVisionTransformer(MPLUGDocOwlPreTrainedModel): + def __init__(self, config: MPLUGDocOwlConfig): + super().__init__(config) + self.config = config + self.embed_dim = config.hidden_size + + self.embeddings = MPLUGDocOwlVisionEmbeddings(config) + + self.encoder = MPLUGDocOwlVisionEncoder(config) + self.post_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.post_init() + + @add_start_docstrings_to_model_forward(MPLUGDocOwl_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=MPLUGDocOwlConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + + Returns: + Union[Tuple, BaseModelOutputWithPooling]: A `BaseModelOutputWithPooling` or a tuple of (last_hidden_state, pooled_output, hidden_states, attentions), where: + - last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)): Sequence of hidden states at the output of the last layer of the model. + - pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)): The last hidden state after applying the post-layer normalization. + - hidden_states (Optional[Tuple[torch.FloatTensor]]): Tuple of torch.FloatTensor (one for the output of each layer) of shape (batch_size, sequence_length, hidden_size). + - attentions (Optional[Tuple[torch.FloatTensor]]): Tuple of torch.FloatTensor (one for each attention head) of shape (batch_size, num_heads, sequence_length, sequence_length). + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + pooled_output = last_hidden_state[:, 0, :] + pooled_output = self.post_layernorm(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The vision model from MPLUGDocOwl without any head or projection on top.""", + MPLUGDOCOWL_START_VISION_DOCSTRING, +) +class MPLUGDocOwlVisionModel(PreTrainedModel): + config_class = MPLUGDocOwlConfig + main_input_name = "pixel_values" + _no_split_modules = ["MPLUGDocOwlEncoderLayer"] + _supports_sdpa = False + + def __init__(self, config: MPLUGDocOwlConfig): + super().__init__(config) + self.vision_model = MPLUGDocOwlVisionTransformer(config) + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings # .patch_embedding + + @add_start_docstrings_to_model_forward(MPLUGDocOwl_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=MPLUGDocOwlConfig) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + `BaseModelOutputWithPooling` or `tuple`: + If `return_dict` is `True`, a `BaseModelOutputWithPooling` is returned, containing: + - **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + - **pooler_output** (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed + by a linear layer and a Tanh activation function. The linear layer weights are trained from the next + sentence prediction (classification) objective during pretraining. This output is usually not a good + summary of the semantic content of the input, you're often better with averaging or pooling the + sequence of hidden-states for the whole input sequence. + - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of each layer + the output of the embedding layer). + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, CLIPVisionModel + + >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32") + >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled CLS states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, inputs_embeds.device, past_key_values_length=past_key_values_length + ).to(inputs_embeds.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MPLUGDocOwl +class MPLUGDocOwlRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MPLUGDocOwlRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +ALL_LAYERNORM_LAYERS.append(MPLUGDocOwlRMSNorm) + + +class MPLUGDocOwlRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class MPLUGDocOwlLinearScalingRotaryEmbedding(MPLUGDocOwlRotaryEmbedding): + """MPLUGDocOwlRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class MPLUGDocOwlDynamicNTKScalingRotaryEmbedding(MPLUGDocOwlRotaryEmbedding): + """MPLUGDocOwlRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MPLUGDocOwlMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.pretraining_tp = config.pretraining_tp + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.pretraining_tp > 1: + slice = self.intermediate_size // self.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +class MultiwayNetwork(nn.Module): + r""" + A multi-path network that applies different modules to different parts of the input tensor based on provided indices. + This approach is particularly useful for handling multi-modal data by projecting visual and language features into a shared semantic space while preserving their distinctive properties. + Formally it is refered to as Modality Adaptive Module (MAM). More details are in the paper: https://arxiv.org/pdf/2311.04257. + + Args: + module_provider (Callable): A callable that returns an instance of the module to be applied to the inputs. + num_multiway (int, optional): The number of different modules to use. Defaults to 2. + + Methods: + forward(hidden_states, multiway_indices): + Applies the corresponding module to each part of the hidden states as indicated by multiway_indices. + + Args: + hidden_states (torch.Tensor): The input tensor of shape (batch_size, seq_length, hidden_size). + multiway_indices (torch.Tensor): A tensor of indices indicating which module to apply to each part of hidden_states. + + Returns: + torch.Tensor: The output tensor after applying the selected modules. + + Example: + Given a vision-language sequence \(X \in \mathbb{R}^{(L_V + L_T) \times d}\) and modality indicators \(M \in \{0, 1\}^{(L_V + L_T) \times d}\), + where \(L_V\) and \(L_T\) are the lengths of the visual and textual sequences respectively, + the modality separated operation \(\phi\) is defined as: + + \[\widetilde{H}^{l-1} = \text{LNV}(\phi(H^{l-1}, M, 0)) + \text{LNT}(\phi(H^{l-1}, M, 1))\] + + Here, \(\phi\) is the modality separated operation, \(M\) indicates the modality (0 for visual, 1 for language), + and \(\text{LNV}\) and \(\text{LNT}\) are layer normalizations for visual and language features respectively. + + The query, key, and value projections are formulated as follows: + + - Query Projection: + \[Q^l = H^{l-1} W_Q^l\] + + - Key Projection: + \[K^l = \phi(\widetilde{H}^{l-1}, M, 0) W_{K0}^l + \phi(\widetilde{H}^{l-1}, M, 1) W_{K1}^l\] + + - Value Projection: + \[V^l = \phi(H^{l-1}, M, 0) W_{V0}^l + \phi(H^{l-1}, M, 1) W_{V1}^l\] + + The attention context features for the \(l\)-th layer are computed as: + + \[C^l = \text{Softmax}\left(\frac{Q^l K^{l \top}}{\sqrt{d}}\right) V^l\] + + Where \(Q^l\), \(K^l\), and \(V^l\) are the query, key, and value projections respectively, and \(d\) is the dimension of the head. + """ + + def __init__(self, module_provider, num_multiway=2): + super(MultiwayNetwork, self).__init__() + + self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)]) + + def forward(self, hidden_states, multiway_indices): + if len(self.multiway) == 1: + return self.multiway[0](hidden_states) + + output_hidden_states = torch.empty_like(hidden_states) + + for idx, subway in enumerate(self.multiway): + local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True) + hidden = hidden_states[local_indices].unsqueeze(1).contiguous() + if hidden.numel(): + output = subway(hidden) + if isinstance(output, tuple): + output = output[0] + output = output.squeeze(1) + output_hidden_states[local_indices] = output + + return output_hidden_states.contiguous() + + +class MultiwayAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MPLUGDocOwlConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = MultiwayNetwork( + module_provider=partial( + nn.Linear, + in_features=self.hidden_size, + out_features=self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + ) + self.v_proj = MultiwayNetwork( + module_provider=partial( + nn.Linear, + in_features=self.hidden_size, + out_features=self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = MPLUGDocOwlRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = MPLUGDocOwlLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = MPLUGDocOwlDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + modality_indicators: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj( + hidden_states, + ) + key_states = self.k_proj(hidden_states, modality_indicators) + value_states = self.v_proj(hidden_states, modality_indicators) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # cos, sin = self.rotary_emb(value_states, position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +MPLUGDocOwl_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MPLUGDocOwlConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +class MPLUGDocOwlDecoderLayer(nn.Module): + def __init__(self, config: MPLUGDocOwlConfig, layer_idx): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = MultiwayAttention(config=config) + self.layer_idx = layer_idx + self.mlp = MPLUGDocOwlMLP(config) + self.input_layernorm = MultiwayNetwork( + module_provider=partial(MPLUGDocOwlRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps) + ) + self.post_attention_layernorm = MultiwayNetwork( + module_provider=partial(MPLUGDocOwlRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps) + ) + + def forward( + self, + hidden_states: torch.Tensor, + modality_indicators: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + modality_indicators (torch.Tensor): A tensor of 1s and 0s indicating which module to apply to each part of hidden_states. 1 - image, 0 - text embeddings. + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states, modality_indicators) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + modality_indicators=modality_indicators, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@add_start_docstrings( + "The bare MPLUGDocOwl Model outputting raw hidden-states without any specific head on top.", + MPLUGDocOwl_START_DOCSTRING, +) +class MPLUGDocOwlPreTrainedLanguageModel(PreTrainedModel): + config_class = MPLUGDocOwlConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MPLUGDocOwlDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = False + _supports_cache_class = True + _supports_static_cache = True + _supports_sdpa = False + + +MPLUGDocOwl_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare MPLUGDocOwl Model outputting raw hidden-states without any specific head on top.", + MPLUGDocOwl_START_DOCSTRING, +) +class MPLUGDocOwlLanguageModel(MPLUGDocOwlPreTrainedLanguageModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MPLUGDocOwlDecoderLayer`] + + Args: + config: MPLUGDocOwlConfig + """ + + def __init__(self, config: MPLUGDocOwlConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MPLUGDocOwlDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = MPLUGDocOwlRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MPLUGDocOwl_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + modality_indicators: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + + attention_mask = _prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + + else: + layer_outputs = decoder_layer( + hidden_states, + modality_indicators=modality_indicators, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +@add_start_docstrings( + """The MPLUGDOCOWL model which consists of a vision backbone and a language model.""", + MPLUGDocOwl_START_DOCSTRING, +) +class MPLUGDocOwlForCausalLM(MPLUGDocOwlPreTrainedLanguageModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MPLUGDocOwlLanguageModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MPLUGDocOwl_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + modality_indicators: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + `Union[Tuple, CausalLMOutputWithPast]`: A `Tuple` containing various elements depending on the configuration + (`config`) and inputs, or a `CausalLMOutputWithPast` if `return_dict=True` is passed or set in the configuration. + The `Tuple` can contain: + - `loss` (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + - `logits` (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + - `past_key_values` (`List[torch.FloatTensor]`, *optional*, returned when `use_cache=True` is passed or set in the configuration): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed in the previous forward pass. + Can be used to speed up sequential decoding. + - `hidden_states` (`List[torch.FloatTensor]`, *optional*, returned when `output_hidden_states=True` is passed or set in the configuration): + Contains the hidden-states of the model at the output of each layer plus the initial embedding outputs. + - `attentions` (`List[torch.FloatTensor]`, *optional*, returned when `output_attentions=True` is passed or set in the configuration): + Contains the attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ``` + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + modality_indicators=modality_indicators, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +class MPLUGDocOwlHReducer(MPLUGDocOwlPreTrainedModel): + r""" + MPLUGDocOwlHReducer is a spatial-aware vision-to-text module designed for Visual Document Understanding. + This component processes high-resolution text-rich images by reducing the visual sequence length while + preserving spatial information. It uses a convolutional layer followed by a fully connected layer to align + visual features with language embeddings. + + Unlike other popular vision-to-text modules such as MLPs or cross-attention modules with learnable queries, + the H-Reducer is specifically designed to handle high-resolution images efficiently without losing spatial + coherence. See the paper https://arxiv.org/pdf/2403.12895 for more details. + + Attributes: + config (Config): Model configuration containing hyperparameters for the language model and hreducer. + + Methods: + __init__(config): + Initializes the MPLUGDocOwlHReducer with the given configuration. + forward(encoder_hidden_states=None): + Processes the encoder hidden states to reduce visual feature length and align them with language embeddings. + """ + + def __init__(self, config): + r""" + Initializes the MPLUGDocOwlHReducer with the given configuration. + + Args: + config (Config): Model configuration containing various hyperparameters. + """ + + super().__init__(config) + self.config = config + self.conv_shape = ( + int(self.config.hreducer_conv_shape.split("x")[0]), + int(self.config.hreducer_conv_shape.split("x")[1]), + ) + self.layer_norm = torch.nn.LayerNorm(self.config.hreducer_hidden_size, eps=self.config.hreducer_layer_norm) + self.conv_patch = self.conv_shape[0] * self.conv_shape[1] + self.reducer_before = torch.nn.Sequential( + nn.Conv2d( + self.config.hreducer_hidden_size, + self.conv_patch * self.config.hreducer_hidden_size, + kernel_size=self.conv_shape, + stride=self.conv_shape, + bias=True, + ), + nn.GELU(), + ) + ## reduce visual feature length with a conv layer + self.reducer = nn.Conv2d( + self.config.hreducer_hidden_size, + self.config.hreducer_hidden_size, + kernel_size=self.conv_shape, + stride=self.conv_shape, + bias=True, + ) + ## align visual features with language embedding with fc + self.visual_fc = torch.nn.Linear(self.config.hreducer_hidden_size, config.text_config.hidden_size) + self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, config.text_config.hidden_size)) + self.gradient_checkpointing = False + self.post_init() + + def forward(self, encoder_hidden_states=None): + r""" + Processes the encoder hidden states to reduce visual feature length and align them with language embeddings. + + Args: + encoder_hidden_states (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional): + Batch size is the number of all images (global+crop) in a batch. + Sequence of hidden-states at the output of the last layer of the encoder. + + Returns: + torch.FloatTensor: The processed sequence output with reduced visual feature length and aligned with language embeddings. + + Example: + >>> config = Config() # Assuming Config is already defined + >>> model = MPLUGDocOwlHReducer(config) + >>> encoder_hidden_states = torch.randn(batch_size, sequence_length, hidden_size) # Example tensor + >>> output = model.forward(encoder_hidden_states) + """ + + # Remove the first cls token + encoder_hidden_states = encoder_hidden_states[ + :, 1:, : + ] # Shape: (batch_size, sequence_length - 1, hidden_size) + + # B - batch_size, L - sequence_length, C - hidden_size + batch_size, seq_len, hidden_size = encoder_hidden_states.shape + + # Calculate height assuming seq_len is a square number + height = int(torch.sqrt(torch.tensor(seq_len))) + + # Transpose and reshape encoder hidden states + encoder_hidden_states = encoder_hidden_states.transpose( + 2, 1 + ) # Shape: (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, hidden_size, height, height + ) # Shape: (batch_size, hidden_size, height, height) + + # Apply reducer (e.g., a convolution) + reduced_states = self.reducer_before( + encoder_hidden_states + ) # Shape: (batch_size, reduced_depth, height, width_reduced) + + # B - batch_size, reduced_depth - reduced depth dimension, height - height, width_reduced - reduced width + batch_size, reduced_depth, height, width_reduced = reduced_states.shape + + # Number of patches in width + num_patches = self.conv_patch + + # New depth dimension + depth = reduced_depth // num_patches + + # Reshape reduced states + reduced_states = reduced_states.view( + batch_size, num_patches, depth, height, width_reduced + ) # Shape: (batch_size, num_patches, depth, height, width_reduced) + reduced_states = reduced_states.permute( + 0, 2, 3, 4, 1 + ) # Shape: (batch_size, depth, height, width_reduced, num_patches) + reduced_states = reduced_states.reshape( + batch_size, depth, height, width_reduced * num_patches + ) # Shape: (batch_size, depth, height, width) + + # Apply final reducer (e.g., a convolution) + sequence_output = self.reducer(reduced_states) # Shape: (batch_size, final_depth, final_height, final_width) + + # Flatten and transpose to (batch_size, seq_length_reduced, final_depth) + sequence_output = sequence_output.flatten(2).transpose( + 1, 2 + ) # Shape: (batch_size, seq_length_reduced, final_depth) + sequence_output = sequence_output.transpose( + 0, 1 + ).contiguous() # Shape: (seq_length_reduced, batch_size, final_depth) + + # Apply final fully connected layer + sequence_output = self.visual_fc(sequence_output) # Shape: (seq_length_reduced, batch_size, final_hidden_size) + sequence_output = sequence_output.transpose( + 0, 1 + ).contiguous() # Shape: (batch_size, seq_length_reduced, final_hidden_size) + + # Concatenate end-of-sequence token + sequence_output = torch.cat( + [sequence_output, self.vit_eos.repeat(batch_size, 1, 1)], dim=1 + ) # Shape: (batch_size, seq_length_reduced + 1, final_hidden_size) + + return sequence_output + + +@add_start_docstrings( + """The MPLUGDOCOWL model which consists of a vision backbone and a language model.""", + MPLUGDocOwl_START_DOCSTRING, +) +class MPLUGDocOwlForConditionalGeneration(MPLUGDocOwlPreTrainedModel): + def __init__(self, config: MPLUGDocOwlConfig): + super().__init__(config) + + self.vision_tower = MPLUGDocOwlVisionModel(config.vision_config) + self.multi_modal_projector = MPLUGDocOwlHReducer(config) + self.vocab_size = config.text_config.vocab_size + + self.language_model = MPLUGDocOwlForCausalLM(config.text_config) + + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + @add_start_docstrings_to_model_forward(MPLUGDOCOWL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MPLUGDocOwlCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + patch_positions: Optional[torch.LongTensor] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, MPLUGDocOwlCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + + Returns: + `Union[Tuple, MPLUGDocOwlCausalLMOutputWithPast]`: A tuple containing the output logits, and optionally the loss if `labels` is provided, or an MPLUGDocOwlCausalLMOutputWithPast object with the following attributes: + - loss (optional): `torch.FloatTensor` of shape `(1,)` if `labels` is provided. + - logits: `torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`. + - past_key_values (optional): list of `torch.FloatTensor` containing pre-computed hidden-states (key and values in the attention blocks) that can be used to speed up sequential decoding. + - hidden_states (optional): list of `torch.FloatTensor` (one for the output of each layer + output embedding). + - attentions (optional): list of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, MPLUGDocOwlForConditionalGeneration + + >>> model = MPLUGDocOwlForConditionalGeneration.from_pretrained("danaaubakirova/mplugdocowl1.5-Chat-hf") + >>> processor = AutoProcessor.from_pretrained("danaaubakirova/mplugdocowl1.5-Chat-hf") + + >>> prompt = "What's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: What's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ``` + + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # modality indicators are like token-type-ids and denote `1` for positions where image_embeddings are + batch_size, seq_len, _ = inputs_embeds.shape + modality_indicators = torch.zeros((batch_size, seq_len), device=inputs_embeds.device) + + if pixel_values is not None: + image_outputs = self.vision_tower(pixel_values, output_hidden_states=False).last_hidden_state + image_features = self.multi_modal_projector(encoder_hidden_states=image_outputs) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + modality_indicators[input_ids == self.config.image_token_index] = 1 + + outputs = self.language_model( + attention_mask=attention_mask, + modality_indicators=modality_indicators, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MPLUGDocOwlCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + pixel_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + # Use pixel values if we are in pre-fill stage or generation w/o cache + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _reorder_cache(self, *args, **kwargs): + return self.language_model._reorder_cache(*args, **kwargs) diff --git a/src/transformers/models/mplugdocowl/processing_mplugdocowl.py b/src/transformers/models/mplugdocowl/processing_mplugdocowl.py new file mode 100644 index 00000000000000..de0841964b1989 --- /dev/null +++ b/src/transformers/models/mplugdocowl/processing_mplugdocowl.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for MPLUGDocOwl. +""" + +from typing import Dict, List, Optional, Union + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class MPLUGDocOwlProcessor(ProcessorMixin): + r""" + Constructs a MPLUGDocOwl processor which wraps a MPLUGDocOwl image processor and a MPLUGDocOwl tokenizer into a single processor. + + [`MPLUGDocOwlProcessor`] offers all the functionalities of [`MPLUGDocOwlImageProcessor`] and [`AutoTokenizer`]. See the + [`~MPLUGDocOwlProcessor.__call__`] and [`~MPLUGDocOwlProcessor.decode`] for more information. + + Args: + image_processor ([`MPLUGDocOwlImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`AutoTokenizer`], *optional*): + The tokenizer is a required input. + num_image_tokens (`int`, *optional*, defaults to 257): + The sequence length of image embeddings after the HReducer module. + image_token (`str`, *optional*, defaults to ""): + The string form of the token corresponding to the special `image` token used as a placeholder. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template", "num_image_tokens", "image_token"] + image_processor_class = "MPLUGDocOwlImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__( + self, + image_processor=None, + tokenizer=None, + chat_template=None, + num_image_tokens=257, + image_token="", + **kwargs, + ): + self.num_image_tokens = num_image_tokens + self.image_token = image_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def generate_text_with_placeholders( + self, text, patch_positions, anchor_max, num_patches, add_textual_crop_indicator + ): + """ + Generates a text string with placeholders for images and optional textual crop indicators. + + Parameters: + - text (str): The input text containing tokens where image placeholders should be inserted. + - patch_positions (numpy.ndarray): Array of patch positions indicating the location of cropped images. + - anchor_max (int): The maximum anchor value used to identify global images. + - num_patches (int): The number of patches (or cropped images) to be represented in the text. + - add_textual_crop_indicator (bool): Flag indicating whether to add textual crop indicators in the output. + + Returns: + - str: The generated text with appropriate image placeholders and optional crop indicators. + """ + media_token = "" + if media_token not in text: + raise ValueError("The prompt must contain the media token ''") + text_list = text.split(media_token) + text = "USER: " + image_token_count = 0 + + for next_text in text_list[1:]: + if add_textual_crop_indicator: + # Generate image placeholders with interleaved textual crop indicator + for patch_pos in patch_positions.tolist(): + if patch_pos[0] == anchor_max and patch_pos[1] == anchor_max: + text += "" + else: + row_col = f"row{patch_pos[0]}_col{patch_pos[1]}" + text += f"" + else: + # Generate successive image placeholders for an image, 1 crop img == 1 + text += "" * num_patches + + text += next_text + image_token_count += 1 + + text += " ASSISTANT:" + return text + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + add_textual_crop_indicator: bool = True, + padding: Union[bool, str, PaddingStrategy] = True, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length=None, + do_rescale: bool = True, + do_convert_rgb: bool = True, + do_resize: bool = True, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = (0.48145466, 0.4578275, 0.40821073), + image_std: Optional[Union[float, List[float]]] = (0.26862954, 0.26130258, 0.27577711), + size: Dict[str, int] = {"width": 448, "height": 448}, + do_anchor_resize: bool = True, + do_shape_adaptive_cropping: bool = True, + do_add_global_image: bool = True, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to AutoTokenizer's [`~AutoTokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + MPLUGDocOwlImageProcessor's [`~MPLUGDocOwlImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + images (ImageInput, optional): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array, or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], optional): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). + add_textual_crop_indicator (bool, optional): + Whether to add a textual crop indicator to the images. Defaults to True. + padding (Union[bool, str, PaddingStrategy], optional): + Select a strategy to pad the returned sequences. Defaults to True. + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence is provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (Union[bool, str, TruncationStrategy], optional): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + max_length (int, optional): + Maximum length of the returned list and optionally padding length. + do_rescale (bool, optional): + Whether to rescale the image. Defaults to True. + do_convert_rgb (bool, optional): + Whether to convert the image to RGB. Defaults to True. + do_resize (bool, optional): + Whether to resize the image. Defaults to True. + do_normalize (bool, optional): + Whether to normalize the image. Defaults to None. + image_mean (Optional[Union[float, List[float]]], optional): + The mean values for image normalization. Defaults to (0.48145466, 0.4578275, 0.40821073). + image_std (Optional[Union[float, List[float]]], optional): + The standard deviation values for image normalization. Defaults to (0.26862954, 0.26130258, 0.27577711). + size (Dict[str, int], optional): + A dictionary specifying the desired width and height for resizing. Defaults to {"width": 448, "height": 448}. + do_anchor_resize (bool, optional): + Whether to resize the image based on the specified anchor. Defaults to True. + do_shape_adaptive_cropping (bool, optional): + Whether to do a shape adaptive cropping of the input image. Should be only called if the `do_anchor_resize` is True. Defaults to True. + do_add_global_image (bool, optional): + Whether to add the global image to the image input. Defaults to True. + return_tensors (Optional[Union[str, TensorType]], optional): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. Defaults to TensorType.PYTORCH. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + + if images is not None: + pixel_values = self.image_processor( + images, + do_rescale=do_rescale, + do_convert_rgb=do_convert_rgb, + do_shape_adaptive_cropping=do_shape_adaptive_cropping, + do_resize=do_resize, + do_normalize=do_normalize, + return_tensors=return_tensors, + image_mean=image_mean, + image_std=image_std, + size=size, + do_anchor_resize=do_anchor_resize, + do_add_global_image=do_add_global_image, + ) + else: + pixel_values = None + # text preprocessing + patch_positions = pixel_values["patch_positions"] + num_patches = pixel_values["num_patches"] + anchor_max = pixel_values["anchor_max"] + + if not isinstance(text, list): + text = [text] + + texts = [ + self.generate_text_with_placeholders(txt, patch_pos, anch_max, n_patches, add_textual_crop_indicator) + for txt, patch_pos, anch_max, n_patches in zip(text, patch_positions, anchor_max, num_patches) + ] + + prompt_strings = [] + for sample in texts: + sample = sample.replace(self.image_token, self.image_token * self.num_image_tokens) + prompt_strings.append(sample) + + text_inputs = self.tokenizer( + prompt_strings, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + ) + + return BatchFeature( + data={**text_inputs, "pixel_values": pixel_values["pixel_values"], "patch_positions": patch_positions} + ) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to AutoTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index de739c6e70044a..9920aa7b42a96f 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6009,6 +6009,69 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class MPLUGDocOwlAttention(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MPLUGDocOwlForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MPLUGDocOwlForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MPLUGDocOwlHReducer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MPLUGDocOwlLanguageModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MPLUGDocOwlPreTrainedLanguageModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MPLUGDocOwlPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MPLUGDocOwlVisionModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MPLUGDocOwlVisionTransformer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MPNetForMaskedLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_vision_objects.py b/src/transformers/utils/dummy_vision_objects.py index 19f8dc1b1d9c9e..28f937a19a0a84 100644 --- a/src/transformers/utils/dummy_vision_objects.py +++ b/src/transformers/utils/dummy_vision_objects.py @@ -436,6 +436,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["vision"]) +class MPLUGDocOwlImageProcessor(metaclass=DummyObject): + _backends = ["vision"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["vision"]) + + class NougatImageProcessor(metaclass=DummyObject): _backends = ["vision"] diff --git a/tests/models/mplugdocowl/__init__.py b/tests/models/mplugdocowl/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/mplugdocowl/test_modeling_mplugdocowl.py b/tests/models/mplugdocowl/test_modeling_mplugdocowl.py new file mode 100644 index 00000000000000..5d57e8e2cc936f --- /dev/null +++ b/tests/models/mplugdocowl/test_modeling_mplugdocowl.py @@ -0,0 +1,400 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch MPLUGDocOwl model.""" + +import gc +import unittest + +import requests +from parameterized import parameterized + +from transformers import ( + MPLUGDocOwlConfig, + MPLUGDocOwlForConditionalGeneration, + MPLUGDocOwlProcessor, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_torch, + require_torch_sdpa, + require_vision, + slow, + torch_device, +) + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch +else: + is_torch_greater_or_equal_than_2_0 = False + +if is_vision_available(): + from PIL import Image + + +class MPLUGDocOwlVisionText2TextModelTester: + def __init__( + self, + parent, + ignore_index=-100, + image_token_index=0, + projector_hidden_act="gelu", + seq_length=7, + vision_feature_select_strategy="default", + hreducer_hidden_size=32, + hreducer_initializer_range=0.02, + hreducer_layer_norm=1e-6, + hreducer_conv_shape="1x2", + vision_feature_layer=-1, + text_config={ + "model_type": "llama", + "seq_length": 7, + "is_training": True, + "use_input_mask": True, + "use_token_type_ids": False, + "use_labels": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 16, + "type_sequence_label_size": 2, + "initializer_range": 0.02, + "num_labels": 3, + "num_choices": 4, + "pad_token_id": 0, + }, + is_training=True, + vision_config={ + "image_size": 30, + "patch_size": 2, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "projection_dim": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + ): + self.parent = parent + self.ignore_index = ignore_index + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.text_config = text_config + self.vision_config = vision_config + self.seq_length = seq_length + self.hreducer_hidden_size = hreducer_hidden_size + self.hreducer_initializer_range = hreducer_initializer_range + self.hreducer_layer_norm = hreducer_layer_norm + self.hreducer_conv_shape = hreducer_conv_shape + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = 3 + self.image_size = 336 + self.encoder_seq_length = 112 + + def get_config(self): + return MPLUGDocOwlConfig( + hreducer_conv_shape=self.hreducer_conv_shape, + hreducer_hidden_size=self.hreducer_hidden_size, + hreducer_initializer_range=self.hreducer_initializer_range, + hreducer_layer_norm=self.hreducer_layer_norm, + text_config=self.text_config, + vision_config=self.vision_config, + ignore_index=self.ignore_index, + image_token_index=self.image_token_index, + projector_hidden_act=self.projector_hidden_act, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(1).to(torch_device) + # we are giving 3 images let's make sure we pass in 3 image tokens + input_ids[:, 1] = config.image_token_index + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + def create_and_check_mplugdocowl_model_fp16_forward(self, config, input_ids, pixel_values, attention_mask): + model = MPLUGDocOwlForConditionalGeneration(config=config) + model.to(torch_device) + model.eval() + with torch.autocast(device_type="cuda", dtype=torch.float16): + logits = model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values.to(torch.bfloat16), + return_dict=True, + )["logits"] + self.parent.assertFalse(torch.isnan(logits).any().item()) + + +@require_torch +class MPLUGDocOwlForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase): + """ + Model tester for `MPLUGDocOwlForConditionalGeneration`. + """ + + all_model_classes = (MPLUGDocOwlForConditionalGeneration,) if is_torch_available() else () + test_pruning = False + test_head_masking = False + test_attention_outputs = False + test_torchscript = False + + def setUp(self): + self.model_tester = MPLUGDocOwlVisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=MPLUGDocOwlConfig, has_text_modality=False) + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="input_embeds cannot be passed in without input_ids") + def test_inputs_embeds(): + pass + + @require_torch_sdpa + @slow + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + self.skipTest(reason="This model does not support SDPA") + + @unittest.skip(reason="MPLUGDocOwl1.5 does not use feedforward chunking.") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="Compile not yet supported in MPLUGDocOwl1.5") + def test_sdpa_can_compile_dynamic(self): + pass + + @unittest.skip(reason="Compile not yet supported in MPLUGDocOwl1.5") + def test_sdpa_can_dispatch_on_flash(self): + pass + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + + # Ensure all parameters are initialized to 0.0 or 1.0 + for name, param in model.named_parameters(): + if "embeddings" not in name and param.requires_grad: + # Explicitly initialize parameters + with torch.no_grad(): + param.fill_(0.0) # or param.fill_(1.0) based on your requirements + + # Calculate the rounded mean of the parameter data + param_mean = ((param.data.mean() * 1e9).round() / 1e9).item() + + # Check if the mean is either 0.0 or 1.0 + try: + self.assertIn( + param_mean, + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized: found {param_mean}, expected 0.0 or 1.0", + ) + except AssertionError as e: + print(f"Initialization error: {e}") + raise + + @unittest.skip( + reason="MPLUGDocOwlVisionModel does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Thus, cannot be created with no checkpoint." + ) + def test_from_pretrained_no_checkpoint(self): + pass + + +@require_vision +@require_torch +class MPLUGDocOwlForConditionalGenerationIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = MPLUGDocOwlProcessor.from_pretrained("danaaubakirova/mplugdocowl1.5-Chat-hf") + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + @slow + def test_small_model_integration_test(self): + model = MPLUGDocOwlForConditionalGeneration.from_pretrained( + "danaaubakirova/mplugdocowl1.5-Chat-hf", load_in_4bit=False + ) + + prompt = "What's the value of the Very well bar in the 65+ age group? Answer the question with detailed explanation." + raw_image = Image.open( + requests.get( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/test_image.png", + stream=True, + ).raw + ) + inputs = self.processor(prompt, raw_image, return_tensors="pt") + + output = model.generate(**inputs, max_new_tokens=500) + EXPECTED_DECODED_TEXT = """ 68%\nIn the image, which appears to be a chart from a Pew Research Center report, the bar representing the percentage of Republicans and Republican leaners who believe "very well" describes how fights for what they believe in describe Trump is at 68% for the 65+ age group.""" + + self.assertEqual( + self.processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_single(self): + # Let' s make sure we test the preprocessing to replace what is used + model = MPLUGDocOwlForConditionalGeneration.from_pretrained( + "danaaubakirova/mplugdocowl1.5-Chat-hf", load_in_4bit=False + ) + + prompt = "What is the name of the movie in the poster? Provide detailed explanation." + raw_image = Image.open( + requests.get( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/examples_Rebecca_(1939_poster)_Small.jpeg", + stream=True, + ).raw + ) + inputs = self.processor(prompt, raw_image, return_tensors="pt", do_add_global_image=True) + output = model.generate(**inputs, max_new_tokens=500) + EXPECTED_DECODED_TEXT = 'Rebecca\nThe name of the movie in the poster is "Rebecca," as indicated by the large title at the top of the poster. The poster also includes the names of the stars, Laurence Olivier and Joan Fontaine, suggesting that they are the lead actors in the film. The poster features a classic Hollywood style with a focus on the two main characters and the title.' # fmt: skip + self.assertEqual( + self.processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_mplugdocowl_single(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "danaaubakirova/mplugdocowl1.5-Chat-hf" + + model = MPLUGDocOwlForConditionalGeneration.from_pretrained( + "danaaubakirova/mplugdocowl1.5-Chat-hf", load_in_4bit=False + ) + processor = MPLUGDocOwlProcessor.from_pretrained(model_id) + + prompt = "Recognize text in the image." + raw_image = Image.open( + requests.get( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/test_image.tif", + stream=True, + ).raw + ) + + inputs = processor(prompt, raw_image, return_tensors="pt") # .to(torch_device, torch.float16) + + output = model.generate(**inputs, max_new_tokens=500, do_sample=False) + + EXPECTED_DECODED_TEXT = "PHILIP MORRIS MANAGEMENT CORP." + self.assertEqual( + processor.decode(output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + # @require_bitsandbytes + def test_small_model_integration_test_llama_batched(self): + # Let' s make sure we test the preprocessing to replace what is used + model_id = "danaaubakirova/mplugdocowl1.5-Chat-hf" + + model = MPLUGDocOwlForConditionalGeneration.from_pretrained( + "danaaubakirova/mplugdocowl1.5-Chat-hf", load_in_4bit=False + ) + processor = MPLUGDocOwlProcessor.from_pretrained(model_id) + + prompts = [ + "What is the name of the movie in the poster? Provide detailed explanation.", + "What is unusual about this image? Provide detailed explanation.", + ] + image1 = Image.open( + requests.get( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/examples_Rebecca_(1939_poster)_Small.jpeg", + stream=True, + ).raw + ) + image2 = Image.open( + requests.get( + "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/extreme_ironing.jpg", + stream=True, + ).raw + ) + + inputs = processor(text=prompts, images=[image1, image2], return_tensors="pt") + + output = model.generate(**inputs, max_new_tokens=512, do_sample=False, use_cache=True) + + EXPECTED_DECODED_TEXT = [ + 'USER: What is the name of the movie in the poster? Provide detailed explanation. ASSISTANT: Rebecca\nThe name of the movie in the poster is "Rebecca," as indicated by the large title at the top of the poster. The poster also includes the names of the stars, Laurence Olivier and Joan Fontaine, suggesting that they are the lead actors in the film. The poster features a classic Hollywood style with a focus on the two main characters and the title.', + "USER: What is unusual about this image? Provide detailed explanation. ASSISTANT:\nThe unusual aspect of this image is that the man is ironing clothes on the back of a taxi, which is not a common sight. It is not typical to see someone ironing on the back of a vehicle, especially in an urban setting where such activities are generally not practical due to the lack of space and the potential for disruption to traffic. The presence of a taxi with a man ironing on its back adds an element of surprise and novelty to the scene.", + ] + self.assertEqual( + processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) diff --git a/utils/check_repo.py b/utils/check_repo.py index 293089ccb662b4..705cd21ab7d2f9 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -128,6 +128,12 @@ "SeamlessM4TCodeHifiGan", # Building part of bigger (tested) model. "SeamlessM4TTextToUnitForConditionalGeneration", # Building part of bigger (tested) model. "ChameleonVQVAE", # VQVAE here is used only for encoding (discretizing) and is tested as part of bigger model + "MPLUGDocOwlHReducer", # Building part of bigger (tested) model. + "MPLUGDocOwlAttention", # Building part of bigger (tested) model. + "MPLUGDocOwlForCausalLM", # Building part of bigger (tested) model. + "MPLUGDocOwlLanguageModel", # Building part of bigger (tested) model. + "MPLUGDocOwlVisionModel", # Building part of bigger (tested) model. + "MPLUGDocOwlVisionTransformer", # Building part of bigger (tested) model. ] # Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't @@ -321,6 +327,12 @@ "SiglipVisionModel", "SiglipTextModel", "ChameleonVQVAE", # no autoclass for VQ-VAE models + "MPLUGDocOwlHReducer", + "MPLUGDocOwlAttention", + "MPLUGDocOwlForCausalLM", + "MPLUGDocOwlLanguageModel", + "MPLUGDocOwlVisionModel", + "MPLUGDocOwlVisionTransformer", ] # DO NOT edit this list!