Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds uniform processing kwargs to paligemma. #32377

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
86 changes: 47 additions & 39 deletions src/transformers/models/paligemma/processing_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@
import logging
from typing import List, Optional, Union


try:
from typing import Unpack
except ImportError:
pass
MnCSSJ4x marked this conversation as resolved.
Show resolved Hide resolved
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessorMixin
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs
from ...tokenization_utils_base import (
AddedToken,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from ...utils import TensorType


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,6 +75,31 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
return f"{image_token * image_seq_len}{bos_token}{prompt}\n"


class PaliGemmaTextKwargs(TextKwargs):
suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]


class PaliGemmaImagesKwargs(ImagesKwargs):
do_convert_rgb: Optional[bool]
do_thumbnail: Optional[bool]
do_align_long_axis: Optional[bool]
do_rescale: Optional[bool]


class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: PaliGemmaTextKwargs
image_kwargs: PaliGemmaImagesKwargs
_defaults = {
"text_kwargs": {
"tokenize_newline_separately": True, # Not Available in Default
"padding": False,
},
"image_kwargs": {
"data_format": "channels_first",
},
}


class PaliGemmaProcessor(ProcessorMixin):
r"""
Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor.
Expand Down Expand Up @@ -124,25 +151,8 @@ def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
Comment on lines 153 to 154
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two inputs should be reversed and support for backward compatibility should be added. This should be similar to what is needed for Fuyu:

if (
text is not None
and not isinstance(text[0], str)
or images is not None
and (isinstance(images, str) or (isinstance(images, (list, tuple)) and isinstance(images[0], str)))
):
warnings.warn(
"It looks like you are passing the inputs in the wrong order. You should pass the images input first and the text input second."
"Images and text inputs will be swapped."
)
images, text = text, images

tokenize_newline_separately: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
do_resize: bool = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
input_data_format: Optional[
Union[str, "ChannelDimension"] # noqa: F821
] = None,
resample: "PILImageResampling" = None, # noqa: F821
do_convert_rgb: bool = None,
do_thumbnail: bool = None,
do_align_long_axis: bool = None,
do_rescale: bool = None,
suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
video=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can also advertise None audio kwarg here!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

audio=None Is still needed here for API consistency, even if this model doesn't support the audio modality.

Suggested change
video=None,
audio = None,
video=None,

**kwargs: Unpack[PaliGemmaProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
Expand Down Expand Up @@ -216,7 +226,12 @@ def __call__(
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **labels** -- Labels compatible with training if `suffix` is not None
"""

output_kwargs = self._merge_kwargs(
PaliGemmaProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
suffix = output_kwargs["text_kwargs"]["suffix"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If suffix is not specified as a kwargs, this will cause an error. Better to use:
suffix = output_kwargs["text_kwargs"].pop("suffix", None)

return_token_type_ids = True if suffix is not None else False

if images is None:
Expand Down Expand Up @@ -253,27 +268,20 @@ def __call__(

pixel_values = self.image_processor(
images,
do_resize=do_resize,
do_normalize=do_normalize,
return_tensors=return_tensors,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
data_format=data_format,
resample=resample,
do_convert_rgb=do_convert_rgb,
**kwargs["image_kwargs"],
)["pixel_values"]

max_length = output_kwargs.get("max_length", None)
if max_length is not None:
max_length += self.image_seq_length # max_length has to account for the image tokens
MnCSSJ4x marked this conversation as resolved.
Show resolved Hide resolved

output_kwargs = self._merge_kwargs(
PaliGemmaProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
MnCSSJ4x marked this conversation as resolved.
Show resolved Hide resolved
inputs = self.tokenizer(
input_strings,
text_pair=suffix,
return_tensors=return_tensors,
padding=padding,
max_length=max_length,
truncation=truncation,
**output_kwargs["text_kwargs"],
return_token_type_ids=return_token_type_ids,
Comment on lines +279 to 280
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**output_kwargs["text_kwargs"],
return_token_type_ids=return_token_type_ids,
return_token_type_ids=return_token_type_ids,
**output_kwargs["text_kwargs"],

)

Expand Down