From a8b895c59b45b25adfdd934956b2106943846163 Mon Sep 17 00:00:00 2001
From: yonigozlan
Date: Fri, 16 Aug 2024 20:34:34 +0000
Subject: [PATCH 1/9] Uniformize kwargs for LlaVa and update docs

---
 .../models/llava/modeling_llava.py          |  2 +-
 .../models/llava/processing_llava.py        | 61 +++++++++++++------
 tests/models/llava/test_modeling_llava.py   | 18 +++---
 tests/models/llava/test_processor_llava.py  | 56 ++++++++++++++++-
 4 files changed, 107 insertions(+), 30 deletions(-)

diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 9ad19ccee72228..eb1c55341b0784 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -405,7 +405,7 @@ def forward(
         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)

-        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

         >>> # Generate
         >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py
index 678724ae95be41..76659710226a4f 100644
--- a/src/transformers/models/llava/processing_llava.py
+++ b/src/transformers/models/llava/processing_llava.py
@@ -16,18 +16,33 @@
 Processor class for Llava.
 """

-from typing import List, Optional, Union
+import sys
+from typing import List, Union

 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, logging
+from ...processing_utils import ProcessingKwargs, ProcessorMixin
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging


+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
 logger = logging.get_logger(__name__)


+class LlavaProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {},
+    }
+
+
 class LlavaProcessor(ProcessorMixin):
     r"""
     Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
