Skip to content

Commit 5c36526

Browse files
yonigozlan authored and BernardZach committed
Uniformize kwargs for Idefics/2 processors (huggingface#32568)
* Add uniformize idefics processor kwargs and tests
* Uniformize idefics2 processor kwargs
* add image_processor tests idefics
* add BC args order change idefics2 processor and update doc
* Add support for multiple images per prompt in image-text-to-text mode idefics
* Fix processor input args in idefics tests
* improve test processing common, remove unnecessary tests, update process uniformization
* fix docstrings idefics
* fix tests processors idefics/2
1 parent 6d56edd commit 5c36526

File tree

6 files changed

+409
-160
lines changed

6 files changed

+409
-160
lines changed

src/transformers/models/idefics/processing_idefics.py

Lines changed: 103 additions & 59 deletions
Original file line number | Diff line number | Diff line change
@@ -16,13 +16,21 @@
1616
Processor class for IDEFICS.
1717
"""
1818

19-
from typing import Callable, List, Optional, Union
19+
from typing import Callable, Dict, List, Optional, Union
2020
from urllib.parse import urlparse
2121

2222
from ...feature_extraction_utils import BatchFeature
23-
from ...processing_utils import ProcessorMixin
24-
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
23+
from ...processing_utils import (
24+
ImagesKwargs,
25+
ProcessingKwargs,
26+
ProcessorMixin,
27+
TextKwargs,
28+
Unpack,
29+
_validate_images_text_input_order,
30+
)
31+
from ...tokenization_utils_base import PreTokenizedInput, TextInput
2532
from ...utils import is_tf_available, is_torch_available
33+
from ...utils.deprecation import deprecate_kwarg
2634

2735

2836
if is_torch_available():
@@ -34,6 +42,32 @@
3442
IMAGE_TOKEN = "<image>"
3543

3644

45+
class IdeficsImagesKwargs(ImagesKwargs, total=False):
46+
transform: Optional[Callable]
47+
image_size: Optional[Dict[str, int]]
48+
image_mean: Optional[Union[float, List[float]]]
49+
image_std: Optional[Union[float, List[float]]]
50+
51+
52+
class IdeficsTextKwargs(TextKwargs, total=False):
53+
add_eos_token: Optional[bool]
54+
add_end_of_utterance_token: Optional[bool]
55+
56+
57+
class IdeficsProcessorKwargs(ProcessingKwargs, total=False):
58+
text_kwargs: IdeficsTextKwargs
59+
images_kwargs: IdeficsImagesKwargs
60+
_defaults = {
61+
"text_kwargs": {
62+
"add_special_tokens": False,
63+
"padding": "longest",
64+
"add_eos_token": False,
65+
},
66+
"images_kwargs": {},
67+
"common_kwargs": {"return_tensors": "pt"},
68+
}
69+
70+
3771
# copied from m4.training.packing
3872
def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
3973
# Set elements >= num_classes to -1
@@ -199,52 +233,32 @@ def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_u
199233
else False
200234
)
201235

236+
@deprecate_kwarg(old_name="prompts", version="5.0.0", new_name="text", raise_if_both_names=True)
202237
def __call__(
203238
self,
204-
prompts: Union[List[TextInput], List[List[TextInput]]],
205-
padding: Union[bool, str, PaddingStrategy] = "longest",
206-
truncation: Union[bool, str, TruncationStrategy] = None,
207-
max_length: Optional[int] = None,
208-
transform: Callable = None,
209-
add_eos_token=False,
210-
add_end_of_utterance_token=None,
211-
debug=False,
212-
return_tensors="pt",
213-
) -> BatchEncoding:
239+
images=None,
240+
text: Union[
241+
TextInput,
242+
PreTokenizedInput,
243+
List[TextInput],
244+
List[PreTokenizedInput],
245+
List[List[TextInput]],
246+
List[List[PreTokenizedInput]],
247+
] = None,
248+
audio=None,
249+
videos=None,
250+
**kwargs: Unpack[IdeficsProcessorKwargs],
251+
) -> BatchFeature:
214252
"""This method takes batched or non-batched prompts made of text and images and converts them into prompts that
215253
the model was trained on and prepares the image pixel values for the model to process.
216254
217255
Args:
218-
prompts (`Union[List[TextInput], [List[List[TextInput]]]]`):
256+
images (`Union[PIL.Image, str, List[PIL.Image], List[str]]`):
257+
either a single image or a batched list of images - can be passed in when text contains only text prompts,
258+
in order to use the image-text-to-text behavior.
259+
text (`Union[List[TextInput], List[List[TextInput]]]`):
219260
either a single prompt or a batched list of prompts - see the detailed description immediately after
220261
the end of the arguments doc section.
221-
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `"longest"`):
222-
Select a strategy to pad the returned sequences (according to the model's padding side and padding
223-
index) among:
224-
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
225-
sequence if provided).
226-
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
227-
acceptable input length for the model if that argument is not provided.
228-
- `False` or `'do_not_pad'`: No padding. This will raise an error if the input sequences are of different
229-
lengths.
230-
Note: Unlike most processors, which set padding=`False` by default, `IdeficsProcessor` sets `padding="longest"`
231-
by default. See https://github.com/huggingface/transformers/pull/29449#pullrequestreview-1925576061 for why.
232-
max_length (`int`, *optional*):
233-
Maximum length of the returned list and optionally padding length (see above).
234-
truncation (`bool`, *optional*):
235-
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
236-
transform (`Callable`, *optional*):
237-
A custom transform function that accepts a single image can be passed for training. For example,
238-
`torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
239-
set of transforms will be applied to the images
240-
add_eos_token (`bool`, *optional*, defaults to `False`):
241-
Adds `eos_token` at the end of the final prompt if True`
242-
add_end_of_utterance_token (`bool`, *optional*)
243-
Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
244-
image). If `None` the tokenizer will be checked instead and if this token is found in
245-
`additional_special_tokens` then the value will be `True`.
246-
debug (`bool`, *optional*, defaults to `False`):
247-
`True` value will help debug prompt generation by dumping useful information
248262
return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
249263
The type of tensors to return. Can be one of:
250264
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
@@ -255,7 +269,7 @@ def __call__(
255269
256270
Detailed explanation:
257271
258-
Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
272+
Each entry in `text` is either a text to be passed as is or an image that will be processed.
259273
260274
An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
261275
@@ -279,7 +293,7 @@ def __call__(
279293
"Describe this image.\nAssistant:",
280294
]
281295
282-
inputs = processor(prompts, return_tensors="pt")
296+
inputs = processor(text=prompts, return_tensors="pt")
283297
generated_ids = model.generate(**inputs, max_length=100)
284298
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
285299
```
@@ -311,18 +325,55 @@ def __call__(
311325
transforms.Normalize(mean=self.image_mean, std=self.image_std),
312326
]
313327
)
314-
inputs = processor(prompts, transform=image_transform, return_tensors="pt")
328+
inputs = processor(text=prompts, transform=image_transform, return_tensors="pt")
315329
```
316330
317331
In order to help debug prompt generation enable `debug=True` which will show you what's happening.
318332
319333
"""
334+
if images is None and text is None:
335+
raise ValueError("You need to specify either `text` or `images` and `text`.")
336+
# check if images and text inputs are reversed for BC
337+
images, text = _validate_images_text_input_order(images, text)
338+
339+
if images is None:
340+
# assuming the user wants to use the old behavior with prompts as the only argument
341+
prompts = text
342+
elif text is not None:
343+
# Assuming image-text-to-text behavior:
344+
# Check if batched images are provided
345+
if not isinstance(images, (list, tuple)):
346+
images = [images]
347+
if isinstance(text, str):
348+
text = [text]
349+
# Check if batched images and text are in the correct format
350+
if isinstance(text, (list, tuple)) and len(text) != len(images):
351+
raise ValueError(
352+
"When providing both images and text arguments, the number of text prompts should be the same as the number of images."
353+
"If you want to have several images per prompt, images should be nested as such: images=[[img1, img2], [img3, img4], ...] for text=[prompt1, prompt2, ...]."
354+
)
355+
# Check that only text is present in the prompts
356+
if not all(isinstance(i, str) for i in text):
357+
raise ValueError("When using the image-text-to-text behavior, the prompts should only contain text.")
358+
if isinstance(images[0], (list, tuple)):
359+
# if nested images, nest text as well
360+
text = [[i] for i in text]
361+
prompts = list(zip(images, text))
362+
363+
output_kwargs = self._merge_kwargs(
364+
IdeficsProcessorKwargs,
365+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
366+
**kwargs,
367+
)
368+
369+
add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
370+
add_end_of_utterance_token = output_kwargs["text_kwargs"].pop("add_end_of_utterance_token", None)
320371

321372
# if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
322373
if add_end_of_utterance_token is None:
323374
add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
324375
# turn non-batched prompts into batched
325-
if not any(isinstance(i, list) for i in prompts):
376+
if not any(isinstance(i, (list, tuple)) for i in prompts):
326377
prompts = [prompts]
327378

328379
fake_token = "<fake_token_around_image>"
@@ -371,21 +422,14 @@ def image_tokens(last_was_image):
371422
if add_eos_token:
372423
full_text += self.tokenizer.eos_token
373424

374-
if debug is True:
375-
print(f"{full_text=}")
376-
377-
image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors)
425+
image_objects = self.image_processor(image_objects, **output_kwargs["images_kwargs"])
378426

379427
all_prompts.append(full_text)
380428
all_images.append(image_objects)
381429

382-
text_encoding = self.tokenizer(
383-
text=all_prompts,
384-
add_special_tokens=False,
385-
padding=padding,
386-
truncation=truncation,
387-
max_length=max_length,
388-
)
430+
# For BC
431+
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
432+
text_encoding = self.tokenizer(all_prompts, **output_kwargs["text_kwargs"])
389433
all_texts = text_encoding["input_ids"]
390434
all_attention_masks = text_encoding["attention_mask"]
391435

@@ -398,12 +442,12 @@ def image_tokens(last_was_image):
398442
output_images = []
399443
output_attention_masks = []
400444

401-
for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images):
402-
padded_input_ids = text
445+
for text_single, attention_mask, extracted_images in zip(all_texts, all_attention_masks, all_images):
446+
padded_input_ids = text_single
403447
image_count = padded_input_ids.count(self.image_token_id)
404448
local_max_num_images = min(image_count, max_num_images)
405449

406-
current_images = images[:local_max_num_images]
450+
current_images = extracted_images[:local_max_num_images]
407451

408452
if len(current_images) > 0:
409453
if return_tensors == "pt":

src/transformers/models/idefics2/modeling_idefics2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1584,7 +1584,7 @@ def forward(
15841584
... "In which city is that bridge located?<image>",
15851585
... ]
15861586
>>> images = [[image1, image2], [image3]]
1587-
>>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to("cuda")
1587+
>>> inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to("cuda")
15881588
15891589
>>> # Generate
15901590
>>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)

0 commit comments

Comments
 (0)