From bf2817eeeaa11d46bad961138cc72d2ae6751a39 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 18 Jul 2024 12:21:39 -0400 Subject: [PATCH 01/13] Add template image-text-to-text pipeline --- docs/source/en/model_doc/auto.md | 4 + src/transformers/__init__.py | 4 + src/transformers/models/auto/__init__.py | 4 + src/transformers/models/auto/modeling_auto.py | 30 ++++ src/transformers/pipelines/__init__.py | 13 ++ .../pipelines/image_text_to_text.py | 144 ++++++++++++++++++ src/transformers/pipelines/image_to_text.py | 9 +- src/transformers/utils/dummy_pt_objects.py | 10 ++ utils/update_metadata.py | 1 + 9 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 src/transformers/pipelines/image_text_to_text.py diff --git a/docs/source/en/model_doc/auto.md b/docs/source/en/model_doc/auto.md index ab42c24d83e82d..a98b105417370d 100644 --- a/docs/source/en/model_doc/auto.md +++ b/docs/source/en/model_doc/auto.md @@ -381,3 +381,7 @@ The following auto classes are available for the following multimodal tasks. ### FlaxAutoModelForVision2Seq [[autodoc]] FlaxAutoModelForVision2Seq + +### AutoModelForImageTextToText + +[[autodoc]] AutoModelForImageTextToText \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9108367f35b321..eb3ac1174d9daa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1353,6 +1353,7 @@ "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_IMAGE_MAPPING", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", + "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING", "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", "MODEL_FOR_KEYPOINT_DETECTION_MAPPING", @@ -1394,6 +1395,7 @@ "AutoModelForDocumentQuestionAnswering", "AutoModelForImageClassification", "AutoModelForImageSegmentation", + "AutoModelForImageTextToText", "AutoModelForImageToImage", "AutoModelForInstanceSegmentation", "AutoModelForKeypointDetection", @@ -6056,6 +6058,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, MODEL_FOR_KEYPOINT_DETECTION_MAPPING, @@ -6097,6 +6100,7 @@ AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, + AutoModelForImageTextToText, AutoModelForImageToImage, AutoModelForInstanceSegmentation, AutoModelForKeypointDetection, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 3bb2b8e9d4c199..2ee0541a1a71b8 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -74,6 +74,7 @@ "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", "MODEL_FOR_VISION_2_SEQ_MAPPING", + "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING", "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", "MODEL_MAPPING", "MODEL_WITH_LM_HEAD_MAPPING", @@ -119,6 +120,7 @@ "AutoModelWithLMHead", "AutoModelForZeroShotImageClassification", "AutoModelForZeroShotObjectDetection", + "AutoModelForImageTextToText", ] try: @@ -238,6 +240,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, MODEL_FOR_IMAGE_TO_IMAGE_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, MODEL_FOR_KEYPOINT_DETECTION_MAPPING, @@ -279,6 +282,7 @@ AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, + AutoModelForImageTextToText, AutoModelForImageToImage, AutoModelForInstanceSegmentation, AutoModelForKeypointDetection, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index d096abf4342614..b4145fefc9249e 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -719,6 +719,26 @@ ] ) +MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict( + [ + ("blip", "BlipForConditionalGeneration"), + ("blip-2", "Blip2ForConditionalGeneration"), + ("fuyu", "FuyuForCausalLM"), + ("git", "GitForCausalLM"), + ("idefics", "IdeficsForVisionText2Text"), + ("instructblip", "InstructBlipForConditionalGeneration"), + ("kosmos-2", "Kosmos2ForConditionalGeneration"), + ("llava", "LlavaForConditionalGeneration"), + ("pix2struct", "Pix2StructForConditionalGeneration"), + ("udop", "UdopForConditionalGeneration"), + ("vipllava", "VipLlavaForConditionalGeneration"), + ("vision-encoder-decoder", "VisionEncoderDecoderModel"), + ("idefics2", "Idefics2ForConditionalGeneration"), + ("llava_next", "LlavaNextForConditionalGeneration"), + ("paligemma", "PaliGemmaForConditionalGeneration"), + ] +) + MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict( [ # Model for Masked LM mapping @@ -1371,6 +1391,9 @@ CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES ) MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES) +MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES +) MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES ) @@ -1665,6 +1688,13 @@ class AutoModelForVision2Seq(_BaseAutoModelClass): AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling") +class AutoModelForImageTextToText(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING + + +AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling") + + class AutoModelForAudioClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 9bc0a1cf8b4677..4ebfb2a086a4f8 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -65,6 +65,7 @@ from .image_classification import ImageClassificationPipeline from .image_feature_extraction import ImageFeatureExtractionPipeline from .image_segmentation import ImageSegmentationPipeline +from .image_text_to_text import ImageTextToTextPipeline from .image_to_image import ImageToImagePipeline from .image_to_text import ImageToTextPipeline from .mask_generation import MaskGenerationPipeline @@ -117,6 +118,7 @@ AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, + AutoModelForImageTextToText, AutoModelForMaskedLM, AutoModelForMaskGeneration, AutoModelForObjectDetection, @@ -382,6 +384,17 @@ }, "type": "multimodal", }, + "image-text-to-text": { + "impl": ImageTextToTextPipeline, + "tf": (), + "pt": (AutoModelForImageTextToText,) if is_torch_available() else (), + "default": { + "model": { + "pt": ("Salesforce/blip-image-captioning-base", "89b09ea"), + } + }, + "type": "multimodal", + }, "object-detection": { "impl": ObjectDetectionPipeline, "tf": (), diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py new file mode 100644 index 00000000000000..60587d02fc0e43 --- /dev/null +++ b/src/transformers/pipelines/image_text_to_text.py @@ -0,0 +1,144 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +from ..utils import ( + add_end_docstrings, + is_torch_available, + is_vision_available, + logging, + requires_backends, +) +from .base import Pipeline, build_pipeline_init_args + + +if is_vision_available(): + from PIL import Image + + from ..image_utils import load_image + + +if is_torch_available(): + from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + +logger = logging.get_logger(__name__) + + +@add_end_docstrings(build_pipeline_init_args(has_processor=True)) +class ImageTextToTextPipeline(Pipeline): + """ + Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text. + + Example: + + ```python + >>> from transformers import pipeline + + >>> pipe = pipeline(task="image-text-to-text", model="Salesforce/blip-image-captioning-base") + >>> pipe("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", text="A photo of") + [{'generated_text': 'a photo of two birds'}] + ``` + + Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) + + This image-text to text pipeline can currently be loaded from pipeline() using the following task identifier: + "image-text-to-text". + + See the list of available models on + [huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-text-to-text). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + requires_backends(self, "vision") + self.check_model_type(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES) + + def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, text=None, timeout=None): + forward_kwargs = {} + preprocess_params = {} + + if text is not None: + preprocess_params["text"] = text + if timeout is not None: + preprocess_params["timeout"] = timeout + + if generate_kwargs is not None: + forward_kwargs["generate_kwargs"] = generate_kwargs + if max_new_tokens is not None: + if "generate_kwargs" not in forward_kwargs: + forward_kwargs["generate_kwargs"] = {} + if "max_new_tokens" in forward_kwargs["generate_kwargs"]: + raise ValueError( + "'max_new_tokens' is defined twice, once in 'generate_kwargs' and once as a direct parameter," + " please use only one" + ) + forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens + + return preprocess_params, forward_kwargs, {} + + def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs): + """ + Generate a text given text and the image(s) passed as inputs. + + Args: + images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`): + The pipeline handles three types of images: + + - A string containing a HTTP(s) link pointing to an image + - A string containing a local path to an image + - An image loaded in PIL directly + + The pipeline accepts either a single image or a batch of images. + + text (`str`): + The text to be used as a prompt for the generation. + + max_new_tokens (`int`, *optional*): + The amount of maximum tokens to generate. By default it will use `generate` default. + + generate_kwargs (`Dict`, *optional*): + Pass it to send all of these arguments directly to `generate` allowing full control of this function. + timeout (`float`, *optional*, defaults to None): + The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and + the call may block forever. + + Return: + A list or a list of list of `dict`: Each result comes as a dictionary with the following key: + + - **generated_text** (`str`) -- The generated text. + """ + return super().__call__(images, **kwargs) + + def preprocess(self, image=None, text=None, timeout=None): + pass + + def _forward(self, model_inputs, generate_kwargs=None): + if generate_kwargs is None: + generate_kwargs = {} + + model_outputs = self.model.generate(**model_inputs, **generate_kwargs) + return model_outputs + + def postprocess(self, model_outputs): + records = [] + generated_texts = self.processor.batch_decode( + model_outputs, + skip_special_tokens=True, + ) + + records = [{"generated_text": text} for text in generated_texts] + + return records \ No newline at end of file diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py index 88dce8e591ae41..ed99dfc497df0d 100644 --- a/src/transformers/pipelines/image_to_text.py +++ b/src/transformers/pipelines/image_to_text.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import List, Union from ..utils import ( @@ -96,7 +97,7 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs): """ - Assign labels to the image(s) passed as inputs. + Generate text based on the image(s) passed as inputs. Args: images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`): @@ -128,6 +129,12 @@ def preprocess(self, image, prompt=None, timeout=None): image = load_image(image, timeout=timeout) if prompt is not None: + warnings.warn( + "Passing `prompt` to the `image-to-text` pipeline is deprecated and will be removed in version 4.45" + " of ๐Ÿค— Transformers. Use the `image-text-to-text` pipeline instead", + FutureWarning, + ) + if not isinstance(prompt, str): raise ValueError( f"Received an invalid text input, got - {type(prompt)} - but expected a single string. " diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index de739c6e70044a..6b1244005db9db 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -668,6 +668,9 @@ def __init__(self, *args, **kwargs): MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None +MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = None + + MODEL_FOR_IMAGE_TO_IMAGE_MAPPING = None @@ -835,6 +838,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class AutoModelForImageTextToText(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModelForImageToImage(metaclass=DummyObject): _backends = ["torch"] diff --git a/utils/update_metadata.py b/utils/update_metadata.py index 1806eb3f03df5a..799a59a9a59a4a 100755 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -110,6 +110,7 @@ "AutoModelForVisualQuestionAnswering", ), ("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"), + ("image-text-to-text", "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"), ( "zero-shot-image-classification", "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES", From ab0139ecd02e98067440572b14e6b40cfb3e50a1 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 18 Jul 2024 16:35:35 +0000 Subject: [PATCH 02/13] fix style --- docs/source/ja/model_doc/auto.md | 4 ++++ src/transformers/models/auto/modeling_auto.py | 6 +++--- src/transformers/pipelines/image_text_to_text.py | 4 +--- src/transformers/pipelines/image_to_text.py | 2 +- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/source/ja/model_doc/auto.md b/docs/source/ja/model_doc/auto.md index d4baaf70e6fd48..492c46c79ea905 100644 --- a/docs/source/ja/model_doc/auto.md +++ b/docs/source/ja/model_doc/auto.md @@ -368,3 +368,7 @@ AutoModel.register(NewModelConfig, NewModel) ### FlaxAutoModelForVision2Seq [[autodoc]] FlaxAutoModelForVision2Seq + +### AutoModelForImageTextToText + +[[autodoc]] AutoModelForImageTextToText diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b4145fefc9249e..9ca99b13a71046 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -726,16 +726,16 @@ ("fuyu", "FuyuForCausalLM"), ("git", "GitForCausalLM"), ("idefics", "IdeficsForVisionText2Text"), + ("idefics2", "Idefics2ForConditionalGeneration"), ("instructblip", "InstructBlipForConditionalGeneration"), ("kosmos-2", "Kosmos2ForConditionalGeneration"), ("llava", "LlavaForConditionalGeneration"), + ("llava_next", "LlavaNextForConditionalGeneration"), + ("paligemma", "PaliGemmaForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"), ("udop", "UdopForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"), ("vision-encoder-decoder", "VisionEncoderDecoderModel"), - ("idefics2", "Idefics2ForConditionalGeneration"), - ("llava_next", "LlavaNextForConditionalGeneration"), - ("paligemma", "PaliGemmaForConditionalGeneration"), ] ) diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 60587d02fc0e43..9cdb8780ab170c 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -28,8 +28,6 @@ if is_vision_available(): from PIL import Image - from ..image_utils import load_image - if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES @@ -141,4 +139,4 @@ def postprocess(self, model_outputs): records = [{"generated_text": text} for text in generated_texts] - return records \ No newline at end of file + return records diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py index ed99dfc497df0d..a34e5e1a7e190b 100644 --- a/src/transformers/pipelines/image_to_text.py +++ b/src/transformers/pipelines/image_to_text.py @@ -134,7 +134,7 @@ def preprocess(self, image, prompt=None, timeout=None): " of ๐Ÿค— Transformers. Use the `image-text-to-text` pipeline instead", FutureWarning, ) - + if not isinstance(prompt, str): raise ValueError( f"Received an invalid text input, got - {type(prompt)} - but expected a single string. " From 91de7146ddf771ab693fc126f8bb305c844de5c0 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 23 Jul 2024 15:57:41 +0000 Subject: [PATCH 03/13] Add processor handling --- src/transformers/pipelines/__init__.py | 25 ++++++++++++++++ src/transformers/pipelines/base.py | 14 ++++++++- .../pipelines/image_text_to_text.py | 30 ++++++++++++++++++- 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 4ebfb2a086a4f8..8c1fdfd4115604 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -24,6 +24,7 @@ from ..dynamic_module_utils import get_class_from_dynamic_module from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..image_processing_utils import BaseImageProcessor +from ..processing_utils import ProcessorMixin from ..models.auto.configuration_auto import AutoConfig from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor @@ -569,6 +570,7 @@ def pipeline( tokenizer: Optional[Union[str, PreTrainedTokenizer, "PreTrainedTokenizerFast"]] = None, feature_extractor: Optional[Union[str, PreTrainedFeatureExtractor]] = None, image_processor: Optional[Union[str, BaseImageProcessor]] = None, + processor: Optional[ProcessorMixin] = None, framework: Optional[str] = None, revision: Optional[str] = None, use_fast: bool = True, @@ -920,6 +922,7 @@ def pipeline( load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None + load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None # If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while # `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some @@ -1082,6 +1085,28 @@ def pipeline( if not is_pyctcdecode_available(): logger.warning("Try to install `pyctcdecode`: `pip install pyctcdecode") + if load_processor: + # Try to infer processor from model or config name (if provided as str) + if processor is None: + if isinstance(model_name, str): + processor = model_name + elif isinstance(config, str): + processor = config + elif load_image_processor or load_feature_extractor: + pass + else: + # Impossible to guess what is the right processor here + raise Exception( + "Impossible to guess which processor to use. " + "Please provide a ProcessorMixin class or a path/identifier " + "to a pretrained processor." + ) + + # Instantiate processor if needed + if isinstance(processor, (str, tuple)): + processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs) + + if task == "translation" and model.config.task_specific_params: for key in model.config.task_specific_params: if key.startswith("translation"): diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 09f77402a143af..a9904fafcd4b17 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -31,6 +31,7 @@ from ..dynamic_module_utils import custom_object_save from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..image_processing_utils import BaseImageProcessor +from ..processing_utils import ProcessorMixin from ..modelcard import ModelCard from ..models.auto.configuration_auto import AutoConfig from ..tokenization_utils import PreTrainedTokenizer @@ -711,6 +712,7 @@ def build_pipeline_init_args( has_tokenizer: bool = False, has_feature_extractor: bool = False, has_image_processor: bool = False, + has_processor: bool = False, supports_binary_output: bool = True, ) -> str: docstring = r""" @@ -733,6 +735,11 @@ def build_pipeline_init_args( image_processor ([`BaseImageProcessor`]): The image processor that will be used by the pipeline to encode data for the model. This object inherits from [`BaseImageProcessor`].""" + if has_processor: + docstring += r""" + processor ([`ProcessorMixin`]): + The processor that will be used by the pipeline to encode data for the model. This object inherits from + [`ProcessorMixin`].""" docstring += r""" modelcard (`str` or [`ModelCard`], *optional*): Model card attributed to the model for this pipeline. @@ -769,7 +776,7 @@ def build_pipeline_init_args( PIPELINE_INIT_ARGS = build_pipeline_init_args( - has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, supports_binary_output=True + has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, has_processor=True, supports_binary_output=True ) @@ -808,6 +815,7 @@ def __init__( tokenizer: Optional[PreTrainedTokenizer] = None, feature_extractor: Optional[PreTrainedFeatureExtractor] = None, image_processor: Optional[BaseImageProcessor] = None, + processor:Optional[ProcessorMixin] = None, modelcard: Optional[ModelCard] = None, framework: Optional[str] = None, task: str = "", @@ -825,6 +833,7 @@ def __init__( self.tokenizer = tokenizer self.feature_extractor = feature_extractor self.image_processor = image_processor + self.processor = processor self.modelcard = modelcard self.framework = framework @@ -990,6 +999,9 @@ def save_pretrained( if self.image_processor is not None: self.image_processor.save_pretrained(save_directory, **kwargs) + if self.processor is not None: + self.processor.save_pretrained(save_directory, **kwargs) + if self.modelcard is not None: self.modelcard.save_pretrained(save_directory) diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 9cdb8780ab170c..549a2cb1b8860f 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -28,6 +28,8 @@ if is_vision_available(): from PIL import Image + from ..image_utils import load_image + if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES @@ -121,7 +123,33 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag return super().__call__(images, **kwargs) def preprocess(self, image=None, text=None, timeout=None): - pass + if image is not None: + image = load_image(image, timeout=timeout) + + model_type = self.model.config.model_type + + kwargs = {} + + # if model_type == "pix2struct": + # kwargs = {"add_special_tokens": False} + + # if model_type == "idefics": + # model_inputs = self.processor(text, return_tensors=self.framework, **kwargs) + model_inputs = self.processor(images=image, text=text, return_tensors=self.framework, **kwargs) + + # if model_type == "git": + # # remove EOS token from input_ids and attention_mask + # model_inputs["input_ids"] = model_inputs["input_ids"][:, :-1] + # model_inputs["attention_mask"] = model_inputs["attention_mask"][:, :-1] + + # if model_type == "vision-encoder-decoder" and self.processor.__class__.__name__ == "DonutProcessor": + # model_inputs["decoder_input_ids"] = self.processor.tokenizer( + # text, + # add_special_tokens=False, + # return_tensors=self.framework, + # ).input_ids + + return model_inputs def _forward(self, model_inputs, generate_kwargs=None): if generate_kwargs is None: From 693f0cd35ca48ee1577bb65d008dd6382aae3116 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 24 Jul 2024 16:14:20 +0000 Subject: [PATCH 04/13] Standardise processing donut, git, idefics, pix2struct --- .../models/auto/processing_auto.py | 1 + .../models/donut/processing_donut.py | 18 ++++++++-- src/transformers/models/git/processing_git.py | 11 +++++- .../models/idefics/processing_idefics.py | 35 ++++++++++++++++--- .../pix2struct/processing_pix2struct.py | 9 +++++ src/transformers/pipelines/__init__.py | 8 ++++- .../pipelines/image_text_to_text.py | 9 +++-- 7 files changed, 80 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index 1ab136a1e74ca7..e261072873c360 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -98,6 +98,7 @@ ("video_llava", "VideoLlavaProcessor"), ("vilt", "ViltProcessor"), ("vipllava", "LlavaProcessor"), + ("vision-encoder-decoder", "DonutProcessor"), ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"), ("wav2vec2", "Wav2Vec2Processor"), ("wav2vec2-bert", "Wav2Vec2Processor"), diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py index daf6e7d1dfe4ab..935d1ec46df8cc 100644 --- a/src/transformers/models/donut/processing_donut.py +++ b/src/transformers/models/donut/processing_donut.py @@ -71,6 +71,13 @@ def __call__(self, *args, **kwargs): [`~DonutTokenizer.__call__`]. Please refer to the doctsring of the above two methods for more information. """ # For backward compatibility + legacy = kwargs.pop("legacy", True) + print("legacy: ", legacy) + if legacy: + warnings.warn( + "The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False." + ) + if self._in_target_context_manager: return self.current_processor(*args, **kwargs) @@ -85,7 +92,11 @@ def __call__(self, *args, **kwargs): if images is not None: inputs = self.image_processor(images, *args, **kwargs) - if text is not None: + if text is not None and images is None: + encodings = self.tokenizer(text, **kwargs) + elif text is not None: + if not legacy: + kwargs.update({"add_special_tokens": False}) encodings = self.tokenizer(text, **kwargs) if text is None: @@ -93,7 +104,10 @@ def __call__(self, *args, **kwargs): elif images is None: return encodings else: - inputs["labels"] = encodings["input_ids"] + if not legacy: + inputs["decoder_input_ids"] = encodings["input_ids"] + else: + inputs["labels"] = encodings["input_ids"] return inputs def batch_decode(self, *args, **kwargs): diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 98649c644e728c..3ffeb077e1dfdc 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -16,6 +16,7 @@ Image/Text processor class for GIT """ +import warnings from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding @@ -42,7 +43,7 @@ def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor - def __call__(self, text=None, images=None, return_tensors=None, **kwargs): + def __call__(self, text=None, images=None, return_tensors=None, legacy=True, **kwargs): """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode @@ -76,6 +77,11 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + if legacy: + warnings.warn( + "The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False." + ) + tokenizer_kwargs, image_processor_kwargs = {}, {} if kwargs: tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys} @@ -94,6 +100,9 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs): if text is not None and images is not None: encoding["pixel_values"] = image_features.pixel_values + if not legacy: + encoding["input_ids"] = encoding["input_ids"][:, :-1] + encoding["attention_mask"] = encoding["attention_mask"][:, :-1] return encoding elif text is not None: return encoding diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index 8e9e196764f923..a0d841f7430538 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -18,11 +18,13 @@ from typing import Callable, List, Optional, Union from urllib.parse import urlparse +import warnings from ...feature_extraction_utils import BatchFeature from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy, PreTokenizedInput from ...utils import is_tf_available, is_torch_available +from ...utils import TensorType if is_torch_available(): @@ -201,15 +203,18 @@ def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_u def __call__( self, - prompts: Union[List[TextInput], List[List[TextInput]]], - padding: Union[bool, str, PaddingStrategy] = "longest", + images=None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = None, max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + legacy = True, + prompts: Optional[Union[List[TextInput], List[List[TextInput]]]] = None, transform: Callable = None, add_eos_token=False, add_end_of_utterance_token=None, debug=False, - return_tensors="pt", ) -> BatchEncoding: """This method takes batched or non-batched prompts made of text and images and converts them into prompts that the model was trained on and prepares the image pixel values for the model to process. @@ -318,11 +323,31 @@ def __call__( """ + if legacy: + warnings.warn( + "The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False." + ) + if prompts is None: + # if the user didn't specify prompts=prompts in the call, we assume they want to use the old behavior with prompts as a first argument + prompts = images + elif prompts is None and (images is not None and text is not None): + # Assuming image-text-to-text behavior: one prompt for all images + # Check if batched images are provided + if not isinstance(images, (list, tuple)): + images = [images] + # Check if batched text is provided + if isinstance(text, (list, tuple)) and len(text) > 1: + raise ValueError("When using the image-text-to-text behavior, a single prompt should be given.") + text_batched = [text] * len(images) + prompts = list(zip(images, text_batched)) + + + # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it if add_end_of_utterance_token is None: add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token # turn non-batched prompts into batched - if not any(isinstance(i, list) for i in prompts): + if not any(isinstance(i, (list, tuple)) for i in prompts): prompts = [prompts] fake_token = "" diff --git a/src/transformers/models/pix2struct/processing_pix2struct.py b/src/transformers/models/pix2struct/processing_pix2struct.py index 269fa8c62fb205..098a1f3e443dc5 100644 --- a/src/transformers/models/pix2struct/processing_pix2struct.py +++ b/src/transformers/models/pix2struct/processing_pix2struct.py @@ -16,6 +16,7 @@ Processor class for Pix2Struct. """ +import warnings from typing import List, Optional, Union from ...processing_utils import ProcessorMixin @@ -65,6 +66,7 @@ def __call__( return_length: bool = False, verbose: bool = True, return_tensors: Optional[Union[str, TensorType]] = None, + legacy = True, **kwargs, ) -> BatchEncoding: """ @@ -73,6 +75,11 @@ def __call__( Please refer to the docstring of the above two methods for more information. """ + if legacy: + warnings.warn( + "The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False." + ) + if images is None and text is None: raise ValueError("You have to specify either images or text.") @@ -111,6 +118,8 @@ def __call__( ) if text is not None and not self.image_processor.is_vqa: + if not legacy: + add_special_tokens = False text_encoding = self.tokenizer( text=text, add_special_tokens=add_special_tokens, diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 8c1fdfd4115604..bdac21ee559800 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -29,6 +29,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage +from ..models.auto.processing_auto import PROCESSOR_MAPPING, AutoProcessor from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer from ..tokenization_utils import PreTrainedTokenizer from ..utils import ( @@ -921,8 +922,10 @@ def pipeline( hub_kwargs["_commit_hash"] = model.config._commit_hash load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None + print("type(model_config)", type(model_config)) load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None load_processor = type(model_config) in PROCESSOR_MAPPING or processor is not None + print("load_processor", load_processor) # If `model` (instance of `PretrainedModel` instead of `str`) is passed (and/or same for config), while # `image_processor` or `feature_extractor` is `None`, the loading will fail. This happens particularly for some @@ -1006,7 +1009,7 @@ def pipeline( tokenizer_identifier = tokenizer tokenizer_kwargs = model_kwargs.copy() tokenizer_kwargs.pop("torch_dtype", None) - + print("tokenizer_identifier", tokenizer_identifier) tokenizer = AutoTokenizer.from_pretrained( tokenizer_identifier, use_fast=use_fast, _from_pipeline=task, **hub_kwargs, **tokenizer_kwargs ) @@ -1129,6 +1132,9 @@ def pipeline( if image_processor is not None: kwargs["image_processor"] = image_processor + if processor is not None: + kwargs["processor"] = processor + if device is not None: kwargs["device"] = device diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 549a2cb1b8860f..f88c45b14f06ed 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -128,14 +128,19 @@ def preprocess(self, image=None, text=None, timeout=None): model_type = self.model.config.model_type - kwargs = {} + kwargs = {"legacy": False} # if model_type == "pix2struct": # kwargs = {"add_special_tokens": False} # if model_type == "idefics": # model_inputs = self.processor(text, return_tensors=self.framework, **kwargs) - model_inputs = self.processor(images=image, text=text, return_tensors=self.framework, **kwargs) + # temporary while waiting for uniformized processors + try: + model_inputs = self.processor(images=image, text=text, return_tensors=self.framework, **kwargs) + except TypeError: + kwargs = {} + model_inputs = self.processor(images=image, text=text, return_tensors=self.framework, **kwargs) # if model_type == "git": # # remove EOS token from input_ids and attention_mask From 68616ecfdc671ac7ac37c5d37ba9cedd7666d12d Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 24 Jul 2024 22:55:11 +0000 Subject: [PATCH 05/13] Fix style and minor issues for idefics and udop --- .../models/auto/processing_auto.py | 1 + src/transformers/models/git/processing_git.py | 1 + .../models/idefics/processing_idefics.py | 24 ++++++++-------- .../pix2struct/processing_pix2struct.py | 2 +- src/transformers/pipelines/__init__.py | 3 +- src/transformers/pipelines/base.py | 10 +++++-- .../pipelines/image_text_to_text.py | 28 ------------------- 7 files changed, 24 insertions(+), 45 deletions(-) diff --git a/src/transformers/models/auto/processing_auto.py b/src/transformers/models/auto/processing_auto.py index e261072873c360..fd2a0c7f01e0a2 100644 --- a/src/transformers/models/auto/processing_auto.py +++ b/src/transformers/models/auto/processing_auto.py @@ -93,6 +93,7 @@ ("trocr", "TrOCRProcessor"), ("tvlt", "TvltProcessor"), ("tvp", "TvpProcessor"), + ("udop", "UdopProcessor"), ("unispeech", "Wav2Vec2Processor"), ("unispeech-sat", "Wav2Vec2Processor"), ("video_llava", "VideoLlavaProcessor"), diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 3ffeb077e1dfdc..583e5ea83c9b75 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -17,6 +17,7 @@ """ import warnings + from ...processing_utils import ProcessorMixin from ...tokenization_utils_base import BatchEncoding diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index a0d841f7430538..375ac6c626d83f 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -16,15 +16,14 @@ Processor class for IDEFICS. """ +import warnings from typing import Callable, List, Optional, Union from urllib.parse import urlparse -import warnings from ...feature_extraction_utils import BatchFeature from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy, PreTokenizedInput -from ...utils import is_tf_available, is_torch_available -from ...utils import TensorType +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType, is_tf_available, is_torch_available if is_torch_available(): @@ -209,7 +208,7 @@ def __call__( truncation: Union[bool, str, TruncationStrategy] = None, max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, - legacy = True, + legacy=True, prompts: Optional[Union[List[TextInput], List[List[TextInput]]]] = None, transform: Callable = None, add_eos_token=False, @@ -335,13 +334,16 @@ def __call__( # Check if batched images are provided if not isinstance(images, (list, tuple)): images = [images] + if not isinstance(text, (list, tuple)): + text = [text] * len(images) # Check if batched text is provided - if isinstance(text, (list, tuple)) and len(text) > 1: - raise ValueError("When using the image-text-to-text behavior, a single prompt should be given.") - text_batched = [text] * len(images) - prompts = list(zip(images, text_batched)) - - + print("images: ", images) + print("text: ", text) + if isinstance(text, (list, tuple)) and len(text) != len(images): + raise ValueError( + "When using the image-text-to-text behavior, the number of prompts should be the same as the number of images." + ) + prompts = list(zip(images, text)) # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it if add_end_of_utterance_token is None: diff --git a/src/transformers/models/pix2struct/processing_pix2struct.py b/src/transformers/models/pix2struct/processing_pix2struct.py index 098a1f3e443dc5..f8a8e35c9c8fb3 100644 --- a/src/transformers/models/pix2struct/processing_pix2struct.py +++ b/src/transformers/models/pix2struct/processing_pix2struct.py @@ -66,7 +66,7 @@ def __call__( return_length: bool = False, verbose: bool = True, return_tensors: Optional[Union[str, TensorType]] = None, - legacy = True, + legacy=True, **kwargs, ) -> BatchEncoding: """ diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index bdac21ee559800..9605e20f76e48c 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -24,13 +24,13 @@ from ..dynamic_module_utils import get_class_from_dynamic_module from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..image_processing_utils import BaseImageProcessor -from ..processing_utils import ProcessorMixin from ..models.auto.configuration_auto import AutoConfig from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING, AutoImageProcessor from ..models.auto.modeling_auto import AutoModelForDepthEstimation, AutoModelForImageToImage from ..models.auto.processing_auto import PROCESSOR_MAPPING, AutoProcessor from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer +from ..processing_utils import ProcessorMixin from ..tokenization_utils import PreTrainedTokenizer from ..utils import ( CONFIG_NAME, @@ -1109,7 +1109,6 @@ def pipeline( if isinstance(processor, (str, tuple)): processor = AutoProcessor.from_pretrained(processor, _from_pipeline=task, **hub_kwargs, **model_kwargs) - if task == "translation" and model.config.task_specific_params: for key in model.config.task_specific_params: if key.startswith("translation"): diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index a9904fafcd4b17..fd9f2ab5f140cb 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -31,9 +31,9 @@ from ..dynamic_module_utils import custom_object_save from ..feature_extraction_utils import PreTrainedFeatureExtractor from ..image_processing_utils import BaseImageProcessor -from ..processing_utils import ProcessorMixin from ..modelcard import ModelCard from ..models.auto.configuration_auto import AutoConfig +from ..processing_utils import ProcessorMixin from ..tokenization_utils import PreTrainedTokenizer from ..utils import ( ModelOutput, @@ -776,7 +776,11 @@ def build_pipeline_init_args( PIPELINE_INIT_ARGS = build_pipeline_init_args( - has_tokenizer=True, has_feature_extractor=True, has_image_processor=True, has_processor=True, supports_binary_output=True + has_tokenizer=True, + has_feature_extractor=True, + has_image_processor=True, + has_processor=True, + supports_binary_output=True, ) @@ -815,7 +819,7 @@ def __init__( tokenizer: Optional[PreTrainedTokenizer] = None, feature_extractor: Optional[PreTrainedFeatureExtractor] = None, image_processor: Optional[BaseImageProcessor] = None, - processor:Optional[ProcessorMixin] = None, + processor: Optional[ProcessorMixin] = None, modelcard: Optional[ModelCard] = None, framework: Optional[str] = None, task: str = "", diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index f88c45b14f06ed..1b57caae992b93 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -106,14 +106,6 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag text (`str`): The text to be used as a prompt for the generation. - max_new_tokens (`int`, *optional*): - The amount of maximum tokens to generate. By default it will use `generate` default. - - generate_kwargs (`Dict`, *optional*): - Pass it to send all of these arguments directly to `generate` allowing full control of this function. - timeout (`float`, *optional*, defaults to None): - The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and - the call may block forever. Return: A list or a list of list of `dict`: Each result comes as a dictionary with the following key: @@ -126,34 +118,14 @@ def preprocess(self, image=None, text=None, timeout=None): if image is not None: image = load_image(image, timeout=timeout) - model_type = self.model.config.model_type - kwargs = {"legacy": False} - # if model_type == "pix2struct": - # kwargs = {"add_special_tokens": False} - - # if model_type == "idefics": - # model_inputs = self.processor(text, return_tensors=self.framework, **kwargs) - # temporary while waiting for uniformized processors try: model_inputs = self.processor(images=image, text=text, return_tensors=self.framework, **kwargs) except TypeError: kwargs = {} model_inputs = self.processor(images=image, text=text, return_tensors=self.framework, **kwargs) - # if model_type == "git": - # # remove EOS token from input_ids and attention_mask - # model_inputs["input_ids"] = model_inputs["input_ids"][:, :-1] - # model_inputs["attention_mask"] = model_inputs["attention_mask"][:, :-1] - - # if model_type == "vision-encoder-decoder" and self.processor.__class__.__name__ == "DonutProcessor": - # model_inputs["decoder_input_ids"] = self.processor.tokenizer( - # text, - # add_special_tokens=False, - # return_tensors=self.framework, - # ).input_ids - return model_inputs def _forward(self, model_inputs, generate_kwargs=None): From 1e4c72add2a674738c07ae0d9bd2f8ef02de291a Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 25 Jul 2024 16:09:50 +0000 Subject: [PATCH 06/13] Add post processing for image-text-to-text models --- .../models/blip/processing_blip.py | 14 +++++++++++ .../models/blip_2/processing_blip_2.py | 14 +++++++++++ .../models/fuyu/processing_fuyu.py | 24 +++++++++++++++++++ src/transformers/models/git/processing_git.py | 14 +++++++++++ .../models/idefics/processing_idefics.py | 14 +++++++++++ .../models/idefics2/processing_idefics2.py | 14 +++++++++++ .../instructblip/processing_instructblip.py | 14 +++++++++++ .../models/kosmos2/processing_kosmos2.py | 15 ++++++++++++ .../models/llava/processing_llava.py | 14 +++++++++++ .../llava_next/processing_llava_next.py | 14 +++++++++++ .../models/paligemma/processing_paligemma.py | 14 +++++++++++ .../pix2struct/processing_pix2struct.py | 14 +++++++++++ .../models/udop/processing_udop.py | 14 +++++++++++ .../pipelines/image_text_to_text.py | 24 ++++++++++++------- 14 files changed, 209 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/blip/processing_blip.py b/src/transformers/models/blip/processing_blip.py index cd96b46ab1d26f..babc11a43c21f0 100644 --- a/src/transformers/models/blip/processing_blip.py +++ b/src/transformers/models/blip/processing_blip.py @@ -144,6 +144,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names diff --git a/src/transformers/models/blip_2/processing_blip_2.py b/src/transformers/models/blip_2/processing_blip_2.py index 2d526a17ba68ba..8f0f256604129b 100644 --- a/src/transformers/models/blip_2/processing_blip_2.py +++ b/src/transformers/models/blip_2/processing_blip_2.py @@ -148,6 +148,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names def model_input_names(self): diff --git a/src/transformers/models/fuyu/processing_fuyu.py b/src/transformers/models/fuyu/processing_fuyu.py index 6b542ba3378e67..109b9b3bb9e3d3 100644 --- a/src/transformers/models/fuyu/processing_fuyu.py +++ b/src/transformers/models/fuyu/processing_fuyu.py @@ -681,6 +681,30 @@ def tokens_to_points(tokens, original_size): return results + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-processes the output of `FuyuForConditionalGeneration` to only return the text output. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + containing the token ids of the generated sequences. + + Returns: + `List[str]`: The decoded text output. + """ + boa = self.tokenizer.vocab[BEGINNING_OF_ANSWER_STRING] + # get boa index for each outputted sequence tensor + # start all generated sequences from the beginning of the answer token, pad to have consistent length + unpadded_output_sequences = [seq[(seq == boa).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs] + max_len = max(len(seq) for seq in unpadded_output_sequences) + # convert to torch and pad sequences + padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id) + for i, seq in enumerate(unpadded_output_sequences): + padded_output_sequences[i, : len(seq)] = torch.tensor(seq) + + return self.batch_decode(padded_output_sequences, skip_special_tokens=True) + def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index 583e5ea83c9b75..f6e2fe29b97df3 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -124,6 +124,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property def model_input_names(self): return ["input_ids", "attention_mask", "pixel_values"] diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index 375ac6c626d83f..74cf29c85c3547 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -513,6 +513,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py index 2e14118144baaa..385cb8d3a1d84c 100644 --- a/src/transformers/models/idefics2/processing_idefics2.py +++ b/src/transformers/models/idefics2/processing_idefics2.py @@ -246,6 +246,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names diff --git a/src/transformers/models/instructblip/processing_instructblip.py b/src/transformers/models/instructblip/processing_instructblip.py index adebd22178efb2..9d23927c84d17e 100644 --- a/src/transformers/models/instructblip/processing_instructblip.py +++ b/src/transformers/models/instructblip/processing_instructblip.py @@ -149,6 +149,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names def model_input_names(self): diff --git a/src/transformers/models/kosmos2/processing_kosmos2.py b/src/transformers/models/kosmos2/processing_kosmos2.py index 6d1cce14b186fe..d16d5f4eedb83d 100644 --- a/src/transformers/models/kosmos2/processing_kosmos2.py +++ b/src/transformers/models/kosmos2/processing_kosmos2.py @@ -403,6 +403,21 @@ def post_process_generation(self, text, cleanup_and_extract=True): return clean_text_and_extract_entities_with_bboxes(caption) return caption + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True) + return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts] + @property # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names def model_input_names(self): diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index a563b1cb82e788..bb9a44b13a9f89 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -129,6 +129,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names def model_input_names(self): diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py index 7664b7954308b3..98984a82457f95 100644 --- a/src/transformers/models/llava_next/processing_llava_next.py +++ b/src/transformers/models/llava_next/processing_llava_next.py @@ -133,6 +133,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names def model_input_names(self): diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 3d0ece60c367e4..1d63e094ee1dd1 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -300,6 +300,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma def model_input_names(self): diff --git a/src/transformers/models/pix2struct/processing_pix2struct.py b/src/transformers/models/pix2struct/processing_pix2struct.py index f8a8e35c9c8fb3..c195d3f5c23160 100644 --- a/src/transformers/models/pix2struct/processing_pix2struct.py +++ b/src/transformers/models/pix2struct/processing_pix2struct.py @@ -165,6 +165,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property def model_input_names(self): tokenizer_input_names = self.tokenizer.model_input_names diff --git a/src/transformers/models/udop/processing_udop.py b/src/transformers/models/udop/processing_udop.py index 2902541d6f5b46..745ff444603de7 100644 --- a/src/transformers/models/udop/processing_udop.py +++ b/src/transformers/models/udop/processing_udop.py @@ -198,6 +198,20 @@ def decode(self, *args, **kwargs): """ return self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.model_input_names def model_input_names(self): diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 1b57caae992b93..bc3313b752e8a0 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -69,9 +69,11 @@ def __init__(self, *args, **kwargs): def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, text=None, timeout=None): forward_kwargs = {} preprocess_params = {} + post_process_params = {} if text is not None: preprocess_params["text"] = text + post_process_params["text"] = text if timeout is not None: preprocess_params["timeout"] = timeout @@ -87,7 +89,7 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, text=N ) forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens - return preprocess_params, forward_kwargs, {} + return preprocess_params, forward_kwargs, post_process_params def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs): """ @@ -135,13 +137,19 @@ def _forward(self, model_inputs, generate_kwargs=None): model_outputs = self.model.generate(**model_inputs, **generate_kwargs) return model_outputs - def postprocess(self, model_outputs): + def postprocess(self, model_outputs, text=None): records = [] - generated_texts = self.processor.batch_decode( - model_outputs, - skip_special_tokens=True, - ) - - records = [{"generated_text": text} for text in generated_texts] + generated_texts = self.processor.post_process_image_text_to_text(model_outputs) + print("generated_texts", generated_texts) + # cleanup the generated text + generated_texts = [text.strip() for text in generated_texts] + print("text", text) + if text is not None: + # remove the input text from the generated text if the generated text starts with the input text + generated_texts = [ + text_generated[len(text) :].strip() if text_generated.startswith(text) else text_generated + for text_generated in generated_texts + ] + records = [{"input_text": text, "generated_text": generated_text} for generated_text in generated_texts] return records From 86f3467fdddc1789555d302f58cabbd615f36f56 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Thu, 25 Jul 2024 22:19:15 +0000 Subject: [PATCH 07/13] Add support for batched text and images, and for chat templates with images --- src/transformers/pipelines/base.py | 7 +- .../pipelines/image_text_to_text.py | 127 ++++++++++++++---- 2 files changed, 110 insertions(+), 24 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index fd9f2ab5f140cb..5392472d1fddc3 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -1198,7 +1198,12 @@ def get_iterator( logger.info("Disabling tokenizer parallelism, we're using DataLoader multithreading already") os.environ["TOKENIZERS_PARALLELISM"] = "false" # TODO hack by collating feature_extractor and image_processor - feature_extractor = self.feature_extractor if self.feature_extractor is not None else self.image_processor + if self.feature_extractor is not None: + feature_extractor = self.feature_extractor + elif self.image_processor is not None: + feature_extractor = self.image_processor + elif self.processor is not None: + feature_extractor = self.processor collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, feature_extractor) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn) model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index bc3313b752e8a0..c848840d84928f 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Dict, List, Union from ..utils import ( add_end_docstrings, @@ -37,6 +37,38 @@ logger = logging.get_logger(__name__) +class Chat: + """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats + to this format because the rest of the pipeline code tends to assume that lists of messages are + actually a batch of samples rather than messages in the same conversation.""" + + def __init__(self, messages: Dict, images: Union[str, List[str], "Image.Image", List["Image.Image"]]): + for message in messages: + if not ("role" in message and "content" in message): + raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") + if count_images_in_chat(messages) != len(images): + raise ValueError("The number of images should be the same as the number of images in the chat.") + + self.messages = messages + self.images = images + + +class ImageText: + """This class is intended to just be used internally in this pipeline and not exposed to users. We used this class + as the base pipeline does not support multiple inputs, so we need to convert multiple inputs to a single input.""" + + def __init__(self, images: List, text: Union[str, List[str]]): + self.images = images + self.text = text + + +def count_images_in_chat(chat): + num_images = 0 + for message in chat: + num_images += sum(1 for content in message["content"] if content.get("type") == "image") + return num_images + + @add_end_docstrings(build_pipeline_init_args(has_processor=True)) class ImageTextToTextPipeline(Pipeline): """ @@ -71,9 +103,6 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, text=N preprocess_params = {} post_process_params = {} - if text is not None: - preprocess_params["text"] = text - post_process_params["text"] = text if timeout is not None: preprocess_params["timeout"] = timeout @@ -91,7 +120,7 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, text=N return preprocess_params, forward_kwargs, post_process_params - def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs): + def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs): """ Generate a text given text and the image(s) passed as inputs. @@ -105,28 +134,72 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag The pipeline accepts either a single image or a batch of images. - text (`str`): - The text to be used as a prompt for the generation. - + text (str, List[str], `List[Dict[str, Union[str, PIL.Image]]]`): + The text to be used for generation. If a list of strings is passed, the length of the list should be the + same as the number of images. Text can also follow the chat format: a list of dictionaries where each + dictionary represents a message in a conversation. Each dictionary should have two keys: 'role' and + 'content'. 'role' should be one of 'user', 'system' or 'assistant'. 'content' should be a dictionary + containing the text of the message and the type of the message. The type of the message can be either + 'text' or 'image'. If the type is 'image', no text is needed. Return: A list or a list of list of `dict`: Each result comes as a dictionary with the following key: - **generated_text** (`str`) -- The generated text. """ - return super().__call__(images, **kwargs) + text = kwargs.pop("text") + + if images is None or text is None: + raise ValueError("You have to specify both `images` and `text`") + + if not isinstance(images, (list, tuple)): + images = [images] - def preprocess(self, image=None, text=None, timeout=None): - if image is not None: - image = load_image(image, timeout=timeout) + if isinstance(text, (list, tuple, text) if is_torch_available() else (list, tuple)) and isinstance( + text[0], (list, tuple, dict) + ): + # We have one or more prompts in list-of-dicts format, so this is chat mode + if isinstance(text[0], dict): + return super().__call__(Chat(text, images), **kwargs) + else: + chats = [Chat(chat, image) for chat, image in zip(text, images)] # ๐Ÿˆ ๐Ÿˆ ๐Ÿˆ + return super().__call__(chats, **kwargs) + if isinstance(text, str): + text = [text] * len(images) + if len(images) != len(text): + raise ValueError("The number of images and text should be the same.") + + return super().__call__([ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs) + + def preprocess(self, inputs=None, timeout=None): kwargs = {"legacy": False} + images = inputs.images + if isinstance(inputs, Chat): + kwargs["chats"] = inputs.messages + text = self.processor.apply_chat_template( + inputs.messages, + add_generation_prompt=True, + return_tensors=self.framework, + **kwargs, + ) + else: + text = inputs.text + + if not isinstance(images, (list, tuple)): + images = load_image(images, timeout=timeout) + else: + images = [load_image(image, timeout=timeout) for image in images] try: - model_inputs = self.processor(images=image, text=text, return_tensors=self.framework, **kwargs) + kwargs["padding"] = True + model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **kwargs) except TypeError: kwargs = {} - model_inputs = self.processor(images=image, text=text, return_tensors=self.framework, **kwargs) + kwargs["padding"] = True + model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **kwargs) + + model_inputs["text"] = text return model_inputs @@ -134,22 +207,30 @@ def _forward(self, model_inputs, generate_kwargs=None): if generate_kwargs is None: generate_kwargs = {} + input_text = model_inputs.pop("text") model_outputs = self.model.generate(**model_inputs, **generate_kwargs) - return model_outputs + return {"outputs": model_outputs, "input_text": input_text} - def postprocess(self, model_outputs, text=None): + def postprocess(self, model_outputs): records = [] - generated_texts = self.processor.post_process_image_text_to_text(model_outputs) - print("generated_texts", generated_texts) + input_text = model_outputs["input_text"] + outputs = model_outputs["outputs"] + generated_texts = self.processor.post_process_image_text_to_text(outputs) # cleanup the generated text generated_texts = [text.strip() for text in generated_texts] - print("text", text) - if text is not None: + if isinstance(input_text, str): + input_text = [input_text] + if input_text is not None: # remove the input text from the generated text if the generated text starts with the input text generated_texts = [ - text_generated[len(text) :].strip() if text_generated.startswith(text) else text_generated - for text_generated in generated_texts + text_generated[len(input_text[i]) :].strip() + if text_generated.startswith(input_text[i]) + else text_generated + for i, text_generated in enumerate(generated_texts) ] - records = [{"input_text": text, "generated_text": generated_text} for generated_text in generated_texts] + records = [ + {"input_text": input_text[i], "generated_text": generated_text} + for i, generated_text in enumerate(generated_texts) + ] return records From 469f825a9b89d707440da2e3cb8c078aaec7c6ca Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 26 Jul 2024 16:39:05 +0000 Subject: [PATCH 08/13] Add support for multiple images per prompt using token, standardize not returning input prompt in the generated text. --- .../pipelines/image_text_to_text.py | 108 ++++++++++++++---- 1 file changed, 84 insertions(+), 24 deletions(-) diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index c848840d84928f..2166f653e13b7d 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -33,9 +33,12 @@ if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + from .pt_utils import KeyDataset logger = logging.get_logger(__name__) +IMAGE_TOKEN = "" + class Chat: """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats @@ -57,7 +60,7 @@ class ImageText: """This class is intended to just be used internally in this pipeline and not exposed to users. We used this class as the base pipeline does not support multiple inputs, so we need to convert multiple inputs to a single input.""" - def __init__(self, images: List, text: Union[str, List[str]]): + def __init__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], text: Union[str, List[str]]): self.images = images self.text = text @@ -72,7 +75,7 @@ def count_images_in_chat(chat): @add_end_docstrings(build_pipeline_init_args(has_processor=True)) class ImageTextToTextPipeline(Pipeline): """ - Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text. + Image-text-to-text pipeline. This pipeline generates text given an image and text. Example: @@ -98,7 +101,16 @@ def __init__(self, *args, **kwargs): requires_backends(self, "vision") self.check_model_type(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES) - def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, text=None, timeout=None): + def _sanitize_parameters( + self, + max_new_tokens=None, + generate_kwargs=None, + text=None, + truncation=None, + padding=None, + max_length=None, + timeout=None, + ): forward_kwargs = {} preprocess_params = {} post_process_params = {} @@ -106,8 +118,18 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, text=N if timeout is not None: preprocess_params["timeout"] = timeout + if truncation is not None: + preprocess_params["truncation"] = truncation + + if padding is not None: + preprocess_params["padding"] = padding + + if max_length is not None: + preprocess_params["max_length"] = max_length + if generate_kwargs is not None: forward_kwargs["generate_kwargs"] = generate_kwargs + if max_new_tokens is not None: if "generate_kwargs" not in forward_kwargs: forward_kwargs["generate_kwargs"] = {} @@ -125,7 +147,7 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag Generate a text given text and the image(s) passed as inputs. Args: - images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`): + images (`str`, `List[str]`, `PIL.Image or `List[PIL.Image]`): The pipeline handles three types of images: - A string containing a HTTP(s) link pointing to an image @@ -146,8 +168,10 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag A list or a list of list of `dict`: Each result comes as a dictionary with the following key: - **generated_text** (`str`) -- The generated text. + - **input_text** (`str`) -- The input text. """ text = kwargs.pop("text") + batch_size = kwargs.get("batch_size", 1) if images is None or text is None: raise ValueError("You have to specify both `images` and `text`") @@ -155,7 +179,7 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag if not isinstance(images, (list, tuple)): images = [images] - if isinstance(text, (list, tuple, text) if is_torch_available() else (list, tuple)) and isinstance( + if isinstance(text, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)) and isinstance( text[0], (list, tuple, dict) ): # We have one or more prompts in list-of-dicts format, so this is chat mode @@ -167,14 +191,51 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag if isinstance(text, str): text = [text] * len(images) + if not isinstance(text[0], str): + raise ValueError("The pipeline does not support nested lists of prompts.") + + # Check number of IMAGE_TOKEN token in each text + num_images_in_text = [text_single.count(IMAGE_TOKEN) for text_single in text] + if sum(num_images_in_text) > 0: + if any(num > 1 for num in num_images_in_text) and batch_size > 1: + raise ValueError( + "The pipeline does not support multiple images for a single prompt with batch_size > 1." + ) + # Check if already nested images and consistency + if isinstance(images[0], (list, tuple)): + if len(images) != len(text): + raise ValueError("The number of nested image groups and prompts should be the same.") + num_images_in_images = [len(image) for image in images] + if num_images_in_text != num_images_in_images: + raise ValueError( + f"The number of images in each nested image group should be the same as the number of {IMAGE_TOKEN} tokens in the corresponding prompt." + ) + elif sum(num_images_in_text) != len(images): + raise ValueError( + f"The total number of {IMAGE_TOKEN} tokens in the prompts should be the same as the number of images passed." + ) + else: + # Reorganize the images to match the prompts + images_reorganized = [] + for num_images in num_images_in_text: + images_reorganized.append(images[:num_images]) + images = images[num_images:] + images = images_reorganized + # After reorganizing, these should be the same if len(images) != len(text): raise ValueError("The number of images and text should be the same.") return super().__call__([ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs) - def preprocess(self, inputs=None, timeout=None): - kwargs = {"legacy": False} + def preprocess(self, inputs=None, truncation=None, padding="longest", max_length=None, timeout=None): + kwargs = { + "legacy": False, + "truncation": truncation, + "padding": padding, + "max_length": max_length, + } images = inputs.images + if isinstance(inputs, Chat): kwargs["chats"] = inputs.messages text = self.processor.apply_chat_template( @@ -192,11 +253,9 @@ def preprocess(self, inputs=None, timeout=None): images = [load_image(image, timeout=timeout) for image in images] try: - kwargs["padding"] = True model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **kwargs) except TypeError: - kwargs = {} - kwargs["padding"] = True + kwargs.pop("legacy", None) model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **kwargs) model_inputs["text"] = text @@ -206,28 +265,29 @@ def preprocess(self, inputs=None, timeout=None): def _forward(self, model_inputs, generate_kwargs=None): if generate_kwargs is None: generate_kwargs = {} - input_text = model_inputs.pop("text") model_outputs = self.model.generate(**model_inputs, **generate_kwargs) - return {"outputs": model_outputs, "input_text": input_text} + return {"outputs": model_outputs, "input_text": input_text, "input_ids": model_inputs["input_ids"]} def postprocess(self, model_outputs): - records = [] input_text = model_outputs["input_text"] + input_text = [input_text] if isinstance(input_text, str) else input_text outputs = model_outputs["outputs"] + inputs_id = model_outputs["input_ids"] + + # Decode inputs and outputs the same way to remove input text from generated text if present generated_texts = self.processor.post_process_image_text_to_text(outputs) - # cleanup the generated text + decoded_inputs = self.processor.post_process_image_text_to_text(inputs_id) generated_texts = [text.strip() for text in generated_texts] - if isinstance(input_text, str): - input_text = [input_text] - if input_text is not None: - # remove the input text from the generated text if the generated text starts with the input text - generated_texts = [ - text_generated[len(input_text[i]) :].strip() - if text_generated.startswith(input_text[i]) - else text_generated - for i, text_generated in enumerate(generated_texts) - ] + decoded_inputs = [text.strip() for text in decoded_inputs] + # Remove the input text from the generated text if the generated text starts with the input text + generated_texts = [ + text_generated[len(decoded_inputs[i]) :].strip() + if text_generated.startswith(decoded_inputs[i]) + else text_generated + for i, text_generated in enumerate(generated_texts) + ] + records = [ {"input_text": input_text[i], "generated_text": generated_text} for i, generated_text in enumerate(generated_texts) From 116ea5719be04f95fdf431a6e81f41eca481c91a Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 26 Jul 2024 18:15:01 +0000 Subject: [PATCH 09/13] Fix unbounded variables issue --- src/transformers/models/idefics/processing_idefics.py | 2 +- src/transformers/pipelines/base.py | 2 ++ src/transformers/pipelines/image_text_to_text.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index 74cf29c85c3547..b3bddb74d59f40 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -207,7 +207,7 @@ def __call__( padding: Union[bool, str, PaddingStrategy] = False, truncation: Union[bool, str, TruncationStrategy] = None, max_length: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + return_tensors: Optional[Union[str, TensorType]] = "pt", legacy=True, prompts: Optional[Union[List[TextInput], List[List[TextInput]]]] = None, transform: Callable = None, diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 5392472d1fddc3..ce176af04dc06e 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -1204,6 +1204,8 @@ def get_iterator( feature_extractor = self.image_processor elif self.processor is not None: feature_extractor = self.processor + else: + feature_extractor = None collate_fn = no_collate_fn if batch_size == 1 else pad_collate_fn(self.tokenizer, feature_extractor) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=collate_fn) model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size) diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 2166f653e13b7d..f0966a6f6788dd 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -75,7 +75,7 @@ def count_images_in_chat(chat): @add_end_docstrings(build_pipeline_init_args(has_processor=True)) class ImageTextToTextPipeline(Pipeline): """ - Image-text-to-text pipeline. This pipeline generates text given an image and text. + Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text. Example: From bb3594eabc1af07a6de0c0fc1b50226890bbaa9c Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 26 Jul 2024 21:26:58 +0000 Subject: [PATCH 10/13] Add tests and fix small issues --- .../models/donut/processing_donut.py | 15 +- .../pipelines/image_text_to_text.py | 21 +-- .../test_pipelines_image_text_to_text.py | 140 ++++++++++++++++++ tests/test_pipeline_mixin.py | 2 + 4 files changed, 168 insertions(+), 10 deletions(-) create mode 100644 tests/pipelines/test_pipelines_image_text_to_text.py diff --git a/src/transformers/models/donut/processing_donut.py b/src/transformers/models/donut/processing_donut.py index 935d1ec46df8cc..6e38414eb83f9a 100644 --- a/src/transformers/models/donut/processing_donut.py +++ b/src/transformers/models/donut/processing_donut.py @@ -72,7 +72,6 @@ def __call__(self, *args, **kwargs): """ # For backward compatibility legacy = kwargs.pop("legacy", True) - print("legacy: ", legacy) if legacy: warnings.warn( "The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False." @@ -194,6 +193,20 @@ def token2json(self, tokens, is_inner_value=False, added_vocab=None): else: return [] if is_inner_value else {"text_sequence": tokens} + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property def feature_extractor_class(self): warnings.warn( diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index f0966a6f6788dd..0ce5cd3276bb31 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -142,7 +142,12 @@ def _sanitize_parameters( return preprocess_params, forward_kwargs, post_process_params - def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs): + def __call__( + self, + images: Union[str, List[str], List[List[str]], "Image.Image", List["Image.Image"], List[List["Image.Image"]]], + text: Union[str, List[str], List[dict]], + **kwargs, + ): """ Generate a text given text and the image(s) passed as inputs. @@ -170,12 +175,8 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag - **generated_text** (`str`) -- The generated text. - **input_text** (`str`) -- The input text. """ - text = kwargs.pop("text") batch_size = kwargs.get("batch_size", 1) - if images is None or text is None: - raise ValueError("You have to specify both `images` and `text`") - if not isinstance(images, (list, tuple)): images = [images] @@ -227,7 +228,7 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag return super().__call__([ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs) - def preprocess(self, inputs=None, truncation=None, padding="longest", max_length=None, timeout=None): + def preprocess(self, inputs=None, truncation=None, padding=False, max_length=None, timeout=None): kwargs = { "legacy": False, "truncation": truncation, @@ -237,7 +238,7 @@ def preprocess(self, inputs=None, truncation=None, padding="longest", max_length images = inputs.images if isinstance(inputs, Chat): - kwargs["chats"] = inputs.messages + # kwargs["chats"] = inputs.messages text = self.processor.apply_chat_template( inputs.messages, add_generation_prompt=True, @@ -246,7 +247,6 @@ def preprocess(self, inputs=None, truncation=None, padding="longest", max_length ) else: text = inputs.text - if not isinstance(images, (list, tuple)): images = load_image(images, timeout=timeout) else: @@ -266,8 +266,11 @@ def _forward(self, model_inputs, generate_kwargs=None): if generate_kwargs is None: generate_kwargs = {} input_text = model_inputs.pop("text") + input_ids = ( + model_inputs["input_ids"] if "input_ids" in model_inputs else model_inputs["decoder_input_ids"] + ) # for decoder-only models model_outputs = self.model.generate(**model_inputs, **generate_kwargs) - return {"outputs": model_outputs, "input_text": input_text, "input_ids": model_inputs["input_ids"]} + return {"outputs": model_outputs, "input_text": input_text, "input_ids": input_ids} def postprocess(self, model_outputs): input_text = model_outputs["input_text"] diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py new file mode 100644 index 00000000000000..9fe31ea91ba40a --- /dev/null +++ b/tests/pipelines/test_pipelines_image_text_to_text.py @@ -0,0 +1,140 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, is_vision_available +from transformers.pipelines import ImageTextToTextPipeline, pipeline +from transformers.testing_utils import ( + is_pipeline_test, + require_torch, + require_vision, + slow, +) + +from .test_pipelines_common import ANY + + +if is_vision_available(): + from PIL import Image +else: + + class Image: + @staticmethod + def open(*args, **kwargs): + pass + + +@is_pipeline_test +@require_vision +class ImageTextToTextPipelineTests(unittest.TestCase): + model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING + + def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"): + pipe = ImageTextToTextPipeline( + model=model, tokenizer=tokenizer, image_processor=processor, torch_dtype=torch_dtype + ) + examples = { + "images": [ + Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"), + "./tests/fixtures/tests_samples/COCO/000000039769.png", + ], + "text": [" This is a ", " Here I see a "], + } + return pipe, examples + + def run_pipeline_test(self, pipe, examples): + outputs = pipe(examples) + self.assertEqual( + outputs, + [ + [{"input_text": ANY(str), "generated_text": ANY(str)}], + [{"input_text": ANY(str), "generated_text": ANY(str)}], + ], + ) + + @require_torch + def test_small_model_pt_token(self): + pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf") + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + text = " What this is? Assistant: This is" + + outputs = pipe(image, text=text, max_new_tokens=20) + self.assertEqual( + outputs, + [ + [ + { + "input_text": " What this is? Assistant: This is", + "generated_text": "a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable", + } + ] + ], + ) + + outputs = pipe([image, image], text=[text, text], max_new_tokens=20) + self.assertEqual( + outputs, + [ + [ + { + "input_text": " What this is? Assistant: This is", + "generated_text": "a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable", + } + ], + [ + { + "input_text": " What this is? Assistant: This is", + "generated_text": "a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable", + } + ], + ], + ) + + @require_torch + def test_consistent_batching_behaviour(self): + pipe = pipeline("image-text-to-text", model="microsoft/kosmos-2-patch14-224") + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + prompt = "a photo of" + + outputs = pipe([image, image], text=[prompt, prompt], max_new_tokens=20) + outputs_batched = pipe([image, image], text=[prompt, prompt], max_new_tokens=20, batch_size=2) + self.assertEqual(outputs, outputs_batched) + + @slow + @require_torch + def test_model_pt_chat_template(self): + pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf") + image_ny = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg" + image_chicago = "https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Whatโ€™s the difference between these two images?"}, + {"type": "image"}, + {"type": "image"}, + ], + } + ] + outputs = pipe([image_ny, image_chicago], max_new_tokens=20, text=messages) + self.assertEqual( + outputs, + [ + { + "input_text": "<|im_start|>user\n\nWhatโ€™s the difference between these two images?<|im_end|>\n<|im_start|>assistant\n", + "generated_text": "The first image shows a statue of the Statue of Liberty in the foreground, while the second image shows", + } + ], + ) diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py index 8a0ca08e8dabec..08b0a2b3b3e5f2 100644 --- a/tests/test_pipeline_mixin.py +++ b/tests/test_pipeline_mixin.py @@ -40,6 +40,7 @@ from .pipelines.test_pipelines_image_classification import ImageClassificationPipelineTests from .pipelines.test_pipelines_image_feature_extraction import ImageFeatureExtractionPipelineTests from .pipelines.test_pipelines_image_segmentation import ImageSegmentationPipelineTests +from .pipelines.test_pipelines_image_text_to_text import ImageTextToTextPipeline from .pipelines.test_pipelines_image_to_image import ImageToImagePipelineTests from .pipelines.test_pipelines_image_to_text import ImageToTextPipelineTests from .pipelines.test_pipelines_mask_generation import MaskGenerationPipelineTests @@ -73,6 +74,7 @@ "image-segmentation": {"test": ImageSegmentationPipelineTests}, "image-to-image": {"test": ImageToImagePipelineTests}, "image-to-text": {"test": ImageToTextPipelineTests}, + "image-text-to-text": {"test": ImageTextToTextPipeline}, "mask-generation": {"test": MaskGenerationPipelineTests}, "object-detection": {"test": ObjectDetectionPipelineTests}, "question-answering": {"test": QAPipelineTests}, From a07bbc8f243a8df90e30373e27a05d2f83c51348 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 29 Jul 2024 14:42:28 +0000 Subject: [PATCH 11/13] Fix Paligemma prompts with image token --- src/transformers/models/paligemma/processing_paligemma.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 1d63e094ee1dd1..ae8c9fc80c01d6 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -17,6 +17,7 @@ """ import logging +import warnings from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature @@ -70,6 +71,11 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token): image_seq_len (`int`): The length of the image sequence. image_token (`str`): The image token. """ + if image_token in prompt: + warnings.warn( + f"The image token {image_token} is already present in the prompt. This may lead to unexpected behavior." + ) + prompt = prompt.replace(image_token, "") return f"{image_token * image_seq_len}{bos_token}{prompt}\n" From 7c14518c637fc6615a7e759a266ac943b2432a34 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 29 Jul 2024 14:50:20 +0000 Subject: [PATCH 12/13] Fix automated model pipeline test --- tests/pipelines/test_pipelines_image_text_to_text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py index 9fe31ea91ba40a..779248e7a23671 100644 --- a/tests/pipelines/test_pipelines_image_text_to_text.py +++ b/tests/pipelines/test_pipelines_image_text_to_text.py @@ -55,7 +55,7 @@ def get_test_pipeline(self, model, tokenizer, processor, torch_dtype="float32"): return pipe, examples def run_pipeline_test(self, pipe, examples): - outputs = pipe(examples) + outputs = pipe(examples.get("images"), text=examples.get("text"), max_new_tokens=20) self.assertEqual( outputs, [ From 5ed61454f68b5db5af7c492586aac0f9f6625758 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 6 Aug 2024 14:20:14 +0000 Subject: [PATCH 13/13] Change legacy arg to kwargs --- src/transformers/models/git/processing_git.py | 3 ++- src/transformers/models/idefics/processing_idefics.py | 4 ++-- src/transformers/models/pix2struct/processing_pix2struct.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/git/processing_git.py b/src/transformers/models/git/processing_git.py index f6e2fe29b97df3..97ac541850049b 100644 --- a/src/transformers/models/git/processing_git.py +++ b/src/transformers/models/git/processing_git.py @@ -44,7 +44,7 @@ def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) self.current_processor = self.image_processor - def __call__(self, text=None, images=None, return_tensors=None, legacy=True, **kwargs): + def __call__(self, text=None, images=None, return_tensors=None, **kwargs): """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` and `kwargs` arguments to BertTokenizerFast's [`~BertTokenizerFast.__call__`] if `text` is not `None` to encode @@ -78,6 +78,7 @@ def __call__(self, text=None, images=None, return_tensors=None, legacy=True, **k `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + legacy = kwargs.pop("legacy", True) if legacy: warnings.warn( "The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False." diff --git a/src/transformers/models/idefics/processing_idefics.py b/src/transformers/models/idefics/processing_idefics.py index b3bddb74d59f40..b9ae341da7c063 100644 --- a/src/transformers/models/idefics/processing_idefics.py +++ b/src/transformers/models/idefics/processing_idefics.py @@ -208,12 +208,12 @@ def __call__( truncation: Union[bool, str, TruncationStrategy] = None, max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = "pt", - legacy=True, prompts: Optional[Union[List[TextInput], List[List[TextInput]]]] = None, transform: Callable = None, add_eos_token=False, add_end_of_utterance_token=None, debug=False, + **kwargs, ) -> BatchEncoding: """This method takes batched or non-batched prompts made of text and images and converts them into prompts that the model was trained on and prepares the image pixel values for the model to process. @@ -321,7 +321,7 @@ def __call__( In order to help debug prompt generation enable `debug=True` which will show you what's happening. """ - + legacy = kwargs.pop("legacy", True) if legacy: warnings.warn( "The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False." diff --git a/src/transformers/models/pix2struct/processing_pix2struct.py b/src/transformers/models/pix2struct/processing_pix2struct.py index c195d3f5c23160..b610d9fd245488 100644 --- a/src/transformers/models/pix2struct/processing_pix2struct.py +++ b/src/transformers/models/pix2struct/processing_pix2struct.py @@ -66,7 +66,6 @@ def __call__( return_length: bool = False, verbose: bool = True, return_tensors: Optional[Union[str, TensorType]] = None, - legacy=True, **kwargs, ) -> BatchEncoding: """ @@ -75,6 +74,8 @@ def __call__( Please refer to the docstring of the above two methods for more information. """ + legacy = kwargs.pop("legacy", True) + print("legacy: ", legacy) if legacy: warnings.warn( "The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False."