Skip to content

Commit

Permalink
Fix test processing utils
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Sep 3, 2024
1 parent 1263a21 commit 146a50a
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions tests/utils/test_processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@
import unittest

import numpy as np
import PIL

from transformers import is_torch_available
from transformers import is_torch_available, is_vision_available
from transformers.processing_utils import _validate_images_text_input_order
from transformers.testing_utils import require_torch
from transformers.testing_utils import require_torch, require_vision


if is_vision_available():
import PIL

if is_torch_available():
import torch


@require_vision
class ProcessingUtilTester(unittest.TestCase):
def test_validate_images_text_input_order(self):
# text string and PIL images inputs
Expand All @@ -46,11 +49,11 @@ def test_validate_images_text_input_order(self):
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertIsInstance(valid_images, np.ndarray)
self.assertTrue(np.array_equal(valid_images, images))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertIsInstance(valid_images, np.ndarray)
self.assertTrue(np.array_equal(valid_images, images))
self.assertEqual(valid_text, text)

# text nested list of string and list of pil images inputs
Expand All @@ -71,23 +74,23 @@ def test_validate_images_text_input_order(self):
print(type(text))
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertIsInstance(valid_images[0], np.ndarray)
self.assertTrue(np.array_equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertIsInstance(valid_images[0], np.ndarray)
self.assertTrue(np.array_equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)

# list of pretokenized text and nested list of numpy images inputs
images = [[np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)], [np.random.rand(224, 224, 3)]]
text = [list(range(10)), list(range(5))]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertIsInstance(valid_images[0][0], np.ndarray)
self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertIsInstance(valid_images[0][0], np.ndarray)
self.assertTrue(np.array_equal(valid_images[0][0], images[0][0]))
self.assertEqual(valid_text, text)

# nested list of pretokenized text and nested list of PIL images inputs
Expand Down Expand Up @@ -142,21 +145,21 @@ def test_validate_images_text_input_order_torch(self):
text = "text"
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertIsInstance(valid_images, torch.Tensor)
self.assertTrue(torch.equal(valid_images, images))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertIsInstance(valid_images, torch.Tensor)
self.assertTrue(torch.equal(valid_images, images))
self.assertEqual(valid_text, text)

# text list of string and list of torch images inputs
images = [torch.rand(224, 224, 3), torch.rand(224, 224, 3)]
text = ["text1", "text2"]
# test correct text and images order
valid_images, valid_text = _validate_images_text_input_order(images=images, text=text)
self.assertIsInstance(valid_images[0], torch.Tensor)
self.assertTrue(torch.equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)
# test incorrect text and images order
valid_images, valid_text = _validate_images_text_input_order(images=text, text=images)
self.assertIsInstance(valid_images[0], torch.Tensor)
self.assertTrue(torch.equal(valid_images[0], images[0]))
self.assertEqual(valid_text, text)

0 comments on commit 146a50a

Please sign in to comment.