def _validate_images_text_input_order(images, text):
    """
    For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped.
    This method should only be called for processors where `images` and `text` have been swapped for
    uniformization purposes.
    Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be
    handled in the processor's `__call__` method before calling this method.

    Args:
        images: Candidate image input. Accepted if it is `None` or if `valid_images` accepts it.
        text: Candidate text input. Accepted if it is `None`, a string, or a non-empty list/tuple whose first
            element is a string or int, optionally nested up to three levels deep for lists of lists of ints.

    Returns:
        A `(images, text)` tuple in the correct order. The inputs are returned unchanged when both are already
        valid, and swapped (with a one-time warning) when they were passed in the wrong order.

    Raises:
        ValueError: If the inputs are neither valid as given nor valid once swapped.
    """

    def _is_valid_text_input_for_processor(t):
        # Only the first element of each nesting level is inspected.
        if isinstance(t, str):
            # Strings are fine
            return True
        if not isinstance(t, (list, tuple)) or len(t) == 0:
            # Anything that is neither a string nor a non-empty sequence is not text
            return False

        first = t[0]
        if isinstance(first, (str, int)):
            # List of strings or ints (for encoded inputs)
            return True
        if not isinstance(first, (list, tuple)) or len(first) == 0:
            # An empty inner list is rejected here instead of raising IndexError below
            return False

        inner = first[0]
        if isinstance(inner, (str, int)):
            # List of list of strings or ints (for list of encoded inputs)
            return True
        if isinstance(inner, (list, tuple)) and len(inner) > 0:
            # List of list of list of ints (for list of list of encoded inputs)
            return isinstance(inner[0], int)
        return False

    def _is_valid(value, validator):
        # `None` is always accepted; test it first so the validator is never called with `None`.
        return value is None or validator(value)

    images_is_valid = _is_valid(images, valid_images)
    images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False

    text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
    text_is_images = valid_images(text) if not text_is_valid else False

    # Handle cases where both inputs are valid
    if images_is_valid and text_is_valid:
        return images, text

    # Handle cases where inputs need to and can be swapped
    if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images):
        logger.warning_once(
            "You may have used the wrong order for inputs. `images` should be passed before `text`. "
            "The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47."
        )
        return text, images

    raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.")
import unittest

import numpy as np

from transformers import is_torch_available, is_vision_available
from transformers.processing_utils import _validate_images_text_input_order
from transformers.testing_utils import require_torch, require_vision


if is_vision_available():
    import PIL

if is_torch_available():
    import torch


@require_vision
class ProcessingUtilTester(unittest.TestCase):
    """Tests for `_validate_images_text_input_order`.

    Each scenario calls the helper twice: once with `images` and `text` in the correct order and once
    with them swapped, and checks that the returned pair is `(images, text)` in both cases.
    """

    def test_validate_images_text_input_order(self):
        """Check PIL / numpy image inputs against string, nested string and pretokenized text inputs."""
        # text string and PIL images inputs
        images = PIL.Image.new("RGB", (224, 224))
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # text list of string and numpy images inputs
        images = np.random.rand(224, 224, 3)
        text = ["text1", "text2"]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images, images))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images, images))
        self.assertEqual(valid_text, text)

        # text nested list of string and list of pil images inputs
        images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
        text = [["text1", "text2, text3"], ["text3", "text4"]]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # pretokenized text and list of numpy images inputs
        images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
        text = list(range(10))
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)

        # list of pretokenized text and nested list of numpy images inputs
        images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
        text = [list(range(10)), list(range(5))]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
        self.assertEqual(valid_text, text)

        # nested list of pretokenized text and nested list of PIL images inputs
        images = [
            [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
            [PIL.Image.new("RGB", (224, 224))],
        ]
        text = [[list(range(10)), list(range(5))], [list(range(10))]]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # None images
        # The assertions check the *returned* values; comparing the local inputs to themselves
        # would pass no matter what the helper returned.
        images = None
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)

        # None text
        images = PIL.Image.new("RGB", (224, 224))
        text = None
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)

        # incorrect inputs
        images = "text"
        text = "text"
        with self.assertRaises(ValueError):
            _validate_images_text_input_order(images=images, text=text)

    @require_torch
    def test_validate_images_text_input_order_torch(self):
        """Check torch tensor image inputs against string and list-of-string text inputs."""
        # text string and torch images inputs
        images = torch.rand(224, 224, 3)
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(torch.equal(valid_images, images))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(torch.equal(valid_images, images))
        self.assertEqual(valid_text, text)

        # text list of string and list of torch images inputs
        images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
        text = ["text1", "text2"]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(torch.equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(torch.equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)