Skip to content

Commit

Permalink
Use non nested images and batched text Idefics2/3 (#34222)
Browse files Browse the repository at this point in the history
* add support for non nested images and add tests

* add tests error scenario

* fix style

* added single and no image to error tests
  • Loading branch information
yonigozlan authored Oct 25, 2024
1 parent 3d99f17 commit 940a6bd
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
isinstance(images, (list, tuple))
and len(images) > 0
and isinstance(images[0], (list, tuple))
and len(images[0]) > 0
and is_valid_image(images[0][0])
):
pass
Expand Down
17 changes: 16 additions & 1 deletion src/transformers/models/idefics2/processing_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Processor class for IDEFICS2.
"""

from itertools import accumulate
from typing import TYPE_CHECKING, List, Optional, Union

from ...feature_extraction_utils import BatchFeature
Expand Down Expand Up @@ -218,7 +219,21 @@ def __call__(
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
if text is not None:
if sum(n_images_in_text) != len(images):
raise ValueError(
f"The total number of {image_token} tokens in the prompts should be the same as the number of images passed."
f" Found {sum(n_images_in_text)} {image_token} tokens and {len(images)} images."
)
# Reorganize the images to match the prompts
cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
images = [
images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
for i in range(len(n_images_in_text))
]
else:
images = [images]

elif (
not isinstance(images, list)
and not isinstance(images[0], list)
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/idefics3/image_processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,11 @@ def get_resize_output_image_size(
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
"""
Convert a single image or a list of images to a list of numpy arrays.
Args:
images (`ImageInput`):
A single image or a list of images.
Returns:
A list of numpy arrays.
"""
Expand All @@ -168,6 +170,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
isinstance(images, (list, tuple))
and len(images) > 0
and isinstance(images[0], (list, tuple))
and len(images[0]) > 0
and is_valid_image(images[0][0])
):
pass
Expand Down
38 changes: 26 additions & 12 deletions src/transformers/models/idefics3/processing_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import re
from itertools import accumulate
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from ...feature_extraction_utils import BatchFeature
Expand Down Expand Up @@ -241,11 +242,31 @@ def __call__(
n_images_in_images = []
inputs = BatchFeature()

if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
n_images_in_text = [sample.count(self.image_token.content) for sample in text]

if images is not None:
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
images = [images]
if text is not None:
if sum(n_images_in_text) != len(images):
raise ValueError(
f"The total number of {self.image_token.content} tokens in the prompts should be the same as the number of images passed."
f" Found {sum(n_images_in_text)} {self.image_token.content} tokens and {len(images)} images."
)
# Reorganize the images to match the prompts
cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
images = [
images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
for i in range(len(n_images_in_text))
]
else:
images = [images]
elif (
not isinstance(images, list)
and not isinstance(images[0], list)
Expand All @@ -263,10 +284,10 @@ def __call__(
inputs.update(image_inputs)

if text is not None:
if isinstance(text, str):
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings")
if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)

image_rows = inputs.pop("rows", [[0] * len(text)])
image_cols = inputs.pop("cols", [[0] * len(text)])
Expand All @@ -277,8 +298,6 @@ def __call__(

prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
n_images_in_text.append(sample.count(image_token))

# Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len`
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
Expand All @@ -305,11 +324,6 @@ def __call__(
text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)

if n_images_in_images != n_images_in_text:
raise ValueError(
f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
)

return inputs

def batch_decode(self, *args, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
isinstance(images, (list, tuple))
and len(images) > 0
and isinstance(images[0], (list, tuple))
and len(images[0]) > 0
and is_valid_image(images[0][0])
):
pass
Expand Down
77 changes: 67 additions & 10 deletions tests/models/idefics2/test_processor_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,73 @@ def test_add_special_tokens_processor(self):
self.assertEqual(inputs["input_ids"], expected_input_ids)
# fmt: on

def test_non_nested_images_with_batched_text(self):
processor = self.get_processor()
processor.image_processor.do_image_splitting = False

image_str = "<image>"
text_str_1 = "In this image, we see"
text_str_2 = "bla, bla"

text = [
image_str + text_str_1,
text_str_2 + image_str + image_str,
]
images = [self.image1, self.image2, self.image3]

inputs = processor(text=text, images=images, padding=True)

self.assertEqual(inputs["pixel_values"].shape, (2, 2, 3, 767, 980))
self.assertEqual(inputs["pixel_attention_mask"].shape, (2, 2, 767, 980))

def test_process_interleaved_images_prompts_image_error(self):
processor = self.get_processor()

text = [
"This is a test sentence.",
"In this other sentence we try some good things",
]
images = [[self.image1], [self.image2]]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [[self.image1], []]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)

text = [
"This is a test sentence.<image>",
"In this other sentence we try some good things<image>",
]
images = [[self.image1], [self.image2, self.image3]]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [[], [self.image2]]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [self.image1, self.image2, self.image3]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [self.image1]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)

text = [
"This is a test sentence.",
"In this other sentence we try some good things<image>",
]
images = [[self.image1], []]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [[], [self.image2]]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [self.image1, self.image2]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [self.image1]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)

def test_apply_chat_template(self):
# Message contains content which a mix of lists with images and image urls and string
messages = [
Expand Down Expand Up @@ -275,13 +342,3 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None):
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
batch_size - 2
)

# Override as PixtralProcessor needs nested images to work properly with batched inputs
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
"""This function prepares a list of PIL images for testing"""
if batch_size is None:
return super().prepare_image_inputs()
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
return [[super().prepare_image_inputs()]] * batch_size
79 changes: 69 additions & 10 deletions tests/models/idefics3/test_processor_idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,74 @@ def test_add_special_tokens_processor(self):
self.assertEqual(inputs["input_ids"], expected_input_ids)
# fmt: on

def test_non_nested_images_with_batched_text(self):
processor = self.get_processor()
processor.image_processor.do_image_splitting = False

image_str = "<image>"
text_str_1 = "In this image, we see"
text_str_2 = "In this image, we see"

text = [
image_str + text_str_1,
image_str + image_str + text_str_2,
]
images = [self.image1, self.image2, self.image3]

inputs = processor(text=text, images=images, padding=True)

self.assertEqual(np.array(inputs["pixel_values"]).shape, (2, 2, 3, 364, 364))
self.assertEqual(np.array(inputs["pixel_attention_mask"]).shape, (2, 2, 364, 364))

# Copied from tests.models.idefics2.test_processor_idefics2.Idefics2ProcessorTest.test_process_interleaved_images_prompts_image_error
def test_process_interleaved_images_prompts_image_error(self):
processor = self.get_processor()

text = [
"This is a test sentence.",
"In this other sentence we try some good things",
]
images = [[self.image1], [self.image2]]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [[self.image1], []]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)

text = [
"This is a test sentence.<image>",
"In this other sentence we try some good things<image>",
]
images = [[self.image1], [self.image2, self.image3]]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [[], [self.image2]]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [self.image1, self.image2, self.image3]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [self.image1]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)

text = [
"This is a test sentence.",
"In this other sentence we try some good things<image>",
]
images = [[self.image1], []]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [[], [self.image2]]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [self.image1, self.image2]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)
images = [self.image1]
with self.assertRaises(ValueError):
processor(text=text, images=images, padding=True)

def test_apply_chat_template(self):
# Message contains content which a mix of lists with images and image urls and string
messages = [
Expand Down Expand Up @@ -299,16 +367,7 @@ def prepare_text_inputs(self, batch_size: Optional[int] = None):
batch_size - 2
)

# Override as Idefics3Processor needs nested images to work properly with batched inputs
@require_vision
def prepare_image_inputs(self, batch_size: Optional[int] = None):
"""This function prepares a list of PIL images for testing"""
if batch_size is None:
return super().prepare_image_inputs()
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
return [[super().prepare_image_inputs()]] * batch_size

# Override tests as inputs_ids padded dimension is the second one but not the last one
@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
Expand Down

0 comments on commit 940a6bd

Please sign in to comment.