Uniformize kwargs for chameleon processor #32181

Status: Open. Wants to merge 9 commits into `main`. Changes shown from all commits.
5 changes: 1 addition & 4 deletions src/transformers/models/auto/tokenization_auto.py
```diff
@@ -109,10 +109,7 @@
         ("canine", ("CanineTokenizer", None)),
         (
             "chameleon",
-            (
-                "LlamaTokenizer" if is_sentencepiece_available() else None,
-                "LlamaTokenizerFast" if is_tokenizers_available() else None,
-            ),
+            (None, "LlamaTokenizerFast" if is_tokenizers_available() else None),
         ),
         ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
         (
```
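Since the mapping's slow-tokenizer slot is now `None`, auto-tokenizer resolution for Chameleon can only return the fast tokenizer. A minimal sketch of the effect (the checkpoint name is illustrative, and the `tokenizers` library is assumed to be installed):

```python
from transformers import AutoTokenizer

# With the mapping above, "chameleon" checkpoints resolve to
# LlamaTokenizerFast; there is no sentencepiece-based slow fallback.
tokenizer = AutoTokenizer.from_pretrained("facebook/chameleon-7b")
print(type(tokenizer).__name__)  # LlamaTokenizerFast
```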
src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py

```diff
@@ -24,7 +24,7 @@
 
 from transformers import (
     ChameleonConfig,
-    ChameleonForCausalLM,
+    ChameleonForConditionalGeneration,
     ChameleonImageProcessor,
     ChameleonProcessor,
 )
```

**Review thread on this change:**

**leloykun (Contributor, PR author):** By the way @zucchini-nlp, we might need to increase the priority of this PR because of this. I have this change in my other PR too, but I forgot we haven't merged it yet.

**zucchini-nlp (Member):** Sorry, I was out for a while. Yes, I think some other contributor also reported the issue and wanted to open a PR to fix the conversion script. Feel free to open a PR if there isn't any, as this issue isn't at all related to processor kwargs.
````diff
@@ -49,10 +49,10 @@
 Thereafter, models can be loaded via:
 
 ```py
-from transformers import ChameleonForCausalLM, LlamaTokenizer
+from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
 
-model = ChameleonForCausalLM.from_pretrained("/output/path")
-tokenizer = LlamaTokenizer.from_pretrained("/output/path")
+model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
+tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
 ```
 
 Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
````
```diff
@@ -372,7 +372,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
         vocabulary_map=vocabulary_map,
     )
     with init_empty_weights():
-        model = ChameleonForCausalLM(config)
+        model = ChameleonForConditionalGeneration(config)
 
     model.load_state_dict(state_dict, assign=True, strict=False)
     model.save_pretrained(model_path, safe_serialization=True)
```
```diff
@@ -397,7 +397,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
     # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
     print("Loading the checkpoint in a Chameleon model...")
     print("*" * 100)
-    model = ChameleonForCausalLM.from_pretrained(
+    model = ChameleonForConditionalGeneration.from_pretrained(
         model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
     )
     processor = ChameleonProcessor.from_pretrained(model_path)
```
83 changes: 42 additions & 41 deletions src/transformers/models/chameleon/processing_chameleon.py
```diff
@@ -16,13 +16,36 @@
 Processor class for Chameleon.
 """
 
+import sys
 from typing import List, Optional, Union
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
+
+class ChameleonTextKwargs(TextKwargs, total=False):
+    return_for_text_completion: bool
+
+
+class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: ChameleonTextKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+            "return_for_text_completion": False,
+        },
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }
 
 
 class ChameleonProcessor(ProcessorMixin):
```
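For readers unfamiliar with the pattern: `TextKwargs` and `ProcessingKwargs` are `TypedDict`s, and `Unpack` lets the `**kwargs` of `__call__` be checked against them by static type checkers. A toy illustration of the mechanism, deliberately independent of the transformers classes:

```python
import sys
from typing import TypedDict

if sys.version_info >= (3, 11):
    from typing import Unpack
else:
    from typing_extensions import Unpack


class GreetKwargs(TypedDict, total=False):
    greeting: str
    punctuation: str


def greet(name: str, **kwargs: Unpack[GreetKwargs]) -> str:
    # Type checkers now know exactly which keyword arguments `greet`
    # accepts and their types; at runtime kwargs is still a plain dict.
    return f"{kwargs.get('greeting', 'Hello')}, {name}{kwargs.get('punctuation', '!')}"


print(greet("Chameleon", greeting="Hi"))  # Hi, Chameleon!
```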
```diff
@@ -45,7 +68,7 @@ class ChameleonProcessor(ProcessorMixin):
     """
 
     attributes = ["image_processor", "tokenizer"]
-    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+    tokenizer_class = "LlamaTokenizerFast"
     image_processor_class = "ChameleonImageProcessor"
 
     def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
```
```diff
@@ -57,13 +80,9 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
 
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: int = None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
-        return_for_text_completion: bool = False,
+        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        images: Optional[ImageInput] = None,
+        **kwargs: Unpack[ChameleonProcessorKwargs],
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
```
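With the unified signature, per-modality options are no longer named parameters; they arrive through `**kwargs`, either flat or grouped. A usage sketch (the checkpoint id is the one used by the new test file below; both calls should be equivalent):

```python
from PIL import Image
from transformers import ChameleonProcessor

processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
image = Image.new("RGB", (512, 512))

# Flat kwargs: each keyword is routed to the matching component.
inputs = processor(
    text="Describe this <image>",
    images=image,
    padding="max_length",
    max_length=512,
)

# Structured kwargs: the same settings, grouped per modality.
inputs = processor(
    text="Describe this <image>",
    images=image,
    text_kwargs={"padding": "max_length", "max_length": 512},
    common_kwargs={"return_tensors": "pt"},
)
```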
```diff
@@ -80,26 +99,6 @@ def __call__(
             images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                 The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                 tensor. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors of a particular framework. Acceptable values are:
-
-                - `'tf'`: Return TensorFlow `tf.constant` objects.
-                - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                - `'np'`: Return NumPy `np.ndarray` objects.
-                - `'jax'`: Return JAX `jnp.ndarray` objects.
 
         Returns:
             [`BatchFeature`]: A [`BatchFeature`] with the following fields:
```
```diff
@@ -114,6 +113,15 @@ def __call__(
             text = [text]
         elif not isinstance(text, list) and not isinstance(text[0], str):
             raise TypeError("Invalid input text. Please provide a string, or a list of strings")
+        if text is None and images is None:
+            raise ValueError("You must provide either text or images")
+
+        output_kwargs = self._merge_kwargs(
+            ChameleonProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)
 
         # Replace the image token with the expanded image token sequence
         prompt_strings = []
```
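`_merge_kwargs` is inherited from `ProcessorMixin`. Conceptually it layers three sources, with later ones winning: the `_defaults` declared on `ChameleonProcessorKwargs`, the kwargs the tokenizer was initialized with, and the kwargs passed at call time. A simplified sketch of that precedence (not the actual transformers implementation; call-time kwargs are assumed to be already grouped):

```python
import copy


def merge_kwargs_sketch(defaults, tokenizer_init_kwargs, call_kwargs):
    # 1) Start from the processor's per-modality defaults (_defaults above).
    merged = copy.deepcopy(defaults)
    # 2) Tokenizer init kwargs may override matching text defaults.
    for key, value in tokenizer_init_kwargs.items():
        if key in merged.get("text_kwargs", {}):
            merged["text_kwargs"][key] = value
    # 3) Kwargs passed at call time win over both earlier layers.
    for group, values in call_kwargs.items():
        merged.setdefault(group, {}).update(values)
    return merged


merged = merge_kwargs_sketch(
    {"text_kwargs": {"padding": False}, "common_kwargs": {"return_tensors": "pt"}},
    {"padding": "max_length"},                # from tokenizer init
    {"text_kwargs": {"padding": "longest"}},  # from the call site
)
assert merged["text_kwargs"]["padding"] == "longest"
```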
```diff
@@ -124,19 +132,12 @@ def __call__(
             sample += self.tokenizer.sep_token  # special Chameleon treatment to add sep for chat mode
             prompt_strings.append(sample)
 
-        data = self.tokenizer(
-            prompt_strings,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
 
         if images is not None:
-            pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
-            data["pixel_values"] = pixel_values
+            data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
 
-        return BatchFeature(data=data, tensor_type=return_tensors)
+        return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])
 
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
     def batch_decode(self, *args, **kwargs):
```
16 changes: 16 additions & 0 deletions tests/models/chameleon/test_processor_chameleon.py
```diff
@@ -0,0 +1,16 @@
+import tempfile
+import unittest
+
+from transformers import ChameleonProcessor
+
+from ...test_processing_common import ProcessorTesterMixin
+
+
+class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    from_pretrained_id = "leloy/Anole-7b-v0.1-hf"
+    processor_class = ChameleonProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        processor = self.processor_class.from_pretrained(self.from_pretrained_id)
+        processor.save_pretrained(self.tmpdirname)
```
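The subclass defines no test methods of its own: every kwargs test below comes from `ProcessorTesterMixin`, so the suite can be run with `pytest tests/models/chameleon/test_processor_chameleon.py`.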
43 changes: 27 additions & 16 deletions tests/test_processing_common.py
```diff
@@ -49,6 +49,8 @@
 @require_torch
 class ProcessorTesterMixin:
     processor_class = None
+    text_data_arg_name = "input_ids"
+    images_data_arg_name = "pixel_values"
 
     def prepare_processor_dict(self):
         return {}
```
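The two new class attributes let model-specific subclasses rename the output tensors that the shared assertions index. A hypothetical override (the class and key names are illustrative, not part of this PR):

```python
class HypotheticalVideoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    processor_class = HypotheticalVideoProcessor  # illustrative placeholder
    # The shared tests will now index outputs under this key
    # instead of the default "pixel_values".
    images_data_arg_name = "pixel_values_videos"
```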
```diff
@@ -136,14 +138,14 @@ def test_tokenizer_defaults_preserved_by_kwargs(self):
         image_input = self.prepare_image_inputs()
 
         inputs = processor(text=input_str, images=image_input, return_tensors="pt")
-        self.assertEqual(len(inputs["input_ids"][0]), 117)
+        self.assertEqual(len(inputs[self.text_data_arg_name][0]), 117)
 
     @require_torch
     @require_vision
     def test_image_processor_defaults_preserved_by_image_kwargs(self):
         if "image_processor" not in self.processor_class.attributes:
             self.skipTest(f"image_processor attribute not present in {self.processor_class}")
-        image_processor = self.get_component("image_processor", size=(234, 234))
+        image_processor = self.get_component("image_processor", size=(234, 234), crop_size=(234, 234))
         tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
 
         processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
@@ -153,7 +155,7 @@ def test_image_processor_defaults_preserved_by_image_kwargs(self):
         input_str = "lower newer"
         image_input = self.prepare_image_inputs()
 
         inputs = processor(text=input_str, images=image_input)
-        self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
+        self.assertEqual(len(inputs[self.images_data_arg_name][0][0]), 234)
 
     @require_vision
     @require_torch
@@ -171,7 +173,7 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self):
         inputs = processor(
             text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
         )
-        self.assertEqual(len(inputs["input_ids"][0]), 112)
+        self.assertEqual(len(inputs[self.text_data_arg_name][0]), 112)
 
     @require_torch
     @require_vision
@@ -187,8 +189,8 @@ def test_kwargs_overrides_default_image_processor_kwargs(self):
         input_str = "lower newer"
         image_input = self.prepare_image_inputs()
 
-        inputs = processor(text=input_str, images=image_input, size=[224, 224])
-        self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
+        inputs = processor(text=input_str, images=image_input, size=[224, 224], crop_size=(224, 224))
+        self.assertEqual(len(inputs[self.images_data_arg_name][0][0]), 224)
 
     @require_torch
     @require_vision
@@ -208,12 +210,13 @@ def test_unstructured_kwargs(self):
             images=image_input,
             return_tensors="pt",
             size={"height": 214, "width": 214},
+            crop_size={"height": 214, "width": 214},
             padding="max_length",
             max_length=76,
         )
 
-        self.assertEqual(inputs["pixel_values"].shape[2], 214)
-        self.assertEqual(len(inputs["input_ids"][0]), 76)
+        self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)
+        self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)
 
     @require_torch
     @require_vision
```
```diff
@@ -233,13 +236,14 @@ def test_unstructured_kwargs_batched(self):
             images=image_input,
             return_tensors="pt",
             size={"height": 214, "width": 214},
+            crop_size={"height": 214, "width": 214},
             padding="longest",
             max_length=76,
         )
 
-        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+        self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)
 
-        self.assertEqual(len(inputs["input_ids"][0]), 6)
+        self.assertEqual(len(inputs[self.text_data_arg_name][0]), 6)
 
     @require_torch
     @require_vision
```

**Review thread on the added `crop_size` kwarg:**

**zucchini-nlp (Member):** @yonigozlan I think you removed `crop_size` from the common tests, and it had something to do with some image processors accepting/not accepting certain kwargs?

**yonigozlan (Member, Aug 20, 2024):** @zucchini-nlp Yes, but actually it would be nice to have both here. @molbap had some CI tests crash because `crop_size` was removed here and the image processor had `do_center_crop` set to `True` by default, which canceled out `size`. Having both handles the cases where either `do_center_crop` is set to `True` in the image processor by default, or `crop_size` is not supported by the image processor. So I am for keeping this and merging this PR before the other kwargs-uniformization PRs.
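The interplay described in the thread can be shown with a toy resize-then-crop pipeline (illustrative only, not the transformers image-processor code):

```python
# Toy model of an image processor's spatial pipeline: resize to `size`,
# then optionally center-crop to `crop_size`. When do_center_crop is
# True, crop_size decides the final dims, so a test that pins only
# `size` can be silently cancelled out.
def final_hw(size, crop_size=None, do_center_crop=True):
    height, width = size["height"], size["width"]
    if do_center_crop and crop_size is not None:
        height, width = crop_size["height"], crop_size["width"]
    return height, width


assert final_hw({"height": 214, "width": 214}) == (214, 214)
assert final_hw({"height": 214, "width": 214}, {"height": 224, "width": 224}) == (224, 224)
```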
```diff
@@ -260,6 +264,7 @@ def test_doubly_passed_kwargs(self):
             images=image_input,
             images_kwargs={"size": {"height": 222, "width": 222}},
             size={"height": 214, "width": 214},
+            crop_size={"height": 214, "width": 214},
         )
 
     @require_torch
```
```diff
@@ -279,16 +284,19 @@ def test_structured_kwargs_nested(self):
         # Define the kwargs for each modality
         all_kwargs = {
             "common_kwargs": {"return_tensors": "pt"},
-            "images_kwargs": {"size": {"height": 214, "width": 214}},
+            "images_kwargs": {
+                "size": {"height": 214, "width": 214},
+                "crop_size": {"height": 214, "width": 214},
+            },
             "text_kwargs": {"padding": "max_length", "max_length": 76},
         }
 
         inputs = processor(text=input_str, images=image_input, **all_kwargs)
         self.skip_processor_without_typed_kwargs(processor)
 
-        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+        self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)
 
-        self.assertEqual(len(inputs["input_ids"][0]), 76)
+        self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)
 
     @require_torch
     @require_vision
```
```diff
@@ -307,14 +315,17 @@ def test_structured_kwargs_nested_from_dict(self):
         # Define the kwargs for each modality
         all_kwargs = {
             "common_kwargs": {"return_tensors": "pt"},
-            "images_kwargs": {"size": {"height": 214, "width": 214}},
+            "images_kwargs": {
+                "size": {"height": 214, "width": 214},
+                "crop_size": {"height": 214, "width": 214},
+            },
             "text_kwargs": {"padding": "max_length", "max_length": 76},
         }
 
         inputs = processor(text=input_str, images=image_input, **all_kwargs)
-        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+        self.assertEqual(inputs[self.images_data_arg_name].shape[2], 214)
 
-        self.assertEqual(len(inputs["input_ids"][0]), 76)
+        self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)
 
 
 class MyProcessor(ProcessorMixin):
```