
Commit f2596d2

yonigozlan authored and BernardZach committed
Uniformize kwargs for LLaVa processor and update docs (huggingface#32858)
* Uniformize kwargs for LlaVa and update docs
* Change order of processor inputs in docstring
* Improve BC support for reversed images and text inputs
* cleanup llava processor call docstring
* Add encoded inputs as valid text inputs in reverse input check, add deprecation version in warning
* Put function check reversed images text outside base processor class
* Refactor _validate_images_text_input_order
* Add ProcessingUtilTester
* fix processing and test_processing
1 parent b24580b commit f2596d2

File tree

4 files changed: +104 -48 lines changed
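
To make the commit concrete, here is a minimal caller-side sketch of the uniformized API: images are now passed before text, and tokenizer/image-processor options go in as plain kwargs instead of the old fixed padding/truncation/max_length/return_tensors signature. The checkpoint, image URL and prompt below are taken from the tests in this commit; the printed keys follow the processor docstring.

from PIL import Image
import requests
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
image = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)

# New uniform order: images first, then text; extra kwargs are dispatched per modality.
inputs = processor(images=image, text=prompt, padding=True, return_tensors="pt")
print(sorted(inputs.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']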

src/transformers/models/llava/modeling_llava.py

Lines changed: 1 addition & 1 deletion
@@ -405,7 +405,7 @@ def forward(
         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)

-        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

         >>> # Generate
         >>> generate_ids = model.generate(**inputs, max_new_tokens=15)

src/transformers/models/llava/processing_llava.py

Lines changed: 39 additions & 34 deletions
@@ -16,18 +16,33 @@
 Processor class for Llava.
 """

-from typing import List, Optional, Union
+import sys
+from typing import List, Union

 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, get_image_size, to_numpy_array
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, logging
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import logging


+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
 logger = logging.get_logger(__name__)


+class LlavaProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {},
+    }
+
+
 class LlavaProcessor(ProcessorMixin):
     r"""
     Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
@@ -73,12 +88,11 @@ def __init__(

     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length=None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[LlavaProcessorKwargs],
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -88,29 +102,15 @@ def __call__(
         of the above two methods for more information.

         Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
             text (`str`, `List[str]`, `List[List[str]]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
-
                 - `'tf'`: Return TensorFlow `tf.constant` objects.
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
                 - `'np'`: Return NumPy `np.ndarray` objects.
@@ -125,8 +125,19 @@ def __call__(
             `None`).
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
+        if images is None and text is None:
+            raise ValueError("You have to specify at least one of `images` or `text`.")
+
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            LlavaProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
         if images is not None:
-            image_inputs = self.image_processor(images, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
         else:
             image_inputs = {}

@@ -158,13 +169,7 @@ def __call__(
                     "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
                 )

-        text_inputs = self.tokenizer(
-            prompt_strings,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
         return BatchFeature(data={**text_inputs, **image_inputs})

     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
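
The behavioral core of the `__call__` rewrite above is `_merge_kwargs`: it layers the `_defaults` declared on `LlavaProcessorKwargs`, the tokenizer's init kwargs, and the call-time kwargs, then hands the tokenizer only `output_kwargs["text_kwargs"]` and the image processor only `output_kwargs["images_kwargs"]`. A rough caller-side sketch of how a mixed flat call is split (reusing `processor`, `image` and `prompt` from the sketch near the top of this page; the `size` value is illustrative):

out = processor(
    images=image,
    text=prompt,
    padding="longest",                   # text kwarg -> output_kwargs["text_kwargs"] -> tokenizer
    size={"height": 336, "width": 336},  # image kwarg -> output_kwargs["images_kwargs"] -> image processor
    return_tensors="pt",                 # shared kwarg, applied to both sub-processors
)

Call-time values take precedence over the class defaults, e.g. `padding` defaults to `False` in `LlavaProcessorKwargs._defaults` unless the caller overrides it as above.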

tests/models/llava/test_modeling_llava.py

Lines changed: 10 additions & 10 deletions
@@ -274,7 +274,7 @@ def test_small_model_integration_test(self):
         prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
         image_file = "https://llava-vl.github.io/static/images/view.jpg"
         raw_image = Image.open(requests.get(image_file, stream=True).raw)
-        inputs = self.processor(prompt, raw_image, return_tensors="pt")
+        inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")

         EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
         self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
@@ -299,7 +299,7 @@ def test_small_model_integration_test_llama_single(self):
         prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
         image_file = "https://llava-vl.github.io/static/images/view.jpg"
         raw_image = Image.open(requests.get(image_file, stream=True).raw)
-        inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
+        inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)

         output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
         EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
@@ -325,7 +325,7 @@ def test_small_model_integration_test_llama_batched(self):
         image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
         image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

-        inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
+        inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)

         output = model.generate(**inputs, max_new_tokens=20)

@@ -349,7 +349,7 @@ def test_small_model_integration_test_batch(self):
         image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
         image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

-        inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
+        inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)

         output = model.generate(**inputs, max_new_tokens=20)

@@ -381,7 +381,7 @@ def test_small_model_integration_test_llama_batched_regression(self):
         image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
         image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

-        inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)
+        inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True)

         output = model.generate(**inputs, max_new_tokens=20)

@@ -409,8 +409,8 @@ def test_batched_generation(self):
         image2 = Image.open(requests.get(url2, stream=True).raw)

         inputs = processor(
-            text=[prompt1, prompt2, prompt3],
             images=[image1, image2, image1, image2],
+            text=[prompt1, prompt2, prompt3],
             return_tensors="pt",
             padding=True,
         ).to(torch_device)
@@ -444,7 +444,7 @@ def test_llava_index_error_bug(self):
         image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

         raw_image = Image.open(requests.get(image_file, stream=True).raw)
-        inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
+        inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)

         # Make sure that `generate` works
         _ = model.generate(**inputs, max_new_tokens=20)
@@ -510,7 +510,7 @@ def test_generation_no_images(self):
         processor = AutoProcessor.from_pretrained(model_id)

         # Prepare inputs with no images
-        inputs = processor("Hello, I am", return_tensors="pt").to(torch_device)
+        inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device)

         # Make sure that `generate` works
         _ = model.generate(**inputs, max_new_tokens=20)
@@ -554,13 +554,13 @@ def test_expansion_in_processing(self):
         # check processing with expansion of inputs
         processor.vision_feature_select_strategy = "default"
         processor.patch_size = 14
-        inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
+        inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
         self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)

         # check processing without expansion of inputs (legacy behavior)
         processor.vision_feature_select_strategy = None
         processor.patch_size = None
-        inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
+        inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
         self.assertTrue(inputs.input_ids.shape[-1] == 18)

         # generate exactly 20 tokens
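
All of the call sites above were switched to explicit `images=`/`text=` keywords in the new order, but existing positional calls do not break: per the diff in processing_llava.py and the commit message, `_validate_images_text_input_order` detects reversed inputs, swaps them back for backward compatibility, and warns with a deprecation version. A small sketch of what that means for user code (reusing `processor`, `image` and `prompt` from the first sketch):

# Legacy positional order (text first) still resolves: the processor detects that the
# first argument is not an image, swaps the two, and emits a warning.
legacy_inputs = processor(prompt, image, return_tensors="pt")

# New positional order (images first) maps directly onto the new signature, no swap needed.
inputs = processor(image, prompt, return_tensors="pt")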

tests/models/llava/test_processor_llava.py

Lines changed: 54 additions & 3 deletions
@@ -11,18 +11,43 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import shutil
+import tempfile
 import unittest

-from transformers.testing_utils import require_vision
+from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
+from transformers.testing_utils import require_torch, require_vision
 from transformers.utils import is_vision_available

+from ...test_processing_common import ProcessorTesterMixin
+

 if is_vision_available():
-    from transformers import AutoTokenizer, LlavaProcessor
+    from transformers import CLIPImageProcessor


 @require_vision
-class LlavaProcessorTest(unittest.TestCase):
+class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = LlavaProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+        image_processor = CLIPImageProcessor(do_center_crop=False)
+        tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
+
+        processor = LlavaProcessor(image_processor=image_processor, tokenizer=tokenizer)
+
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def tearDown(self):
+        shutil.rmtree(self.tmpdirname)
+
     def test_can_load_various_tokenizers(self):
         for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]:
             processor = LlavaProcessor.from_pretrained(checkpoint)
@@ -45,3 +70,29 @@ def test_chat_template(self):

         formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
         self.assertEqual(expected_prompt, formatted_prompt)
+
+    @require_torch
+    @require_vision
+    def test_unstructured_kwargs_batched(self):
+        if "image_processor" not in self.processor_class.attributes:
+            self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+        image_processor = self.get_component("image_processor")
+        tokenizer = self.get_component("tokenizer")
+
+        processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+        self.skip_processor_without_typed_kwargs(processor)
+
+        input_str = ["lower newer", "upper older longer string"]
+        image_input = self.prepare_image_inputs() * 2
+        inputs = processor(
+            images=image_input,
+            text=input_str,
+            return_tensors="pt",
+            size={"height": 214, "width": 214},
+            padding="longest",
+            max_length=76,
+        )
+
+        self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+        self.assertEqual(len(inputs["input_ids"][0]), 5)
