diff --git a/docs/source/en/model_doc/paligemma.md b/docs/source/en/model_doc/paligemma.md
index 48debe593f97a9..41d785bba29dba 100644
--- a/docs/source/en/model_doc/paligemma.md
+++ b/docs/source/en/model_doc/paligemma.md
@@ -41,7 +41,7 @@ processor = AutoProcessor.from_pretrained(model_id)
 prompt = "What is on the flower?"
 image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
 raw_image = Image.open(requests.get(image_file, stream=True).raw)
-inputs = processor(prompt, raw_image, return_tensors="pt")
+inputs = processor(raw_image, prompt, return_tensors="pt")
 output = model.generate(**inputs, max_new_tokens=20)
 
 print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
@@ -53,7 +53,7 @@ print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
 ```python
 prompt = "What is on the flower?"
 answer = "a bee"
-inputs = processor(text=prompt, images=raw_image, suffix=answer, return_tensors="pt")
+inputs = processor(images=raw_image, text=prompt, suffix=answer, return_tensors="pt")
 ```
 
 ## Resources
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 4e456b9f08e213..48fffb6b428df7 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -443,7 +443,7 @@ def forward(
         >>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
         >>> image = Image.open(requests.get(url, stream=True).raw)
 
-        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
 
         >>> # Generate
         >>> generate_ids = model.generate(**inputs, max_length=30)
diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py
index 3d0ece60c367e4..4457b6fe957bf3 100644
--- a/src/transformers/models/paligemma/processing_paligemma.py
+++ b/src/transformers/models/paligemma/processing_paligemma.py
@@ -21,15 +21,19 @@
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import (
+    ImagesKwargs,
+    ProcessingKwargs,
+    ProcessorMixin,
+    TextKwargs,
+    Unpack,
+    _validate_images_text_input_order,
+)
 from ...tokenization_utils_base import (
     AddedToken,
-    PaddingStrategy,
     PreTokenizedInput,
     TextInput,
-    TruncationStrategy,
 )
-from ...utils import TensorType
 
 
 logger = logging.getLogger(__name__)
@@ -38,6 +42,27 @@
 EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]
 
 
+class PaliGemmaTextKwargs(TextKwargs):
+    suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
+
+
+class PaliGemmaImagesKwargs(ImagesKwargs):
+    do_convert_rgb: Optional[bool]
+
+
+class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: PaliGemmaTextKwargs
+    images_kwargs: PaliGemmaImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {
+            "data_format": "channels_first",
+        },
+    }
+
+
 # Copied from transformers.models.idefics2.processing_idefics2.is_url
 def is_url(val) -> bool:
     return isinstance(val, str) and val.startswith("http")
@@ -122,27 +147,11 @@ def __init__(
 
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         images: ImageInput = None,
-
tokenize_newline_separately: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length=None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, - do_resize: bool = None, - do_normalize: bool = None, - image_mean: Optional[Union[float, List[float]]] = None, - image_std: Optional[Union[float, List[float]]] = None, - data_format: Optional["ChannelDimension"] = "channels_first", # noqa: F821 - input_data_format: Optional[ - Union[str, "ChannelDimension"] # noqa: F821 - ] = None, - resample: "PILImageResampling" = None, # noqa: F821 - do_convert_rgb: bool = None, - do_thumbnail: bool = None, - do_align_long_axis: bool = None, - do_rescale: bool = None, - suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[PaliGemmaProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -171,29 +180,14 @@ def __call__( Args: - text (`str`, `List[str]`, `List[List[str]]`): - The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings - (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set - `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - tokenize_newline_separately (`bool`, defaults to `True`): - Adds a separately tokenized '\n' at the end of the prompt. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: @@ -216,6 +210,15 @@ def __call__( - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
- **labels** -- Labels compatible with training if `suffix` is not None """ + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + PaliGemmaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + suffix = output_kwargs["text_kwargs"].pop("suffix", None) return_token_type_ids = True if suffix is not None else False @@ -251,30 +254,17 @@ def __call__( for prompt in text ] - pixel_values = self.image_processor( - images, - do_resize=do_resize, - do_normalize=do_normalize, - return_tensors=return_tensors, - image_mean=image_mean, - image_std=image_std, - input_data_format=input_data_format, - data_format=data_format, - resample=resample, - do_convert_rgb=do_convert_rgb, - )["pixel_values"] - - if max_length is not None: - max_length += self.image_seq_length # max_length has to account for the image tokens + pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] + + # max_length has to account for the image tokens + if output_kwargs["text_kwargs"].get("max_length", None) is not None: + output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length inputs = self.tokenizer( input_strings, text_pair=suffix, - return_tensors=return_tensors, - padding=padding, - max_length=max_length, - truncation=truncation, return_token_type_ids=return_token_type_ids, + **output_kwargs["text_kwargs"], ) return_data = {**inputs, "pixel_values": pixel_values} diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index d592205443e1c2..3918292133b406 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -337,7 +337,7 @@ def test_small_model_integration_test(self): "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" ) raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt") + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt") EXPECTED_INPUT_IDS = torch.tensor([[257152] * 256 + [2, 108]]) self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) @@ -360,7 +360,7 @@ def test_small_model_integration_test_paligemma_VQA(self): "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" ) raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16) + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch.float16) output = model.generate(**inputs, max_new_tokens=900, do_sample=False) EXPECTED_DECODED_TEXT = "answer en Where is the cow standing?\nbeach" # fmt: skip @@ -382,7 +382,7 @@ def test_small_model_integration_test_paligemma_empty_prompt(self): "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png" ) raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16) + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch.float16) output = model.generate(**inputs, max_new_tokens=900, do_sample=False) EXPECTED_DECODED_TEXT = "\ncow on the beach" # fmt: skip @@ -412,7 +412,7 @@ def 
test_small_model_integration_test_paligemma_batched(self): ) image2 = image1 - inputs = self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True) + inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -443,7 +443,7 @@ def test_small_model_integration_test_paligemma_batched_bf16(self): image2 = image1 inputs = ( - self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True) + self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) .to(torch.bfloat16) .to(torch_device) ) @@ -475,7 +475,7 @@ def test_small_model_integration_test_paligemma_batched_f16(self): image2 = image1 inputs = ( - self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True) + self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) .to(torch.float16) .to(torch_device) ) @@ -504,7 +504,7 @@ def test_integration_detection_bug(self): ).raw ) - inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(torch_device) + inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(torch.bfloat16).to(torch_device) output = model.generate(**inputs, max_new_tokens=20) @@ -528,8 +528,8 @@ def test_paligemma_index_error_bug(self): raw_image = Image.open(requests.get(image_file, stream=True).raw) inputs = self.processor( - text=prompt, images=raw_image, + text=prompt, return_tensors="pt", ).to(torch.float16) @@ -561,7 +561,7 @@ def test_paligemma_finetuning_with_suffixes_bf16(self): image2 = image1 inputs = ( - self.processor(text=prompts, suffix=suffixes, images=[image1, image2], return_tensors="pt", padding=True) + self.processor(images=[image1, image2], text=prompts, suffix=suffixes, return_tensors="pt", padding=True) .to(torch.bfloat16) .to(torch_device) ) diff --git a/tests/models/paligemma/test_processor_paligemma.py b/tests/models/paligemma/test_processor_paligemma.py new file mode 100644 index 00000000000000..47810f1832416f --- /dev/null +++ b/tests/models/paligemma/test_processor_paligemma.py @@ -0,0 +1,89 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import shutil +import tempfile +import unittest + +from transformers import GemmaTokenizer +from transformers.testing_utils import get_tests_dir, require_torch, require_vision +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import ( + PaliGemmaProcessor, + SiglipImageProcessor, + is_vision_available, + ) + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + + +@require_vision +class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = PaliGemmaProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384") + image_processor.image_seq_length = 0 + tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True) + processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer) + processor.save_pretrained(self.tmpdirname) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + @require_torch + @require_vision + def test_image_seq_length(self): + input_str = "lower newer" + image_input = self.prepare_image_inputs() + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length") + image_processor.image_seq_length = 14 + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + inputs = processor( + text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length" + ) + self.assertEqual(len(inputs["input_ids"][0]), 112 + 14) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 10)
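For quick reference, below is a minimal usage sketch of the processor after this change (not part of the patch): the image now comes first, positionally or via `images=`, and `suffix` plus the usual tokenizer kwargs are routed through the new `PaliGemmaProcessorKwargs`. The checkpoint id is assumed for illustration; the image URL and argument names come from the docs updated above.

```python
import requests
from PIL import Image

from transformers import AutoProcessor

# Assumed checkpoint for illustration; any PaliGemma checkpoint exposes the same processor API.
model_id = "google/paligemma-3b-pt-224"
processor = AutoProcessor.from_pretrained(model_id)

image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
raw_image = Image.open(requests.get(image_file, stream=True).raw)

# Inference-style call: images first, then text.
# The old (text, images) order is still accepted through the BC check added in this patch.
inputs = processor(raw_image, "What is on the flower?", return_tensors="pt")

# Training-style call: `suffix` is turned into labels; padding/truncation/max_length flow through
# text_kwargs, and max_length is extended internally by processor.image_seq_length.
train_inputs = processor(
    images=raw_image,
    text="What is on the flower?",
    suffix="a bee",
    return_tensors="pt",
    padding="max_length",
    max_length=32,
)
print(sorted(train_inputs.keys()))
# expected: attention_mask, input_ids, labels, pixel_values, token_type_ids
```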