Skip to content

Commit

Permalink
Add ProcessingUtilTester
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Aug 28, 2024
1 parent 445e472 commit 1263a21
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 7 deletions.
1 change: 0 additions & 1 deletion src/transformers/models/llava/processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def __call__(
raise ValueError("You have to specify at least one of `images` or `text`.")

# check if images and text inputs are reversed for BC
text, images = images, text
images, text = _validate_images_text_input_order(images, text)

output_kwargs = self._merge_kwargs(
Expand Down
13 changes: 7 additions & 6 deletions src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,12 +1011,13 @@ def _is_valid_text_input_for_processor(t):
# ... list of strings or int (for encoded inputs)
return True
elif isinstance(t[0], (list, tuple)):
# ... list of list of strings or int (for encoded inputs)
return isinstance(t[0][0], (str, int))
else:
return False
else:
return False
# ... list of list of strings or int (for list of encoded inputs)
if isinstance(t[0][0], (str, int)):
return True
elif isinstance(t[0][0], (list, tuple)):
# ... list of list of list of int (for list of list of encoded inputs)
return isinstance(t[0][0][0], int)
return False

def _is_valid(input, validator):
return validator(input) or input is None
Expand Down
162 changes: 162 additions & 0 deletions tests/utils/test_processing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import PIL

from transformers import is_torch_available
from transformers.processing_utils import _validate_images_text_input_order
from transformers.testing_utils import require_torch


if is_torch_available():
import torch


class ProcessingUtilTester(unittest.TestCase):
    """Tests for `_validate_images_text_input_order`.

    Each scenario calls the helper twice: once with `images` and `text` in the
    expected positions and once with the two arguments swapped, and checks
    that the returned `(images, text)` pair is the same in both cases.
    """

    def test_validate_images_text_input_order(self):
        """Cover PIL / numpy images with plain, nested and pre-tokenized text."""
        # text string and PIL images inputs
        images = PIL.Image.new("RGB", (224, 224))
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # text list of string and numpy images inputs
        # (arrays are checked by type only: `==` on arrays is element-wise)
        images = np.random.rand(224, 224, 3)
        text = ["text1", "text2"]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsInstance(valid_images, np.ndarray)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsInstance(valid_images, np.ndarray)
        self.assertEqual(valid_text, text)

        # text nested list of string and list of pil images inputs
        images = [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))]
        text = [["text1", "text2, text3"], ["text3", "text4"]]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # pretokenized text and list of numpy images inputs
        images = [np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)]
        text = list(range(10))
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsInstance(valid_images[0], np.ndarray)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsInstance(valid_images[0], np.ndarray)
        self.assertEqual(valid_text, text)

        # list of pretokenized text and nested list of numpy images inputs
        images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
        text = [list(range(10)), list(range(5))]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsInstance(valid_images[0][0], np.ndarray)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsInstance(valid_images[0][0], np.ndarray)
        self.assertEqual(valid_text, text)

        # nested list of pretokenized text and nested list of PIL images inputs
        images = [
            [PIL.Image.new("RGB", (224, 224)), PIL.Image.new("RGB", (224, 224))],
            [PIL.Image.new("RGB", (224, 224))],
        ]
        text = [[list(range(10)), list(range(5))], [list(range(10))]]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertEqual(valid_text, text)

        # None images
        # The assertions must target the *returned* values; comparing the
        # local inputs with themselves would pass regardless of the result.
        images = None
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsNone(valid_images)
        self.assertEqual(valid_text, text)

        # None text
        images = PIL.Image.new("RGB", (224, 224))
        text = None
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertEqual(valid_images, images)
        self.assertIsNone(valid_text)

        # incorrect inputs
        images = "text"
        text = "text"
        with self.assertRaises(ValueError):
            _validate_images_text_input_order(images=images, text=text)

    @require_torch
    def test_validate_images_text_input_order_torch(self):
        """Cover torch tensor images with string and list-of-string text."""
        # text string and torch images inputs
        images = torch.rand(224, 224, 3)
        text = "text"
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsInstance(valid_images, torch.Tensor)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsInstance(valid_images, torch.Tensor)
        self.assertEqual(valid_text, text)

        # text list of string and list of torch images inputs
        images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
        text = ["text1", "text2"]
        # test correct text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
        self.assertIsInstance(valid_images[0], torch.Tensor)
        self.assertEqual(valid_text, text)
        # test incorrect text and images order
        valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
        self.assertIsInstance(valid_images[0], torch.Tensor)
        self.assertEqual(valid_text, text)

0 comments on commit 1263a21

Please sign in to comment.