diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 0f272f481c5d77..dc74bceb4c020a 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -129,7 +129,6 @@ def __call__( raise ValueError("You have to specify at least one of `images` or `text`.") # check if images and text inputs are reversed for BC - text, images = images, text images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 28b023f2153ad1..8ce88e01421e2d 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -1011,12 +1011,13 @@ def _is_valid_text_input_for_processor(t): # ... list of strings or int (for encoded inputs) return True elif isinstance(t[0], (list, tuple)): - # ... list of list of strings or int (for encoded inputs) - return isinstance(t[0][0], (str, int)) - else: - return False - else: - return False + # ... list of list of strings or int (for list of encoded inputs) + if isinstance(t[0][0], (str, int)): + return True + elif isinstance(t[0][0], (list, tuple)): + # ... list of list of list of int (for list of list of encoded inputs) + return isinstance(t[0][0][0], int) + return False def _is_valid(input, validator): return validator(input) or input is None diff --git a/tests/utils/test_processing_utils.py b/tests/utils/test_processing_utils.py new file mode 100644 index 00000000000000..77dc710f497f61 --- /dev/null +++ b/tests/utils/test_processing_utils.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import PIL + +from transformers import is_torch_available +from transformers.processing_utils import _validate_images_text_input_order +from transformers.testing_utils import require_torch + + +if is_torch_available(): + import torch + + +class ProcessingUtilTester(unittest.TestCase): + def test_validate_images_text_input_order(self): + # text string and PIL images inputs + images = PIL.Image.new("RGB", (224, 224)) + text = "text" + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + + # text list of string and numpy images inputs + images = np.random.rand(224, 224, 3) + text = ["text1", "text2"] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertIsInstance(valid_images, np.ndarray) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertIsInstance(valid_images, np.ndarray) + self.assertEqual(valid_text, text) + + # text nested list of string and list of pil images inputs + images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))] + text = [["text1", "text2, text3"], ["text3", "text4"]] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + + # pretokenized text and list of numpy images inputs + images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)] + text = list(range(10)) + print(type(text)) + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertIsInstance(valid_images[0], np.ndarray) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertIsInstance(valid_images[0], np.ndarray) + self.assertEqual(valid_text, text) + + # list of pretokenized text and nested list of numpy images inputs + images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]] + text = [list(range(10)), list(range(5))] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertIsInstance(valid_images[0][0], np.ndarray) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertIsInstance(valid_images[0][0], np.ndarray) + self.assertEqual(valid_text, text) + + # nested list of pretokenized text and nested list of PIL images inputs + images = [ + [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))], + [PIL.Image.new("RGB", (224, 224))], + ] + text = [[list(range(10)), list(range(5))], [list(range(10))]] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(valid_images, images) + self.assertEqual(valid_text, text) + + # None images + images = None + text = "text" + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(images, None) + self.assertEqual(text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(images, None) + self.assertEqual(text, text) + + # None text + images = PIL.Image.new("RGB", (224, 224)) + text = None + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertEqual(images, images) + self.assertEqual(text, None) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertEqual(images, images) + self.assertEqual(text, None) + + # incorrect inputs + images = "text" + text = "text" + with self.assertRaises(ValueError): + _validate_images_text_input_order(images=images, text=text) + + @require_torch + def test_validate_images_text_input_order_torch(self): + # text string and torch images inputs + images = torch.rand(224, 224, 3) + text = "text" + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertIsInstance(valid_images, torch.Tensor) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertIsInstance(valid_images, torch.Tensor) + self.assertEqual(valid_text, text) + + # text list of string and list of torch images inputs + images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)] + text = ["text1", "text2"] + # test correct text and images order + valid_images, valid_text = _validate_images_text_input_order(images=images, text=text) + self.assertIsInstance(valid_images[0], torch.Tensor) + self.assertEqual(valid_text, text) + # test incorrect text and images order + valid_images, valid_text = _validate_images_text_input_order(images=text, text=images) + self.assertIsInstance(valid_images[0], torch.Tensor) + self.assertEqual(valid_text, text)