From bfe6445d0608276a99ef2ba9a104cd1bf4af755c Mon Sep 17 00:00:00 2001 From: MnCSSJ4x Date: Thu, 1 Aug 2024 19:22:29 +0530 Subject: [PATCH 1/8] Adds uniform processing to paligemma. --- .../models/paligemma/processing_paligemma.py | 69 ++++++++----------- 1 file changed, 30 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 3d0ece60c367e4..3c31680417da20 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -17,19 +17,21 @@ """ import logging -from typing import List, Optional, Union +from typing import List, Union + +try: + from typing import Unpack +except ImportError: + pass from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image -from ...processing_utils import ProcessorMixin +from ...processing_utils import ProcessingKwargs, ProcessorMixin from ...tokenization_utils_base import ( AddedToken, - PaddingStrategy, PreTokenizedInput, TextInput, - TruncationStrategy, ) -from ...utils import TensorType logger = logging.getLogger(__name__) @@ -72,6 +74,18 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token): """ return f"{image_token * image_seq_len}{bos_token}{prompt}\n" +class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "tokenize_newline_separately": True, + "suffix": None, + }, + "image_kwargs": { + "do_convert_rgb": None, + "do_thumbnail": None, + "do_align_long_axis": None, + }, + } class PaliGemmaProcessor(ProcessorMixin): r""" @@ -124,25 +138,8 @@ def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, images: ImageInput = None, - tokenize_newline_separately: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length=None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, - do_resize: bool = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 - input_data_format: Optional[ - Union[str, "ChannelDimension"] # noqa: F821 - ] = None, - resample: "PILImageResampling" = None, # noqa: F821 - do_convert_rgb: bool = None, - do_thumbnail: bool = None, - do_align_long_axis: bool = None, - do_rescale: bool = None, - suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + video =None, + **kwargs: Unpack[PaliGemmaProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -216,7 +213,7 @@ def __call__( - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - **labels** -- Labels compatible with training if `suffix` is not None """ - + suffix = kwargs["text_kwargs"]["suffix"] return_token_type_ids = True if suffix is not None else False if images is None: @@ -253,27 +250,21 @@ def __call__( pixel_values = self.image_processor( images, - do_resize=do_resize, - do_normalize=do_normalize, - return_tensors=return_tensors, - image_mean=image_mean, - image_std=image_std, - input_data_format=input_data_format, - data_format=data_format, - resample=resample, - do_convert_rgb=do_convert_rgb, + **kwargs["image_kwargs"], )["pixel_values"] + max_length = kwargs.get("max_length", None) if max_length is not None: max_length += self.image_seq_length # max_length has to account for the image tokens - + output_kwargs = self._merge_kwargs( + PaliGemmaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) inputs = self.tokenizer( input_strings, text_pair=suffix, - return_tensors=return_tensors, - padding=padding, - max_length=max_length, - truncation=truncation, + **output_kwargs['text_kwargs'], return_token_type_ids=return_token_type_ids, ) From ea95bafbe070e0b6bc4e1c4aa548147e595a4b8f Mon Sep 17 00:00:00 2001 From: MnCSSJ4x Date: Mon, 5 Aug 2024 15:42:23 +0530 Subject: [PATCH 2/8] Added specific args and updated call. --- .../models/paligemma/processing_paligemma.py | 54 ++++++++++++++----- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 3c31680417da20..15e4bcf573237f 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -17,7 +17,7 @@ """ import logging -from typing import List, Union +from typing import List, Optional, Union try: @@ -26,7 +26,7 @@ pass from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image -from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs from ...tokenization_utils_base import ( AddedToken, PreTokenizedInput, @@ -74,19 +74,45 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token): """ return f"{image_token * image_seq_len}{bos_token}{prompt}\n" + +class PaliGemmaTextKwargs(TextKwargs): + suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None + + +class PaliGemmaImagesKwargs(ImagesKwargs): + do_convert_rgb: Optional[bool] = None + do_thumbnail: Optional[bool] = None + do_align_long_axis: Optional[bool] = None + do_rescale: Optional[bool] = None + + class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: PaliGemmaTextKwargs + image_kwargs: PaliGemmaImagesKwargs _defaults = { "text_kwargs": { - "tokenize_newline_separately": True, - "suffix": None, + "tokenize_newline_separately": True, # Not Available in Default + "suffix": None, # Not Available in Default + "padding": False, + "truncation": None, + "max_length": None, }, "image_kwargs": { - "do_convert_rgb": None, - "do_thumbnail": None, - "do_align_long_axis": None, + "do_resize": None, + "do_normalize": None, + "image_mean": None, + "image_std": None, + "data_format": "channels_first", + "input_data_format": None, + "resample": None, + "do_convert_rgb": None, # Not Available in Default + "do_thumbnail": None, # Not Available in Default + "do_align_long_axis": None, # Not Available in Default + "do_rescale": None, }, } + class PaliGemmaProcessor(ProcessorMixin): r""" Constructs a PaliGemma processor which wraps a PaliGemma image processor and a PaliGemma tokenizer into a single processor. @@ -138,7 +164,7 @@ def __call__( self, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, images: ImageInput = None, - video =None, + video=None, **kwargs: Unpack[PaliGemmaProcessorKwargs], ) -> BatchFeature: """ @@ -213,7 +239,12 @@ def __call__( - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. - **labels** -- Labels compatible with training if `suffix` is not None """ - suffix = kwargs["text_kwargs"]["suffix"] + output_kwargs = self._merge_kwargs( + PaliGemmaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + suffix = output_kwargs["text_kwargs"]["suffix"] return_token_type_ids = True if suffix is not None else False if images is None: @@ -252,8 +283,7 @@ def __call__( images, **kwargs["image_kwargs"], )["pixel_values"] - - max_length = kwargs.get("max_length", None) + max_length = output_kwargs.get("max_length", None) if max_length is not None: max_length += self.image_seq_length # max_length has to account for the image tokens output_kwargs = self._merge_kwargs( @@ -264,7 +294,7 @@ def __call__( inputs = self.tokenizer( input_strings, text_pair=suffix, - **output_kwargs['text_kwargs'], + **output_kwargs["text_kwargs"], return_token_type_ids=return_token_type_ids, ) From 3c89e17d6612d6909a1122aab7013608fc39f788 Mon Sep 17 00:00:00 2001 From: MnCSSJ4x Date: Mon, 5 Aug 2024 15:58:46 +0530 Subject: [PATCH 3/8] Removed none defaults. --- .../models/paligemma/processing_paligemma.py | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 15e4bcf573237f..e4dca5698be27e 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -76,14 +76,14 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token): class PaliGemmaTextKwargs(TextKwargs): - suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None + suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] class PaliGemmaImagesKwargs(ImagesKwargs): - do_convert_rgb: Optional[bool] = None - do_thumbnail: Optional[bool] = None - do_align_long_axis: Optional[bool] = None - do_rescale: Optional[bool] = None + do_convert_rgb: Optional[bool] + do_thumbnail: Optional[bool] + do_align_long_axis: Optional[bool] + do_rescale: Optional[bool] class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False): @@ -92,23 +92,10 @@ class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "tokenize_newline_separately": True, # Not Available in Default - "suffix": None, # Not Available in Default "padding": False, - "truncation": None, - "max_length": None, }, "image_kwargs": { - "do_resize": None, - "do_normalize": None, - "image_mean": None, - "image_std": None, "data_format": "channels_first", - "input_data_format": None, - "resample": None, - "do_convert_rgb": None, # Not Available in Default - "do_thumbnail": None, # Not Available in Default - "do_align_long_axis": None, # Not Available in Default - "do_rescale": None, }, } From d62f00b9381e6ff674381bc49407092e8862f624 Mon Sep 17 00:00:00 2001 From: Monjoy Narayan Choudhury <77499007+MnCSSJ4x@users.noreply.github.com> Date: Wed, 14 Aug 2024 20:45:53 +0530 Subject: [PATCH 4/8] Update src/transformers/models/paligemma/processing_paligemma.py Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> --- src/transformers/models/paligemma/processing_paligemma.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index e4dca5698be27e..758e78cbe8c14a 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -20,10 +20,10 @@ from typing import List, Optional, Union -try: +if sys.version_info >= (3, 11): from typing import Unpack -except ImportError: - pass +else: + from typing_extensions import Unpack from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, is_valid_image from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs From d5a5eed575ac623d5fd25cd70c65a39a3fc8568b Mon Sep 17 00:00:00 2001 From: Monjoy Narayan Choudhury <77499007+MnCSSJ4x@users.noreply.github.com> Date: Wed, 14 Aug 2024 20:46:19 +0530 Subject: [PATCH 5/8] Update src/transformers/models/paligemma/processing_paligemma.py Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> --- src/transformers/models/paligemma/processing_paligemma.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 758e78cbe8c14a..86aebc320077e8 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -270,9 +270,6 @@ def __call__( images, **kwargs["image_kwargs"], )["pixel_values"] - max_length = output_kwargs.get("max_length", None) - if max_length is not None: - max_length += self.image_seq_length # max_length has to account for the image tokens output_kwargs = self._merge_kwargs( PaliGemmaProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, From b51e05f6547919658b9d3279b85c60e691b43948 Mon Sep 17 00:00:00 2001 From: Monjoy Narayan Choudhury <77499007+MnCSSJ4x@users.noreply.github.com> Date: Wed, 14 Aug 2024 20:46:28 +0530 Subject: [PATCH 6/8] Update src/transformers/models/paligemma/processing_paligemma.py Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> --- src/transformers/models/paligemma/processing_paligemma.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 86aebc320077e8..3cd68fa33bf527 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -270,11 +270,8 @@ def __call__( images, **kwargs["image_kwargs"], )["pixel_values"] - output_kwargs = self._merge_kwargs( - PaliGemmaProcessorKwargs, - tokenizer_init_kwargs=self.tokenizer.init_kwargs, - **kwargs, - ) + if output_kwargs["text_kwargs"].get("max_length", None) is not None: + output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length inputs = self.tokenizer( input_strings, text_pair=suffix, From 580ef32047103329d930146590eb1effb736d03e Mon Sep 17 00:00:00 2001 From: MnCSSJ4x Date: Wed, 14 Aug 2024 21:13:15 +0530 Subject: [PATCH 7/8] Add test for processing_paligemma.py --- .../models/paligemma/processing_paligemma.py | 4 +- .../paligemma/test_processing_paligemma.py | 209 ++++++++++++++++++ 2 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 tests/models/paligemma/test_processing_paligemma.py diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index 3cd68fa33bf527..bd1d05d5c9574c 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -17,9 +17,9 @@ """ import logging +import sys from typing import List, Optional, Union - if sys.version_info >= (3, 11): from typing import Unpack else: @@ -91,7 +91,7 @@ class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False): image_kwargs: PaliGemmaImagesKwargs _defaults = { "text_kwargs": { - "tokenize_newline_separately": True, # Not Available in Default + "tokenize_newline_separately": True, "padding": False, }, "image_kwargs": { diff --git a/tests/models/paligemma/test_processing_paligemma.py b/tests/models/paligemma/test_processing_paligemma.py new file mode 100644 index 00000000000000..3af8415cf33c77 --- /dev/null +++ b/tests/models/paligemma/test_processing_paligemma.py @@ -0,0 +1,209 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import shutil +import tempfile +import unittest + +import numpy as np + +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from PIL import Image + + from transformers import ( + PaliGemmaProcessor, + is_vision_available, + AutoProcessor, + ) + + +@require_vision +class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = PaliGemmaProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224") + processor.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + # Some kwargs tests are overriden from common tests to handle shortest_edge + # and size_divisor behaviour + + @require_torch + @require_vision + def test_image_processor_defaults_preserved_by_image_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component( + "image_processor", + crop_size={"shortest_edge": 234, "longest_edge": 234}, + ) + tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + self.assertEqual(len(inputs["pixel_values"][0][0]), 234) + + @require_torch + @require_vision + def test_structured_kwargs_nested_from_dict(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"crop_size": {"shortest_edge": 214}, "data_format": "channels_first"}, + "text_kwargs": {"padding": "max_length", "max_length": 76, "tokenize_newline_separately": False}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_kwargs_overrides_default_image_processor_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor", crop_size={"shortest_edge": 234}) + tokenizer = self.get_component("tokenizer", max_length=117) + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor(text=input_str, images=image_input, crop_size={"shortest_edge": 224}) + self.assertEqual(len(inputs["pixel_values"][0][0]), 224) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"shortest_edge": 214}, + padding="longest", + max_length=76, + ) + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 6) + + @require_torch + @require_vision + def test_unstructured_kwargs(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + crop_size={"shortest_edge": 214}, + padding="max_length", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + self.assertEqual(len(inputs["input_ids"][0]), 76) + + @require_torch + @require_vision + def test_structured_kwargs_nested(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + if not tokenizer.pad_token: + tokenizer.pad_token = "[TEST_PAD]" + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + # Define the kwargs for each modality + all_kwargs = { + "common_kwargs": {"return_tensors": "pt"}, + "images_kwargs": {"crop_size": {"shortest_edge": 214}, "data_format": "channels_first"}, + "text_kwargs": {"padding": "max_length", "max_length": 76, "tokenize_newline_separately": False}, + } + + inputs = processor(text=input_str, images=image_input, **all_kwargs) + self.skip_processor_without_typed_kwargs(processor) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + self.assertEqual(len(inputs["input_ids"][0]), 76) \ No newline at end of file From 77baadcb235cd4a7a13ea19ec3a7e2aabc156477 Mon Sep 17 00:00:00 2001 From: MnCSSJ4x Date: Wed, 14 Aug 2024 21:14:27 +0530 Subject: [PATCH 8/8] Update style using make style. --- src/transformers/models/paligemma/processing_paligemma.py | 1 + tests/models/paligemma/test_processing_paligemma.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index bd1d05d5c9574c..3cfd66c1ef98fb 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -20,6 +20,7 @@ import sys from typing import List, Optional, Union + if sys.version_info >= (3, 11): from typing import Unpack else: diff --git a/tests/models/paligemma/test_processing_paligemma.py b/tests/models/paligemma/test_processing_paligemma.py index 3af8415cf33c77..c5977709a3ed63 100644 --- a/tests/models/paligemma/test_processing_paligemma.py +++ b/tests/models/paligemma/test_processing_paligemma.py @@ -27,9 +27,9 @@ from PIL import Image from transformers import ( + AutoProcessor, PaliGemmaProcessor, is_vision_available, - AutoProcessor, ) @@ -206,4 +206,4 @@ def test_structured_kwargs_nested(self): self.skip_processor_without_typed_kwargs(processor) self.assertEqual(inputs["pixel_values"].shape[2], 214) - self.assertEqual(len(inputs["input_ids"][0]), 76) \ No newline at end of file + self.assertEqual(len(inputs["input_ids"][0]), 76)