Uniformize kwargs for Idefics/2 processors (#32568)
* Add uniformize idefics processor kwargs and tests

* Uniformize idefics2 processor kwargs

* add image_processor tests idefics

* add BC args order change idefics2 processor and update doc

* Add support for multiple images per prompt in image-text-to-text mode idefics

* Fix processor input args in idefics tests

* improve test processing common, remove unnecessary tests, update process uniformization

* fix docstrings idefics

* fix tests processors idefics/2
yonigozlan authored Oct 3, 2024
1 parent b0c5660 commit 074aa3b
Showing 6 changed files with 409 additions and 160 deletions.
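In practice, the uniformized interface gives both Idefics processors the shared `__call__(images, text, audio, videos, **kwargs)` ordering, with per-modality kwargs merged from typed defaults. A minimal sketch of the new calling convention (the checkpoint name and placeholder image are illustrative, not taken from this diff):

```python
from PIL import Image
from transformers import AutoProcessor

# Illustrative checkpoint; any Idefics2 checkpoint should behave the same way.
processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")

image = Image.new("RGB", (224, 224))  # placeholder image for the sketch

# `images` now comes first and `text` second; tokenizer and image-processor
# kwargs are passed flat and dispatched to the matching sub-processor.
inputs = processor(
    images=image,
    text="<image>Describe this image.",
    padding="longest",
    return_tensors="pt",
)
print(inputs.keys())  # input_ids, attention_mask, pixel_values, ...
```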
162 changes: 103 additions & 59 deletions src/transformers/models/idefics/processing_idefics.py
@@ -16,13 +16,21 @@
 Processor class for IDEFICS.
 """
 
-from typing import Callable, List, Optional, Union
+from typing import Callable, Dict, List, Optional, Union
 from urllib.parse import urlparse
 
 from ...feature_extraction_utils import BatchFeature
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
+from ...processing_utils import (
+    ImagesKwargs,
+    ProcessingKwargs,
+    ProcessorMixin,
+    TextKwargs,
+    Unpack,
+    _validate_images_text_input_order,
+)
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
 from ...utils import is_tf_available, is_torch_available
+from ...utils.deprecation import deprecate_kwarg
 
 
 if is_torch_available():
@@ -34,6 +42,32 @@
 IMAGE_TOKEN = "<image>"
 
 
+class IdeficsImagesKwargs(ImagesKwargs, total=False):
+    transform: Optional[Callable]
+    image_size: Optional[Dict[str, int]]
+    image_mean: Optional[Union[float, List[float]]]
+    image_std: Optional[Union[float, List[float]]]
+
+
+class IdeficsTextKwargs(TextKwargs, total=False):
+    add_eos_token: Optional[bool]
+    add_end_of_utterance_token: Optional[bool]
+
+
+class IdeficsProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: IdeficsTextKwargs
+    images_kwargs: IdeficsImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": False,
+            "padding": "longest",
+            "add_eos_token": False,
+        },
+        "images_kwargs": {},
+        "common_kwargs": {"return_tensors": "pt"},
+    }
+
+
 # copied from m4.training.packing
 def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
     # Set elements >= num_classes to -1
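The `_defaults` mapping on `IdeficsProcessorKwargs` above is merged with the tokenizer's init kwargs and any call-time kwargs by `ProcessorMixin._merge_kwargs`, with call-time values taking precedence. A simplified sketch of that precedence using plain dicts (not the actual `_merge_kwargs` implementation):

```python
# Simplified illustration of the merge order: defaults, then tokenizer init
# kwargs, then call-time kwargs; later sources override earlier ones.
defaults = {"add_special_tokens": False, "padding": "longest", "add_eos_token": False}
tokenizer_init = {"padding": "longest"}  # hypothetical kwarg stored at tokenizer init
call_time = {"padding": "max_length", "max_length": 64}

text_kwargs = {**defaults, **tokenizer_init, **call_time}
assert text_kwargs["padding"] == "max_length"  # the caller's value wins
assert text_kwargs["add_eos_token"] is False   # untouched defaults survive
```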
@@ -199,52 +233,32 @@ def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_u
             else False
         )
 
+    @deprecate_kwarg(old_name="prompts", version="5.0.0", new_name="text", raise_if_both_names=True)
     def __call__(
         self,
-        prompts: Union[List[TextInput], List[List[TextInput]]],
-        padding: Union[bool, str, PaddingStrategy] = "longest",
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: Optional[int] = None,
-        transform: Callable = None,
-        add_eos_token=False,
-        add_end_of_utterance_token=None,
-        debug=False,
-        return_tensors="pt",
-    ) -> BatchEncoding:
+        images=None,
+        text: Union[
+            TextInput,
+            PreTokenizedInput,
+            List[TextInput],
+            List[PreTokenizedInput],
+            List[List[TextInput]],
+            List[List[PreTokenizedInput]],
+        ] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[IdeficsProcessorKwargs],
+    ) -> BatchFeature:
         """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
         the model was trained on and prepares the image pixel values for the model to process.
 
         Args:
-            prompts (`Union[List[TextInput], [List[List[TextInput]]]]`):
+            images (`Union[PIL.Image, str, List[PIL.Image], List[str]]`):
+                either a single image or a batched list of images - can be passed in when text contains only text prompts,
+                in order to use the image-text-to-text behavior.
+            text (`Union[List[TextInput], [List[List[TextInput]]]]`):
                 either a single prompt or a batched list of prompts - see the detailed description immediately after
                 the end of the arguments doc section.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `"longest"`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'`: No padding. This will raise an error if the input sequences are of different
-                  lengths.
-                Note: Unlike most processors, which set padding=`False` by default, `IdeficsProcessor` sets `padding="longest"`
-                by default. See https://github.com/huggingface/transformers/pull/29449#pullrequestreview-1925576061 for why.
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
-            transform (`Callable`, *optional*):
-                A custom transform function that accepts a single image can be passed for training. For example,
-                `torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
-                set of transforms will be applied to the images
-            add_eos_token (`bool`, *optional*, defaults to `False`):
-                Adds `eos_token` at the end of the final prompt if True`
-            add_end_of_utterance_token (`bool`, *optional*)
-                Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
-                image). If `None` the tokenizer will be checked instead and if this token is found in
-                `additional_special_tokens` then the value will be `True`.
-            debug (`bool`, *optional*, defaults to `False`):
-                `True` value will help debug prompt generation by dumping useful information
             return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
                 The type of tensors to return. Can be one of:
                 - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
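With the `@deprecate_kwarg` decorator added above, the old `prompts=` keyword keeps working until v5.0.0 but emits a deprecation warning, and passing both names raises because of `raise_if_both_names=True`. Roughly, for a processor instantiated as in the docstring examples:

```python
# Old keyword: still accepted for now, but warns until removal in v5.0.0.
inputs = processor(prompts=["Describe this image.\nAssistant:"])

# New keyword:
inputs = processor(text=["Describe this image.\nAssistant:"])

# Passing both raises an error:
# processor(prompts=[...], text=[...])
```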
@@ -255,7 +269,7 @@ def __call__(
         Detailed explanation:
 
-        Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
+        Each entry in `text` is either a text to be passed as is or an image that will be processed.
 
         An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
@@ -279,7 +293,7 @@ def __call__(
             "Describe this image.\nAssistant:",
         ]
 
-        inputs = processor(prompts, return_tensors="pt")
+        inputs = processor(text=prompts, return_tensors="pt")
         generated_ids = model.generate(**inputs, max_length=100)
         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
         ```
@@ -311,18 +325,55 @@ def __call__(
                 transforms.Normalize(mean=self.image_mean, std=self.image_std),
             ]
         )
-        inputs = processor(prompts, transform=image_transform, return_tensors="pt")
+        inputs = processor(text=prompts, transform=image_transform, return_tensors="pt")
         ```
 
-        In order to help debug prompt generation enable `debug=True` which will show you what's happening.
         """
+        if images is None and text is None:
+            raise ValueError("You need to specify either `text` or `images` and `text`.")
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        if images is None:
+            # assuming the user wants to use the old behavior with prompts as the only argument
+            prompts = text
+        elif text is not None:
+            # Assuming image-text-to-text behavior:
+            # Check if batched images are provided
+            if not isinstance(images, (list, tuple)):
+                images = [images]
+            if isinstance(text, str):
+                text = [text]
+            # Check if batched images and text are in the correct format
+            if isinstance(text, (list, tuple)) and len(text) != len(images):
+                raise ValueError(
+                    "When providing both images and text arguments, the number of text prompts should be the same as the number of images."
+                    "If you want to have several images per prompt, images should be nested as such: images=[[img1, img2], [img3, img4], ...] for text=[prompt1, prompt2, ...]."
+                )
+            # Check that only text is present in the prompts
+            if not all(isinstance(i, str) for i in text):
+                raise ValueError("When using the image-text-to-text behavior, the prompts should only contain text.")
+            if isinstance(images[0], (list, tuple)):
+                # if nested images, nest text as well
+                text = [[i] for i in text]
+            prompts = list(zip(images, text))
+
+        output_kwargs = self._merge_kwargs(
+            IdeficsProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
+        add_end_of_utterance_token = output_kwargs["text_kwargs"].pop("add_end_of_utterance_token", None)
+
         # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it
         if add_end_of_utterance_token is None:
             add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
 
         # turn non-batched prompts into batched
-        if not any(isinstance(i, list) for i in prompts):
+        if not any(isinstance(i, (list, tuple)) for i in prompts):
             prompts = [prompts]
 
         fake_token = "<fake_token_around_image>"
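The validation added above enforces the shapes spelled out in the error message: one prompt per image, or one nested image list per prompt when a prompt uses several images. A short sketch of the accepted inputs (placeholder images, processor instantiated as above):

```python
from PIL import Image

img1, img2, img3 = (Image.new("RGB", (64, 64)) for _ in range(3))

# One image per prompt: a flat image list with the same length as `text`.
inputs = processor(images=[img1, img2], text=["prompt 1", "prompt 2"])

# Several images for the first prompt: nest the images per prompt.
inputs = processor(images=[[img1, img2], [img3]], text=["prompt 1", "prompt 2"])
```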
@@ -371,21 +422,14 @@ def image_tokens(last_was_image):
             if add_eos_token:
                 full_text += self.tokenizer.eos_token
 
-            if debug is True:
-                print(f"{full_text=}")
-
-            image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors)
+            image_objects = self.image_processor(image_objects, **output_kwargs["images_kwargs"])
 
             all_prompts.append(full_text)
             all_images.append(image_objects)
 
-        text_encoding = self.tokenizer(
-            text=all_prompts,
-            add_special_tokens=False,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        # For BC
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
+        text_encoding = self.tokenizer(all_prompts, **output_kwargs["text_kwargs"])
         all_texts = text_encoding["input_ids"]
         all_attention_masks = text_encoding["attention_mask"]
 
@@ ... @@
         output_images = []
         output_attention_masks = []
 
-        for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images):
-            padded_input_ids = text
+        for text_single, attention_mask, extracted_images in zip(all_texts, all_attention_masks, all_images):
+            padded_input_ids = text_single
             image_count = padded_input_ids.count(self.image_token_id)
             local_max_num_images = min(image_count, max_num_images)
 
-            current_images = images[:local_max_num_images]
+            current_images = extracted_images[:local_max_num_images]
 
             if len(current_images) > 0:
                 if return_tensors == "pt":
2 changes: 1 addition & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
@@ -1584,7 +1584,7 @@ def forward(
         ...     "In which city is that bridge located?<image>",
         ... ]
         >>> images = [[image1, image2], [image3]]
-        >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to("cuda")
+        >>> inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to("cuda")
 
         >>> # Generate
         >>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)
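This doctest tweak reflects the "BC args order change" bullet in the commit message: under the uniformized signature, `images` is the first parameter, and `_validate_images_text_input_order` keeps accidentally reversed image/text inputs working for backward compatibility. Keyword calls in either order remain equivalent, for example:

```python
# Both keyword orders produce the same result; only the documented
# order changed (images first, then text).
inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt")
inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt")
```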
(Diffs for the remaining 4 changed files are not shown.)
