From 9230d78e76611cfa38c845213021aeb185362d10 Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:50:31 -0400 Subject: [PATCH] Add validate images and text inputs order util for processors and test_processing_utils (#33285) * Add validate images and test processing utils * Remove encoded text from possible inputs in tests * Removed encoded inputs as valid in processing_utils * change text input check to be recursive * change text check to all element of lists and not just the first one in recursive checks --- src/transformers/processing_utils.py | 46 +++++++- tests/utils/test_processing_utils.py | 164 +++++++++++++++++++++++++++ 2 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_processing_utils.py diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 60085fd00705b0..59a1ba98c73fdc 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -27,7 +27,7 @@ import numpy as np from .dynamic_module_utils import custom_object_save -from .image_utils import ChannelDimension, is_vision_available +from .image_utils import ChannelDimension, is_vision_available, valid_images if is_vision_available(): @@ -993,6 +993,50 @@ def apply_chat_template( ) +def _validate_images_text_input_order(images, text): + """ + For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped. + This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes. + Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled + in the processor's `__call__` method before calling this method. + """ + + def _is_valid_text_input_for_processor(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... not empty + return False + for t_s in t: + return _is_valid_text_input_for_processor(t_s) + return False + + def _is_valid(input, validator): + return validator(input) or input is None + + images_is_valid = _is_valid(images, valid_images) + images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False + + text_is_valid = _is_valid(text, _is_valid_text_input_for_processor) + text_is_images = valid_images(text) if not text_is_valid else False + # Handle cases where both inputs are valid + if images_is_valid and text_is_valid: + return images, text + + # Handle cases where inputs need to and can be swapped + if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images): + logger.warning_once( + "You may have used the wrong order for inputs. `images` should be passed before `text`. " + "The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47." + ) + return text, images + + raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.") + + ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) if ProcessorMixin.push_to_hub.__doc__ is not None: ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format( diff --git a/tests/utils/test_processing_utils.py b/tests/utils/test_processing_utils.py new file mode 100644 index 00000000000000..cf0e66cd7bd08f --- /dev/null +++ b/tests/utils/test_processing_utils.py @@ -0,0 +1,164 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +from transformers import is_torch_available, is_vision_available +from transformers.processing_utils import _validate_images_text_input_order +from transformers.testing_utils import require_torch, require_vision + + +if is_vision_available(): + import PIL + +if is_torch_available(): + import torch + + +@require_vision +class ProcessingUtilTester(unittest.TestCase): + def test_validate_images_text_input_order(self): + # text string and PIL images inputs + images = PIL.Image.new("RGB", (224, 224)) + text = "text" + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + + # text list of string and numpy images inputs + images = np.random.rand(224, 224, 3) + text = ["text1", "text2"] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertTrue(np.array_equal(valid_images, images)) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertTrue(np.array_equal(valid_images, images)) + self.assertEqual(valid_text, text) + + # text nested list of string and list of pil images inputs + images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))] + text = [["text1", "text2, text3"], ["text3", "text4"]] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + + # list of strings and list of numpy images inputs + images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)] + text = ["text1", "text2"] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertTrue(np.array_equal(valid_images[0], images[0])) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertTrue(np.array_equal(valid_images[0], images[0])) + self.assertEqual(valid_text, text) + + # list of strings and nested list of numpy images inputs + images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]] + text = ["text1", "text2"] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertTrue(np.array_equal(valid_images[0][0], images[0][0])) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertTrue(np.array_equal(valid_images[0][0], images[0][0])) + self.assertEqual(valid_text, text) + + # nested list of strings and nested list of PIL images inputs + images = [ + [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))], + [PIL.Image.new("RGB", (224, 224))], + ] + text = [["text1", "text2, text3"], ["text3", "text4"]] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + + # None images + images = None + text = "text" + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(images, None) + self.assertEqual(text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(images, None) + self.assertEqual(text, text) + + # None text + images = PIL.Image.new("RGB", (224, 224)) + text = None + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(images, images) + self.assertEqual(text, None) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(images, images) + self.assertEqual(text, None) + + # incorrect inputs + images = "text" + text = "text" + with self.assertRaises(ValueError): + _validate_images_text_input_order(images=images, text=text) + + @require_torch + def test_validate_images_text_input_order_torch(self): + # text string and torch images inputs + images = torch.rand(224, 224, 3) + text = "text" + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertTrue(torch.equal(valid_images, images)) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertTrue(torch.equal(valid_images, images)) + self.assertEqual(valid_text, text) + + # text list of string and list of torch images inputs + images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)] + text = ["text1", "text2"] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertTrue(torch.equal(valid_images[0], images[0])) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertTrue(torch.equal(valid_images[0], images[0])) + self.assertEqual(valid_text, text)