uniformize the kwargs for the rest of the processors in the list
leloykun committed Aug 16, 2024
1 parent f3c8b18 commit 45507e0
Showing 6 changed files with 295 additions and 180 deletions.
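Each of the touched processors follows the same recipe: the kwargs a processor accepts are declared in a `ProcessingKwargs`-style `TypedDict`, `__call__` takes them through `Unpack`, and `ProcessorMixin._merge_kwargs` resolves them against class defaults and tokenizer init kwargs. A minimal, self-contained sketch of that mechanic (the `Demo*` names and the `merge_kwargs` helper are illustrative stand-ins, not transformers API):

import sys
from typing import Optional, TypedDict

if sys.version_info >= (3, 11):
    from typing import Unpack
else:
    from typing_extensions import Unpack


class DemoTextKwargs(TypedDict, total=False):
    padding: bool
    max_length: Optional[int]


class DemoProcessorKwargs(TypedDict, total=False):
    text_kwargs: DemoTextKwargs
    return_tensors: Optional[str]


def merge_kwargs(**kwargs: Unpack[DemoProcessorKwargs]) -> dict:
    # stand-in for ProcessorMixin._merge_kwargs: class defaults lose to call-site kwargs
    merged = {"text_kwargs": {"padding": False}, "return_tensors": None}
    merged.update(kwargs)
    return merged


print(merge_kwargs(return_tensors="pt"))
# {'text_kwargs': {'padding': False}, 'return_tensors': 'pt'}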
72 changes: 57 additions & 15 deletions src/transformers/models/clipseg/processing_clipseg.py
@@ -16,10 +16,28 @@
Image/Text processor class for CLIPSeg
"""

import sys
import warnings
from typing import List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...image_utils import ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


class CLIPSegImagesKwargs(ImagesKwargs, total=False):
visual_prompt: Optional[ImageInput]


class CLIPSegProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: CLIPSegImagesKwargs
_defaults = {}


class CLIPSegProcessor(ProcessorMixin):
@@ -58,7 +76,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):

super().__init__(image_processor, tokenizer)

def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
images: Optional[ImageInput] = None,
audio=None,
videos=None,
**kwargs: Unpack[CLIPSegProcessorKwargs],
):
"""
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
and `kwargs` arguments to CLIPTokenizerFast's [`~CLIPTokenizerFast.__call__`] if `text` is not `None` to encode
@@ -79,14 +104,6 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape
(C, H, W), where C is the number of channels and H and W are the image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
@@ -96,20 +113,43 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""

output_kwargs = self._merge_kwargs(
CLIPSegProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

if output_kwargs["text_kwargs"].get("visual_prompt") is not None and audio is not None:
raise ValueError(
"You cannot provide `visual_prompt` as a positional argument and as a keyword argument at the same time."
"Please provide it only as a keyword argument (i.e. `visual_prompt=...`)."
)
if "visual_prompt" not in output_kwargs["text_kwargs"]:
warnings.warn(
"No `visual_prompt` kwarg was detected. The use of `visual_prompt` as an argument without specifying it explicitely as `visual_prompt=` will be deprecated in future versions."
)
# For backwards compatibility, we reuse `audio` as `visual_prompt` in case
# downstream users passed it as a positional argument
if audio is not None:
output_kwargs["text_kwargs"]["visual_prompt"] = audio

visual_prompt = output_kwargs["text_kwargs"].pop("visual_prompt", None)

if text is None and visual_prompt is None and images is None:
raise ValueError("You have to specify either text, visual prompt or images.")

if text is not None and visual_prompt is not None:
raise ValueError("You have to specify exactly one type of prompt. Either text or visual prompt.")

if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])

if visual_prompt is not None:
prompt_features = self.image_processor(visual_prompt, return_tensors=return_tensors, **kwargs)
prompt_features = self.image_processor(visual_prompt, **output_kwargs["images_kwargs"])

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])

if visual_prompt is not None and images is not None:
encoding = {
@@ -128,7 +168,9 @@ def __call__(self, text=None, images=None, visual_prompt=None, return_tensors=None, **kwargs):
}
return encoding
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
return BatchEncoding(
data=dict(**image_features), tensor_type=output_kwargs["common_kwargs"]["return_tensors"]
)

def batch_decode(self, *args, **kwargs):
"""
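For reference, a hedged usage sketch of the updated CLIPSeg call pattern (the checkpoint is the usual public CLIPSeg one; the blank PIL images stand in for real inputs):

from PIL import Image
from transformers import CLIPSegProcessor

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
image = Image.new("RGB", (352, 352))         # stand-in for a real input image
prompt_image = Image.new("RGB", (352, 352))  # stand-in for a visual-prompt image

# text prompts: tokenizer kwargs such as `padding` now travel through the merged kwargs
inputs = processor(text=["a cat", "a remote"], images=[image, image], padding=True, return_tensors="pt")

# visual prompting: pass the prompt image as a keyword; positional use is deprecated
inputs = processor(images=image, visual_prompt=prompt_image, return_tensors="pt")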
src/transformers/models/instructblipvideo/processing_instructblipvideo.py
@@ -17,26 +17,48 @@
"""

import os
import sys
from typing import List, Optional, Union

from ...image_processing_utils import BatchFeature
from ...image_utils import VideoInput
from ...processing_utils import ProcessorMixin
from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import (
AddedToken,
BatchEncoding,
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
from ...utils import TensorType, logging
from ...utils import logging
from ..auto import AutoTokenizer


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


logger = logging.get_logger(__name__)


class InstructBlipVideoProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"truncation": None,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": False,
"return_length": False,
"verbose": True,
},
}


class InstructBlipVideoProcessor(ProcessorMixin):
r"""
Constructs an InstructBLIPVideo processor which wraps an InstructBLIP image processor and a LLaMa/T5 tokenizer into a single
@@ -71,30 +93,24 @@ def __init__(self, image_processor, tokenizer, qformer_tokenizer=None, num_query_tokens=None, **kwargs):

def __call__(
self,
images: VideoInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
images: Optional[VideoInput] = None,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
audio=None,
videos=None,
**kwargs: Unpack[InstructBlipVideoProcessorKwargs],
) -> BatchFeature:
"""
This method uses [`InstructBlipVideoImageProcessor.__call__`] to prepare image(s) or video(s) for the model, and
[`BertTokenizerFast.__call__`] to prepare text for the model.
Please refer to the docstrings of the two methods above for more information.
"""
output_kwargs = self._merge_kwargs(
InstructBlipVideoProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

encoding = BatchFeature()

if text is not None:
@@ -105,21 +121,10 @@ def __call__(

_text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=None, # required to concatenate below
**kwargs,
**{
**output_kwargs["text_kwargs"],
"return_tensors": None, # required to concatenate below
},
)

# if we know how many query tokens, expand text inside processor. We need this hacky manipulation
@@ -145,31 +150,14 @@ def __call__(
)

# cast to desired return tensors type after concatenating
text_encoding = BatchEncoding(text_encoding, tensor_type=return_tensors)
text_encoding = BatchEncoding(text_encoding, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])
encoding.update(text_encoding)
qformer_text_encoding = self.qformer_tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
)
qformer_text_encoding = self.qformer_tokenizer(text=text, **output_kwargs["text_kwargs"])
encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")

if images is not None:
image_encoding = self.image_processor(images, return_tensors=return_tensors)
image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
encoding.update(image_encoding)

return encoding
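The `_defaults` table above sits at the bottom of the precedence stack: kwargs the tokenizer was initialized with override it, and call-site kwargs override both. A simplified stand-in for that part of `ProcessorMixin._merge_kwargs` (the helper name is illustrative):

def merge_text_kwargs(defaults, tokenizer_init_kwargs, call_kwargs):
    # lowest precedence first: class defaults, then tokenizer init kwargs,
    # then whatever the caller passed to __call__
    merged = dict(defaults)
    merged.update(tokenizer_init_kwargs)
    merged.update(call_kwargs)
    return merged


defaults = {"add_special_tokens": True, "padding": False, "verbose": True}
print(merge_text_kwargs(defaults, {"padding": "max_length"}, {"padding": True}))
# {'add_special_tokens': True, 'padding': True, 'verbose': True}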
src/transformers/models/llava_next_video/processing_llava_next_video.py
@@ -16,13 +16,20 @@
Processor class for LLaVa-NeXT-Video.
"""

import sys
from typing import TYPE_CHECKING, List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, VideoInput, get_image_size, to_numpy_array
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType, logging
from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import logging


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


if TYPE_CHECKING:
@@ -31,6 +38,17 @@
logger = logging.get_logger(__name__)


class LlavaNextVideoProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"common_kwargs": {
"return_tensors": "pt",
},
}


class LlavaNextVideoProcessor(ProcessorMixin):
r"""
Constructs a LLaVa-NeXT-Video processor which wraps a LLaVa-NeXT image processor, LLaVa-NeXT-Video video processor and
@@ -88,12 +106,10 @@ def __init__(
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
images: ImageInput = None,
videos: VideoInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: int = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
images: Optional[ImageInput] = None,
videos: Optional[VideoInput] = None,
audio=None,
**kwargs: Unpack[LlavaNextVideoProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
Expand All @@ -115,26 +131,6 @@ def __call__(
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
@@ -145,13 +141,19 @@ def __call__(
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
LlavaNextVideoProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

if images is not None:
image_inputs = self.image_processor(images, return_tensors=return_tensors)
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
else:
image_inputs = {}

if videos is not None:
videos_inputs = self.video_processor(videos, return_tensors=return_tensors)
videos_inputs = self.video_processor(videos, **output_kwargs["videos_kwargs"])
else:
videos_inputs = {}

@@ -203,14 +205,7 @@ def __call__(
sample = sample.replace(self.video_token, self.video_token * num_video_tokens)
prompt_strings.append(sample)

text_inputs = self.tokenizer(
prompt_strings,
return_tensors=return_tensors,
padding=padding,
truncation=truncation,
max_length=max_length,
)
text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])

return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})

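Because `common_kwargs` now defaults `return_tensors` to `"pt"`, callers can omit it entirely. A hedged usage sketch; the checkpoint name and prompt template are assumptions based on the published LLaVA-NeXT-Video checkpoints:

import numpy as np
from transformers import LlavaNextVideoProcessor

processor = LlavaNextVideoProcessor.from_pretrained("llava-hf/LLaVA-NeXT-Video-7B-hf")

video = np.zeros((8, 336, 336, 3), dtype=np.uint8)  # eight dummy frames, channels-last
inputs = processor(text="USER: <video>\nWhat is happening? ASSISTANT:", videos=video)

# return_tensors="pt" comes from LlavaNextVideoProcessorKwargs' common_kwargs default
print(inputs["input_ids"].shape)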
