Uniformize kwargs for Pixtral processor #33521

Merged · 4 commits · Sep 17, 2024
6 changes: 3 additions & 3 deletions docs/source/en/model_doc/pixtral.md
````diff
@@ -24,7 +24,7 @@ The Pixtral model was released by the Mistral AI team on [Vllm](https://github.c
 Tips:
 
 - Pixtral is a multimodal model, the main contribution is the 2d ROPE on the images, and support for arbitrary image size (the images are not padded together nor are they resized)
-- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
+- This model follows the `Llava` familiy, meaning image embeddings are placed instead of the `[IMG]` token placeholders.
 - The format for one or mulitple prompts is the following:
 ```
 "<s>[INST][IMG]\nWhat are the things I should be cautious about when I visit this place?[/INST]"
@@ -35,7 +35,7 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)
 
 Here is an example of how to run it:
 
-```python
+```python
 from transformers import LlavaForConditionalGeneration, AutoProcessor
 from PIL import Image
 
@@ -51,7 +51,7 @@ IMG_URLS = [
 ]
 PROMPT = "<s>[INST]Describe the images.\n[IMG][IMG][IMG][IMG][/INST]"
 
-inputs = processor(text=PROMPT, images=IMG_URLS, return_tensors="pt").to("cuda")
+inputs = processor(images=IMG_URLS, text=PROMPT, return_tensors="pt").to("cuda")
 generate_ids = model.generate(**inputs, max_new_tokens=500)
 ouptut = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
````
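Of the three changed line pairs above, only the last is substantive: the example now passes `images` before `text`, matching the uniformized processor signature; the other two pairs appear to differ only in trailing whitespace. A minimal sketch of the new canonical call (the checkpoint name and image URL are illustrative assumptions, not taken from this diff):

```python
from transformers import AutoProcessor

# Assumed checkpoint name, for illustration only
processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")

PROMPT = "<s>[INST]Describe the image.\n[IMG][/INST]"
IMG_URL = "https://example.com/image.jpg"  # placeholder URL

# images now comes first; the processor loads URLs itself via load_image
inputs = processor(images=[IMG_URL], text=PROMPT, return_tensors="pt")
```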
3 changes: 2 additions & 1 deletion src/transformers/models/pixtral/__init__.py
```diff
@@ -43,7 +43,8 @@
 
 
 if TYPE_CHECKING:
-    from .configuration_pixtral import PixtralProcessor, PixtralVisionConfig
+    from .configuration_pixtral import PixtralVisionConfig
+    from .processing_pixtral import PixtralProcessor
 
     try:
         if not is_torch_available():
```
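The `__init__.py` change is a plain import-path fix: `PixtralProcessor` is defined in `processing_pixtral.py`, not `configuration_pixtral.py`, so the old `TYPE_CHECKING` import pointed at the wrong module. A sketch of the corrected pattern, shown with absolute paths for clarity:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Each symbol is imported from the module that actually defines it
    from transformers.models.pixtral.configuration_pixtral import PixtralVisionConfig
    from transformers.models.pixtral.processing_pixtral import PixtralProcessor
```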
74 changes: 40 additions & 34 deletions src/transformers/models/pixtral/processing_pixtral.py
```diff
@@ -16,18 +16,36 @@
 Processor class for Pixtral.
 """
 
-from typing import List, Optional, Union
+import sys
+from typing import List, Union
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image, load_image
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType, is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+from ...utils import is_torch_device, is_torch_dtype, is_torch_tensor, logging, requires_backends
+
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
 
 logger = logging.get_logger(__name__)
 
 
+class PixtralProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+        "images_kwargs": {},
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }
+
+
 # Copied from transformers.models.idefics2.processing_idefics2.is_url
 def is_url(val) -> bool:
     return isinstance(val, str) and val.startswith("http")
@@ -143,12 +161,11 @@ def __init__(
 
     def __call__(
         self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
         images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length=None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[PixtralProcessorKwargs],
     ) -> BatchMixFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -158,26 +175,13 @@ def __call__(
         of the above two methods for more information.
 
         Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
             text (`str`, `List[str]`, `List[List[str]]`):
                 The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                 (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                 `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
             return_tensors (`str` or [`~utils.TensorType`], *optional*):
                 If set, will return tensors of a particular framework. Acceptable values are:
 
@@ -195,6 +199,15 @@ def __call__(
             `None`).
             - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
         """
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
+
+        output_kwargs = self._merge_kwargs(
+            PixtralProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
         if images is not None:
             if is_image_or_image_url(images):
                 images = [[images]]
@@ -209,7 +222,7 @@ def __call__(
                     "Invalid input images. Please provide a single image or a list of images or a list of list of images."
                 )
             images = [[load_image(im) for im in sample] for sample in images]
-            image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors=return_tensors)
+            image_inputs = self.image_processor(images, patch_size=self.patch_size, **output_kwargs["images_kwargs"])
         else:
             image_inputs = {}
 
@@ -246,16 +259,9 @@ def __call__(
                 while "<placeholder>" in sample:
                     replace_str = replace_strings.pop(0)
                     sample = sample.replace("<placeholder>", replace_str, 1)
-
                 prompt_strings.append(sample)
 
-        text_inputs = self.tokenizer(
-            prompt_strings,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])
         return BatchMixFeature(data={**text_inputs, **image_inputs})
 
     # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
```
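With `PixtralProcessorKwargs` in place, per-modality options travel through `**kwargs` and `_merge_kwargs` splits them into `text_kwargs`, `images_kwargs`, and `common_kwargs`; the class `_defaults` are overridden by the tokenizer's init kwargs and then by call-time values. A sketch of how a call routes its options under this scheme (`processor` and `image` are assumed to already exist):

```python
out = processor(
    images=image,
    text="<s>[INST][IMG]\nDescribe the image.[/INST]",
    padding="longest",    # text kwarg, forwarded to the tokenizer (default above is False)
    truncation=True,      # text kwarg
    max_length=512,       # text kwarg
    return_tensors="pt",  # common kwarg, shared by tokenizer and image processor
)
```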
22 changes: 18 additions & 4 deletions src/transformers/processing_utils.py
```diff
@@ -27,7 +27,7 @@
 import numpy as np
 
 from .dynamic_module_utils import custom_object_save
-from .image_utils import ChannelDimension, is_vision_available, valid_images
+from .image_utils import ChannelDimension, is_valid_image, is_vision_available
 
 
 if is_vision_available():
@@ -1003,6 +1003,20 @@ def _validate_images_text_input_order(images, text):
     in the processor's `__call__` method before calling this method.
     """
 
+    def is_url(val) -> bool:
+        return isinstance(val, str) and val.startswith("http")
+
+    def _is_valid_images_input_for_processor(imgs):
+        # If we have an list of images, make sure every image is valid
+        if isinstance(imgs, (list, tuple)):
+            for img in imgs:
+                if not _is_valid_images_input_for_processor(img):
+                    return False
+        # If not a list or tuple, we have been given a single image or batched tensor of images
+        elif not (is_valid_image(imgs) or is_url(imgs)):
+            return False
+        return True
+
     def _is_valid_text_input_for_processor(t):
         if isinstance(t, str):
             # Strings are fine
@@ -1019,11 +1033,11 @@ def _is_valid_text_input_for_processor(t):
     def _is_valid(input, validator):
         return validator(input) or input is None
 
-    images_is_valid = _is_valid(images, valid_images)
-    images_is_text = _is_valid_text_input_for_processor(images) if not images_is_valid else False
+    images_is_valid = _is_valid(images, _is_valid_images_input_for_processor)
+    images_is_text = _is_valid_text_input_for_processor(images)
 
     text_is_valid = _is_valid(text, _is_valid_text_input_for_processor)
-    text_is_images = valid_images(text) if not text_is_valid else False
+    text_is_images = _is_valid_images_input_for_processor(text)
     # Handle cases where both inputs are valid
     if images_is_valid and text_is_valid:
         return images, text
```
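The net effect in `processing_utils.py` is that image inputs are now validated recursively, so nested lists and URL strings count as valid images when deciding whether a caller reversed the argument order. A self-contained sketch of the backward-compatibility swap this enables (assuming the function keeps its documented behavior of reordering reversed inputs):

```python
from PIL import Image

from transformers.processing_utils import _validate_images_text_input_order

img = Image.new("RGB", (16, 16))  # dummy image so the snippet is self-contained

# A legacy caller passing text first is detected and the pair is swapped back
images, text = _validate_images_text_input_order("a prompt", [img])
assert images == [img] and text == "a prompt"
```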