@@ -73,12 +88,11 @@ def __init__(

     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length=None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[LlavaProcessorKwargs],
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
@@ -125,8 +139,27 @@ def __call__(
             `None`).
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
""" + if images is None and text is None: + raise ValueError("You have to specify at least images or text.") + # check if images and text inputs are reversed for BC + if ( + text is not None + and not isinstance(text[0], str) + or images is not None + and (isinstance(images, str) or (isinstance(images, (list, tuple)) and isinstance(images[0], str))) + ): + logger.warning_once( + "It looks like you are passing the inputs in the wrong order. You should pass the images input first and the text input second." + "Images and text inputs will be swapped." + ) + images, text = text, images + output_kwargs = self._merge_kwargs( + LlavaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) if images is not None: - image_inputs = self.image_processor(images, return_tensors=return_tensors) + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) else: image_inputs = {} @@ -158,13 +191,7 @@ def __call__( "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." ) - text_inputs = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs}) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 2fed802b5a2fb3..6a6450554965c2 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -274,7 +274,7 @@ def test_small_model_integration_test(self): prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:" image_file = "https://llava-vl.github.io/static/images/view.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(prompt, raw_image, return_tensors="pt") + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt") EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) @@ -299,7 +299,7 @@ def test_small_model_integration_test_llama_single(self): prompt = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT:" image_file = "https://llava-vl.github.io/static/images/view.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) output = model.generate(**inputs, max_new_tokens=900, do_sample=False) EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. 
Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip @@ -325,7 +325,7 @@ def test_small_model_integration_test_llama_batched(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -349,7 +349,7 @@ def test_small_model_integration_test_batch(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -381,7 +381,7 @@ def test_small_model_integration_test_llama_batched_regression(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True) + inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -409,8 +409,8 @@ def test_batched_generation(self): image2 = Image.open(requests.get(url2, stream=True).raw) inputs = processor( - text=[prompt1, prompt2, prompt3], images=[image1, image2, image1, image2], + text=[prompt1, prompt2, prompt3], return_tensors="pt", padding=True, ).to(torch_device) @@ -444,7 +444,7 @@ def test_llava_index_error_bug(self): image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) @@ -554,13 +554,13 @@ def test_expansion_in_processing(self): # check processing with expansion of inputs processor.vision_feature_select_strategy = "default" processor.patch_size = 14 - inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593) # check processing without expansion of inputs (legacy behavior) processor.vision_feature_select_strategy = None processor.patch_size = None - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) self.assertTrue(inputs.input_ids.shape[-1] == 18) # generate 
exactly 20 tokens diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py index 54c1b4674cbcef..79493665f9d862 100644 --- a/tests/models/llava/test_processor_llava.py +++ b/tests/models/llava/test_processor_llava.py @@ -11,18 +11,42 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import shutil +import tempfile import unittest -from transformers.testing_utils import require_vision +from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available +from ...test_processing_common import ProcessorTesterMixin + if is_vision_available(): - from transformers import AutoTokenizer, LlavaProcessor + from transformers import AutoProcessor, AutoTokenizer, CLIPImageProcessor, LlamaTokenizerFast, LlavaProcessor @require_vision -class LlavaProcessorTest(unittest.TestCase): +class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = LlavaProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = CLIPImageProcessor(do_center_crop=False) + tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b") + + processor = LlavaProcessor(image_processor, tokenizer) + + processor.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + def test_can_load_various_tokenizers(self): for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]: processor = LlavaProcessor.from_pretrained(checkpoint) @@ -45,3 +69,29 @@ def test_chat_template(self): formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) self.assertEqual(expected_prompt, formatted_prompt) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + text=input_str, + images=image_input, + return_tensors="pt", + size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 5) From 937054de09b0efd8192a82df1d9e0cfde1df0bb5 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 16 Aug 2024 20:40:17 +0000 Subject: [PATCH 2/9] Change order of processor inputs in docstring --- src/transformers/models/llava/processing_llava.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 76659710226a4f..bbb57d808d5143 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -102,13 +102,13 @@ def __call__( of the above 
two methods for more information. Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): Select a strategy to pad the returned sequences (according to the model's padding side and padding index) among: From 31e5132f46fdd714adc4c18d22d49b6c5f496bc0 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 19 Aug 2024 23:07:19 +0000 Subject: [PATCH 3/9] Improve BC support for reversed images and text inputs --- .../models/llava/processing_llava.py | 14 +---- src/transformers/processing_utils.py | 57 ++++++++++++++++++- tests/models/llava/test_modeling_llava.py | 2 +- tests/models/llava/test_processor_llava.py | 2 +- 4 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index bbb57d808d5143..11cba7fe41893f 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -141,18 +141,10 @@ def __call__( """ if images is None and text is None: raise ValueError("You have to specify at least images or text.") + # check if images and text inputs are reversed for BC - if ( - text is not None - and not isinstance(text[0], str) - or images is not None - and (isinstance(images, str) or (isinstance(images, (list, tuple)) and isinstance(images[0], str))) - ): - logger.warning_once( - "It looks like you are passing the inputs in the wrong order. You should pass the images input first and the text input second." - "Images and text inputs will be swapped." - ) - images, text = text, images + images, text = self._check_reversed_images_text(images, text) + output_kwargs = self._merge_kwargs( LlavaProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index ee28c01189b439..3ecf3eba10359b 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -102,8 +102,6 @@ class TextKwargs(TypedDict, total=False): Whether or not to return the lengths of the encoded inputs. verbose (`bool`, *optional*): Whether or not to print more information and warnings. - padding_side (`str`, *optional*): - The side on which padding will be applied. 
""" add_special_tokens: Optional[bool] @@ -120,7 +118,6 @@ class TextKwargs(TypedDict, total=False): return_offsets_mapping: Optional[bool] return_length: Optional[bool] verbose: Optional[bool] - padding_side: Optional[str] class ImagesKwargs(TypedDict, total=False): @@ -832,6 +829,60 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg output_kwargs[modality].update(output_kwargs["common_kwargs"]) return output_kwargs + def _check_reversed_images_text(self, images, text): + """ + For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped. + This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes. + Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled + in the processor's `__call__` method before calling this method. + """ + + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + def _is_valid_or_convertible(input, validator, converter): + is_valid = validator(input) or input is None + is_convertible = converter(input) if not is_valid else False + return is_valid, is_convertible + + images_is_valid, images_is_text = _is_valid_or_convertible(images, valid_images, _is_valid_text_input) + text_is_valid, text_is_images = _is_valid_or_convertible(text, _is_valid_text_input, valid_images) + + # Handle cases where both inputs are valid + if images_is_valid and text_is_valid: + return images, text + + # Handle cases where inputs need to and can be swapped + if ( + (images is None and text_is_images) + or (text is None and images_is_text) + or (images_is_text and text_is_images) + ): + logger.warning_once( + "You may have used the wrong order for inputs. `images` should be passed before `text`. " + "The `images` and `text` inputs will be swapped." + ) + return text, images + + raise ValueError("Invalid input type. 
Check that `images` and/or `text` are valid inputs.") + @classmethod def from_pretrained( cls, diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 6a6450554965c2..712250e4d903b4 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -510,7 +510,7 @@ def test_generation_no_images(self): processor = AutoProcessor.from_pretrained(model_id) # Prepare inputs with no images - inputs = processor("Hello, I am", return_tensors="pt").to(torch_device) + inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device) # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py index 79493665f9d862..ab9904a6c59f27 100644 --- a/tests/models/llava/test_processor_llava.py +++ b/tests/models/llava/test_processor_llava.py @@ -84,8 +84,8 @@ def test_unstructured_kwargs_batched(self): input_str = ["lower newer", "upper older longer string"] image_input = self.prepare_image_inputs() * 2 inputs = processor( - text=input_str, images=image_input, + text=input_str, return_tensors="pt", size={"height": 214, "width": 214}, padding="longest", From 7097e665e2b51818b9d81370a8cb6105be6c57f6 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 19 Aug 2024 23:34:24 +0000 Subject: [PATCH 4/9] cleanup llava processor call docstring --- .../models/llava/processing_llava.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 11cba7fe41893f..11da217f09a83f 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -109,26 +109,6 @@ def __call__( The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors of a particular framework. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return NumPy `np.ndarray` objects. - - `'jax'`: Return JAX `jnp.ndarray` objects. 
Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: From c345dc8fb35efd3707f10777d3709be0f439b365 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 20 Aug 2024 15:17:43 +0000 Subject: [PATCH 5/9] Add encoded inputs as valid text inputs in reverse input check, add deprecation version in warning --- src/transformers/processing_utils.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 3ecf3eba10359b..4797c1b412757c 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -837,21 +837,21 @@ def _check_reversed_images_text(self, images, text): in the processor's `__call__` method before calling this method. """ - def _is_valid_text_input(t): + def _is_valid_text_input_for_processor(t): if isinstance(t, str): # Strings are fine return True elif isinstance(t, (list, tuple)): # List are fine as long as they are... if len(t) == 0: - # ... empty - return True - elif isinstance(t[0], str): - # ... list of strings + # ... not empty + return False + elif isinstance(t[0], (str, int)): + # ... list of strings or int (for encoded inputs) return True elif isinstance(t[0], (list, tuple)): - # ... list with an empty list or with a list of strings - return len(t[0]) == 0 or isinstance(t[0][0], str) + # ... list of list of strings or int (for encoded inputs) + return isinstance(t[0][0], (str, int)) else: return False else: @@ -862,8 +862,12 @@ def _is_valid_or_convertible(input, validator, converter): is_convertible = converter(input) if not is_valid else False return is_valid, is_convertible - images_is_valid, images_is_text = _is_valid_or_convertible(images, valid_images, _is_valid_text_input) - text_is_valid, text_is_images = _is_valid_or_convertible(text, _is_valid_text_input, valid_images) + images_is_valid, images_is_text = _is_valid_or_convertible( + images, valid_images, _is_valid_text_input_for_processor + ) + text_is_valid, text_is_images = _is_valid_or_convertible( + text, _is_valid_text_input_for_processor, valid_images + ) # Handle cases where both inputs are valid if images_is_valid and text_is_valid: @@ -877,7 +881,7 @@ def _is_valid_or_convertible(input, validator, converter): ): logger.warning_once( "You may have used the wrong order for inputs. `images` should be passed before `text`. " - "The `images` and `text` inputs will be swapped." + "The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47." 
) return text, images From 04918e7f56a666c6a07a709c3fd988c45b41059f Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 20 Aug 2024 17:33:33 +0000 Subject: [PATCH 6/9] Put function check reversed images text outside base processor class --- .../models/llava/processing_llava.py | 4 +- src/transformers/processing_utils.py | 58 ------------------- 2 files changed, 2 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 11da217f09a83f..59ad219fca059c 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -21,7 +21,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessingKwargs, ProcessorMixin +from ...processing_utils import ProcessingKwargs, ProcessorMixin, _check_reversed_images_text_for_vlms from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging @@ -123,7 +123,7 @@ def __call__( raise ValueError("You have to specify at least images or text.") # check if images and text inputs are reversed for BC - images, text = self._check_reversed_images_text(images, text) + images, text = _check_reversed_images_text_for_vlms(images, text) output_kwargs = self._merge_kwargs( LlavaProcessorKwargs, diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 4797c1b412757c..ef7911217727fd 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -829,64 +829,6 @@ class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwarg output_kwargs[modality].update(output_kwargs["common_kwargs"]) return output_kwargs - def _check_reversed_images_text(self, images, text): - """ - For backward compatibility: reverse the order of `images` and `text` inputs if they are swapped. - This method should only be called for processors where `images` and `text` have been swapped for uniformization purposes. - Note that this method assumes that two `None` inputs are valid inputs. If this is not the case, it should be handled - in the processor's `__call__` method before calling this method. - """ - - def _is_valid_text_input_for_processor(t): - if isinstance(t, str): - # Strings are fine - return True - elif isinstance(t, (list, tuple)): - # List are fine as long as they are... - if len(t) == 0: - # ... not empty - return False - elif isinstance(t[0], (str, int)): - # ... list of strings or int (for encoded inputs) - return True - elif isinstance(t[0], (list, tuple)): - # ... 
list of list of strings or int (for encoded inputs) - return isinstance(t[0][0], (str, int)) - else: - return False - else: - return False - - def _is_valid_or_convertible(input, validator, converter): - is_valid = validator(input) or input is None - is_convertible = converter(input) if not is_valid else False - return is_valid, is_convertible - - images_is_valid, images_is_text = _is_valid_or_convertible( - images, valid_images, _is_valid_text_input_for_processor - ) - text_is_valid, text_is_images = _is_valid_or_convertible( - text, _is_valid_text_input_for_processor, valid_images - ) - - # Handle cases where both inputs are valid - if images_is_valid and text_is_valid: - return images, text - - # Handle cases where inputs need to and can be swapped - if ( - (images is None and text_is_images) - or (text is None and images_is_text) - or (images_is_text and text_is_images) - ): - logger.warning_once( - "You may have used the wrong order for inputs. `images` should be passed before `text`. " - "The `images` and `text` inputs will be swapped. This behavior will be deprecated in transformers v4.47." - ) - return text, images - - raise ValueError("Invalid input type. Check that `images` and/or `text` are valid inputs.") - @classmethod def from_pretrained( cls, From 3d8ec3d1e26332f230a467f57a33cc532ca650d3 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 23 Aug 2024 14:52:34 +0000 Subject: [PATCH 7/9] Refactor _validate_images_text_input_order --- src/transformers/models/llava/processing_llava.py | 13 ++++++++++--- tests/models/llava/test_processor_llava.py | 5 +++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 59ad219fca059c..4600d1063b2f52 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -21,7 +21,7 @@ from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessingKwargs, ProcessorMixin, _check_reversed_images_text_for_vlms +from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import logging @@ -109,6 +109,12 @@ def __call__( The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: @@ -120,10 +126,11 @@ def __call__( - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. 
""" if images is None and text is None: - raise ValueError("You have to specify at least images or text.") + raise ValueError("You have to specify at least one of `images` or `text`.") # check if images and text inputs are reversed for BC - images, text = _check_reversed_images_text_for_vlms(images, text) + text, images = images, text + images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( LlavaProcessorKwargs, diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py index ab9904a6c59f27..5b05a8b92ea513 100644 --- a/tests/models/llava/test_processor_llava.py +++ b/tests/models/llava/test_processor_llava.py @@ -15,6 +15,7 @@ import tempfile import unittest +from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available @@ -22,7 +23,7 @@ if is_vision_available(): - from transformers import AutoProcessor, AutoTokenizer, CLIPImageProcessor, LlamaTokenizerFast, LlavaProcessor + from transformers import CLIPImageProcessor @require_vision @@ -34,7 +35,7 @@ def setUp(self): image_processor = CLIPImageProcessor(do_center_crop=False) tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b") - processor = LlavaProcessor(image_processor, tokenizer) + processor = LlavaProcessor(image_processor=image_processor, tokenizer=tokenizer) processor.save_pretrained(self.tmpdirname) From 14af1c970cbacf6178a715e6b55909fb402163be Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 28 Aug 2024 05:08:00 +0000 Subject: [PATCH 8/9] Add ProcessingUtilTester --- src/transformers/models/llava/processing_llava.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 4600d1063b2f52..28a9410e6cbf0b 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -129,7 +129,6 @@ def __call__( raise ValueError("You have to specify at least one of `images` or `text`.") # check if images and text inputs are reversed for BC - text, images = images, text images, text = _validate_images_text_input_order(images, text) output_kwargs = self._merge_kwargs( From 4dc7adae7c978ddf70ffdbc1eb33fb2166931f55 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Wed, 4 Sep 2024 17:44:12 +0000 Subject: [PATCH 9/9] fix processing and test_processing --- src/transformers/processing_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index ef7911217727fd..ee28c01189b439 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -102,6 +102,8 @@ class TextKwargs(TypedDict, total=False): Whether or not to return the lengths of the encoded inputs. verbose (`bool`, *optional*): Whether or not to print more information and warnings. + padding_side (`str`, *optional*): + The side on which padding will be applied. """ add_special_tokens: Optional[bool] @@ -118,6 +120,7 @@ class TextKwargs(TypedDict, total=False): return_offsets_mapping: Optional[bool] return_length: Optional[bool] verbose: Optional[bool] + padding_side: Optional[str] class ImagesKwargs(TypedDict, total=False):