Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validate images and text inputs order util for processors and test_processing_utils #33285

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np

from .dynamic_module_utils import custom_object_save
from .image_utils import ChannelDimension, is_vision_available
from .image_utils import ChannelDimension, is_vision_available, valid_images


if is_vision_available():
Expand Down Expand Up @@ -993,6 +993,50 @@ def apply_chat_template(
)


def _validate_images_text_input_order(images, text):
"""
For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped.
This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes.
Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled
in the processor's `__call__` method before calling this method.
"""

def _is_valid_text_input_for_processor(t):
if isinstance(t, str):
# Strings are fine
return True
elif isinstance(t, (list, tuple)):
# List are fine as long as they are...
if len(t) == 0:
# ... not empty
return False
for t_s in t:
return _is_valid_text_input_for_processor(t_s)
return False

def _is_valid(input, validator):
return validator(input) or input is None

images_is_valid = _is_valid(images, valid_images)
images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False

text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
text_is_images = valid_images(text) if not text_is_valid else False
# Handle cases where both inputs are valid
if images_is_valid and text_is_valid:
return images, text

# Handle cases where inputs need to and can be swapped
if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images):
logger.warning_once(
"You may have used the wrong order for inputs. `images` should be passed before `text`. "
"The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47."
)
return text, images

raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.")


ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
if ProcessorMixin.push_to_hub.__doc__ is not None:
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
Expand Down
164 changes: 164 additions & 0 deletions tests/utils/test_processing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers import is_torch_available, is_vision_available
from transformers.processing_utils import _validate_images_text_input_order
from transformers.testing_utils import require_torch, require_vision


if is_vision_available():
import PIL

if is_torch_available():
import torch


@require_vision
class ProcessingUtilTester(unittest.TestCase):
def test_validate_images_text_input_order(self):
# text string and PIL images inputs
images = PIL.Image.new("RGB", (224, 224))
text = "text"
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)

# text list of string and numpy images inputs
images = np.random.rand(224, 224, 3)
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(np.array_equal(valid_images, images))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(np.array_equal(valid_images, images))
self.assertEqual(valid_text, text)

# text nested list of string and list of pil images inputs
images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
text = [["text1", "text2, text3"], ["text3", "text4"]]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)

# list of strings and list of numpy images inputs
images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(np.array_equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(np.array_equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)

# list of strings and nested list of numpy images inputs
images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
self.assertEqual(valid_text, text)

# nested list of strings and nested list of PIL images inputs
images = [
[PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
[PIL.Image.new("RGB", (224, 224))],
]
text = [["text1", "text2, text3"], ["text3", "text4"]]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(valid_images, images)
self.assertEqual(valid_text, text)

# None images
images = None
text = "text"
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(images, None)
self.assertEqual(text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(images, None)
self.assertEqual(text, text)

# None text
images = PIL.Image.new("RGB", (224, 224))
text = None
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertEqual(images, images)
self.assertEqual(text, None)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertEqual(images, images)
self.assertEqual(text, None)

# incorrect inputs
images = "text"
text = "text"
with self.assertRaises(ValueError):
_validate_images_text_input_order(images=images, text=text)

@require_torch
def test_validate_images_text_input_order_torch(self):
# text string and torch images inputs
images = torch.rand(224, 224, 3)
text = "text"
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(torch.equal(valid_images, images))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(torch.equal(valid_images, images))
self.assertEqual(valid_text, text)

# text list of string and list of torch images inputs
images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertTrue(torch.equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertTrue(torch.equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)
Loading