Commit

remove optional args and udop uniformization from this PR
yonigozlan committed Sep 16, 2024
1 parent 7c504d1 commit f5d8507
Showing 5 changed files with 145 additions and 317 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/udop/modeling_udop.py
@@ -1790,7 +1790,7 @@ def forward(
         >>> # one can use the various task prefixes (prompts) used during pre-training
         >>> # e.g. the task prefix for DocVQA is "Question answering. "
         >>> question = "Question answering. What is the date on the form?"
-        >>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")
+        >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
 
         >>> # autoregressive generation
         >>> predicted_ids = model.generate(**encoding)
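For context, the changed line lives in the DocVQA docstring example of `UdopForConditionalGeneration.forward`. A minimal runnable sketch of that example as it reads after this commit — the checkpoint (`microsoft/udop-large`) and dataset (`nielsr/funsd-layoutlmv3`) are assumptions here, since only the `processor(...)` call appears in the hunk:

```python
# Hedged sketch: checkpoint and dataset names are assumed, not part of this diff.
from datasets import load_dataset
from transformers import AutoProcessor, UdopForConditionalGeneration

processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large")

dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
example = dataset[0]
image, words, boxes = example["image"], example["tokens"], example["bboxes"]

# After this commit, `words` is passed positionally (as `text_pair`) again.
question = "Question answering. What is the date on the form?"
encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")

predicted_ids = model.generate(**encoding)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
```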
163 changes: 68 additions & 95 deletions src/transformers/models/udop/processing_udop.py
@@ -16,47 +16,12 @@
 Processor class for UDOP.
 """
 
-import sys
 from typing import List, Optional, Union
 
-from transformers import logging
-
-from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
-from ...tokenization_utils_base import PreTokenizedInput, TextInput
-
-
-if sys.version_info >= (3, 11):
-    from typing import Unpack
-else:
-    from typing_extensions import Unpack
-
-
-logger = logging.get_logger(__name__)
-
-
-class UdopTextKwargs(TextKwargs, total=False):
-    word_labels: Optional[Union[List[int], List[List[int]]]]
-    boxes: Union[List[List[int]], List[List[List[int]]]]
-
-
-class UdopProcessorKwargs(ProcessingKwargs, total=False):
-    text_kwargs: UdopTextKwargs
-    _defaults = {
-        "text_kwargs": {
-            "add_special_tokens": True,
-            "padding": False,
-            "truncation": False,
-            "stride": 0,
-            "return_overflowing_tokens": False,
-            "return_special_tokens_mask": False,
-            "return_offsets_mapping": False,
-            "return_length": False,
-            "verbose": True,
-        },
-        "images_kwargs": {},
-    }
+from ...processing_utils import ProcessorMixin
+from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from ...utils import TensorType
 
 
 class UdopProcessor(ProcessorMixin):
@@ -84,8 +49,6 @@ class UdopProcessor(ProcessorMixin):
     attributes = ["image_processor", "tokenizer"]
     image_processor_class = "LayoutLMv3ImageProcessor"
     tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast")
-    # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
-    optional_call_args = ["text_pair"]
 
     def __init__(self, image_processor, tokenizer):
         super().__init__(image_processor, tokenizer)
@@ -94,14 +57,28 @@ def __call__(
         self,
         images: Optional[ImageInput] = None,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        # The following is to capture `text_pair` argument that may be passed as a positional argument.
-        # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
-        # This behavior is only needed for backward compatibility and will be removed in future versions.
-        *args,
-        audio=None,
-        videos=None,
-        **kwargs: Unpack[UdopProcessorKwargs],
-    ) -> BatchFeature:
+        text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+        boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+        word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        text_pair_target: Optional[
+            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
+        ] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = False,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ) -> BatchEncoding:
         """
         This method first forwards the `images` argument to [`~UdopImageProcessor.__call__`]. In case
         [`UdopImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
@@ -116,19 +93,6 @@ def __call__(
         Please refer to the docstring of the above two methods for more information.
         """
         # verify input
-        output_kwargs = self._merge_kwargs(
-            UdopProcessorKwargs,
-            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
-            **kwargs,
-            **self.prepare_and_validate_optional_call_args(*args),
-        )
-
-        boxes = output_kwargs["text_kwargs"].pop("boxes", None)
-        word_labels = output_kwargs["text_kwargs"].pop("word_labels", None)
-        text_pair = output_kwargs["text_kwargs"].pop("text_pair", None)
-        return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False)
-        return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False)
-
         if self.image_processor.apply_ocr and (boxes is not None):
             raise ValueError(
                 "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True."
@@ -142,44 +106,66 @@ def __call__(
         if return_overflowing_tokens is True and return_offsets_mapping is False:
             raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")
 
-        if output_kwargs["text_kwargs"].get("text_target", None) is not None:
+        if text_target is not None:
             # use the processor to prepare the targets of UDOP
             return self.tokenizer(
-                **output_kwargs["text_kwargs"],
+                text_target=text_target,
+                text_pair_target=text_pair_target,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=pad_to_multiple_of,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                return_tensors=return_tensors,
             )
 
         else:
             # use the processor to prepare the inputs of UDOP
             # first, apply the image processor
-            features = self.image_processor(images=images, **output_kwargs["images_kwargs"])
-            features_words = features.pop("words", None)
-            features_boxes = features.pop("boxes", None)
-
-            _ = output_kwargs["text_kwargs"].pop("text_target", None)
-            _ = output_kwargs["text_kwargs"].pop("text_pair_target", None)
-            output_kwargs["text_kwargs"]["text_pair"] = text_pair
-            output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes
-            output_kwargs["text_kwargs"]["word_labels"] = word_labels
+            features = self.image_processor(images=images, return_tensors=return_tensors)
 
             # second, apply the tokenizer
             if text is not None and self.image_processor.apply_ocr and text_pair is None:
                 if isinstance(text, str):
                     text = [text]  # add batch dimension (as the image processor always adds a batch dimension)
-                output_kwargs["text_kwargs"]["text_pair"] = features_words
+                text_pair = features["words"]
 
             encoded_inputs = self.tokenizer(
-                text=text if text is not None else features_words,
-                **output_kwargs["text_kwargs"],
+                text=text if text is not None else features["words"],
+                text_pair=text_pair if text_pair is not None else None,
+                boxes=boxes if boxes is not None else features["boxes"],
+                word_labels=word_labels,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=pad_to_multiple_of,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                return_tensors=return_tensors,
            )
 
             # add pixel values
+            pixel_values = features.pop("pixel_values")
             if return_overflowing_tokens is True:
-                features["pixel_values"] = self.get_overflowing_images(
-                    features["pixel_values"], encoded_inputs["overflow_to_sample_mapping"]
-                )
-            features.update(encoded_inputs)
+                pixel_values = self.get_overflowing_images(pixel_values, encoded_inputs["overflow_to_sample_mapping"])
+            encoded_inputs["pixel_values"] = pixel_values
 
-            return features
+            return encoded_inputs
 
     # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.get_overflowing_images
     def get_overflowing_images(self, images, overflow_to_sample_mapping):
@@ -212,20 +198,7 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)
 
-    def post_process_image_text_to_text(self, generated_outputs):
-        """
-        Post-process the output of the model to decode the text.
-
-        Args:
-            generated_outputs (`torch.Tensor` or `np.ndarray`):
-                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
-                or `(sequence_length,)`.
-
-        Returns:
-            `List[str]`: The decoded text.
-        """
-        return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
-
     @property
+    # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.model_input_names
     def model_input_names(self):
-        return ["pixel_values", "input_ids", "bbox", "attention_mask"]
+        return ["input_ids", "bbox", "attention_mask", "pixel_values"]
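The revert restores the explicit keyword surface of `UdopProcessor.__call__`: either the bundled `LayoutLMv3ImageProcessor` runs OCR itself, or the caller supplies `boxes` (and optionally `word_labels`) with `apply_ocr=False`, since the diff keeps the ValueError guarding that combination. A sketch of both paths; the checkpoint name and toy inputs are assumptions, and the OCR path additionally needs Tesseract installed:

```python
# Hedged sketch of the two call paths of the restored API (assumed checkpoint, toy data).
from PIL import Image
from transformers import UdopProcessor

image = Image.new("RGB", (224, 224), "white")  # stand-in for a document image

# Path 1: let the image processor run OCR (requires pytesseract/Tesseract).
processor = UdopProcessor.from_pretrained("microsoft/udop-large")
encoding = processor(images=image, return_tensors="pt")

# Path 2: bring your own words/boxes; OCR must be off, otherwise __call__ raises
# the "You cannot provide bounding boxes ..." ValueError shown above.
processor = UdopProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
words = ["hello", "world"]
boxes = [[10, 20, 30, 40], [50, 60, 70, 80]]  # one 0-1000-normalized box per word
encoding = processor(
    images=image,
    text=words,
    boxes=boxes,
    word_labels=[0, 1],
    return_tensors="pt",
)
print(encoding.keys())  # input_ids, attention_mask, bbox, labels, pixel_values
```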
64 changes: 0 additions & 64 deletions src/transformers/processing_utils.py
@@ -35,9 +35,7 @@
 
 from .tokenization_utils_base import (
     PaddingStrategy,
-    PreTokenizedInput,
     PreTrainedTokenizerBase,
-    TextInput,
     TruncationStrategy,
 )
 from .utils import (
@@ -108,9 +106,6 @@ class TextKwargs(TypedDict, total=False):
             The side on which padding will be applied.
     """
 
-    text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
-    text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
-    text_pair_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
     add_special_tokens: Optional[bool]
     padding: Union[bool, str, PaddingStrategy]
     truncation: Union[bool, str, TruncationStrategy]
@@ -322,7 +317,6 @@ class ProcessorMixin(PushToHubMixin):
 
     attributes = ["feature_extractor", "tokenizer"]
    optional_attributes = ["chat_template"]
-    optional_call_args: List[str] = []
     # Names need to be attr_class for attr in attributes
     feature_extractor_class = None
     tokenizer_class = None
@@ -964,64 +958,6 @@ def validate_init_kwargs(processor_config, valid_kwargs):
         unused_kwargs = {k: processor_config[k] for k in unused_keys}
         return unused_kwargs
 
-    def prepare_and_validate_optional_call_args(self, *args):
-        """
-        Matches optional positional arguments to their corresponding names in `optional_call_args`
-        in the processor class in the order they are passed to the processor call.
-
-        Note that this should only be used in the `__call__` method of the processors with special
-        arguments. Special arguments are arguments that aren't `text`, `images`, `audio`, nor `videos`
-        but also aren't passed to the tokenizer, image processor, etc. Examples of such processors are:
-        - `CLIPSegProcessor`
-        - `LayoutLMv2Processor`
-        - `OwlViTProcessor`
-
-        Also note that passing by position to the processor call is now deprecated and will be disallowed
-        in future versions. We only have this for backward compatibility.
-
-        Example:
-            Suppose that the processor class has `optional_call_args = ["arg_name_1", "arg_name_2"]`.
-            And we define the call method as:
-            ```python
-            def __call__(
-                self,
-                text: str,
-                images: Optional[ImageInput] = None,
-                *arg,
-                audio=None,
-                videos=None,
-            )
-            ```
-
-            Then, if we call the processor as:
-            ```python
-            images = [...]
-            processor("What is common in these images?", images, "arg_value_1", "arg_value_2")
-            ```
-
-            Then, this method will return:
-            ```python
-            {
-                "arg_name_1": "arg_value_1",
-                "arg_name_2": "arg_value_2",
-            }
-            ```
-            which we could then pass as kwargs to `self._merge_kwargs`
-        """
-        if len(args):
-            warnings.warn(
-                "Passing positional arguments to the processor call is now deprecated and will be disallowed in future versions. "
-                "Please pass all arguments as keyword arguments."
-            )
-        if len(args) > len(self.optional_call_args):
-            raise ValueError(
-                f"Expected *at most* {len(self.optional_call_args)} optional positional arguments in processor call"
-                f"which will be matched with {' '.join(self.optional_call_args)} in the order they are passed."
-                f"However, got {len(args)} positional arguments instead."
-                "Please pass all arguments as keyword arguments instead (e.g. `processor(arg_name_1=..., arg_name_2=...))`."
-            )
-        return {arg_name: arg_value for arg_value, arg_name in zip(args, self.optional_call_args)}
-
     def apply_chat_template(
         self,
         conversation: Union[List[Dict[str, str]]],
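At its core, the removed `prepare_and_validate_optional_call_args` just zips leftover positional arguments with the names declared in `optional_call_args`. A standalone sketch of that matching logic with toy values (not a transformers API):

```python
# Hedged sketch: reproduces the zip-based matching from the removed helper.
optional_call_args = ["text_pair"]   # what UdopProcessor declared before this commit
args = (["hello", "world"],)         # words passed positionally to __call__

matched = {arg_name: arg_value for arg_value, arg_name in zip(args, optional_call_args)}
print(matched)  # {'text_pair': ['hello', 'world']}
```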
[Diffs for the remaining 2 changed files were not loaded.]
