Uniformize kwargs for Pixtral processor (#33521)
* add uniformized pixtral and kwargs

* update doc

* fix _validate_images_text_input_order

* nit
yonigozlan committed Sep 17, 2024
1 parent c29a869 commit d8500cd
Showing 7 changed files with 255 additions and 62 deletions.
6 changes: 3 additions & 3 deletions docs/source/en/model_doc/pixtral.md
@@ -24,7 +24,7 @@ The Pixtral model was released by the Mistral AI team on [Vllm](https://github.c
Tips:

- Pixtral is a multimodal model; its main contributions are 2D ROPE on the images and support for arbitrary image sizes (the images are neither padded together nor resized).
- This model follows the `Llava` family, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
- The format for one or multiple prompts is the following:
```
"<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
```
@@ -35,7 +35,7 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)

Here is an example of how to run it:

```python
from transformers import LlavaForConditionalGeneration, AutoProcessor
from PIL import Image

@@ -51,7 +51,7 @@ IMG_URLS = [
]
PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"

-inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
+inputs = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt").to("cuda")
generate_ids = model.generate(**inputs, max_new_tokens=500)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

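The user-facing effect of the docs change above is the images-first argument order, with tokenizer options now passed as plain kwargs. A minimal sketch of the updated call pattern (the checkpoint id is a placeholder and the URL and `padding` value are illustrative, not taken from this commit):

```python
from transformers import AutoProcessor

# Placeholder checkpoint id; substitute a real Pixtral checkpoint.
processor = AutoProcessor.from_pretrained("<pixtral-checkpoint>")

IMG_URLS = ["https://example.com/photo.jpg"]  # illustrative image URL
PROMPT = "<s>[INST]Describe the images.\n[IMG][/INST]"

# Images come first; padding/truncation/max_length ride along as ordinary
# kwargs and are routed into the tokenizer's text_kwargs by the processor.
inputs = processor(images=IMG_URLS, text=PROMPT, padding="longest", return_tensors="pt")
```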
3 changes: 2 additions & 1 deletion src/transformers/models/pixtral/__init__.py
@@ -43,7 +43,8 @@


if TYPE_CHECKING:
-    from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig
+    from .configuration_pixtral import PixtralVisionConfig
+    from .processing_pixtral import PixtralProcessor

try:
    if not is_torch_available():
74 changes: 40 additions & 34 deletions src/transformers/models/pixtral/processing_pixtral.py
@@ -16,18 +16,36 @@
Processor class for Pixtral.
"""

-from typing import List, Optional, Union
+import sys
+from typing import List, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends


+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack

logger = logging.get_logger(__name__)


+class PixtralProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {},
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }
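A rough sketch of how these `_defaults` are expected to combine with caller overrides. The priority order (per-call kwargs over tokenizer init kwargs over class defaults) is an assumption inferred from the `_merge_kwargs` call further down, and `merge` below is a hypothetical stand-in, not the real implementation:

```python
# Hypothetical stand-in for _merge_kwargs, to illustrate the assumed priority:
# per-call kwargs > tokenizer init kwargs > PixtralProcessorKwargs._defaults.
defaults = {"text_kwargs": {"padding": False}, "common_kwargs": {"return_tensors": "pt"}}

def merge(call_kwargs, tokenizer_init_kwargs):
    text_kwargs = {**defaults["text_kwargs"]}
    for key in ("padding", "truncation", "max_length"):
        if key in tokenizer_init_kwargs:      # init kwargs fill options the caller left unset
            text_kwargs[key] = tokenizer_init_kwargs[key]
    text_kwargs.update(call_kwargs)           # explicit call kwargs win
    return {"text_kwargs": text_kwargs, "common_kwargs": dict(defaults["common_kwargs"])}

print(merge({"padding": "longest"}, {"truncation": True}))
# {'text_kwargs': {'padding': 'longest', 'truncation': True}, 'common_kwargs': {'return_tensors': 'pt'}}
```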


# Copied from transformers.models.idefics2.processing_idefics2.is_url
def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")
@@ -143,12 +161,11 @@ def __init__(

    def __call__(
        self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length=None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[PixtralProcessorKwargs],
    ) -> BatchMixFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -158,26 +175,13 @@
        of the above two methods for more information.

        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
             text (`str`, `List[str]`, `List[List[str]]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
@@ -195,6 +199,15 @@
              `None`).
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
        """
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            PixtralProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )

        if images is not None:
            if is_image_or_image_url(images):
                images = [[images]]
@@ -209,7 +222,7 @@
"Invalid input images. Please provide a single image or a list of images or a list of list of images."
)
images = [[load_image(im) for im in sample] for sample in images]
image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors)
image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
else:
image_inputs = {}

@@ -246,16 +259,9 @@ def __call__(
while "<placeholder>" in sample:
replace_str = replace_strings.pop(0)
sample = sample.replace("<placeholder>", replace_str, 1)

prompt_strings.append(sample)

text_inputs = self.tokenizer(
prompt_strings,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
return BatchMixFeature(data={**text_inputs, **image_inputs})

# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
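The `<placeholder>` loop above is plain string surgery, so it can be exercised in isolation. A self-contained sketch, where the `[IMG_BREAK]`/`[IMG_END]` token strings are illustrative stand-ins for whatever the processor precomputes per image:

```python
# Each "<placeholder>" is consumed left to right and replaced, once, by the
# precomputed image-token string for the corresponding image.
sample = "<s>[INST]<placeholder><placeholder>\nDescribe the images.[/INST]"
replace_strings = ["[IMG][IMG][IMG_BREAK]", "[IMG][IMG][IMG_END]"]  # illustrative

while "<placeholder>" in sample:
    replace_str = replace_strings.pop(0)
    sample = sample.replace("<placeholder>", replace_str, 1)

print(sample)
# <s>[INST][IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]
# Describe the images.[/INST]
```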
22 changes: 18 additions & 4 deletions src/transformers/processing_utils.py
@@ -27,7 +27,7 @@
import numpy as np

from .dynamic_module_utils import custom_object_save
-from .image_utils import ChannelDimension, is_vision_available, valid_images
+from .image_utils import ChannelDimension, is_valid_image, is_vision_available


if is_vision_available():
@@ -1003,6 +1003,20 @@ def _validate_images_text_input_order(images, text):
    in the processor's `__call__` method before calling this method.
    """

+    def is_url(val) -> bool:
+        return isinstance(val, str) and val.startswith("http")
+
+    def _is_valid_images_input_for_processor(imgs):
+        # If we have a list of images, make sure every image is valid
+        if isinstance(imgs, (list, tuple)):
+            for img in imgs:
+                if not _is_valid_images_input_for_processor(img):
+                    return False
+        # If not a list or tuple, we have been given a single image or batched tensor of images
+        elif not (is_valid_image(imgs) or is_url(imgs)):
+            return False
+        return True

    def _is_valid_text_input_for_processor(t):
        if isinstance(t, str):
            # Strings are fine
@@ -1019,11 +1033,11 @@ def _is_valid_text_input_for_processor(t):
    def _is_valid(input, validator):
        return validator(input) or input is None

-    images_is_valid = _is_valid(images, valid_images)
-    images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False
+    images_is_valid = _is_valid(images, _is_valid_images_input_for_processor)
+    images_is_text = _is_valid_text_input_for_processor(images)

    text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
-    text_is_images = valid_images(text) if not text_is_valid else False
+    text_is_images = _is_valid_images_input_for_processor(text)
    # Handle cases where both inputs are valid
    if images_is_valid and text_is_valid:
        return images, text
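The net effect of the `processing_utils.py` change is that a processor can detect a legacy text-first call and recover. A small sketch of the expected behavior, inferred from the "reversed for BC" comment in the Pixtral diff (whether a deprecation warning is also emitted is not shown in this commit):

```python
from PIL import Image
from transformers.processing_utils import _validate_images_text_input_order

img = Image.new("RGB", (8, 8))   # tiny illustrative image
text = "Describe the image."

# Correct images-first order passes through unchanged.
images, texts = _validate_images_text_input_order(img, text)

# Legacy text-first order is detected and swapped back for backward compatibility.
images, texts = _validate_images_text_input_order(text, img)
assert images is img and texts is text
```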
