Uniformize model processors (models w/o special arg names) #32845

Open · wants to merge 25 commits into base: main
Changes from 9 commits
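Every file below follows the same recipe: add a ProcessingKwargs subclass holding per-modality defaults, accept **kwargs: Unpack[...] in __call__, and let self._merge_kwargs(...) combine the class defaults, the tokenizer's init kwargs, and the call-time kwargs into text_kwargs, images_kwargs, and common_kwargs. The snippet below is a standalone toy that mimics that merge order so the diffs are easier to follow; it is not the real transformers.processing_utils implementation, and the merge_kwargs helper name is invented for illustration.

from typing import Any, Dict


def merge_kwargs(
    defaults: Dict[str, Dict[str, Any]],
    tokenizer_init_kwargs: Dict[str, Any],
    **call_kwargs: Any,
) -> Dict[str, Dict[str, Any]]:
    # Toy merge order: class defaults < tokenizer init kwargs < kwargs passed at call time.
    output = {
        "text_kwargs": dict(defaults.get("text_kwargs", {})),
        "images_kwargs": dict(defaults.get("images_kwargs", {})),
        "common_kwargs": dict(defaults.get("common_kwargs", {})),
    }
    for key, value in tokenizer_init_kwargs.items():
        output["text_kwargs"][key] = value
    for key, value in call_kwargs.items():
        # Simplification: route return_tensors to common_kwargs, everything else to text_kwargs.
        target = "common_kwargs" if key == "return_tensors" else "text_kwargs"
        output[target][key] = value
    return output


merged = merge_kwargs(
    {"text_kwargs": {"padding": False}, "common_kwargs": {"return_tensors": "pt"}},
    tokenizer_init_kwargs={"model_max_length": 512},
    padding="max_length",
    return_tensors="np",
)
print(merged["text_kwargs"]["padding"])           # 'max_length' -- call-time value wins
print(merged["common_kwargs"]["return_tensors"])  # 'np'
print(merged["text_kwargs"]["model_max_length"])  # 512 -- back-filled from the tokenizer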
38 changes: 32 additions & 6 deletions src/transformers/models/altclip/processing_altclip.py
@@ -16,10 +16,23 @@
Image/Text processor class for AltCLIP
"""

import sys
import warnings
from typing import List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


class AltCLIPProcessingKwargs(ProcessingKwargs, total=False):
_defaults = {}


class AltCLIPProcessor(ProcessorMixin):
@@ -59,7 +72,12 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):

super().__init__(image_processor, tokenizer)

def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
images: Optional[ImageInput] = None,
**kwargs: Unpack[AltCLIPProcessingKwargs],
):
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to XLMRobertaTokenizerFast's [`~XLMRobertaTokenizerFast.__call__`] if `text` is not
@@ -97,19 +115,27 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.")

output_kwargs = self._merge_kwargs(
AltCLIPProcessingKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
elif text is not None:
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
return BatchEncoding(
data=dict(**image_features), tensor_type=output_kwargs["common_kwargs"]["return_tensors"]
)

def batch_decode(self, *args, **kwargs):
"""
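With this change, AltCLIPProcessor no longer exposes return_tensors as a dedicated argument; everything is passed as flat kwargs and routed by _merge_kwargs. A usage sketch, assuming the PR is applied and the BAAI/AltCLIP checkpoint is reachable (checkpoint name and image URL are illustrative):

from PIL import Image
import requests
from transformers import AltCLIPProcessor

processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")  # assumed checkpoint name

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# padding is routed to text_kwargs; return_tensors lands in common_kwargs and is
# applied to both the tokenizer output and the BatchEncoding built from image features.
inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=image,
    padding=True,
    return_tensors="pt",
)
print(inputs.keys())  # expected: input_ids, attention_mask, pixel_values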
@@ -115,8 +115,8 @@ def get_resize_output_image_size(
new_width = scale * new_width

new_height, new_width = int(new_height + 0.5), int(new_width + 0.5)
new_height = new_height // size_divisor * size_divisor
new_width = new_width // size_divisor * size_divisor
new_height = max(1, new_height // size_divisor) * size_divisor
new_width = max(1, new_width // size_divisor) * size_divisor

return new_height, new_width
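The max(1, ...) guard matters when a scaled dimension falls below size_divisor: the old floor-then-multiply could round that dimension down to 0 and produce an invalid output size. A standalone check of both formulas (values picked purely for illustration):

size_divisor = 32
for new_height in (20, 40, 100):
    old = new_height // size_divisor * size_divisor          # pre-PR behavior
    new = max(1, new_height // size_divisor) * size_divisor  # behavior after this hunk
    print(new_height, "->", old, new)
# 20  -> old 0 (degenerate), new 32
# 40  -> old 32, new 32
# 100 -> old 96, new 96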

@@ -238,9 +238,7 @@ def resize(
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size, default_to_square=False)
if "shortest_edge" not in size:
raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}")
shorter = size["shortest_edge"]
shorter = size["shortest_edge"] if "shortest_edge" in size else min(size["height"], size["width"])
longer = int(1333 / 800 * shorter)
output_size = get_resize_output_image_size(
image, shorter=shorter, longer=longer, size_divisor=size_divisor, input_data_format=input_data_format
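The second hunk relaxes the size contract instead of raising: when the dict carries height/width rather than shortest_edge, the smaller side is used as the shortest edge, and the longer bound keeps the same 1333/800 ratio. A small standalone illustration of that fallback (sizes are arbitrary examples):

def pick_shorter(size: dict) -> int:
    # Mirrors the changed line: prefer shortest_edge, else the smaller of height/width.
    return size["shortest_edge"] if "shortest_edge" in size else min(size["height"], size["width"])

for size in ({"shortest_edge": 288}, {"height": 320, "width": 480}):
    shorter = pick_shorter(size)
    longer = int(1333 / 800 * shorter)
    print(size, "->", shorter, longer)
# {'shortest_edge': 288}        -> 288, 479
# {'height': 320, 'width': 480} -> 320, 533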
81 changes: 41 additions & 40 deletions src/transformers/models/chameleon/processing_chameleon.py
@@ -16,13 +16,36 @@
Processor class for Chameleon.
"""

import sys
from typing import List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


class ChameleonTextKwargs(TextKwargs, total=False):
return_for_text_completion: bool


class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: ChameleonTextKwargs
_defaults = {
"text_kwargs": {
"padding": False,
"return_for_text_completion": False,
},
"common_kwargs": {
"return_tensors": "pt",
},
}


class ChameleonProcessor(ProcessorMixin):
@@ -57,13 +80,9 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, ima

def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: int = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
return_for_text_completion: bool = False,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
images: Optional[ImageInput] = None,
**kwargs: Unpack[ChameleonProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -80,26 +99,6 @@ def __call__(
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:

- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.

Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
@@ -114,6 +113,15 @@
text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str):
raise TypeError("Invalid input text. Please provide a string, or a list of strings")
if text is None and images is None:
raise ValueError("You must provide either text or images")

output_kwargs = self._merge_kwargs(
ChameleonProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)

# Replace the image token with the expanded image token sequence
prompt_strings = []
@@ -124,19 +132,12 @@ def __call__(
sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
prompt_strings.append(sample)

data = self.tokenizer(
prompt_strings,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])

if images is not None:
pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
data["pixel_values"] = pixel_values
data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

return BatchFeature(data=data, tensor_type=return_tensors)
return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])

# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
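ChameleonProcessor keeps its old behavior but routes it through the kwargs classes: padding and return_for_text_completion become declared defaults, return_tensors now defaults to "pt" via common_kwargs, and return_for_text_completion is popped from text_kwargs before the tokenizer call. A usage sketch, assuming a Chameleon checkpoint such as facebook/chameleon-7b is available (checkpoint name and prompt are illustrative):

from PIL import Image
import requests
from transformers import ChameleonProcessor

processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")  # assumed checkpoint

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# No return_tensors given: the "pt" default from common_kwargs applies.
# padding goes to text_kwargs; return_for_text_completion is the Chameleon-only text kwarg.
inputs = processor(
    text="What do you see in this image?<image>",
    images=image,
    padding=True,
    return_for_text_completion=False,
)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)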
89 changes: 43 additions & 46 deletions src/transformers/models/flava/processing_flava.py
@@ -16,13 +16,41 @@
Image/Text processor class for FLAVA
"""

import sys
import warnings
from typing import List, Optional, Union

from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


class FlavaImagesKwargs(ImagesKwargs, total=False):
return_image_mask: Optional[bool]
return_codebook_pixels: Optional[bool]


class FlavaProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: FlavaImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"truncation": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_length": False,
"verbose": True,
},
}


class FlavaProcessor(ProcessorMixin):
@@ -64,23 +92,7 @@ def __call__(
self,
images: Optional[ImageInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_image_mask: Optional[bool] = None,
return_codebook_pixels: Optional[bool] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
**kwargs: Unpack[FlavaProcessorKwargs],
):
"""
This method uses [`FlavaImageProcessor.__call__`] method to prepare image(s) for the model, and
@@ -92,41 +104,26 @@ def __call__(
if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be none.")

output_kwargs = self._merge_kwargs(
FlavaProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

if text is not None:
encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
if images is not None:
image_features = self.image_processor(
images,
return_image_mask=return_image_mask,
return_codebook_pixels=return_codebook_pixels,
return_tensors=return_tensors,
**kwargs,
)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])

if text is not None and images is not None:
encoding.update(image_features)
return encoding
elif text is not None:
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
return BatchEncoding(
data=dict(**image_features), tensor_type=output_kwargs["common_kwargs"]["return_tensors"]
)

def batch_decode(self, *args, **kwargs):
"""
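FLAVA's long explicit signature collapses the same way: the image-specific flags return_image_mask and return_codebook_pixels now live in FlavaImagesKwargs, so _merge_kwargs sends them to the image processor while tokenizer settings go to text_kwargs. A usage sketch, assuming the facebook/flava-full checkpoint (checkpoint name and expected output keys are best-effort assumptions):

from PIL import Image
import requests
from transformers import FlavaProcessor

processor = FlavaProcessor.from_pretrained("facebook/flava-full")  # assumed checkpoint

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The FLAVA-only flags are image kwargs, so they end up in output_kwargs["images_kwargs"];
# padding is a text kwarg; return_tensors applies to both through common_kwargs.
inputs = processor(
    images=image,
    text="a photo of two cats",
    return_image_mask=True,
    return_codebook_pixels=True,
    padding=True,
    return_tensors="pt",
)
print(sorted(inputs.keys()))
# expected to include: attention_mask, bool_masked_pos, codebook_pixel_values, input_ids, pixel_values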