diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 9ad19ccee72228..eb1c55341b0784 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -405,7 +405,7 @@ def forward( >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(**inputs, max_new_tokens=15) diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py index 678724ae95be41..28a9410e6cbf0b 100644 --- a/src/transformers/models/llava/processing_llava.py +++ b/src/transformers/models/llava/processing_llava.py @@ -16,18 +16,33 @@ Processor class for Llava. """ -from typing import List, Optional, Union +import sys +from typing import List, Union from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, get_image_size, to_numpy_array -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType, logging +from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order +from ...tokenization_utils_base import PreTokenizedInput, TextInput +from ...utils import logging +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + logger = logging.get_logger(__name__) +class LlavaProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": {}, + } + + class LlavaProcessor(ProcessorMixin): r""" Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor. @@ -73,12 +88,11 @@ def __init__( def __call__( self, - text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, images: ImageInput = None, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = None, - max_length=None, - return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + audio=None, + videos=None, + **kwargs: Unpack[LlavaProcessorKwargs], ) -> BatchFeature: """ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` @@ -88,29 +102,15 @@ def __call__( of the above two methods for more information. Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. text (`str`, `List[str]`, `List[List[str]]`): The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). - images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): - The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch - tensor. Both channels-first and channels-last formats are supported. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. return_tensors (`str` or [`~utils.TensorType`], *optional*): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. @@ -125,8 +125,19 @@ def __call__( `None`). - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. """ + if images is None and text is None: + raise ValueError("You have to specify at least one of `images` or `text`.") + + # check if images and text inputs are reversed for BC + images, text = _validate_images_text_input_order(images, text) + + output_kwargs = self._merge_kwargs( + LlavaProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) if images is not None: - image_inputs = self.image_processor(images, return_tensors=return_tensors) + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) else: image_inputs = {} @@ -158,13 +169,7 @@ def __call__( "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." ) - text_inputs = self.tokenizer( - prompt_strings, - return_tensors=return_tensors, - padding=padding, - truncation=truncation, - max_length=max_length, - ) + text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"]) return BatchFeature(data={**text_inputs, **image_inputs}) # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 5c05480ffa6dbb..305fc9e9a84cdb 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -274,7 +274,7 @@ def test_small_model_integration_test(self): prompt = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:" image_file = "https://llava-vl.github.io/static/images/view.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = self.processor(prompt, raw_image, return_tensors="pt") + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt") EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS)) @@ -299,7 +299,7 @@ def test_small_model_integration_test_llama_single(self): prompt = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT:" image_file = "https://llava-vl.github.io/static/images/view.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) output = model.generate(**inputs, max_new_tokens=900, do_sample=False) EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip @@ -325,7 +325,7 @@ def test_small_model_integration_test_llama_batched(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -349,7 +349,7 @@ def test_small_model_integration_test_batch(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True) + inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -381,7 +381,7 @@ def test_small_model_integration_test_llama_batched_regression(self): image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) - inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True) + inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True) output = model.generate(**inputs, max_new_tokens=20) @@ -409,8 +409,8 @@ def test_batched_generation(self): image2 = Image.open(requests.get(url2, stream=True).raw) inputs = processor( - text=[prompt1, prompt2, prompt3], images=[image1, image2, image1, image2], + text=[prompt1, prompt2, prompt3], return_tensors="pt", padding=True, ).to(torch_device) @@ -444,7 +444,7 @@ def test_llava_index_error_bug(self): image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) @@ -510,7 +510,7 @@ def test_generation_no_images(self): processor = AutoProcessor.from_pretrained(model_id) # Prepare inputs with no images - inputs = processor("Hello, I am", return_tensors="pt").to(torch_device) + inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device) # Make sure that `generate` works _ = model.generate(**inputs, max_new_tokens=20) @@ -554,13 +554,13 @@ def test_expansion_in_processing(self): # check processing with expansion of inputs processor.vision_feature_select_strategy = "default" processor.patch_size = 14 - inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593) # check processing without expansion of inputs (legacy behavior) processor.vision_feature_select_strategy = None processor.patch_size = None - inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16) + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16) self.assertTrue(inputs.input_ids.shape[-1] == 18) # generate exactly 20 tokens diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py index 54c1b4674cbcef..5b05a8b92ea513 100644 --- a/tests/models/llava/test_processor_llava.py +++ b/tests/models/llava/test_processor_llava.py @@ -11,18 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import shutil +import tempfile import unittest -from transformers.testing_utils import require_vision +from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor +from transformers.testing_utils import require_torch, require_vision from transformers.utils import is_vision_available +from ...test_processing_common import ProcessorTesterMixin + if is_vision_available(): - from transformers import AutoTokenizer, LlavaProcessor + from transformers import CLIPImageProcessor @require_vision -class LlavaProcessorTest(unittest.TestCase): +class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase): + processor_class = LlavaProcessor + + def setUp(self): + self.tmpdirname = tempfile.mkdtemp() + image_processor = CLIPImageProcessor(do_center_crop=False) + tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b") + + processor = LlavaProcessor(image_processor=image_processor, tokenizer=tokenizer) + + processor.save_pretrained(self.tmpdirname) + + def get_tokenizer(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer + + def get_image_processor(self, **kwargs): + return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + def test_can_load_various_tokenizers(self): for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]: processor = LlavaProcessor.from_pretrained(checkpoint) @@ -45,3 +70,29 @@ def test_chat_template(self): formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) self.assertEqual(expected_prompt, formatted_prompt) + + @require_torch + @require_vision + def test_unstructured_kwargs_batched(self): + if "image_processor" not in self.processor_class.attributes: + self.skipTest(f"image_processor attribute not present in {self.processor_class}") + image_processor = self.get_component("image_processor") + tokenizer = self.get_component("tokenizer") + + processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor) + self.skip_processor_without_typed_kwargs(processor) + + input_str = ["lower newer", "upper older longer string"] + image_input = self.prepare_image_inputs() * 2 + inputs = processor( + images=image_input, + text=input_str, + return_tensors="pt", + size={"height": 214, "width": 214}, + padding="longest", + max_length=76, + ) + + self.assertEqual(inputs["pixel_values"].shape[2], 214) + + self.assertEqual(len(inputs["input_ids"][0]), 5)