Skip to content

Commit

Permalink
Add validate images and test processing utils
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Sep 3, 2024
1 parent ecd61c6 commit 81c25de
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 1 deletion.
54 changes: 53 additions & 1 deletion src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np

from .dynamic_module_utils import custom_object_save
from .image_utils import ChannelDimension, is_vision_available
from .image_utils import ChannelDimension, is_vision_available, valid_images


if is_vision_available():
Expand Down Expand Up @@ -993,6 +993,58 @@ def apply_chat_template(
)


def _validate_images_text_input_order(images, text):
"""
For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped.
This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes.
Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled
in the processor's `__call__` method before calling this method.
"""

def _is_valid_text_input_for_processor(t):
if isinstance(t, str):
# Strings are fine
return True
elif isinstance(t, (list, tuple)):
# List are fine as long as they are...
if len(t) == 0:
# ... not empty
return False
elif isinstance(t[0], (str, int)):
# ... list of strings or int (for encoded inputs)
return True
elif isinstance(t[0], (list, tuple)):
# ... list of list of strings or int (for list of encoded inputs)
if isinstance(t[0][0], (str, int)):
return True
elif isinstance(t[0][0], (list, tuple)):
# ... list of list of list of int (for list of list of encoded inputs)
return isinstance(t[0][0][0], int)
return False

def _is_valid(input, validator):
return validator(input) or input is None

images_is_valid = _is_valid(images, valid_images)
images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False

text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
text_is_images = valid_images(text) if not text_is_valid else False
# Handle cases where both inputs are valid
if images_is_valid and text_is_valid:
return images, text

# Handle cases where inputs need to and can be swapped
if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images):
logger.warning_once(
"You may have used the wrong order for inputs. `images` should be passed before `text`. "
"The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47."
)
return text, images

raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.")


ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
if ProcessorMixin.push_to_hub.__doc__ is not None:
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
Expand Down
165 changes: 165 additions & 0 deletions tests/utils/test_processing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers import is_torch_available, is_vision_available
from transformers.processing_utils import _validate_images_text_input_order
from transformers.testing_utils import require_torch, require_vision


if is_vision_available():
import PIL

if is_torch_available():
import torch


@require_vision
class ProcessingUtilTester(unittest.TestCase):
def test_validate_images_text_input_order(self):
# text string and PIL images inputs
images = PIL.Image.new("RGB", (224, 224))
text = "text"
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)

# text list of string and numpy images inputs
images = np.random.rand(224, 224, 3)
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(np.array_equal(valid_images, images))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(np.array_equal(valid_images, images))
self.assertEqual(valid_text, text)

# text nested list of string and list of pil images inputs
images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
text = [["text1", "text2, text3"], ["text3", "text4"]]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)

# pretokenized text and list of numpy images inputs
images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
text = list(range(10))
print(type(text))
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(np.array_equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(np.array_equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)

# list of pretokenized text and nested list of numpy images inputs
images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
text = [list(range(10)), list(range(5))]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
self.assertEqual(valid_text, text)

# nested list of pretokenized text and nested list of PIL images inputs
images = [
[PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
[PIL.Image.new("RGB", (224, 224))],
]
text = [[list(range(10)), list(range(5))], [list(range(10))]]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)

# None images
images = None
text = "text"
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(images, None)
self.assertEqual(text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(images, None)
self.assertEqual(text, text)

# None text
images = PIL.Image.new("RGB", (224, 224))
text = None
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(images, images)
self.assertEqual(text, None)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(images, images)
self.assertEqual(text, None)

# incorrect inputs
images = "text"
text = "text"
with self.assertRaises(ValueError):
_validate_images_text_input_order(images=images, text=text)

@require_torch
def test_validate_images_text_input_order_torch(self):
# text string and torch images inputs
images = torch.rand(224, 224, 3)
text = "text"
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(torch.equal(valid_images, images))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(torch.equal(valid_images, images))
self.assertEqual(valid_text, text)

# text list of string and list of torch images inputs
images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(torch.equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(torch.equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)

0 comments on commit 81c25de

Please sign in to comment.