diff --git a/docs/source/en/model_doc/chameleon.md b/docs/source/en/model_doc/chameleon.md index 28ec01ad615871..eb12bd80e0a615 100644 --- a/docs/source/en/model_doc/chameleon.md +++ b/docs/source/en/model_doc/chameleon.md @@ -19,7 +19,7 @@ rendered properly in your Markdown viewer. ## Overview The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models -](https://arxiv.org/abs/2405.09818v1) by META AI Chameleon Team. Chameleon is a Vision-Language Model that use vector quantization to tokenize images which enables the model to generate multimodal output. The model takes images and texts as input, including an interleaved format, and generates textual response. Image generation module is not released yet. +](https://arxiv.org/abs/2405.09818v1) by META AI Chameleon Team. Chameleon is a Vision-Language Model that use vector quantization to tokenize images which enables the model to generate multimodal output. The model takes images and texts as input, including an interleaved format, and generates textual response. Image generation module is not released yet. The abstract from the paper is the following: @@ -61,7 +61,7 @@ The original code can be found [here](https://github.com/facebookresearch/chamel ### Single image inference -Chameleon is a gated model so make sure to have access and login to Hugging Face Hub using a token. +Chameleon is a gated model so make sure to have access and login to Hugging Face Hub using a token. Here's how to load the model and perform inference in half-precision (`torch.bfloat16`): ```python @@ -78,7 +78,7 @@ url = 'http://images.cocodataset.org/val2017/000000039769.jpg' image = Image.open(requests.get(url, stream=True).raw) prompt = "What do you see in this image?" -inputs = processor(prompt, image, return_tensors="pt").to(model.device, dtype=torch.bfloat16) +inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16) # autoregressively complete prompt output = model.generate(**inputs, max_new_tokens=50) @@ -117,7 +117,7 @@ prompts = [ # We can simply feed images in the order they have to be used in the text prompt # Each "" token uses one image leaving the next for the subsequent "" tokens -inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16) +inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16) # Generate generate_ids = model.generate(**inputs, max_new_tokens=50) @@ -152,8 +152,8 @@ from transformers import ChameleonForConditionalGeneration model_id = "facebook/chameleon-7b" model = ChameleonForConditionalGeneration.from_pretrained( - model_id, - torch_dtype=torch.bfloat16, + model_id, + torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, attn_implementation="flash_attention_2" ).to(0) diff --git a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py index 1aebeb0f0bb711..ff45c9b597e0b4 100644 --- a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py +++ b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py @@ -24,7 +24,7 @@ from transformers import ( ChameleonConfig, - ChameleonForCausalLM, + ChameleonForConditionalGeneration, ChameleonImageProcessor, ChameleonProcessor, ) @@ -49,10 +49,10 @@ Thereafter, models can be loaded via: ```py -from transformers import ChameleonForCausalLM, LlamaTokenizer +from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast -model = ChameleonForCausalLM.from_pretrained("/output/path") -tokenizer = LlamaTokenizer.from_pretrained("/output/path") +model = ChameleonForConditionalGeneration.from_pretrained("/output/path") +tokenizer = LlamaTokenizerFast.from_pretrained("/output/path") ``` Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions @@ -372,7 +372,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim): vocabulary_map=vocabulary_map, ) with init_empty_weights(): - model = ChameleonForCausalLM(config) + model = ChameleonForConditionalGeneration(config) model.load_state_dict(state_dict, assign=True, strict=False) model.save_pretrained(model_path, safe_serialization=True) @@ -397,7 +397,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim): # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl print("Loading the checkpoint in a Chameleon model...") print("*" * 100) - model = ChameleonForCausalLM.from_pretrained( + model = ChameleonForConditionalGeneration.from_pretrained( model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto" ) processor = ChameleonProcessor.from_pretrained(model_path) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index c631181f00c59e..c4eb1eade6e2f7 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1568,7 +1568,7 @@ def forward( >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw) >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw) - >>> inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.bfloat16) + >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16) >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index 1480808336d14e..2d699c8f663a61 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -20,9 +20,25 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +class ChameleonTextKwargs(TextKwargs, total=False): + return_for_text_completion: bool + + +class ChameleonProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: ChameleonTextKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_for_text_completion": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } class ChameleonProcessor(ProcessorMixin): @@ -57,13 +73,11 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, ima def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - images: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: int = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, - return_for_text_completion: bool = False, + images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, + **kwargs: Unpack[ChameleonProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -73,26 +87,13 @@ def __call__( of the above two methods for more information. Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: @@ -110,10 +111,21 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): raise TypeError("Invalid input text. Please provide a string, or a list of strings") + if text is None and images is None: + raise ValueError("You must provide either text or images") + + output_kwargs = self._merge_kwargs( + ChameleonProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False) # Replace the image token with the expanded image token sequence prompt_strings = [] @@ -124,19 +136,12 @@ def __call__( sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode prompt_strings.append(sample) - data = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) if images is not None: - pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] - data["pixel_values"] = pixel_values + data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] - return BatchFeature(data=data, tensor_type=return_tensors) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index 16e0a548e6dc47..00e3ad40a57652 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -350,7 +350,7 @@ def test_flash_attn_2_generate_padding_right(self): processor.tokenizer.padding_side = "right" - inputs = processor(texts, return_tensors="pt", padding=True).to(0) + inputs = processor(text=texts, return_tensors="pt", padding=True).to(0) output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) output_native = processor.tokenizer.batch_decode(output_native) @@ -392,7 +392,7 @@ def test_model_7b(self): ) prompt = "Describe what do you see here and tell me about the history behind it?" - inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.float16) + inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16) # greedy generation outputs EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and'] # fmt: skip @@ -420,7 +420,7 @@ def test_model_7b_batched(self): "What constellation is this image showing?", ] - inputs = processor(prompts, images=[image, image_2], padding=True, return_tensors="pt").to( + inputs = processor(images=[image, image_2], text=prompts, padding=True, return_tensors="pt").to( model.device, torch.float16 ) @@ -450,7 +450,7 @@ def test_model_7b_multi_image(self): ) prompt = "What do these two images have in common?" - inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.float16) + inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.float16) # greedy generation outputs EXPECTED_TEXT_COMPLETION = ['What do these two images have in common?The two images show a connection between two things that are not necessarily related. The first image shows a group of stars, while the second image shows a network of lines connecting two points. The connection between'] # fmt: skip diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py new file mode 100644 index 00000000000000..0bf2c2ddf2b4b6 --- /dev/null +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -0,0 +1,44 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch chameleon model.""" + +import tempfile +import unittest + +from transformers import ChameleonProcessor, LlamaTokenizer +from transformers.testing_utils import get_tests_dir +from transformers.utils import is_vision_available + +from ...test_processing_common import ProcessorTesterMixin + + +if is_vision_available(): + from transformers import ChameleonImageProcessor + + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + + +class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = ChameleonProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = ChameleonImageProcessor() + tokenizer = LlamaTokenizer(vocab_file=SAMPLE_VOCAB) + tokenizer.pad_token_id = 0 + tokenizer.sep_token_id = 1 + processor = self.processor_class(image_processor=image_processor, tokenizer=tokenizer) + processor.save_pretrained(self.tmpdirname)