Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds uniform processing kwargs to paligemma. #32377

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
69 changes: 30 additions & 39 deletions src/transformers/models/paligemma/processing_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
"""

import logging
from typing import List, Optional, Union
from typing import List, Union


try:
from typing import Unpack
except ImportError:
pass
MnCSSJ4x marked this conversation as resolved.
Show resolved Hide resolved
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessorMixin
from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import (
AddedToken,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from ...utils import TensorType


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -72,6 +74,18 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
"""
return f"{image_token * image_seq_len}{bos_token}{prompt}\n"

class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
    """Typed kwargs accepted by `PaliGemmaProcessor.__call__`.

    `total=False` makes every key optional (TypedDict semantics); `_defaults`
    supplies the values used when the caller omits a kwarg, merged by
    `ProcessorMixin._merge_kwargs` with tokenizer init kwargs and call-site
    kwargs.
    """

    _defaults = {
        "text_kwargs": {
            # NOTE: `tokenize_newline_separately` was removed — per review, it
            # is not used anywhere and is not a valid tokenizer text kwarg.
            "suffix": None,
        },
        "image_kwargs": {
            # NOTE(review): these default to None so they are only forwarded
            # when explicitly set; reviewers report do_thumbnail and
            # do_align_long_axis are not consumed by the PaliGemma image
            # processor — confirm against SiglipImageProcessor before relying
            # on them.
            "do_convert_rgb": None,
            "do_thumbnail": None,
            "do_align_long_axis": None,
        },
    }

class PaliGemmaProcessor(ProcessorMixin):
r"""
Expand Down Expand Up @@ -124,25 +138,8 @@ def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
Comment on lines 153 to 154
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two inputs should be reversed and support for backward compatibility should be added. This should be similar to what is needed for Fuyu:

if (
text is not None
and not isinstance(text[0], str)
or images is not None
and (isinstance(images, str) or (isinstance(images, (list, tuple)) and isinstance(images[0], str)))
):
warnings.warn(
"It looks like you are passing the inputs in the wrong order. You should pass the images input first and the text input second."
"Images and text inputs will be swapped."
)
images, text = text, images

tokenize_newline_separately: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
do_resize: bool = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
input_data_format: Optional[
Union[str, "ChannelDimension"] # noqa: F821
] = None,
resample: "PILImageResampling" = None, # noqa: F821
do_convert_rgb: bool = None,
do_thumbnail: bool = None,
do_align_long_axis: bool = None,
do_rescale: bool = None,
suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
video =None,
**kwargs: Unpack[PaliGemmaProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
Expand Down Expand Up @@ -216,7 +213,7 @@ def __call__(
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **labels** -- Labels compatible with training if `suffix` is not None
"""

suffix = kwargs["text_kwargs"]["suffix"]
return_token_type_ids = True if suffix is not None else False

if images is None:
Expand Down Expand Up @@ -253,27 +250,21 @@ def __call__(

pixel_values = self.image_processor(
images,
do_resize=do_resize,
do_normalize=do_normalize,
return_tensors=return_tensors,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
data_format=data_format,
resample=resample,
do_convert_rgb=do_convert_rgb,
**kwargs["image_kwargs"],
)["pixel_values"]

max_length = kwargs.get("max_length", None)
MnCSSJ4x marked this conversation as resolved.
Show resolved Hide resolved
if max_length is not None:
max_length += self.image_seq_length # max_length has to account for the image tokens

output_kwargs = self._merge_kwargs(
PaliGemmaProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
MnCSSJ4x marked this conversation as resolved.
Show resolved Hide resolved
inputs = self.tokenizer(
input_strings,
text_pair=suffix,
return_tensors=return_tensors,
padding=padding,
max_length=max_length,
truncation=truncation,
**output_kwargs['text_kwargs'],
return_token_type_ids=return_token_type_ids,
)

Expand Down