From d8e65c0a385e0fb5c81810592b342f126a731c52 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Fri, 16 Aug 2024 15:38:35 +0800 Subject: [PATCH 01/12] uniformize kwargs of Chameleon --- .../models/chameleon/processing_chameleon.py | 81 ++++++++++--------- 1 file changed, 41 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index 1480808336d14e..2cac2d4bcb986a 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -16,13 +16,36 @@ Processor class for Chameleon. """ +import sys from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +class ChameleonTextKwargs(TextKwargs, total=False): + return_for_text_completion: bool + + +class ChameleonProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: ChameleonTextKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "stride": 0, + "return_for_text_completion": False, + }, + "common_kwargs": { + "return_tensors": "pt", + }, + } class ChameleonProcessor(ProcessorMixin): @@ -57,13 +80,9 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, ima def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - images: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length: int = None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, - return_for_text_completion: bool = False, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + images: Optional[ImageInput] = None, + **kwargs: Unpack[ChameleonProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -80,26 +99,6 @@ def __call__( images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -114,6 +113,15 @@ def __call__( text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): raise TypeError("Invalid input text. Please provide a string, or a list of strings") + if text is None and images is None: + raise ValueError("You must provide either text or images") + + output_kwargs = self._merge_kwargs( + ChameleonProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False) # Replace the image token with the expanded image token sequence prompt_strings = [] @@ -124,19 +132,12 @@ def __call__( sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode prompt_strings.append(sample) - data = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) if images is not None: - pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] - data["pixel_values"] = pixel_values + data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"] - return BatchFeature(data=data, tensor_type=return_tensors) + return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"]) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama def batch_decode(self, *args, **kwargs): From a2f71e653c7bdd0c54f3c67605600108387823ab Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Fri, 16 Aug 2024 15:41:52 +0800 Subject: [PATCH 02/12] fix linter nit --- src/transformers/models/chameleon/processing_chameleon.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index 2cac2d4bcb986a..a039101f56d71e 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -24,6 +24,7 @@ from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput + if sys.version_info >= (3, 11): from typing import Unpack else: From 15955135fb23c52b38195135e9c47f43369f22ec Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Fri, 16 Aug 2024 17:38:16 +0800 Subject: [PATCH 03/12] rm stride default --- src/transformers/models/chameleon/processing_chameleon.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index a039101f56d71e..d999267ef1fa9c 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -40,7 +40,6 @@ class ChameleonProcessorKwargs(ProcessingKwargs, total=False): _defaults = { "text_kwargs": { "padding": False, - "stride": 0, "return_for_text_completion": False, }, "common_kwargs": { From 47169c1a2978b2c8b8709fe43f512c186d8ffb7f Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Fri, 16 Aug 2024 17:43:30 +0800 Subject: [PATCH 04/12] add tests for chameleon processor --- .../models/chameleon/test_processor_chameleon.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 tests/models/chameleon/test_processor_chameleon.py diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py new file mode 100644 index 00000000000000..74314e3d4c1e95 --- /dev/null +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -0,0 +1,16 @@ +import tempfile +import unittest + +from transformers import ChameleonProcessor + +from ...test_processing_common import ProcessorTesterMixin + + +class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): + from_pretrained_id = "leloy/Anole-7b-v0.1-hf" + processor_class = ChameleonProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + processor = self.processor_class.from_pretrained(self.from_pretrained_id) + processor.save_pretrained(self.tmpdirname) From 0d25ae6e29dc027df6032fa71864ddbd6c59bbac Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Fri, 16 Aug 2024 18:07:33 +0800 Subject: [PATCH 05/12] fix tests --- .../chameleon/test_processor_chameleon.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py index 74314e3d4c1e95..1efeaa5339d304 100644 --- a/tests/models/chameleon/test_processor_chameleon.py +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -2,6 +2,7 @@ import unittest from transformers import ChameleonProcessor +from transformers.models.auto.processing_auto import processor_class_from_name from ...test_processing_common import ProcessorTesterMixin @@ -10,6 +11,22 @@ class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): from_pretrained_id = "leloy/Anole-7b-v0.1-hf" processor_class = ChameleonProcessor + def get_component(self, attribute, **kwargs): + assert attribute in self.processor_class.attributes + component_class_name = getattr(self.processor_class, f"{attribute}_class") + if isinstance(component_class_name, tuple): + if "_fast" in component_class_name[0]: + component_class_name = component_class_name[0] + else: + component_class_name = component_class_name[1] + + component_class = processor_class_from_name(component_class_name) + component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa + if attribute == "tokenizer" and not component.pad_token: + component.pad_token = "[TEST_PAD]" + + return component + def setUp(self): self.tmpdirname = tempfile.mkdtemp() processor = self.processor_class.from_pretrained(self.from_pretrained_id) From 0ee73cd5f684383471ba2bed226f3663ee1de660 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Mon, 19 Aug 2024 18:33:57 +0800 Subject: [PATCH 06/12] add comment on get_component --- tests/models/chameleon/test_processor_chameleon.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py index 1efeaa5339d304..18a517399b8605 100644 --- a/tests/models/chameleon/test_processor_chameleon.py +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -13,6 +13,10 @@ class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): def get_component(self, attribute, **kwargs): assert attribute in self.processor_class.attributes + if attribute != "tokenizer": + return super().get_component(attribute, **kwargs) + # We use the fast tokenizer by default as the slow tokenizer expects the vocab file to be present in the loading directory. + # This vocab file is neither in the official repo nor does it get saved when we save the processor in `setUp` below. component_class_name = getattr(self.processor_class, f"{attribute}_class") if isinstance(component_class_name, tuple): if "_fast" in component_class_name[0]: From 0178af6d21010c26a707710dba00c3e9f74196d8 Mon Sep 17 00:00:00 2001 From: Franz Louis Cesista Date: Wed, 21 Aug 2024 01:09:11 +0800 Subject: [PATCH 07/12] rm Chameleon's slow tokenizer --- .../models/auto/tokenization_auto.py | 5 +---- .../convert_chameleon_weights_to_hf.py | 12 +++++------ .../models/chameleon/processing_chameleon.py | 2 +- .../chameleon/test_processor_chameleon.py | 21 ------------------- 4 files changed, 8 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e735579108d857..c677fcaea337fc 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -109,10 +109,7 @@ ("canine", ("CanineTokenizer", None)), ( "chameleon", - ( - "LlamaTokenizer" if is_sentencepiece_available() else None, - "LlamaTokenizerFast" if is_tokenizers_available() else None, - ), + (None, "LlamaTokenizerFast" if is_tokenizers_available() else None), ), ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py index 1aebeb0f0bb711..ff45c9b597e0b4 100644 --- a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py +++ b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py @@ -24,7 +24,7 @@ from transformers import ( ChameleonConfig, - ChameleonForCausalLM, + ChameleonForConditionalGeneration, ChameleonImageProcessor, ChameleonProcessor, ) @@ -49,10 +49,10 @@ Thereafter, models can be loaded via: ```py -from transformers import ChameleonForCausalLM, LlamaTokenizer +from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast -model = ChameleonForCausalLM.from_pretrained("/output/path") -tokenizer = LlamaTokenizer.from_pretrained("/output/path") +model = ChameleonForConditionalGeneration.from_pretrained("/output/path") +tokenizer = LlamaTokenizerFast.from_pretrained("/output/path") ``` Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions @@ -372,7 +372,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim): vocabulary_map=vocabulary_map, ) with init_empty_weights(): - model = ChameleonForCausalLM(config) + model = ChameleonForConditionalGeneration(config) model.load_state_dict(state_dict, assign=True, strict=False) model.save_pretrained(model_path, safe_serialization=True) @@ -397,7 +397,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim): # taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl print("Loading the checkpoint in a Chameleon model...") print("*" * 100) - model = ChameleonForCausalLM.from_pretrained( + model = ChameleonForConditionalGeneration.from_pretrained( model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto" ) processor = ChameleonProcessor.from_pretrained(model_path) diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index d999267ef1fa9c..dfab8a5dbaf2c7 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -68,7 +68,7 @@ class ChameleonProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") + tokenizer_class = "LlamaTokenizerFast" image_processor_class = "ChameleonImageProcessor" def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py index 18a517399b8605..74314e3d4c1e95 100644 --- a/tests/models/chameleon/test_processor_chameleon.py +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -2,7 +2,6 @@ import unittest from transformers import ChameleonProcessor -from transformers.models.auto.processing_auto import processor_class_from_name from ...test_processing_common import ProcessorTesterMixin @@ -11,26 +10,6 @@ class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): from_pretrained_id = "leloy/Anole-7b-v0.1-hf" processor_class = ChameleonProcessor - def get_component(self, attribute, **kwargs): - assert attribute in self.processor_class.attributes - if attribute != "tokenizer": - return super().get_component(attribute, **kwargs) - # We use the fast tokenizer by default as the slow tokenizer expects the vocab file to be present in the loading directory. - # This vocab file is neither in the official repo nor does it get saved when we save the processor in `setUp` below. - component_class_name = getattr(self.processor_class, f"{attribute}_class") - if isinstance(component_class_name, tuple): - if "_fast" in component_class_name[0]: - component_class_name = component_class_name[0] - else: - component_class_name = component_class_name[1] - - component_class = processor_class_from_name(component_class_name) - component = component_class.from_pretrained(self.tmpdirname, **kwargs) # noqa - if attribute == "tokenizer" and not component.pad_token: - component.pad_token = "[TEST_PAD]" - - return component - def setUp(self): self.tmpdirname = tempfile.mkdtemp() processor = self.processor_class.from_pretrained(self.from_pretrained_id) From 50826302b156073d6ad2f8811a35ff78065caa0e Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 23 Sep 2024 18:56:09 +0000 Subject: [PATCH 08/12] add check order images text + nit --- .../models/chameleon/processing_chameleon.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index dfab8a5dbaf2c7..97c43898b6c915 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -16,21 +16,14 @@ Processor class for Chameleon. """ -import sys from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order from ...tokenization_utils_base import PreTokenizedInput, TextInput -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - - class ChameleonTextKwargs(TextKwargs, total=False): return_for_text_completion: bool @@ -80,8 +73,10 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, ima def __call__( self, - text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, images: Optional[ImageInput] = None, + text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + audio=None, + videos=None, **kwargs: Unpack[ChameleonProcessorKwargs], ) -> BatchFeature: """ @@ -92,13 +87,13 @@ def __call__( of the above two methods for more information. Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -109,6 +104,8 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) if isinstance(text, str): text = [text] elif not isinstance(text, list) and not isinstance(text[0], str): From 272ff5c05426a11332a0338158faaea5ef751320 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 24 Sep 2024 01:28:53 +0000 Subject: [PATCH 09/12] update docs and tests --- docs/source/en/model_doc/chameleon.md | 12 ++++++------ .../models/chameleon/modeling_chameleon.py | 2 +- tests/models/chameleon/test_modeling_chameleon.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/source/en/model_doc/chameleon.md b/docs/source/en/model_doc/chameleon.md index 28ec01ad615871..eb12bd80e0a615 100644 --- a/docs/source/en/model_doc/chameleon.md +++ b/docs/source/en/model_doc/chameleon.md @@ -19,7 +19,7 @@ rendered properly in your Markdown viewer. ## Overview The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models -](https://arxiv.org/abs/2405.09818v1) by META AI Chameleon Team. Chameleon is a Vision-Language Model that use vector quantization to tokenize images which enables the model to generate multimodal output. The model takes images and texts as input, including an interleaved format, and generates textual response. Image generation module is not released yet. +](https://arxiv.org/abs/2405.09818v1) by META AI Chameleon Team. Chameleon is a Vision-Language Model that use vector quantization to tokenize images which enables the model to generate multimodal output. The model takes images and texts as input, including an interleaved format, and generates textual response. Image generation module is not released yet. The abstract from the paper is the following: @@ -61,7 +61,7 @@ The original code can be found [here](https://github.com/facebookresearch/chamel ### Single image inference -Chameleon is a gated model so make sure to have access and login to Hugging Face Hub using a token. +Chameleon is a gated model so make sure to have access and login to Hugging Face Hub using a token. Here's how to load the model and perform inference in half-precision (`torch.bfloat16`): ```python @@ -78,7 +78,7 @@ url = 'http://images.cocodataset.org/val2017/000000039769.jpg' image = Image.open(requests.get(url, stream=True).raw) prompt = "What do you see in this image?" -inputs = processor(prompt, image, return_tensors="pt").to(model.device, dtype=torch.bfloat16) +inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16) # autoregressively complete prompt output = model.generate(**inputs, max_new_tokens=50) @@ -117,7 +117,7 @@ prompts = [ # We can simply feed images in the order they have to be used in the text prompt # Each "" token uses one image leaving the next for the subsequent "" tokens -inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16) +inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16) # Generate generate_ids = model.generate(**inputs, max_new_tokens=50) @@ -152,8 +152,8 @@ from transformers import ChameleonForConditionalGeneration model_id = "facebook/chameleon-7b" model = ChameleonForConditionalGeneration.from_pretrained( - model_id, - torch_dtype=torch.bfloat16, + model_id, + torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, attn_implementation="flash_attention_2" ).to(0) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index c631181f00c59e..c4eb1eade6e2f7 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1568,7 +1568,7 @@ def forward( >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw) >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw) - >>> inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.bfloat16) + >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16) >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False) >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0] diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index 16e0a548e6dc47..00e3ad40a57652 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -350,7 +350,7 @@ def test_flash_attn_2_generate_padding_right(self): processor.tokenizer.padding_side = "right" - inputs = processor(texts, return_tensors="pt", padding=True).to(0) + inputs = processor(text=texts, return_tensors="pt", padding=True).to(0) output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) output_native = processor.tokenizer.batch_decode(output_native) @@ -392,7 +392,7 @@ def test_model_7b(self): ) prompt = "Describe what do you see here and tell me about the history behind it?" - inputs = processor(prompt, images=image, return_tensors="pt").to(model.device, torch.float16) + inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, torch.float16) # greedy generation outputs EXPECTED_TEXT_COMPLETION = ['Describe what do you see here and tell me about the history behind it?The image depicts a star map, with a bright blue line extending across the center of the image. The line is labeled "390 light years" and is accompanied by a small black and'] # fmt: skip @@ -420,7 +420,7 @@ def test_model_7b_batched(self): "What constellation is this image showing?", ] - inputs = processor(prompts, images=[image, image_2], padding=True, return_tensors="pt").to( + inputs = processor(images=[image, image_2], text=prompts, padding=True, return_tensors="pt").to( model.device, torch.float16 ) @@ -450,7 +450,7 @@ def test_model_7b_multi_image(self): ) prompt = "What do these two images have in common?" - inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.float16) + inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.float16) # greedy generation outputs EXPECTED_TEXT_COMPLETION = ['What do these two images have in common?The two images show a connection between two things that are not necessarily related. The first image shows a group of stars, while the second image shows a network of lines connecting two points. The connection between'] # fmt: skip From 71b9a06f2e30c26fd79499e2ae79075c09699a69 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 24 Sep 2024 14:28:06 +0000 Subject: [PATCH 10/12] Fix LlamaTokenizer tests --- .../models/auto/tokenization_auto.py | 5 ++- .../models/chameleon/processing_chameleon.py | 9 ++++- .../chameleon/test_processor_chameleon.py | 33 +++++++++++++++++-- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index c677fcaea337fc..e735579108d857 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -109,7 +109,10 @@ ("canine", ("CanineTokenizer", None)), ( "chameleon", - (None, "LlamaTokenizerFast" if is_tokenizers_available() else None), + ( + "LlamaTokenizer" if is_sentencepiece_available() else None, + "LlamaTokenizerFast" if is_tokenizers_available() else None, + ), ), ("chinese_clip", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ( diff --git a/src/transformers/models/chameleon/processing_chameleon.py b/src/transformers/models/chameleon/processing_chameleon.py index 97c43898b6c915..2d699c8f663a61 100644 --- a/src/transformers/models/chameleon/processing_chameleon.py +++ b/src/transformers/models/chameleon/processing_chameleon.py @@ -61,7 +61,7 @@ class ChameleonProcessor(ProcessorMixin): """ attributes = ["image_processor", "tokenizer"] - tokenizer_class = "LlamaTokenizerFast" + tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") image_processor_class = "ChameleonImageProcessor" def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = ""): @@ -94,6 +94,13 @@ def __call__( The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py index 74314e3d4c1e95..50ce2ea3b6722f 100644 --- a/tests/models/chameleon/test_processor_chameleon.py +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -1,16 +1,43 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch chameleon model.""" + import tempfile import unittest -from transformers import ChameleonProcessor +from transformers import ChameleonProcessor, LlamaTokenizer +from transformers.testing_utils import get_tests_dir +from transformers.utils import is_vision_available from ...test_processing_common import ProcessorTesterMixin +if is_vision_available(): + from transformers import ChameleonImageProcessor, is_vision_available + + +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + + class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): - from_pretrained_id = "leloy/Anole-7b-v0.1-hf" processor_class = ChameleonProcessor def setUp(self): self.tmpdirname = tempfile.mkdtemp() - processor = self.processor_class.from_pretrained(self.from_pretrained_id) + + image_processor = ChameleonImageProcessor.from_pretrained("facebook/chameleon-7b") + tokenizer = LlamaTokenizer.from_pretrained("facebook/chameleon-7b", vocab_file=SAMPLE_VOCAB) + processor = self.processor_class(image_processor=image_processor, tokenizer=tokenizer) processor.save_pretrained(self.tmpdirname) From c3bcb7a7a6cd1f0fc359857a6aecca28dbf91359 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 24 Sep 2024 14:35:46 +0000 Subject: [PATCH 11/12] fix gated repo access --- tests/models/chameleon/test_processor_chameleon.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py index 50ce2ea3b6722f..97ce171e21b368 100644 --- a/tests/models/chameleon/test_processor_chameleon.py +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -36,8 +36,9 @@ class ChameleonProcessorTest(ProcessorTesterMixin, unittest.TestCase): def setUp(self): self.tmpdirname = tempfile.mkdtemp() - - image_processor = ChameleonImageProcessor.from_pretrained("facebook/chameleon-7b") - tokenizer = LlamaTokenizer.from_pretrained("facebook/chameleon-7b", vocab_file=SAMPLE_VOCAB) + image_processor = ChameleonImageProcessor() + tokenizer = LlamaTokenizer(vocab_file=SAMPLE_VOCAB) + tokenizer.pad_token_id = 0 + tokenizer.sep_token_id = 1 processor = self.processor_class(image_processor=image_processor, tokenizer=tokenizer) processor.save_pretrained(self.tmpdirname) From f00aeb7a68bf95c3b21ff8e02940876348389d8c Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 25 Sep 2024 15:15:06 +0000 Subject: [PATCH 12/12] fix wrong import --- tests/models/chameleon/test_processor_chameleon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/chameleon/test_processor_chameleon.py b/tests/models/chameleon/test_processor_chameleon.py index 97ce171e21b368..0bf2c2ddf2b4b6 100644 --- a/tests/models/chameleon/test_processor_chameleon.py +++ b/tests/models/chameleon/test_processor_chameleon.py @@ -25,7 +25,7 @@ if is_vision_available(): - from transformers import ChameleonImageProcessor, is_vision_available + from transformers import ChameleonImageProcessor SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")