Uniformize kwargs for LLaVa processor and update docs (#32858)
* Uniformize kwargs for LlaVa and update docs

* Change order of processor inputs in docstring

* Improve BC support for reversed images and text inputs

* cleanup llava processor call docstring

* Add encoded inputs as valid text inputs in reverse input check, add deprecation version in warning

* Put function check reversed images text outside base processor class

* Refactor _validate_images_text_input_order

* Add ProcessingUtilTester

* fix processing and test_processing
yonigozlan committed Sep 16, 2024
1 parent ce62a41 commit 2f62146
Showing 4 changed files with 104 additions and 48 deletions.
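In practice, the change standardizes every LLaVA processor call on one convention: `images` first, `text` second, and all remaining tokenizer or image-processor options passed as keyword arguments that the processor routes to the right component. A minimal sketch of the new convention, with the checkpoint and option values borrowed from the tests below (illustrative, not part of the diff):

    from PIL import Image
    import requests
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
    raw_image = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)

    # New canonical order: images first, then text; all options go through **kwargs.
    inputs = processor(
        images=raw_image,
        text=prompt,
        return_tensors="pt",
        padding="longest",                   # routed to the tokenizer
        size={"height": 214, "width": 214},  # routed to the image processor
    )

The old positional order, `processor(prompt, raw_image)`, still works for backward compatibility but now emits a deprecation warning (see the validation helper below).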
2 changes: 1 addition & 1 deletion src/transformers/models/llava/modeling_llava.py
@@ -405,7 +405,7 @@ def forward(
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
- >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(**inputs, max_new_tokens=15)
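For context, docstring examples like the one above typically finish by decoding the generated ids back to text; a hedged completion of the truncated snippet (not part of this diff) would be:

    >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]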
73 changes: 39 additions & 34 deletions src/transformers/models/llava/processing_llava.py
@@ -16,18 +16,33 @@
Processor class for Llava.
"""

- from typing import List, Optional, Union
+ import sys
+ from typing import List, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, get_image_size, to_numpy_array
- from ...processing_utils import ProcessorMixin
- from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
- from ...utils import TensorType, logging
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order
+ from ...tokenization_utils_base import PreTokenizedInput, TextInput
+ from ...utils import logging


+ if sys.version_info >= (3, 11):
+     from typing import Unpack
+ else:
+     from typing_extensions import Unpack

logger = logging.get_logger(__name__)


+ class LlavaProcessorKwargs(ProcessingKwargs, total=False):
+     _defaults = {
+         "text_kwargs": {
+             "padding": False,
+         },
+         "images_kwargs": {},
+     }


class LlavaProcessor(ProcessorMixin):
r"""
Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
@@ -73,12 +88,11 @@ def __init__(

def __call__(
    self,
-     text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
    images: ImageInput = None,
-     padding: Union[bool, str, PaddingStrategy] = False,
-     truncation: Union[bool, str, TruncationStrategy] = None,
-     max_length=None,
-     return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+     text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+     audio=None,
+     videos=None,
+     **kwargs: Unpack[LlavaProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -88,29 +102,15 @@ def __call__(
of the above two methods for more information.
Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+     The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+     tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
    The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
    (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
    `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-     The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-     tensor. Both channels-first and channels-last formats are supported.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-     Select a strategy to pad the returned sequences (according to the model's padding side and padding
-     index) among:
-     - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-       sequence if provided).
-     - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-       acceptable input length for the model if that argument is not provided.
-     - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-       lengths).
- max_length (`int`, *optional*):
-     Maximum length of the returned list and optionally padding length (see above).
- truncation (`bool`, *optional*):
-     Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
    If set, will return tensors of a particular framework. Acceptable values are:
    - `'tf'`: Return TensorFlow `tf.constant` objects.
    - `'pt'`: Return PyTorch `torch.Tensor` objects.
    - `'np'`: Return NumPy `np.ndarray` objects.
@@ -125,8 +125,19 @@ def __call__(
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
+ if images is None and text is None:
+     raise ValueError("You have to specify at least one of `images` or `text`.")
+
+ # check if images and text inputs are reversed for BC
+ images, text = _validate_images_text_input_order(images, text)
+
+ output_kwargs = self._merge_kwargs(
+     LlavaProcessorKwargs,
+     tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+     **kwargs,
+ )
if images is not None:
-     image_inputs = self.image_processor(images, return_tensors=return_tensors)
+     image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
else:
    image_inputs = {}

@@ -158,13 +169,7 @@ def __call__(
"Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
)

- text_inputs = self.tokenizer(
-     prompt_strings,
-     return_tensors=return_tensors,
-     padding=padding,
-     truncation=truncation,
-     max_length=max_length,
- )
+ text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs})

# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
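The backward-compatibility check added above, `_validate_images_text_input_order`, is defined in `transformers.processing_utils`. Roughly, it detects calls that still pass text before images and swaps the two with a deprecation warning; a simplified sketch of the idea (the real helper is more thorough, e.g. it also treats encoded, integer-list inputs as valid text):

    import warnings

    def validate_images_text_input_order_sketch(images, text):
        # Heuristic: a string (or list of strings) in the `images` slot means the
        # caller used the old (text, images) positional order.
        def looks_like_text(x):
            if isinstance(x, str):
                return True
            return isinstance(x, (list, tuple)) and len(x) > 0 and isinstance(x[0], str)

        if looks_like_text(images) and not looks_like_text(text):
            warnings.warn(
                "You may have passed `text` before `images`; the expected order is "
                "(images, text). Support for the reversed order is deprecated.",
                FutureWarning,
            )
            return text, images  # swap back to (images, text)
        return images, text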
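`_merge_kwargs` is what makes the flat `**kwargs` routing work: it starts from the `_defaults` declared on `LlavaProcessorKwargs`, overrides them with matching values saved in the tokenizer's init kwargs, lets call-time kwargs win over both, and splits the result into per-modality dicts such as `output_kwargs["text_kwargs"]` and `output_kwargs["images_kwargs"]`. A rough sketch of that precedence, assuming each modality declares the keys it owns (the real helper also handles common, video, and audio kwargs):

    def merge_kwargs_sketch(defaults, tokenizer_init_kwargs, call_kwargs, text_keys, image_keys):
        # Layer 1: class-level defaults, e.g. {"text_kwargs": {"padding": False}, "images_kwargs": {}}.
        merged = {
            "text_kwargs": dict(defaults.get("text_kwargs", {})),
            "images_kwargs": dict(defaults.get("images_kwargs", {})),
        }
        # Layer 2: values the tokenizer was initialized with override the defaults.
        for key, value in tokenizer_init_kwargs.items():
            if key in text_keys:
                merged["text_kwargs"][key] = value
        # Layer 3: call-time kwargs win, routed to the modality that owns the key.
        for key, value in call_kwargs.items():
            if key in text_keys:
                merged["text_kwargs"][key] = value
            elif key in image_keys:
                merged["images_kwargs"][key] = value
        return merged

    # Example: padding goes to the tokenizer, size to the image processor.
    # merge_kwargs_sketch(
    #     {"text_kwargs": {"padding": False}, "images_kwargs": {}}, {},
    #     {"padding": "longest", "size": {"height": 214, "width": 214}},
    #     text_keys={"padding", "truncation", "max_length"},
    #     image_keys={"size", "do_center_crop"},
    # ) == {"text_kwargs": {"padding": "longest"}, "images_kwargs": {"size": {"height": 214, "width": 214}}}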
20 changes: 10 additions & 10 deletions tests/models/llava/test_modeling_llava.py
@@ -274,7 +274,7 @@ def test_small_model_integration_test(self):
prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
- inputs = self.processor(prompt, raw_image, return_tensors="pt")
+ inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt")

EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722,315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456,1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
@@ -299,7 +299,7 @@ def test_small_model_integration_test_llama_single(self):
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place? ASSISTANT:"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
- inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)

output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place? ASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the tides and currents, as they can change rapidly and pose a risk to swimmers or those who venture too close to the edge of the pier. Finally, be respectful of the environment and other visitors, and follow any posted rules or guidelines for the area." # fmt: skip
@@ -325,7 +325,7 @@ def test_small_model_integration_test_llama_batched(self):
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

- inputs = processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
+ inputs = processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)

output = model.generate(**inputs, max_new_tokens=20)

@@ -349,7 +349,7 @@ def test_small_model_integration_test_batch(self):
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

- inputs = self.processor(prompts, images=[image1, image2], return_tensors="pt", padding=True)
+ inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True)

output = model.generate(**inputs, max_new_tokens=20)

@@ -381,7 +381,7 @@ def test_small_model_integration_test_llama_batched_regression(self):
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)

- inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)
+ inputs = processor(images=[image1, image2, image1], text=prompts, return_tensors="pt", padding=True)

output = model.generate(**inputs, max_new_tokens=20)

@@ -409,8 +409,8 @@ def test_batched_generation(self):
image2 = Image.open(requests.get(url2, stream=True).raw)

inputs = processor(
-     text=[prompt1, prompt2, prompt3],
      images=[image1, image2, image1, image2],
+     text=[prompt1, prompt2, prompt3],
return_tensors="pt",
padding=True,
).to(torch_device)
@@ -444,7 +444,7 @@ def test_llava_index_error_bug(self):
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"

raw_image = Image.open(requests.get(image_file, stream=True).raw)
- inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)

# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)
@@ -510,7 +510,7 @@ def test_generation_no_images(self):
processor = AutoProcessor.from_pretrained(model_id)

# Prepare inputs with no images
inputs = processor("Hello, I am", return_tensors="pt").to(torch_device)
inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device)

# Make sure that `generate` works
_ = model.generate(**inputs, max_new_tokens=20)
@@ -554,13 +554,13 @@ def test_expansion_in_processing(self):
# check processing with expansion of inputs
processor.vision_feature_select_strategy = "default"
processor.patch_size = 14
- inputs_expanded = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
+ inputs_expanded = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs_expanded.input_ids.shape[-1] == 593)

# check processing without expansion of inputs (legacy behavior)
processor.vision_feature_select_strategy = None
processor.patch_size = None
- inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device, torch.float16)
self.assertTrue(inputs.input_ids.shape[-1] == 18)

# generate exactly 20 tokens
57 changes: 54 additions & 3 deletions tests/models/llava/test_processor_llava.py
@@ -11,18 +11,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ import shutil
+ import tempfile
import unittest

- from transformers.testing_utils import require_vision
+ from transformers import AutoProcessor, AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
+ from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available

+ from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
-     from transformers import AutoTokenizer, LlavaProcessor
+     from transformers import CLIPImageProcessor


@require_vision
- class LlavaProcessorTest(unittest.TestCase):
+ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+     processor_class = LlavaProcessor

+     def setUp(self):
+         self.tmpdirname = tempfile.mkdtemp()
+         image_processor = CLIPImageProcessor(do_center_crop=False)
+         tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
+
+         processor = LlavaProcessor(image_processor=image_processor, tokenizer=tokenizer)
+
+         processor.save_pretrained(self.tmpdirname)
+
+     def get_tokenizer(self, **kwargs):
+         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+     def get_image_processor(self, **kwargs):
+         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+     def tearDown(self):
+         shutil.rmtree(self.tmpdirname)

def test_can_load_various_tokenizers(self):
for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]:
processor = LlavaProcessor.from_pretrained(checkpoint)
@@ -45,3 +70,29 @@ def test_chat_template(self):

formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)

+     @require_torch
+     @require_vision
+     def test_unstructured_kwargs_batched(self):
+         if "image_processor" not in self.processor_class.attributes:
+             self.skipTest(f"image_processor attribute not present in {self.processor_class}")
+         image_processor = self.get_component("image_processor")
+         tokenizer = self.get_component("tokenizer")
+
+         processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
+         self.skip_processor_without_typed_kwargs(processor)
+
+         input_str = ["lower newer", "upper older longer string"]
+         image_input = self.prepare_image_inputs() * 2
+         inputs = processor(
+             images=image_input,
+             text=input_str,
+             return_tensors="pt",
+             size={"height": 214, "width": 214},
+             padding="longest",
+             max_length=76,
+         )
+
+         self.assertEqual(inputs["pixel_values"].shape[2], 214)
+
+         self.assertEqual(len(inputs["input_ids"][0]), 5)
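This test passes everything as flat ("unstructured") kwargs and relies on the processor to route each key. The shared `ProcessorTesterMixin` also exercises an explicitly nested form, roughly like the following (illustrative; the nested dict names follow the modality groups defined on `LlavaProcessorKwargs`):

    inputs = processor(
        images=image_input,
        text=input_str,
        return_tensors="pt",
        images_kwargs={"size": {"height": 214, "width": 214}},
        text_kwargs={"padding": "longest", "max_length": 76},
    )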
