Commit 1789e57

Uniformize kwargs for Paligemma processor and update docs (huggingface#33571)

* Uniformize paligemma processor

* nit
yonigozlan authored and itazap committed Sep 20, 2024
1 parent f629df5 commit 1789e57
Showing 5 changed files with 153 additions and 74 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/paligemma.md
@@ -41,7 +41,7 @@ processor = AutoProcessor.from_pretrained(model_id)
prompt = "What is on the flower?"
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
-inputs = processor(prompt, raw_image, return_tensors="pt")
+inputs = processor(raw_image, prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)

print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
@@ -53,7 +53,7 @@ print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
```python
prompt = "What is on the flower?"
answer = "a bee"
-inputs = processor(text=prompt, images=raw_image, suffix=answer, return_tensors="pt")
+inputs = processor(images=raw_image, text=prompt, suffix=answer, return_tensors="pt")
```

## Resources
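Note: with the uniformized signature, the image argument now comes first positionally, matching the other multimodal processors; keyword arguments remain order-independent. A minimal sketch of both call styles, reusing `processor`, `raw_image`, and `prompt` from the example above:

```python
# Positional arguments: images now come before text.
inputs = processor(raw_image, prompt, return_tensors="pt")

# Keyword arguments are order-independent and read more explicitly.
inputs = processor(images=raw_image, text=prompt, return_tensors="pt")
```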
2 changes: 1 addition & 1 deletion src/transformers/models/paligemma/modeling_paligemma.py
@@ -443,7 +443,7 @@ def forward(
>>> url = "https://huggingface.co/gv-hf/PaliGemma-test-224px-hf/resolve/main/cow_beach_1.png"
>>> image = Image.open(requests.get(url, stream=True).raw)
->>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_length=30)
114 changes: 52 additions & 62 deletions src/transformers/models/paligemma/processing_paligemma.py
@@ -21,15 +21,19 @@

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
-from ...processing_utils import ProcessorMixin
+from ...processing_utils import (
+    ImagesKwargs,
+    ProcessingKwargs,
+    ProcessorMixin,
+    TextKwargs,
+    Unpack,
+    _validate_images_text_input_order,
+)
from ...tokenization_utils_base import (
AddedToken,
-    PaddingStrategy,
PreTokenizedInput,
TextInput,
-    TruncationStrategy,
)
-from ...utils import TensorType


logger = logging.getLogger(__name__)
@@ -38,6 +42,27 @@
EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]


+class PaliGemmaTextKwargs(TextKwargs):
+    suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
+
+
+class PaliGemmaImagesKwargs(ImagesKwargs):
+    do_convert_rgb: Optional[bool]
+
+
+class PaliGemmaProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: PaliGemmaTextKwargs
+    images_kwargs: PaliGemmaImagesKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {
+            "data_format": "channels_first",
+        },
+    }


# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
return isinstance(val, str) and val.startswith("http")
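Note: the `ProcessingKwargs` pattern lets each processor declare per-modality defaults once and have call-time kwargs override them. A simplified sketch of that precedence, for illustration only — `merge_kwargs_sketch` and its routing table are hypothetical; the real merging lives in `ProcessorMixin._merge_kwargs` and also folds in tokenizer init kwargs:

```python
from typing import Any, Dict

def merge_kwargs_sketch(defaults: Dict[str, Dict[str, Any]], **kwargs: Any) -> Dict[str, Dict[str, Any]]:
    """Toy merge: route flat call-time kwargs into per-modality buckets,
    letting them override class-level defaults. Illustrative only."""
    # Hand-picked subset of kwarg names, just for this sketch.
    known = {
        "text_kwargs": {"padding", "truncation", "max_length", "suffix"},
        "images_kwargs": {"data_format", "do_convert_rgb", "do_resize"},
    }
    merged = {group: dict(group_defaults) for group, group_defaults in defaults.items()}
    for key, value in kwargs.items():
        for group, names in known.items():
            if key in names:
                merged[group][key] = value
    return merged

merged = merge_kwargs_sketch(
    {"text_kwargs": {"padding": False}, "images_kwargs": {"data_format": "channels_first"}},
    padding="longest",
)
assert merged["text_kwargs"]["padding"] == "longest"  # call-time value wins over the default
assert merged["images_kwargs"]["data_format"] == "channels_first"  # default survives untouched
```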
@@ -122,27 +147,11 @@ def __init__(

def __call__(
self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
-        tokenize_newline_separately: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length=None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
-        do_resize: bool = None,
-        do_normalize: bool = None,
-        image_mean: Optional[Union[float, List[float]]] = None,
-        image_std: Optional[Union[float, List[float]]] = None,
-        data_format: Optional["ChannelDimension"] = "channels_first",  # noqa: F821
-        input_data_format: Optional[
-            Union[str, "ChannelDimension"]  # noqa: F821
-        ] = None,
-        resample: "PILImageResampling" = None,  # noqa: F821
-        do_convert_rgb: bool = None,
-        do_thumbnail: bool = None,
-        do_align_long_axis: bool = None,
-        do_rescale: bool = None,
-        suffix: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[PaliGemmaProcessorKwargs],
) -> BatchFeature:
"""
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
@@ -171,29 +180,14 @@ def __call__(
Args:
-            text (`str`, `List[str]`, `List[List[str]]`):
-                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
-                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
-                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
-            tokenize_newline_separately (`bool`, defaults to `True`):
-                Adds a separately tokenized '\n' at the end of the prompt.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
+            text (`str`, `List[str]`, `List[List[str]]`):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
@@ -216,6 +210,15 @@ def __call__(
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **labels** -- Labels compatible with training if `suffix` is not None
"""
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            PaliGemmaProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        suffix = output_kwargs["text_kwargs"].pop("suffix", None)

return_token_type_ids = True if suffix is not None else False

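Note: `_validate_images_text_input_order` keeps old call sites working: the previous signature took `text` first, so positional callers using the legacy order would otherwise hand their prompt to the image pipeline. A rough, self-contained sketch of the idea — `swap_if_reversed_sketch` and `fake_image` are hypothetical names, and the actual helper in `processing_utils` handles many more input types:

```python
def swap_if_reversed_sketch(images, text):
    """Toy version: if `images` looks like text and `text` does not,
    assume the legacy (text, images) call order and swap. Illustrative only."""
    def looks_like_text(x):
        return isinstance(x, str) or (isinstance(x, list) and bool(x) and isinstance(x[0], str))

    if looks_like_text(images) and not looks_like_text(text):
        return text, images  # legacy order detected: swap back
    return images, text

fake_image = object()  # stand-in for a PIL.Image or np.ndarray
images, text = swap_if_reversed_sketch("What is on the flower?", fake_image)  # legacy order
assert text == "What is on the flower?" and images is fake_image
```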
@@ -251,30 +254,17 @@ def __call__(
for prompt in text
]

-        pixel_values = self.image_processor(
-            images,
-            do_resize=do_resize,
-            do_normalize=do_normalize,
-            return_tensors=return_tensors,
-            image_mean=image_mean,
-            image_std=image_std,
-            input_data_format=input_data_format,
-            data_format=data_format,
-            resample=resample,
-            do_convert_rgb=do_convert_rgb,
-        )["pixel_values"]

-        if max_length is not None:
-            max_length += self.image_seq_length  # max_length has to account for the image tokens
+        pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

+        # max_length has to account for the image tokens
+        if output_kwargs["text_kwargs"].get("max_length", None) is not None:
+            output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length

inputs = self.tokenizer(
input_strings,
text_pair=suffix,
-            return_tensors=return_tensors,
-            padding=padding,
-            max_length=max_length,
-            truncation=truncation,
return_token_type_ids=return_token_type_ids,
+            **output_kwargs["text_kwargs"],
)

return_data = {**inputs, "pixel_values": pixel_values}
18 changes: 9 additions & 9 deletions tests/models/paligemma/test_modeling_paligemma.py
@@ -337,7 +337,7 @@ def test_small_model_integration_test(self):
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
-        inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt")
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")
EXPECTED_INPUT_IDS = torch.tensor([[257152] * 256 + [2, 108]])
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

@@ -360,7 +360,7 @@ def test_small_model_integration_test_paligemma_VQA(self):
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
-        inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16)
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch.float16)

output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "answer en Where is the cow standing?\nbeach" # fmt: skip
@@ -382,7 +382,7 @@ def test_small_model_integration_test_paligemma_empty_prompt(self):
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
-        inputs = self.processor(text=prompt, images=raw_image, return_tensors="pt").to(torch.float16)
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch.float16)

output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "\ncow on the beach" # fmt: skip
@@ -412,7 +412,7 @@ def test_small_model_integration_test_paligemma_batched(self):
)
image2 = image1

-        inputs = self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
+        inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)

output = model.generate(**inputs, max_new_tokens=20)

@@ -443,7 +443,7 @@ def test_small_model_integration_test_paligemma_batched_bf16(self):
image2 = image1

inputs = (
-            self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
+            self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
.to(torch.bfloat16)
.to(torch_device)
)
@@ -475,7 +475,7 @@ def test_small_model_integration_test_paligemma_batched_f16(self):
image2 = image1

inputs = (
-            self.processor(text=prompts, images=[image1, image2], return_tensors="pt", padding=True)
+            self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)
.to(torch.float16)
.to(torch_device)
)
@@ -504,7 +504,7 @@ def test_integration_detection_bug(self):
).raw
)

-        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to(torch_device)
+        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(torch.bfloat16).to(torch_device)

output = model.generate(**inputs, max_new_tokens=20)

@@ -528,8 +528,8 @@ def test_paligemma_index_error_bug(self):

raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(
-            text=prompt,
images=raw_image,
+            text=prompt,
return_tensors="pt",
).to(torch.float16)

@@ -561,7 +561,7 @@ def test_paligemma_finetuning_with_suffixes_bf16(self):
image2 = image1

inputs = (
-            self.processor(text=prompts, suffix=suffixes, images=[image1, image2], return_tensors="pt", padding=True)
+            self.processor(images=[image1, image2], text=prompts, suffix=suffixes, return_tensors="pt", padding=True)
.to(torch.bfloat16)
.to(torch_device)
)
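Note: as the fine-tuning test above shows, `suffix` now travels through `text_kwargs` but still drives training outputs: when it is set, the processor returns `labels` ready for a loss computation. A minimal sketch, assuming `processor`, `model`, and `raw_image` as in the tests:

```python
# Hypothetical fine-tuning step: `suffix` carries the target text.
inputs = processor(
    images=raw_image,
    text="answer en Where is the cow standing?",
    suffix="beach",
    return_tensors="pt",
)
# `labels` is present only because `suffix` was given; per the model's
# token-type handling, the loss covers the suffix tokens rather than the
# prefix prompt or image placeholders.
loss = model(**inputs).loss
loss.backward()
```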
89 changes: 89 additions & 0 deletions tests/models/paligemma/test_processor_paligemma.py
@@ -0,0 +1,89 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import tempfile
import unittest

from transformers import GemmaTokenizer
from transformers.testing_utils import get_tests_dir, require_torch, require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from transformers import (
PaliGemmaProcessor,
SiglipImageProcessor,
is_vision_available,
)

SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")


@require_vision
class PaliGemmaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = PaliGemmaProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
image_processor.image_seq_length = 0
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True)
processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
processor.save_pretrained(self.tmpdirname)

def tearDown(self):
shutil.rmtree(self.tmpdirname)

@require_torch
@require_vision
def test_image_seq_length(self):
input_str = "lower newer"
image_input = self.prepare_image_inputs()
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
image_processor.image_seq_length = 14
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112 + 14)

@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 10)
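Note: `test_image_seq_length` pins down the `max_length` adjustment introduced in the processor rewrite: the requested `max_length` applies to the text alone, and the processor adds `image_seq_length` placeholder tokens on top of it. The arithmetic behind the assertion, spelled out:

```python
max_length = 112        # requested via padding="max_length"
image_seq_length = 14   # one placeholder token per image patch position
expected_len = max_length + image_seq_length
assert expected_len == 126  # == len(inputs["input_ids"][0]) in the test
```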
