Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validate images and text inputs order util for processors and test_processing_utils #33285

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import numpy as np

from .dynamic_module_utils import custom_object_save
from .image_utils import ChannelDimension, is_vision_available
from .image_utils import ChannelDimension, is_vision_available, valid_images


if is_vision_available():
Expand Down Expand Up @@ -993,6 +993,54 @@ def apply_chat_template(
)


def _validate_images_text_input_order(images, text):
    """
    For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped.
    This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes.
    Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled
    in the processor's `__call__` method before calling this method.

    Args:
        images: Candidate image input. Considered valid when it is `None` or accepted by `valid_images`.
        text: Candidate text input. Considered valid when it is `None`, a string, a non-empty list/tuple whose
            first element is a string, or a non-empty list/tuple whose first element is a non-empty list/tuple
            whose first element is a string.

    Returns:
        A `(images, text)` tuple, in that order. The inputs are returned unchanged when both are valid, and
        swapped (with a warning logged once) when they were evidently passed in the wrong order.

    Raises:
        ValueError: If the inputs are neither both valid as given nor valid once swapped.
    """

    def _is_valid_text_input_for_processor(t, max_depth=2):
        # Strings are always fine.
        if isinstance(t, str):
            return True
        # Lists/tuples are fine as long as they are not empty and their first element is itself valid text.
        # `max_depth` bounds the nesting to "list of list of strings", and the emptiness check is applied at
        # every level so that inputs such as `[[]]` are rejected instead of raising an `IndexError`.
        if isinstance(t, (list, tuple)) and max_depth > 0 and len(t) > 0:
            return _is_valid_text_input_for_processor(t[0], max_depth - 1)
        return False

    def _is_valid(value, validator):
        # `None` is accepted for either argument (see the note in the docstring).
        return validator(value) or value is None

    images_is_valid = _is_valid(images, valid_images)
    # Only test whether `images` looks like text when it failed the image validation.
    images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False

    text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
    # Only test whether `text` looks like images when it failed the text validation.
    text_is_images = valid_images(text) if not text_is_valid else False

    # Handle cases where both inputs are valid
    if images_is_valid and text_is_valid:
        return images, text

    # Handle cases where inputs need to and can be swapped
    if (images is None and text_is_images) or (text is None and images_is_text) or (images_is_text and text_is_images):
        logger.warning_once(
            "You may have used the wrong order for inputs. `images` should be passed before `text`. "
            "The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47."
        )
        return text, images

    raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.")


ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
if ProcessorMixin.push_to_hub.__doc__ is not None:
ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
Expand Down
164 changes: 164 additions & 0 deletions tests/utils/test_processing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from transformers import is_torch_available, is_vision_available
from transformers.processing_utils import _validate_images_text_input_order
from transformers.testing_utils import require_torch, require_vision


if is_vision_available():
import PIL

if is_torch_available():
import torch


@require_vision
class ProcessingUtilTester(unittest.TestCase):
    """Tests for `_validate_images_text_input_order` with PIL, numpy and torch image inputs."""

    def test_validate_images_text_input_order(self):
        # text string and PIL images inputs
        images = PIL.Image.new("RGB", (224, 224))
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # text list of string and numpy images inputs
        images = np.random.rand(224, 224, 3)
        text = ["text1", "text2"]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images, images))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images, images))
        self.assertEqual(valid_text, text)

        # text nested list of string and list of pil images inputs
        images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
        text = [["text1", "text2, text3"], ["text3", "text4"]]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # list of strings and list of numpy images inputs
        images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
        text = ["text1", "text2"]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)

        # list of strings and nested list of numpy images inputs
        images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
        text = ["text1", "text2"]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
        self.assertEqual(valid_text, text)

        # nested list of strings and nested list of PIL images inputs
        images = [
            [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
            [PIL.Image.new("RGB", (224, 224))],
        ]
        text = [["text1", "text2, text3"], ["text3", "text4"]]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # None images
        # NOTE: assertions are made on the *returned* values; asserting on the inputs would be tautological.
        images = None
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)

        # None text
        images = PIL.Image.new("RGB", (224, 224))
        text = None
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)

        # incorrect inputs
        images = "text"
        text = "text"
        with self.assertRaises(ValueError):
            _validate_images_text_input_order(images=images, text=text)

    @require_torch
    def test_validate_images_text_input_order_torch(self):
        # text string and torch images inputs
        images = torch.rand(224, 224, 3)
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(torch.equal(valid_images, images))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(torch.equal(valid_images, images))
        self.assertEqual(valid_text, text)

        # text list of string and list of torch images inputs
        images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
        text = ["text1", "text2"]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertTrue(torch.equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertTrue(torch.equal(valid_images[0], images[0]))
        self.assertEqual(valid_text, text)
Loading