Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds uniform processing kwargs to paligemma. #32377

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
85 changes: 44 additions & 41 deletions src/transformers/models/paligemma/processing_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@
"""

import logging
import sys
from typing import List, Optional, Union


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessorMixin
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs
from ...tokenization_utils_base import (
AddedToken,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from ...utils import TensorType


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,6 +76,31 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token):
return f"{image_token * image_seq_len}{bos_token}{prompt}\n"


class PaliGemmaTextKwargs(TextKwargs):
    """Text keyword arguments for the PaliGemma processor: `TextKwargs` plus an optional `suffix`."""

    # Accepts a single string, a pre-tokenized sequence, or a batch of either.
    suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]


class PaliGemmaImagesKwargs(ImagesKwargs):
    """Image keyword arguments for the PaliGemma processor, declared on top of `ImagesKwargs`."""

    do_convert_rgb: Optional[bool]
    # NOTE(review): the three flags below were reported in review as unused by
    # this processor — confirm they are still needed before relying on them.
    do_thumbnail: Optional[bool]
    do_align_long_axis: Optional[bool]
    do_rescale: Optional[bool]


class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: PaliGemmaTextKwargs
image_kwargs: PaliGemmaImagesKwargs
_defaults = {
"text_kwargs": {
"tokenize_newline_separately": True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like tokenize_newline_separately is not used anywhere, and it is not a default text_kwargs, so it might be best to remove it entirely?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's not used anymore and is not needed - IIUC neither are do_thumbnail, do_align_long_axis and do_rescale (FYI, they are not used here)
+1 for removing it

"padding": False,
},
"image_kwargs": {
"data_format": "channels_first",
},
}


class PaliGemmaProcessor(ProcessorMixin):
r"""
Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor.
Expand Down Expand Up @@ -124,25 +152,8 @@ def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
Comment on lines 153 to 154
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two inputs should be reversed and support for backward compatibility should be added. This should be similar to what is needed for Fuyu:

if (
text is not None
and not isinstance(text[0], str)
or images is not None
and (isinstance(images, str) or (isinstance(images, (list, tuple)) and isinstance(images[0], str)))
):
warnings.warn(
"It looks like you are passing the inputs in the wrong order. You should pass the images input first and the text input second."
"Images and text inputs will be swapped."
)
images, text = text, images

tokenize_newline_separately: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
do_resize: bool = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821
input_data_format: Optional[
Union[str, "ChannelDimension"] # noqa: F821
] = None,
resample: "PILImageResampling" = None, # noqa: F821
do_convert_rgb: bool = None,
do_thumbnail: bool = None,
do_align_long_axis: bool = None,
do_rescale: bool = None,
suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
video=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can also advertise None audio kwarg here!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

audio=None is still needed here for API consistency, even if this model doesn't support the audio modality.

Suggested change
video=None,
audio = None,
video=None,

