From 35ad460377748c2a359613d276923a86c4fd8fc7 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 16 Sep 2024 23:22:11 +0000 Subject: [PATCH 1/4] add uniformized pixtral and kwargs --- src/transformers/models/pixtral/__init__.py | 3 +- .../models/pixtral/processing_pixtral.py | 74 ++++--- src/transformers/processing_utils.py | 20 +- .../models/pixtral/test_processor_pixtral.py | 198 ++++++++++++++++-- tests/test_processing_common.py | 2 + 5 files changed, 239 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/pixtral/__init__.py b/src/transformers/models/pixtral/__init__.py index e09ed8e60127dd..8c32b8750b0358 100644 --- a/src/transformers/models/pixtral/__init__.py +++ b/src/transformers/models/pixtral/__init__.py @@ -43,7 +43,8 @@ if TYPE_CHECKING: - from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig + from .configuration_pixtral import PixtralVisionConfig + from .processing_pixtral import PixtralProcessor try: if not is_torch_available(): diff --git a/src/transformers/models/pixtral/processing_pixtral.py b/src/transformers/models/pixtral/processing_pixtral.py index 9362703c8aa6da..1b07aa02771dc9 100644 --- a/src/transformers/models/pixtral/processing_pixtral.py +++ b/src/transformers/models/pixtral/processing_pixtral.py @@ -16,18 +16,36 @@ Processor class for Pixtral. """ -from typing import List, Optional, Union +import sys +from typing import List, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image, load_image -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends +from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + logger = logging.get_logger(__name__) +class PixtralProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": {}, + "common_kwargs": { + "return_tensors": "pt", + }, + } + + # Copied from transformers.models.idefics2.processing_idefics2.is_url def is_url(val) -> bool: return isinstance(val, str) and val.startswith("http") @@ -143,12 +161,11 @@ def __init__( def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, images: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length=None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[PixtralProcessorKwargs], ) -> BatchMixFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -158,26 +175,13 @@ def __call__( of the above two methods for more information. Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: @@ -195,6 +199,15 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + PixtralProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: if is_image_or_image_url(images): images = [[images]] @@ -209,7 +222,7 @@ def __call__( "Invalid input images. Please provide a single image or a list of images or a list of list of images." ) images = [[load_image(im) for im in sample] for sample in images] - image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors) + image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"]) else: image_inputs = {} @@ -246,16 +259,9 @@ def __call__( while "" in sample: replace_str = replace_strings.pop(0) sample = sample.replace("", replace_str, 1) - prompt_strings.append(sample) - text_inputs = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) return BatchMixFeature(data={**text_inputs, **image_inputs}) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index ee28c01189b439..f73e8d24cbcd9c 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -27,7 +27,7 @@ import numpy as np from .dynamic_module_utils import custom_object_save -from .image_utils import ChannelDimension, is_vision_available, valid_images +from .image_utils import ChannelDimension, is_valid_image, is_vision_available if is_vision_available(): @@ -1003,6 +1003,20 @@ def _validate_images_text_input_order(images, text): in the processor's `__call__` method before calling this method. """ + def is_url(val) -> bool: + return isinstance(val, str) and val.startswith("http") + + def _is_valid_images_input_for_processor(imgs): + # If we have an list of images, make sure every image is valid + if isinstance(imgs, (list, tuple)): + for img in imgs: + if not _is_valid_images_input_for_processor(img): + return False + # If not a list of tuple, we have been given a single image or batched tensor of images + elif not (is_valid_image(imgs) or is_url(imgs)): + return False + return True + def _is_valid_text_input_for_processor(t): if isinstance(t, str): # Strings are fine @@ -1019,11 +1033,11 @@ def _is_valid_text_input_for_processor(t): def _is_valid(input, validator): return validator(input) or input is None - images_is_valid = _is_valid(images, valid_images) + images_is_valid = _is_valid(images, _is_valid_images_input_for_processor) images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False text_is_valid = _is_valid(text, _is_valid_text_input_for_processor) - text_is_images = valid_images(text) if not text_is_valid else False + text_is_images = _is_valid_images_input_for_processor(text) if not text_is_valid else False # Handle cases where both inputs are valid if images_is_valid and text_is_valid: return images, text diff --git a/tests/models/pixtral/test_processor_pixtral.py b/tests/models/pixtral/test_processor_pixtral.py index b70cab1c074480..04aa3ee8a38b4e 100644 --- a/tests/models/pixtral/test_processor_pixtral.py +++ b/tests/models/pixtral/test_processor_pixtral.py @@ -11,14 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import shutil +import tempfile import unittest import requests import torch -from transformers.testing_utils import require_vision +from transformers.testing_utils import ( + require_torch, + require_vision, +) from transformers.utils import is_vision_available +from ...test_processing_common import ProcessorTesterMixin + if is_vision_available(): from PIL import Image @@ -27,7 +34,7 @@ @require_vision -class PixtralProcessorTest(unittest.TestCase): +class PixtralProcessorTest(ProcessorTesterMixin, unittest.TestCase): processor_class = PixtralProcessor @classmethod @@ -40,15 +47,20 @@ def setUpClass(cls): cls.image_2 = Image.open(requests.get(cls.url_2, stream=True).raw) def setUp(self): - super().setUp() + self.tmpdirname = tempfile.mkdtemp() # FIXME - just load the processor directly from the checkpoint tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/pixtral-12b") image_processor = PixtralImageProcessor() - self.processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor) + processor = PixtralProcessor(tokenizer=tokenizer, image_processor=image_processor) + processor.save_pretrained(self.tmpdirname) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) @unittest.skip("No chat template was set for this model (yet)") def test_chat_template(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) expected_prompt = "USER: [IMG]\nWhat is shown in this image? ASSISTANT:" messages = [ @@ -60,11 +72,12 @@ def test_chat_template(self): ], }, ] - formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) + formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) self.assertEqual(expected_prompt, formatted_prompt) @unittest.skip("No chat template was set for this model (yet)") def test_image_token_filling(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) # Important to check with non square image image = torch.randint(0, 2, (3, 500, 316)) expected_image_tokens = 1526 @@ -79,8 +92,8 @@ def test_image_token_filling(self): ], }, ] - inputs = self.processor( - text=[self.processor.apply_chat_template(messages)], + inputs = processor( + text=[processor.apply_chat_template(messages)], images=[image], return_tensors="pt", ) @@ -88,14 +101,15 @@ def test_image_token_filling(self): self.assertEqual(expected_image_tokens, image_tokens) def test_processor_with_single_image(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) prompt_string = "USER: [IMG]\nWhat's the content of the image? ASSISTANT:" # Make small for checking image token expansion - self.processor.image_processor.size = {"longest_edge": 30} - self.processor.image_processor.patch_size = {"height": 2, "width": 2} + processor.image_processor.size = {"longest_edge": 30} + processor.image_processor.patch_size = {"height": 2, "width": 2} # Test passing in an image - inputs_image = self.processor(text=prompt_string, images=self.image_0, return_tensors="pt") + inputs_image = processor(text=prompt_string, images=self.image_0, return_tensors="pt") self.assertIn("input_ids", inputs_image) self.assertTrue(len(inputs_image["input_ids"]) == 1) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) @@ -115,7 +129,7 @@ def test_processor_with_single_image(self): # fmt: on # Test passing in a url - inputs_url = self.processor(text=prompt_string, images=self.url_0, return_tensors="pt") + inputs_url = processor(text=prompt_string, images=self.url_0, return_tensors="pt") self.assertIn("input_ids", inputs_url) self.assertTrue(len(inputs_url["input_ids"]) == 1) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) @@ -135,14 +149,15 @@ def test_processor_with_single_image(self): # fmt: on def test_processor_with_multiple_images_single_list(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) prompt_string = "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:" # Make small for checking image token expansion - self.processor.image_processor.size = {"longest_edge": 30} - self.processor.image_processor.patch_size = {"height": 2, "width": 2} + processor.image_processor.size = {"longest_edge": 30} + processor.image_processor.patch_size = {"height": 2, "width": 2} # Test passing in an image - inputs_image = self.processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt") + inputs_image = processor(text=prompt_string, images=[self.image_0, self.image_1], return_tensors="pt") self.assertIn("input_ids", inputs_image) self.assertTrue(len(inputs_image["input_ids"]) == 1) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) @@ -162,7 +177,7 @@ def test_processor_with_multiple_images_single_list(self): # fmt: on # Test passing in a url - inputs_url = self.processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt") + inputs_url = processor(text=prompt_string, images=[self.url_0, self.url_1], return_tensors="pt") self.assertIn("input_ids", inputs_url) self.assertTrue(len(inputs_url["input_ids"]) == 1) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) @@ -181,19 +196,20 @@ def test_processor_with_multiple_images_single_list(self): # fmt: on def test_processor_with_multiple_images_multiple_lists(self): + processor = self.processor_class.from_pretrained(self.tmpdirname) prompt_string = [ "USER: [IMG][IMG]\nWhat's the difference between these two images? ASSISTANT:", "USER: [IMG]\nWhat's the content of the image? ASSISTANT:", ] - self.processor.tokenizer.pad_token = "" + processor.tokenizer.pad_token = "" image_inputs = [[self.image_0, self.image_1], [self.image_2]] # Make small for checking image token expansion - self.processor.image_processor.size = {"longest_edge": 30} - self.processor.image_processor.patch_size = {"height": 2, "width": 2} + processor.image_processor.size = {"longest_edge": 30} + processor.image_processor.patch_size = {"height": 2, "width": 2} # Test passing in an image - inputs_image = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) + inputs_image = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) self.assertIn("input_ids", inputs_image) self.assertTrue(len(inputs_image["input_ids"]) == 2) self.assertIsInstance(inputs_image["input_ids"], torch.Tensor) @@ -213,7 +229,7 @@ def test_processor_with_multiple_images_multiple_lists(self): # fmt: on # Test passing in a url - inputs_url = self.processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) + inputs_url = processor(text=prompt_string, images=image_inputs, return_tensors="pt", padding=True) self.assertIn("input_ids", inputs_url) self.assertTrue(len(inputs_url["input_ids"]) == 2) self.assertIsInstance(inputs_url["input_ids"], torch.Tensor) @@ -231,3 +247,145 @@ def test_processor_with_multiple_images_multiple_lists(self): [21510, 1058, 1032, 10, 10, 12, 10, 10, 13, 10, 10, 12, 10, 10, 13, 1010, 7493, 1681, 1278, 6592, 2396, 2576, 2295, 8061, 1063, 1349, 4290, 16002, 41150, 1058] ) # fmt: on + + # Override all tests requiring shape as returning tensor batches is not supported by PixtralProcessor + + @require_torch + @require_vision + def test_image_processor_defaults_preserved_by_image_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", size={"height": 240, "width": 240}) + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + # Added dimension by pixtral image processor + self.assertEqual(len(inputs["pixel_values"][0][0][0][0]), 240) + + @require_torch + @require_vision + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", size={"height": 400, "width": 400}) + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input, size={"height": 240, "width": 240}) + self.assertEqual(len(inputs["pixel_values"][0][0][0][0]), 240) + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"size": {"height": 240, "width": 240}}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"size": {"height": 240, "width": 240}}, + "text_kwargs": {"padding": "max_length", "max_length": 76}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + size={"height": 240, "width": 240}, + padding="max_length", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240) + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + # images needs to be nested to detect multiple prompts + image_input = [self.prepare_image_inputs()] * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + size={"height": 240, "width": 240}, + padding="longest", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"][0][0].shape[-1], 240) + self.assertEqual(len(inputs["input_ids"][0]), 4) diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index b8ca7a6d6733fe..f3111f5d57851f 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -64,6 +64,8 @@ def get_component(self, attribute, **kwargs): component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa if attribute == "tokenizer" and not component.pad_token: component.pad_token = "[TEST_PAD]" + if component.pad_token_id is None: + component.pad_token_id = 0 return component From 9e6cda35897c609da8c44219f54a055009465188 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 16 Sep 2024 23:30:44 +0000 Subject: [PATCH 2/4] update doc --- docs/source/en/model_doc/pixtral.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/pixtral.md b/docs/source/en/model_doc/pixtral.md index 8df2bf5af5f9ca..1c610d19b681c4 100644 --- a/docs/source/en/model_doc/pixtral.md +++ b/docs/source/en/model_doc/pixtral.md @@ -24,7 +24,7 @@ The Pixtral model was released by the Mistral AI team on [Vllm](https://github.c Tips: - Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized) -- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders. +- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders. - The format for one or mulitple prompts is the following: ``` "[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]" @@ -35,7 +35,7 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts) Here is an example of how to run it: -```python +```python from transformers import LlavaForConditionalGeneration, AutoProcessor from PIL import Image @@ -51,7 +51,7 @@ IMG_URLS = [ ] PROMPT = "[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]" -inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda") +inputs = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt").to("cuda") generate_ids = model.generate(**inputs, max_new_tokens=500) ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] From b87a0d9623584a0f31213e9af9af9404778ae0cf Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 17 Sep 2024 13:28:31 +0000 Subject: [PATCH 3/4] fix _validate_images_text_input_order --- src/transformers/processing_utils.py | 4 ++-- tests/utils/test_processing_utils.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index f73e8d24cbcd9c..55dadb832d3141 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1034,10 +1034,10 @@ def _is_valid(input, validator): return validator(input) or input is None images_is_valid = _is_valid(images, _is_valid_images_input_for_processor) - images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False + images_is_text = _is_valid_text_input_for_processor(images) text_is_valid = _is_valid(text, _is_valid_text_input_for_processor) - text_is_images = _is_valid_images_input_for_processor(text) if not text_is_valid else False + text_is_images = _is_valid_images_input_for_processor(text) # Handle cases where both inputs are valid if images_is_valid and text_is_valid: return images, text diff --git a/tests/utils/test_processing_utils.py b/tests/utils/test_processing_utils.py index cf0e66cd7bd08f..f669da25385fab 100644 --- a/tests/utils/test_processing_utils.py +++ b/tests/utils/test_processing_utils.py @@ -80,6 +80,18 @@ def test_validate_images_text_input_order(self): self.assertTrue(np.array_equal(valid_images[0], images[0])) self.assertEqual(valid_text, text) + # list of strings and list of url images inputs + images = ["https://url1", "https://url2"] + text = ["text1", "text2"] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + # list of strings and nested list of numpy images inputs images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]] text = ["text1", "text2"] From 50847494d77ec5f050f4511960a11f4d2960bb1b Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 17 Sep 2024 14:50:13 +0000 Subject: [PATCH 4/4] nit --- src/transformers/processing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 55dadb832d3141..ddf40c4fe6d442 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1012,7 +1012,7 @@ def _is_valid_images_input_for_processor(imgs): for img in imgs: if not _is_valid_images_input_for_processor(img): return False - # If not a list of tuple, we have been given a single image or batched tensor of images + # If not a list or tuple, we have been given a single image or batched tensor of images elif not (is_valid_image(imgs) or is_url(imgs)): return False return True