Uniformize model processors (models *with* special arg names) #32841

Open
wants to merge 10 commits into
base: main
91 changes: 53 additions & 38 deletions src/transformers/models/clipseg/processing_clipseg.py
@@ -16,10 +16,24 @@
Image/Text processor class for CLIPSeg
"""

import sys
import warnings
from typing import List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import PreTokenizedInput, TextInput


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


class CLIPSegProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


class CLIPSegProcessor(ProcessorMixin):
@@ -39,6 +53,8 @@ class CLIPSegProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "ViTImageProcessor"
tokenizer_class = ("CLIPTokenizer", "CLIPTokenizerFast")
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = ["visual_prompt"]

def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
@@ -58,7 +74,18 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):

super().__init__(image_processor, tokenizer)

def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
images: Optional[ImageInput] = None,
Comment on lines +79 to +80 (Member):
Another thing we're doing now is swapping the arg order, so that it is image, text, audio, videos. And that needs another deprecation cycle...

BTW, I am quite out of the loop; do we need this order-swapping for the pipeline @yonigozlan?
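
For context, a minimal sketch of what an images-first signature could look like once that swap lands (hypothetical ordering, not part of this diff):

def __call__(
    self,
    images: Optional[ImageInput] = None,
    text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
    audio=None,
    videos=None,
    **kwargs: Unpack[CLIPSegProcessorKwargs],
) -> BatchFeature:
    # Hypothetical target order: images, text, audio, videos. Callers passing text
    # positionally would need a deprecation warning during the transition.
    ...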

# The following is to capture `visual_prompt` argument that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args,
audio=None,
videos=None,
**kwargs: Unpack[CLIPSegProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
@@ -79,56 +106,44 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None
NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape
(C, H, W), where C is a number of channels, H and W are image height and width.

return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:

- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.

Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:

- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `visual_prompt` is `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
`None`) and `visual_prompt` is `None`.
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **conditional_pixel_values** -- Conditional pixel values to be fed to a model. Returned when `visual_prompt` is not `None`.
"""

output_kwargs = self._merge_kwargs(
CLIPSegProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)

visual_prompt = output_kwargs["images_kwargs"].pop("visual_prompt", None)

if text is None and visual_prompt is None and images is None:
raise ValueError("You have to specify either text, visual prompt or images.")

if text is not None and visual_prompt is not None:
raise ValueError("You have to specify exactly one type of prompt. Either text or visual prompt.")

data = {}
if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)

text_features = self.tokenizer(text, **output_kwargs["text_kwargs"])
data.update(text_features)
if visual_prompt is not None:
prompt_features = self.image_processor(visual_prompt, return_tensors=return_tensors, **kwargs)

prompt_image_features = self.image_processor(visual_prompt, **output_kwargs["images_kwargs"])
data["conditional_pixel_values"] = prompt_image_features.pixel_values
if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)

if visual_prompt is not None and images is not None:
encoding = {
"pixel_values": image_features.pixel_values,
"conditional_pixel_values": prompt_features.pixel_values,
}
return encoding
elif text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif text is not None:
return encoding
elif visual_prompt is not None:
encoding = {
"conditional_pixel_values": prompt_features.pixel_values,
}
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data["pixel_values"] = image_features.pixel_values

return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))

def batch_decode(self, *args, **kwargs):
"""
2 changes: 2 additions & 0 deletions src/transformers/models/donut/image_processing_donut.py
@@ -183,6 +183,7 @@ def pad_image(
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size)
output_height, output_width = size["height"], size["width"]
input_height, input_width = get_image_size(image, channel_dim=input_data_format)

@@ -232,6 +233,7 @@ def thumbnail(
The channel dimension format of the input image. If not provided, it will be inferred.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
size = get_size_dict(size)
Member:
Not clear why we needed these changes; was this causing a CI failure?

Contributor Author:
These are utils for old processors that don't support the new image size format yet.

We might as well add these here since (1) they help with backwards compatibility, (2) they make the image-text-to-text pipeline easier to implement, and (3) they just revert to a no-op if `size` already follows the new image size format.
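
As a rough illustration of that no-op behaviour (a sketch of the intended semantics, not the exact implementation):

from transformers.image_processing_utils import get_size_dict

# Already in the new dict format: returned unchanged, so the
# size["height"] / size["width"] lookups below keep working.
size = get_size_dict({"height": 896, "width": 672})

# Legacy formats, e.g. a (height, width) tuple, are normalized
# to the same {"height": ..., "width": ...} dict.
legacy = get_size_dict((896, 672))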

output_height, output_width = size["height"], size["width"]

# We always resize to the smallest of either the input or output size.
3 changes: 2 additions & 1 deletion src/transformers/models/fuyu/image_processing_fuyu.py
@@ -19,7 +19,7 @@

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
pad,
resize,
@@ -344,6 +344,7 @@ def pad_image(
The channel dimension format of the input image. If not provided, it will be inferred.
"""
image_height, image_width = get_image_size(image, input_data_format)
size = get_size_dict(size)
target_height, target_width = size["height"], size["width"]
padding_top = 0
padding_left = 0
2 changes: 2 additions & 0 deletions src/transformers/models/nougat/image_processing_nougat.py
@@ -250,6 +250,7 @@ def pad_image(
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size)
output_height, output_width = size["height"], size["width"]
input_height, input_width = get_image_size(image, channel_dim=input_data_format)

@@ -292,6 +293,7 @@ def thumbnail(
The channel dimension format of the input image. If not provided, it will be inferred.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
size = get_size_dict(size)
output_height, output_width = size["height"], size["width"]

# We always resize to the smallest of either the input or output size.
179 changes: 88 additions & 91 deletions src/transformers/models/nougat/processing_nougat.py
@@ -16,12 +16,45 @@
Processor class for Nougat.
"""

from typing import Dict, List, Optional, Union
import sys
from typing import List, Optional, Union

from transformers.tokenization_utils_base import PreTokenizedInput, TextInput, TruncationStrategy
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import PreTokenizedInput, TextInput

from ...processing_utils import ProcessorMixin
from ...utils import PaddingStrategy, TensorType

if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


class NougatImagesKwargs(ImagesKwargs, total=False):
do_crop_margin: Optional[bool]
do_thumbnail: Optional[bool]
do_align_long_axis: Optional[bool]


class NougatProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: NougatImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"is_split_into_words": False,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_length": False,
"verbose": True,
},
"images_kwargs": {
"data_format": "channels_first",
},
}


class NougatProcessor(ProcessorMixin):
@@ -39,104 +72,68 @@ class NougatProcessor(ProcessorMixin):
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
image_processor_class = "NougatImageProcessor"
tokenizer_class = "NougatTokenizerFast"

def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor

def __call__(
self,
images=None,
text=None,
do_crop_margin: bool = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: "PILImageResampling" = None, # noqa: F821
do_thumbnail: bool = None,
do_align_long_axis: bool = None,
do_pad: bool = None,
do_rescale: bool = None,
rescale_factor: Union[int, float] = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
input_data_format: Optional[Union[str, "ChannelDimension"]] = None, # noqa: F821
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
text_pair_target: Optional[
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
):
images: Optional[ImageInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio=None,
videos=None,
**kwargs: Unpack[NougatProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several text(s) and image(s). This method forwards the `text` and
`kwargs` arguments to NougatTokenizerFast's [`~NougatTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
NougatImageProcessor's [`~NougatImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.

Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
`List[torch.Tensor]`, *optional*):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).

Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **labels** -- List of token ids to be fed to a model. Returned when both `text` and `images` are not `None`.
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None` and `images` is `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is None and text is None:
raise ValueError("You need to specify either an `images` or `text` input to process.")

if images is not None:
inputs = self.image_processor(
images,
do_crop_margin=do_crop_margin,
do_resize=do_resize,
size=size,
resample=resample,
do_thumbnail=do_thumbnail,
do_align_long_axis=do_align_long_axis,
do_pad=do_pad,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
return_tensors=return_tensors,
data_format=data_format,
input_data_format=input_data_format,
)
output_kwargs = self._merge_kwargs(
NougatProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# Temporary fix for "padding_side" in init_kwargs
_ = output_kwargs["text_kwargs"].pop("padding_side", None)

data = {}
if text is not None:
encodings = self.tokenizer(
text,
text_pair=text_pair,
text_target=text_target,
text_pair_target=text_pair_target,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
is_split_into_words=is_split_into_words,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors=return_tensors,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
)

if text is None:
return inputs
elif images is None:
return encodings
else:
inputs["labels"] = encodings["input_ids"]
return inputs
text_features = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
data.update(text_features)
if images is not None:
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data.update(image_features)
if "input_ids" in data:
data["labels"] = data.pop("input_ids")
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))

def batch_decode(self, *args, **kwargs):
"""