Uniformize kwargs for image-text-to-text processors #32544

Open
yonigozlan wants to merge 32 commits into base: main

Commits (32)
083b4bd  uniformize FUYU processor kwargs (yonigozlan, Aug 7, 2024)
425ed0e  Uniformize instructblip processor kwargs (yonigozlan, Aug 7, 2024)
dda5c5d  Fix processor kwargs and tests Fuyu, InstructBlip, Kosmos2 (yonigozlan, Aug 7, 2024)
c2806f4  Uniformize llava_next processor (yonigozlan, Aug 8, 2024)
0c30bcb  Fix save_load test for processor with chat_template only as extra ini… (yonigozlan, Aug 8, 2024)
9538867  Fix import Unpack (yonigozlan, Aug 8, 2024)
d8311d2  Fix Fuyu Processor import (yonigozlan, Aug 8, 2024)
770eb38  Fix FuyuProcessor import (yonigozlan, Aug 8, 2024)
20c1e6e  Fix FuyuProcessor (yonigozlan, Aug 8, 2024)
b887c6d  Add defaults for specific kwargs kosmos2 (yonigozlan, Aug 9, 2024)
325ce26  Fix Udop to return BatchFeature instead of BatchEncoding and uniformi… (yonigozlan, Aug 9, 2024)
a7fcb8b  Add tests processor Udop (yonigozlan, Aug 9, 2024)
8a296f5  remove Copied from in processing Udop as change of input orders cause… (yonigozlan, Aug 9, 2024)
f1be841  Fix overwrite tests kwargs processors (yonigozlan, Aug 12, 2024)
58b70a1  Add warnings and BC for changes in processor inputs order, change doc… (yonigozlan, Aug 12, 2024)
99f6673  Fix processing test fuyu (yonigozlan, Aug 13, 2024)
e0fecb5  remove unnecessary pad_token check in instructblip ProcessorTest (yonigozlan, Aug 13, 2024)
2172b9d  Fix BC tests and cleanup (yonigozlan, Aug 13, 2024)
3557693  FIx imports fuyu (yonigozlan, Aug 13, 2024)
228acee  Uniformize Pix2Struct (yonigozlan, Aug 14, 2024)
ce68136  Fix wrong name for FuyuProcessorKwargs (yonigozlan, Aug 14, 2024)
5db32c9  Fix slow tests reversed inputs align fuyu llava-next, change udop war… (yonigozlan, Aug 15, 2024)
6b95100  Fix wrong logging import udop (yonigozlan, Aug 15, 2024)
120a370  Add check images text input order (yonigozlan, Sep 9, 2024)
5657519  Fix copies (yonigozlan, Sep 9, 2024)
7b4bcb5  change text pair handling when positional arg (yonigozlan, Sep 10, 2024)
41b5d4c  rebase on main, fix imports in test_processing_common (yonigozlan, Sep 13, 2024)
7420705  remove optional args and udop uniformization from this PR (yonigozlan, Sep 16, 2024)
e6ceb28  fix failing tests (yonigozlan, Sep 16, 2024)
ac91b1e  remove unnecessary test, fix processing utils and test processing common (yonigozlan, Sep 20, 2024)
aecbd9e  cleanup Unpack (yonigozlan, Sep 20, 2024)
0e75363  cleanup (yonigozlan, Sep 20, 2024)
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/align.md
@@ -46,7 +46,7 @@ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
candidate_labels = ["an image of a cat", "an image of a dog"]

-inputs = processor(text=candidate_labels, images=image, return_tensors="pt")
+inputs = processor(images=image, text=candidate_labels, return_tensors="pt")

with torch.no_grad():
outputs = model(**inputs)
12 changes: 6 additions & 6 deletions docs/source/en/model_doc/fuyu.md
@@ -18,16 +18,16 @@ rendered properly in your Markdown viewer.

## Overview

The Fuyu model was created by [ADEPT](https://www.adept.ai/blog/fuyu-8b), and authored by Rohan Bavishi, Erich Elsen, Curtis Hawthorne, Maxwell Nye, Augustus Odena, Arushi Somani, Sağnak Taşırlar.

The authors introduced Fuyu-8B, a decoder-only multimodal model based on the classic transformers architecture, with query and key normalization. A linear encoder is added to create multimodal embeddings from image inputs.

By treating image tokens like text tokens and using a special image-newline character, the model knows when an image line ends. Image positional embeddings are removed. This avoids the need for different training phases for various image resolutions. With 8 billion parameters and licensed under CC-BY-NC, Fuyu-8B is notable for its ability to handle both text and images, its impressive context size of 16K, and its overall performance.

<Tip warning={true}>

The `Fuyu` models were trained using `bfloat16`, but the original inference uses `float16`. The checkpoints uploaded on the hub use `torch_dtype = 'float16'`, which will be
used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.

The `dtype` of the online weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model with `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype="auto")`. The reason is that the model will first be downloaded (using the `dtype` of the online checkpoints) and then cast to the default `dtype` of `torch` (`torch.float32`). Users should specify the `torch_dtype` they want; if they don't, it will be `torch.float32`.
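
A minimal sketch of this behavior, assuming the hub weights for `adept/fuyu-8b` are stored in `float16` (as the docs above state):

```py
import torch
from transformers import AutoModelForCausalLM

# torch_dtype="auto" keeps the dtype the checkpoint was saved in
model_auto = AutoModelForCausalLM.from_pretrained("adept/fuyu-8b", torch_dtype="auto")
print(next(model_auto.parameters()).dtype)  # torch.float16

# omitting torch_dtype casts the weights to torch's default dtype
model_default = AutoModelForCausalLM.from_pretrained("adept/fuyu-8b")
print(next(model_default.parameters()).dtype)  # torch.float32
```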

@@ -56,7 +56,7 @@ tar -xvf 8b_base_model_release.tar
```
Then, the model can be loaded via:

```py
from transformers import FuyuConfig, FuyuForCausalLM
model_config = FuyuConfig()
model = FuyuForCausalLM.from_pretrained('/output/path', config=model_config)
@@ -81,7 +81,7 @@ text_prompt = "Generate a coco-style caption.\\n"

bus_image_url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png"
bus_image_pil = Image.open(io.BytesIO(requests.get(bus_image_url).content))
-inputs_to_model = processor(text=text_prompt, images=bus_image_pil)
+inputs_to_model = processor(images=bus_image_pil, text=text_prompt)


```
@@ -90,7 +90,7 @@ This model was contributed by [Molbap](https://huggingface.co/Molbap).
The original code can be found [here](https://github.com/persimmon-ai-labs/adept-inference).

- Fuyu uses a `sentencepiece` based tokenizer, with a `Unigram` model. It supports bytefallback, which is only available in `tokenizers==0.14.0` for the fast tokenizer.
The `LlamaTokenizer` is used as it is a standard wrapper around sentencepiece; a quick check is sketched below, after this list.

- The authors suggest using the following prompt for image captioning: `f"Generate a coco-style caption.\\n"`

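A quick check of the tokenizer wrapper described above (a sketch; assumes network access to the `adept/fuyu-8b` hub repo):

```py
from transformers import AutoTokenizer

# Load the Fuyu tokenizer; it is expected to be a Llama tokenizer class
# wrapping a sentencepiece Unigram model, as described above
tokenizer = AutoTokenizer.from_pretrained("adept/fuyu-8b")
print(type(tokenizer).__name__)
print(tokenizer("Generate a coco-style caption.\n").input_ids[:5])
```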
10 changes: 5 additions & 5 deletions docs/source/en/model_doc/llava_next.md
@@ -133,7 +133,7 @@ import requests

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("cuda:0")

# prepare image and text prompt, using the appropriate prompt template
@@ -150,7 +150,7 @@ conversation = [
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
-inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
+inputs = processor(image, prompt, return_tensors="pt").to("cuda:0")

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=100)
@@ -222,7 +222,7 @@ prompts = [prompt_1, prompt_2]

# We can simply feed images in the order they have to be used in the text prompt
# Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
-inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)
+inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(model.device)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=30)
@@ -256,8 +256,8 @@ First make sure to install flash-attn. Refer to the [original repository of Flas
from transformers import LlavaNextForConditionalGeneration

model = LlavaNextForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
use_flash_attention_2=True
).to(0)
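A side note on the block above: recent transformers releases expose Flash Attention 2 through the `attn_implementation` argument, which supersedes the `use_flash_attention_2` flag shown in the diff. A sketch of the equivalent call, reusing the checkpoint id from the earlier example:

```py
import torch
from transformers import LlavaNextForConditionalGeneration

# Same load as above, using the newer attn_implementation argument
# (requires flash-attn to be installed)
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
).to(0)
```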
2 changes: 1 addition & 1 deletion src/transformers/models/align/modeling_align.py
@@ -1575,7 +1575,7 @@ def forward(
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> inputs = processor(
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
... images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
... )

>>> outputs = model(**inputs)
17 changes: 8 additions & 9 deletions src/transformers/models/align/processing_align.py
@@ -19,11 +19,7 @@
from typing import List, Union

from ...image_utils import ImageInput
-from ...processing_utils import (
-ProcessingKwargs,
-ProcessorMixin,
-Unpack,
-)
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput


@@ -76,8 +72,8 @@ def __init__(self, image_processor, tokenizer):

def __call__(
self,
-text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
+text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[AlignProcessorKwargs],
@@ -90,13 +86,13 @@ def __call__(
to the docstring of the above two methods for more information.

Args:
+images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-tensor. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
@@ -114,6 +110,9 @@
"""
if text is None and images is None:
raise ValueError("You must specify either text or images.")
+# check if images and text inputs are reversed for BC
+images, text = _validate_images_text_input_order(images, text)

output_kwargs = self._merge_kwargs(
AlignProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
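For context, `_validate_images_text_input_order` is defined in `processing_utils.py` and is not shown in this diff. A self-contained sketch of the idea, an illustration rather than the PR's actual implementation:

```py
import warnings

def validate_images_text_input_order(images, text):
    """Heuristic swap check: if callers used the old (text, images) positional
    order, detect it and swap the two arguments back for backward compatibility."""
    def looks_like_text(value):
        if isinstance(value, str):
            return True
        return isinstance(value, (list, tuple)) and len(value) > 0 and isinstance(value[0], str)

    if looks_like_text(images) and text is not None and not looks_like_text(text):
        warnings.warn(
            "It looks like you passed (text, images); the canonical order is now "
            "(images, text). Swapping the inputs."
        )
        return text, images
    return images, text

# usage inside a processor's __call__:
# images, text = validate_images_text_input_order(images, text)
```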
2 changes: 1 addition & 1 deletion src/transformers/models/fuyu/modeling_fuyu.py
@@ -264,7 +264,7 @@ def forward(
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> prompt = "Generate a coco-style caption.\n"

->>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> outputs = model(**inputs)

>>> generated_ids = model.generate(**inputs, max_new_tokens=7)
91 changes: 46 additions & 45 deletions src/transformers/models/fuyu/processing_fuyu.py
@@ -21,9 +21,10 @@

import numpy as np

-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, TruncationStrategy
-from ...utils import TensorType, is_torch_available, logging, requires_backends
+from ...image_utils import ImageInput
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_available, logging, requires_backends


if is_torch_available():
@@ -49,6 +50,24 @@
BEGINNING_OF_ANSWER_STRING = "<0x04>" # <boa>


+class FuyuProcessorKwargs(ProcessingKwargs, total=False):
+_defaults = {
+"text_kwargs": {
+"add_special_tokens": True,
+"padding": False,
+"stride": 0,
+"return_attention_mask": True,
+"return_overflowing_tokens": False,
+"return_special_tokens_mask": False,
+"return_offsets_mapping": False,
+"return_token_type_ids": False,
+"return_length": False,
+"verbose": True,
+},
+"images_kwargs": {},
+}


def full_unpacked_stream_to_tensor(
all_bi_tokens_to_place: List[int],
full_unpacked_stream: List["torch.Tensor"],
@@ -452,23 +471,11 @@ def get_sample_encoding(

def __call__(
self,
-text=None,
-images=None,
-add_special_tokens: bool = True,
-return_attention_mask: bool = True,
-padding: Union[bool, str, PaddingStrategy] = False,
-truncation: Union[bool, str, TruncationStrategy] = None,
-max_length: Optional[int] = None,
-stride: int = 0,
-pad_to_multiple_of: Optional[int] = None,
-return_overflowing_tokens: bool = False,
-return_special_tokens_mask: bool = False,
-return_offsets_mapping: bool = False,
-return_token_type_ids: bool = False,
-return_length: bool = False,
-verbose: bool = True,
-return_tensors: Optional[Union[str, TensorType]] = None,
-**kwargs,
+images: ImageInput = None,
+text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+audio=None,
+videos=None,
+**kwargs: Unpack[FuyuProcessorKwargs],
) -> "FuyuBatchFeature":
"""
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
@@ -478,13 +485,13 @@ def __call__(
of the above two methods for more information.

Args:
+images (`PIL.Image.Image`, `List[PIL.Image.Image]`):
+The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-images (`PIL.Image.Image`, `List[PIL.Image.Image]`):
-The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-tensor. Both channels-first and channels-last formats are supported.

Returns:
[`FuyuBatchEncoding`]: A [`FuyuBatchEncoding`] with the following fields:
@@ -498,31 +505,24 @@
requires_backends(self, ["torch"])

# --- Check input validity ---
-if not return_attention_mask:
-raise ValueError("`return_attention_mask=False` is not supported for this model.")
if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be None.")
+# check if images and text inputs are reversed for BC
+images, text = _validate_images_text_input_order(images, text)
+
+output_kwargs = self._merge_kwargs(
+FuyuProcessorKwargs,
+tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+**kwargs,
+)
+
+if not output_kwargs["text_kwargs"].setdefault("return_attention_mask", True):
+raise ValueError("`return_attention_mask=False` is not supported for this model.")

if text is not None and images is None:
logger.warning("You are processing a text with no associated image. Make sure it is intended.")
self.current_processor = self.tokenizer
-text_encoding = self.tokenizer(
-text=text,
-add_special_tokens=add_special_tokens,
-padding=padding,
-truncation=truncation,
-max_length=max_length,
-stride=stride,
-pad_to_multiple_of=pad_to_multiple_of,
-return_attention_mask=return_attention_mask,
-return_overflowing_tokens=return_overflowing_tokens,
-return_special_tokens_mask=return_special_tokens_mask,
-return_offsets_mapping=return_offsets_mapping,
-return_token_type_ids=return_token_type_ids,
-return_length=return_length,
-verbose=verbose,
-return_tensors=return_tensors,
-**kwargs,
-)
+text_encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
return text_encoding

if text is None and images is not None:
@@ -537,7 +537,8 @@
# --- Preprocess images using self.image_processor ---

# FIXME - We hard code "pt" here because the rest of the processing assumes torch tensors
-image_encoding = self.image_processor.preprocess(images, return_tensors="pt")
+output_kwargs["images_kwargs"]["return_tensors"] = "pt"
+image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
batch_images = image_encoding["images"]
image_unpadded_heights = image_encoding["image_unpadded_heights"]
image_unpadded_widths = image_encoding["image_unpadded_widths"]
@@ -568,7 +569,7 @@
)
all_encodings.append(sample_encoding)
batch_encoding = self._left_pad_inputs_with_attention_mask(
-model_inputs=all_encodings, return_attention_mask=return_attention_mask
+model_inputs=all_encodings, return_attention_mask=True
)
return FuyuBatchFeature(data=batch_encoding)

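The `FuyuProcessorKwargs` TypedDict and the `_merge_kwargs` call above are the core of the uniformization pattern in this PR. A simplified, self-contained sketch of how per-modality defaults and flat user kwargs can be merged (stand-in code, not the transformers implementation):

```py
from typing import Any, Dict, TypedDict

class TextKwargs(TypedDict, total=False):
    add_special_tokens: bool
    padding: bool
    return_attention_mask: bool

class ImagesKwargs(TypedDict, total=False):
    return_tensors: str

# Per-modality defaults, mirroring the _defaults dict in FuyuProcessorKwargs
DEFAULTS: Dict[str, Dict[str, Any]] = {
    "text_kwargs": {"add_special_tokens": True, "padding": False, "return_attention_mask": True},
    "images_kwargs": {},
}

BUCKETS = {"text_kwargs": TextKwargs, "images_kwargs": ImagesKwargs}

def merge_kwargs(**kwargs: Any) -> Dict[str, Dict[str, Any]]:
    """Route flat user kwargs into per-modality buckets, defaults first."""
    merged = {name: dict(values) for name, values in DEFAULTS.items()}
    for key, value in kwargs.items():
        for name, bucket_type in BUCKETS.items():
            if key in bucket_type.__annotations__:
                merged[name][key] = value  # user value overrides the default
    return merged

print(merge_kwargs(padding=True, return_tensors="pt"))
# {'text_kwargs': {'add_special_tokens': True, 'padding': True, 'return_attention_mask': True},
#  'images_kwargs': {'return_tensors': 'pt'}}
```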
@@ -24,16 +24,15 @@
from ...image_transforms import center_to_corners_format
from ...image_utils import AnnotationFormat, ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
+from ...utils import TensorType, is_torch_available


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack

-from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
-from ...utils import TensorType, is_torch_available


if is_torch_available():
import torch