**kwargs: Unpack[PaliGemmaProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
Expand Down Expand Up @@ -216,7 +227,12 @@ def __call__(
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **labels** -- Labels compatible with training if `suffix` is not None
"""

output_kwargs = self._merge_kwargs(
PaliGemmaProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
suffix = output_kwargs["text_kwargs"]["suffix"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If suffix is not specified as a kwarg, this will cause an error. Better to use:
suffix = output_kwargs["text_kwargs"].pop("suffix", None)

return_token_type_ids = True if suffix is not None else False

if images is None:
Expand Down Expand Up @@ -253,27 +269,14 @@ def __call__(

pixel_values = self.image_processor(
images,
do_resize=do_resize,
do_normalize=do_normalize,
return_tensors=return_tensors,
image_mean=image_mean,
image_std=image_std,
input_data_format=input_data_format,
data_format=data_format,
resample=resample,
do_convert_rgb=do_convert_rgb,
**kwargs["image_kwargs"],
)["pixel_values"]

if max_length is not None:
max_length += self.image_seq_length # max_length has to account for the image tokens

if output_kwargs["text_kwargs"].get("max_length", None) is not None:
output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
inputs = self.tokenizer(
input_strings,
text_pair=suffix,
return_tensors=return_tensors,
padding=padding,
max_length=max_length,
truncation=truncation,
**output_kwargs["text_kwargs"],
return_token_type_ids=return_token_type_ids,
Comment on lines +279 to 280
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**output_kwargs["text_kwargs"],
return_token_type_ids=return_token_type_ids,
return_token_type_ids=return_token_type_ids,
**output_kwargs["text_kwargs"],

)

Expand Down
209 changes: 209 additions & 0 deletions tests/models/paligemma/test_processing_paligemma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from PIL import Image

from transformers import (
AutoProcessor,
PaliGemmaProcessor,
is_vision_available,
)


@require_vision
class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = PaliGemmaProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will indeed cause a gated repo issue. You could rebuild a processor without using this repo, something like

Suggested change
processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True)
processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)

Where
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
as it is done for test_tokenization_gemma.py.

Not sure if that's the nicest way to fix this though, any idea @zucchini-nlp @molbap ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CI token can be updated so that it can read this repo: in test_modeling.py there is

        self.processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")

so that should not be an issue already - any idea @ydshieh here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, what is the issue here? This repo seems to be public, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
I suppose it is, with a license to accept?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait, are we talking about

google/paligemma-3b-pt-224

or

google/siglip-so400m-patch14-384

but both are accessible even if I am using a firefox private window

processor.save_pretrained(self.tmpdirname)

def get_tokenizer(self, **kwargs):
    """Reload the processor saved under `self.tmpdirname` and return its tokenizer.

    Any `kwargs` are forwarded to `AutoProcessor.from_pretrained`.
    """
    saved_processor = AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
    return saved_processor.tokenizer

def get_image_processor(self, **kwargs):
    """Reload the processor saved under `self.tmpdirname` and return its image processor.

    Any `kwargs` are forwarded to `AutoProcessor.from_pretrained`.
    """
    saved_processor = AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
    return saved_processor.image_processor

def tearDown(self):
    """Delete the temporary directory stored in `self.tmpdirname`."""
    shutil.rmtree(self.tmpdirname)

def prepare_image_inputs(self):
    """Return a list containing one random RGB PIL image (height 30, width 400)."""
    # Pixels are drawn channels-first, then the channel axis is moved last
    # because `Image.fromarray` is given a (height, width, channels) array.
    channels_first = np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)
    channels_last = np.moveaxis(channels_first, 0, -1)
    return [Image.fromarray(channels_last)]
Comment on lines +54 to +63
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm noticing more of this function across the repo and it's identical in 17 places, I think we can move it to processing_utils.py at some point and save some loc, same remark for above helper functions!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this one works, I usually use it for image processor tests? from tests.test_image_processing_common import prepare_image_inputs

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep! Let's move. I also have my own personal agenda to remove the "numpify" and "torchify" arguments which are confusing, clash and inconsistent so would be a good opportunity for that


# Some kwargs tests are overridden from the common tests to handle shortest_edge
# and size_divisor behaviour

@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component(
"image_processor",
crop_size={"shortest_edge": 234, "longest_edge": 234},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may not be needed anymore as the base tests were changed recently, same for other tests. Please fetch and rebase on upstream main :)

)
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)

@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
    """Per-modality kwargs passed as nested dicts must reach both sub-processors."""
    if "image_processor" not in self.processor_class.attributes:
        self.skipTest(f"image_processor attribute not present in {self.processor_class}")

    img_proc = self.get_component("image_processor")
    tok = self.get_component("tokenizer")
    processor = self.processor_class(tokenizer=tok, image_processor=img_proc)
    self.skip_processor_without_typed_kwargs(processor)

    prompt = "lower newer"
    images = self.prepare_image_inputs()

    # One sub-dict per modality, plus the kwargs shared by all modalities.
    nested_kwargs = {
        "common_kwargs": {"return_tensors": "pt"},
        "images_kwargs": {"crop_size": {"shortest_edge": 214}, "data_format": "channels_first"},
        "text_kwargs": {"padding": "max_length", "max_length": 76, "tokenize_newline_separately": False},
    }

    inputs = processor(text=prompt, images=images, **nested_kwargs)

    self.assertEqual(inputs["pixel_values"].shape[2], 214)
    self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
    """A `crop_size` given at call time must override the one the image processor was built with."""
    if "image_processor" not in self.processor_class.attributes:
        self.skipTest(f"image_processor attribute not present in {self.processor_class}")

    img_proc = self.get_component("image_processor", crop_size={"shortest_edge": 234})
    tok = self.get_component("tokenizer", max_length=117)
    if not tok.pad_token:
        # Give the tokenizer a pad token when it does not define one.
        tok.pad_token = "[TEST_PAD]"

    processor = self.processor_class(tokenizer=tok, image_processor=img_proc)
    self.skip_processor_without_typed_kwargs(processor)

    prompt = "lower newer"
    images = self.prepare_image_inputs()
    inputs = processor(text=prompt, images=images, crop_size={"shortest_edge": 224})

    # The call-time value (224) is expected, not the construction-time one (234).
    self.assertEqual(len(inputs["pixel_values"][0][0]), 224)

@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
    """Flat (un-nested) kwargs must be routed to the right sub-processor for batched inputs."""
    if "image_processor" not in self.processor_class.attributes:
        self.skipTest(f"image_processor attribute not present in {self.processor_class}")

    img_proc = self.get_component("image_processor")
    tok = self.get_component("tokenizer")
    if not tok.pad_token:
        # Give the tokenizer a pad token when it does not define one.
        tok.pad_token = "[TEST_PAD]"

    processor = self.processor_class(tokenizer=tok, image_processor=img_proc)
    self.skip_processor_without_typed_kwargs(processor)

    prompts = ["lower newer", "upper older longer string"]
    images = self.prepare_image_inputs() * 2  # one image per prompt

    flat_kwargs = {
        "return_tensors": "pt",
        "crop_size": {"shortest_edge": 214},
        "padding": "longest",
        "max_length": 76,
    }
    inputs = processor(text=prompts, images=images, **flat_kwargs)

    self.assertEqual(inputs["pixel_values"].shape[2], 214)
    # NOTE(review): the expected length of 6 does not appear to count any image
    # placeholder tokens the processor prepends — confirm against the processor.
    self.assertEqual(len(inputs["input_ids"][0]), 6)

@require_torch
@require_vision
def test_unstructured_kwargs(self):
    """Flat (un-nested) kwargs must be routed to the right sub-processor for a single input."""
    if "image_processor" not in self.processor_class.attributes:
        self.skipTest(f"image_processor attribute not present in {self.processor_class}")

    img_proc = self.get_component("image_processor")
    tok = self.get_component("tokenizer")
    if not tok.pad_token:
        # Give the tokenizer a pad token when it does not define one.
        tok.pad_token = "[TEST_PAD]"

    processor = self.processor_class(tokenizer=tok, image_processor=img_proc)
    self.skip_processor_without_typed_kwargs(processor)

    prompt = "lower newer"
    images = self.prepare_image_inputs()

    flat_kwargs = {
        "return_tensors": "pt",
        "crop_size": {"shortest_edge": 214},
        "padding": "max_length",
        "max_length": 76,
    }
    inputs = processor(text=prompt, images=images, **flat_kwargs)

    self.assertEqual(inputs["pixel_values"].shape[2], 214)
    # NOTE(review): the processor under review adds its image sequence length to
    # `max_length` before tokenizing — confirm 76 is still the expected length.
    self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
@require_vision
def test_structured_kwargs_nested(self):
    """Per-modality kwargs passed as nested dicts must reach both sub-processors.

    Skipped when the processor class has no image processor, or when
    `skip_processor_without_typed_kwargs` decides the processor should be skipped.
    """
    if "image_processor" not in self.processor_class.attributes:
        self.skipTest(f"image_processor attribute not present in {self.processor_class}")
    image_processor = self.get_component("image_processor")
    tokenizer = self.get_component("tokenizer")
    if not tokenizer.pad_token:
        # Give the tokenizer a pad token when it does not define one.
        tokenizer.pad_token = "[TEST_PAD]"
    processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
    # The skip check runs once, before the processor is called; the original
    # repeated it after the call, where it could no longer guard anything.
    self.skip_processor_without_typed_kwargs(processor)

    input_str = "lower newer"
    image_input = self.prepare_image_inputs()

    # Define the kwargs for each modality
    all_kwargs = {
        "common_kwargs": {"return_tensors": "pt"},
        "images_kwargs": {"crop_size": {"shortest_edge": 214}, "data_format": "channels_first"},
        "text_kwargs": {"padding": "max_length", "max_length": 76, "tokenize_newline_separately": False},
    }

    inputs = processor(text=input_str, images=image_input, **all_kwargs)

    self.assertEqual(inputs["pixel_values"].shape[2], 214)
    self.assertEqual(len(inputs["input_ids"][0]), 76)