Uniformize kwargs for Idefics/2 processors #32568

Open
wants to merge 6 commits into base: main
172 changes: 111 additions & 61 deletions src/transformers/models/idefics/processing_idefics.py
@@ -16,15 +16,23 @@
Processor class for IDEFICS.
"""

from typing import Callable, List, Optional, Union
import sys
import warnings
from typing import Callable, Dict, List, Optional, Union
from urllib.parse import urlparse

from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import is_tf_available, is_torch_available
from ...utils.deprecation import deprecate_kwarg


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack

if is_torch_available():
import torch

@@ -34,6 +42,32 @@
IMAGE_TOKEN = "<image>"


class IdeficsImagesKwargs(ImagesKwargs, total=False):
transform: Optional[Callable]
image_size: Optional[Dict[str, int]]
image_mean: Optional[Union[float, List[float]]]
image_std: Optional[Union[float, List[float]]]


class IdeficsTextKwargs(TextKwargs, total=False):
add_eos_token: Optional[bool]
add_end_of_utterance_token: Optional[bool]


class IdeficsProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: IdeficsTextKwargs
images_kwargs: IdeficsImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": False,
"padding": "longest",
"add_eos_token": False,
},
"images_kwargs": {},
"common_kwargs": {"return_tensors": "pt"},
}


# copied from m4.training.packing
def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
# Set elements >= num_classes to -1
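For reference, a minimal sketch of how the kwargs classes defined above are consumed. It assumes the `_merge_kwargs` layering from `ProcessorMixin` used later in this diff (class `_defaults`, then tokenizer init kwargs, then call-site kwargs); `processor` and the concrete values are illustrative only:

```python
# Illustrative only: precedence is _defaults < tokenizer init kwargs < call-site kwargs.
output_kwargs = processor._merge_kwargs(
    IdeficsProcessorKwargs,
    tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
    padding="max_length",        # overrides the "longest" default declared in _defaults
    image_mean=[0.5, 0.5, 0.5],  # routed into images_kwargs via IdeficsImagesKwargs
)
# The merged result is split per modality, e.g.:
#   output_kwargs["text_kwargs"]   -> {"add_special_tokens": False, "padding": "max_length", "add_eos_token": False, ...}
#   output_kwargs["images_kwargs"] -> {"image_mean": [0.5, 0.5, 0.5], ...}
# and common kwargs such as return_tensors="pt" are shared across modalities
# (hence the return_tensors pop from text_kwargs further down in this diff).
```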
@@ -199,63 +233,40 @@ def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_u
else False
)

@deprecate_kwarg(old_name="prompts", version="5.0.0", new_name="text", raise_if_both_names=True)
def __call__(
self,
prompts: Union[List[TextInput], List[List[TextInput]]],
padding: Union[bool, str, PaddingStrategy] = "longest",
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
transform: Callable = None,
add_eos_token=False,
add_end_of_utterance_token=None,
debug=False,
return_tensors="pt",
images=None,
text: Union[
TextInput,
PreTokenizedInput,
List[TextInput],
List[PreTokenizedInput],
List[List[TextInput]],
List[List[PreTokenizedInput]],
] = None,
audio=None,
videos=None,
**kwargs: Unpack[IdeficsProcessorKwargs],
) -> BatchEncoding:
"""This method takes batched or non-batched prompts made of text and images and converts them into prompts that
the model was trained on and prepares the image pixel values for the model to process.

Args:
prompts (`Union[List[TextInput], [List[List[TextInput]]]]`):
text (`Union[List[TextInput], List[List[TextInput]]]`):
either a single prompt or a batched list of prompts - see the detailed description immediately after
the end of the arguments doc section.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `"longest"`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'`: No padding. This will raise an error if the input sequences are of different
lengths.
Note: Unlike most processors, which set padding=`False` by default, `IdeficsProcessor` sets `padding="longest"`
by default. See https://github.com/huggingface/transformers/pull/29449#pullrequestreview-1925576061 for why.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
transform (`Callable`, *optional*):
A custom transform function that accepts a single image can be passed for training. For example,
`torchvision.Compose` can be used to compose multiple functions. If `None` a preset inference-specific
set of transforms will be applied to the images
add_eos_token (`bool`, *optional*, defaults to `False`):
Adds `eos_token` at the end of the final prompt if `True`
add_end_of_utterance_token (`bool`, *optional*):
Whether to automatically add `<end_of_utterance>` after each prompt's text input (unless followed by an
image). If `None` the tokenizer will be checked instead and if this token is found in
`additional_special_tokens` then the value will be `True`.
debug (`bool`, *optional*, defaults to `False`):
`True` value will help debug prompt generation by dumping useful information
return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
The type of tensors to return. Can be one of:
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
images (`Union[PIL.Image, str, List[PIL.Image], List[str]]`):
either a single image or a batched list of images - can be passed in when text contains only text prompts,
in order to use the image-text-to-text behavior.

Returns:
a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be
directly passed to `model.generate`

Detailed explanation:

Each entry in `prompts` is either a text to be passed as is or an image that will be processed.
Each entry in `text` is either a text to be passed as is or an image that will be processed.

An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.

@@ -279,7 +290,7 @@ def __call__(
"Describe this image.\nAssistant:",
]

inputs = processor(prompts, return_tensors="pt")
inputs = processor(text=prompts, return_tensors="pt")
generated_ids = model.generate(**inputs, max_length=100)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
```
@@ -311,18 +322,64 @@ def __call__(
transforms.Normalize(mean=self.image_mean, std=self.image_std),
]
)
inputs = processor(prompts, transform=image_transform, return_tensors="pt")
inputs = processor(text=prompts, transform=image_transform, return_tensors="pt")
```

In order to help debug prompt generation enable `debug=True` which will show you what's happening.

"""
if images is None and text is None:
raise ValueError("You need to specify either `text` or `images` and `text`.")
# for BC
if text is None:
# if the user didn't specify text=text in the call, we assume they want to use the old behavior
# with text (previously prompts) as a first argument
warnings.warn(
"The use of `text` as the first argument will be deprecated in the future. `images` is now the first argument."
"The first given argument will be considered as `prompts` in the old behavior.",
)
text = images
images = None
if images is None:
# assuming the user wants to use the old behavior with prompts as the only argument
prompts = text
elif text is not None:
# Assuming image-text-to-text behavior:
# Check if batched images are provided
if not isinstance(images, (list, tuple)):
images = [images]
if isinstance(text, str):
text = [text]
# Check if batched images and text are in the correct format
if isinstance(text, (list, tuple)) and len(text) != len(images):
raise ValueError(
"When providing both images and text arguments, the number of text prompts should be the same as the number of images."
"If you want to have several images per prompt, images should be nested as such: images=[[img1, img2], [img3, img4], ...] for text=[prompt1, prompt2, ...]."
)
# Check that only text is present in the prompts
if not all(isinstance(i, str) for i in text):
raise ValueError("When using the image-text-to-text behavior, the prompts should only contain text.")
if isinstance(images[0], (list, tuple)):
# if nested images, nest text as well
text = [[i] for i in text]
prompts = list(zip(images, text))

output_kwargs = self._merge_kwargs(
IdeficsProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# Temporary fix for "padding_side" in init_kwargs
_ = output_kwargs["text_kwargs"].pop("padding_side", None)

add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
add_end_of_utterance_token = output_kwargs["text_kwargs"].pop("add_end_of_utterance_token", None)

# if the value isn't overridden by the user, check if the tokenizer was trained with this token and then use it
if add_end_of_utterance_token is None:
add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
# turn non-batched prompts into batched
if not any(isinstance(i, list) for i in prompts):
if not any(isinstance(i, (list, tuple)) for i in prompts):
prompts = [prompts]

fake_token = "<fake_token_around_image>"
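A hedged illustration of the image/text pairing introduced above (the image variables and prompt strings are placeholders; the behavior mirrors the branches shown in this hunk):

```python
# One image per prompt: images and text are zipped one-to-one.
images = [img1, img2]
text = ["User: Describe this image.", "User: What is shown here?"]
prompts = list(zip(images, text))
# -> [(img1, "User: Describe this image."), (img2, "User: What is shown here?")]

# Several images per prompt: nest the images; each text entry is then wrapped in a list
# before zipping, mirroring `text = [[i] for i in text]` above.
images = [[img1, img2], [img3]]
text = ["User: Compare these images.", "User: Describe this image."]
prompts = list(zip(images, [[t] for t in text]))
# -> [([img1, img2], ["User: Compare these images."]), ([img3], ["User: Describe this image."])]
```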
@@ -371,21 +428,14 @@ def image_tokens(last_was_image):
if add_eos_token:
full_text += self.tokenizer.eos_token

if debug is True:
print(f"{full_text=}")

image_objects = self.image_processor(image_objects, transform=transform, return_tensors=return_tensors)
image_objects = self.image_processor(image_objects, **output_kwargs["images_kwargs"])

all_prompts.append(full_text)
all_images.append(image_objects)

text_encoding = self.tokenizer(
text=all_prompts,
add_special_tokens=False,
padding=padding,
truncation=truncation,
max_length=max_length,
)
# For BC
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
text_encoding = self.tokenizer(all_prompts, **output_kwargs["text_kwargs"])
all_texts = text_encoding["input_ids"]
all_attention_masks = text_encoding["attention_mask"]

@@ -398,12 +448,12 @@ def image_tokens(last_was_image):
output_images = []
output_attention_masks = []

for text, attention_mask, images in zip(all_texts, all_attention_masks, all_images):
padded_input_ids = text
for text_single, attention_mask, extracted_images in zip(all_texts, all_attention_masks, all_images):
padded_input_ids = text_single
image_count = padded_input_ids.count(self.image_token_id)
local_max_num_images = min(image_count, max_num_images)

current_images = images[:local_max_num_images]
current_images = extracted_images[:local_max_num_images]

if len(current_images) > 0:
if return_tensors == "pt":
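Putting the changes to this file together, a minimal usage sketch of the updated interface (hedged: `processor` and the image/prompt variables are placeholders, not part of this diff):

```python
# New uniform convention: `images` first, `text` second, remaining options as typed kwargs.
inputs = processor(
    images=[image1, image2],
    text=["User: Describe this image.\nAssistant:", "User: What is in this photo?\nAssistant:"],
    padding="longest",
    return_tensors="pt",
)

# Backwards-compatible path: the old positional `prompts` argument still works, but it now
# lands in `images`, triggers the warning added above, and is reinterpreted as `text`.
inputs = processor(["User: Describe this image.\nAssistant:", image1])
```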
2 changes: 1 addition & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
@@ -1558,7 +1558,7 @@ def forward(
... "In which city is that bridge located?<image>",
... ]
>>> images = [[image1, image2], [image3]]
>>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to("cuda")
>>> inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to("cuda")

>>> # Generate
>>> generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=20)