Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for args to ProcessorMixin for backward compatibility #33479

Merged
merged 8 commits into from
Sep 20, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ class LlavaOnevisionProcessor(ProcessorMixin):
r"""
Constructs a LLaVa-Onevision processor which wraps a LLaVa-Onevision video processor, LLaVa-NeXT image processor and a LLaMa tokenizer into a single processor.

[`LlavaNextProcessor`] offers all the functionalities of [`LlavaOnevisionVideoProcessor`], [`LlavaNextImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`LlavaNextProcessor`] offers all the functionalities of [`LlavaOnevisionVideoProcessor`], [`LlavaOnevisionImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~LlavaOnevisionVideoProcessor.__call__`], [`~LlavaNextProcessor.__call__`] and [`~LlavaNextProcessor.decode`] for more information.

Args:
image_processor ([`LlavaNextImageProcessor`], *optional*):
image_processor ([`LlavaOnevisionImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
Expand Down Expand Up @@ -114,6 +114,7 @@ def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos: VideoInput = None,
**kwargs: Unpack[LlavaOnevisionProcessorKwargs],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zucchini-nlp It seems that the Unpack was in the init instead of the call function of the processor. Do you think it's worth a separate PR or can we include it here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, yes, can be here, thanks! Regarding adding audio=None before video, while audio won't be used for this model: seems to be counterintuitive to me.

@molbap welcome back! We had questions about changing order of main input args in your absence, and this is also distantnly related to that. Should we be adding unused args with this order?

Copy link
Contributor

@molbap molbap Sep 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey! glad to be back, github notifications are hefty 😁
keeping audio=None is a bit strange, but it's the price for having always the same signature of constant size. I'd be pro-keeping always the same order as well. This is because in terms of integration of transformers in other tools, it makes things (a bit) easier, and would allow more future impact for processors. But if you think of something smart along the lines of informing a user of which modality/types are accepted by the model, I'm open to it!
edit: pressed enter too soon

) -> BatchFeature:
Expand Down
64 changes: 64 additions & 0 deletions src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@

from .tokenization_utils_base import (
PaddingStrategy,
PreTokenizedInput,
PreTrainedTokenizerBase,
TextInput,
TruncationStrategy,
)
from .utils import (
Expand Down Expand Up @@ -114,6 +116,9 @@ class TextKwargs(TypedDict, total=False):
The side on which padding will be applied.
"""

text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
text_pair_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
add_special_tokens: Optional[bool]
padding: Union[bool, str, PaddingStrategy]
truncation: Union[bool, str, TruncationStrategy]
Expand Down Expand Up @@ -328,6 +333,7 @@ class ProcessorMixin(PushToHubMixin):

attributes = ["feature_extractor", "tokenizer"]
optional_attributes = ["chat_template"]
optional_call_args: List[str] = []
# Names need to be attr_class for attr in attributes
feature_extractor_class = None
tokenizer_class = None
Expand Down Expand Up @@ -973,6 +979,64 @@ def validate_init_kwargs(processor_config, valid_kwargs):
unused_kwargs = {k: processor_config[k] for k in unused_keys}
return unused_kwargs

def prepare_and_validate_optional_call_args(self, *args):
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice docstring :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's from @leloykun :)

Matches optional positional arguments to their corresponding names in `optional_call_args`
in the processor class in the order they are passed to the processor call.

Note that this should only be used in the `__call__` method of the processors with special
arguments. Special arguments are arguments that aren't `text`, `images`, `audio`, nor `videos`
but also aren't passed to the tokenizer, image processor, etc. Examples of such processors are:
- `CLIPSegProcessor`
- `LayoutLMv2Processor`
- `OwlViTProcessor`

Also note that passing by position to the processor call is now deprecated and will be disallowed
in future versions. We only have this for backward compatibility.

Example:
Suppose that the processor class has `optional_call_args = ["arg_name_1", "arg_name_2"]`.
And we define the call method as:
```python
def __call__(
self,
text: str,
images: Optional[ImageInput] = None,
*arg,
audio=None,
videos=None,
)
```

Then, if we call the processor as:
```python
images = [...]
processor("What is common in these images?", images, arg_value_1, arg_value_2)
```

Then, this method will return:
```python
{
"arg_name_1": arg_value_1,
"arg_name_2": arg_value_2,
}
```
which we could then pass as kwargs to `self._merge_kwargs`
"""
if len(args):
warnings.warn(
"Passing positional arguments to the processor call is now deprecated and will be disallowed in v4.47. "
"Please pass all arguments as keyword arguments."
)
if len(args) > len(self.optional_call_args):
raise ValueError(
f"Expected *at most* {len(self.optional_call_args)} optional positional arguments in processor call"
f"which will be matched with {' '.join(self.optional_call_args)} in the order they are passed."
f"However, got {len(args)} positional arguments instead."
"Please pass all arguments as keyword arguments instead (e.g. `processor(arg_name_1=..., arg_name_2=...))`."
)
return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)}

def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]]],
Expand Down
115 changes: 1 addition & 114 deletions tests/models/altclip/test_processor_altclip.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for cleaning these up!

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import unittest

from transformers import XLMRobertaTokenizer, XLMRobertaTokenizerFast
from transformers.testing_utils import require_torch, require_vision
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available

from ...test_processing_common import ProcessorTesterMixin
Expand Down Expand Up @@ -50,116 +50,3 @@ def get_rust_tokenizer(self, **kwargs):

def get_image_processor(self, **kwargs):
return CLIPImageProcessor.from_pretrained(self.model_id, **kwargs)

@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 7)

def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")

image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)

def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", crop_size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
126 changes: 0 additions & 126 deletions tests/models/chinese_clip/test_processor_chinese_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,129 +206,3 @@ def test_model_input_names(self):
inputs = processor(text=input_str, images=image_input)

self.assertListEqual(list(inputs.keys()), processor.model_input_names)

def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 6)

def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")

image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"crop_size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)

def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", crop_size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)

def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", crop_size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, crop_size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
Loading
Loading