Fix Udop to return BatchFeature instead of BatchEncoding and uniformize kwargs
yonigozlan committed Aug 9, 2024
1 parent 016d443 commit 22b5295
Showing 2 changed files with 145 additions and 69 deletions.
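In practice, the change means `UdopProcessor.__call__` now takes the tokenizer and image-processor options as uniform keyword arguments, validated and defaulted through `UdopProcessorKwargs`, and returns a `BatchFeature`. A minimal sketch of the new calling convention (the checkpoint name and inputs are illustrative, and `apply_ocr=False` is assumed so that words and boxes can be supplied directly):

from PIL import Image
from transformers import UdopProcessor

# Illustrative checkpoint; apply_ocr=False is assumed so words/boxes are user-supplied.
processor = UdopProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)

image = Image.new("RGB", (224, 224), color="white")  # placeholder document image
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # one normalized box per word

# Tokenizer options (padding, truncation, ...) are now plain kwargs; anything
# left unset falls back to UdopProcessorKwargs._defaults via self._merge_kwargs.
inputs = processor(images=image, text=words, boxes=boxes, padding=True, return_tensors="pt")

print(type(inputs).__name__)  # BatchFeature (previously BatchEncoding)
print(sorted(inputs.keys()))  # ['attention_mask', 'bbox', 'input_ids', 'pixel_values']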
208 changes: 141 additions & 67 deletions src/transformers/models/udop/processing_udop.py
@@ -16,12 +16,47 @@
 Processor class for UDOP.
 """
 
+import sys
 from typing import List, Optional, Union
 
 from ...image_processing_utils import BatchFeature
 from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
+from ...tokenization_utils_base import PreTokenizedInput, TextInput
+
+
+if sys.version_info >= (3, 11):
+    from typing import Unpack
+else:
+    from typing_extensions import Unpack
+
+
+class UdopTextKwargs(TextKwargs, total=False):
+    text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]]
+    word_labels: Optional[Union[List[int], List[List[int]]]]
+    text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
+    text_pair_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
+    boxes: Union[List[List[int]], List[List[List[int]]]]
+
+
+class UdopProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: UdopTextKwargs
+    _defaults = {
+        "text_kwargs": {
+            "add_special_tokens": True,
+            "padding": False,
+            "truncation": False,
+            "stride": 0,
+            "return_overflowing_tokens": False,
+            "return_special_tokens_mask": False,
+            "return_offsets_mapping": False,
+            "return_length": False,
+            "verbose": True,
+        },
+        "images_kwargs": {
+            "num_image_tokens": 64,
+        },
+    }
+
+
 class UdopProcessor(ProcessorMixin):
@@ -57,28 +92,29 @@ def __call__(
         self,
         images: Optional[ImageInput] = None,
         text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
-        boxes: Union[List[List[int]], List[List[List[int]]]] = None,
-        word_labels: Optional[Union[List[int], List[List[int]]]] = None,
-        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        text_pair_target: Optional[
-            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
-        ] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = False,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        pad_to_multiple_of: Optional[int] = None,
-        return_token_type_ids: Optional[bool] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> BatchEncoding:
+        # text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
+        # boxes: Union[List[List[int]], List[List[List[int]]]] = None,
+        # word_labels: Optional[Union[List[int], List[List[int]]]] = None,
+        # text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        # text_pair_target: Optional[
+        #     Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
+        # ] = None,
+        # add_special_tokens: bool = True,
+        # padding: Union[bool, str, PaddingStrategy] = False,
+        # truncation: Union[bool, str, TruncationStrategy] = False,
+        # max_length: Optional[int] = None,
+        # stride: int = 0,
+        # pad_to_multiple_of: Optional[int] = None,
+        # return_token_type_ids: Optional[bool] = None,
+        # return_attention_mask: Optional[bool] = None,
+        # return_overflowing_tokens: bool = False,
+        # return_special_tokens_mask: bool = False,
+        # return_offsets_mapping: bool = False,
+        # return_length: bool = False,
+        # verbose: bool = True,
+        # return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs: Unpack[UdopProcessorKwargs],
+    ) -> BatchFeature:
         """
         This method first forwards the `images` argument to [`~UdopImageProcessor.__call__`]. In case
         [`UdopImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
@@ -93,6 +129,19 @@
         Please refer to the docstring of the above two methods for more information.
         """
         # verify input
+
+        output_kwargs = self._merge_kwargs(
+            UdopProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+
+        boxes = output_kwargs["text_kwargs"].pop("boxes", None)
+        word_labels = output_kwargs["text_kwargs"].pop("word_labels", None)
+        text_pair = output_kwargs["text_kwargs"].pop("text_pair", None)
+        return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False)
+        return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False)
+
         if self.image_processor.apply_ocr and (boxes is not None):
             raise ValueError(
                 "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True."
@@ -106,66 +155,77 @@
         if return_overflowing_tokens is True and return_offsets_mapping is False:
             raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")
 
-        if text_target is not None:
+        if output_kwargs["text_kwargs"].get("text_target", None) is not None:
             # use the processor to prepare the targets of UDOP
             return self.tokenizer(
-                text_target=text_target,
-                text_pair_target=text_pair_target,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_token_type_ids=return_token_type_ids,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
+                # text_target=text_target,
+                # text_pair_target=text_pair_target,
+                # add_special_tokens=add_special_tokens,
+                # padding=padding,
+                # truncation=truncation,
+                # max_length=max_length,
+                # stride=stride,
+                # pad_to_multiple_of=pad_to_multiple_of,
+                # return_token_type_ids=return_token_type_ids,
+                # return_attention_mask=return_attention_mask,
+                # return_overflowing_tokens=return_overflowing_tokens,
+                # return_special_tokens_mask=return_special_tokens_mask,
+                # return_offsets_mapping=return_offsets_mapping,
+                # return_length=return_length,
+                # verbose=verbose,
+                # return_tensors=return_tensors,
+                **output_kwargs["text_kwargs"],
             )
 
         else:
             # use the processor to prepare the inputs of UDOP
             # first, apply the image processor
-            features = self.image_processor(images=images, return_tensors=return_tensors)
+            features = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+            features_words = features.pop("words", None)
+            features_boxes = features.pop("boxes", None)
+
+            _ = output_kwargs["text_kwargs"].pop("text_target", None)
+            _ = output_kwargs["text_kwargs"].pop("text_pair_target", None)
+            output_kwargs["text_kwargs"]["text_pair"] = text_pair
+            output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes
+            output_kwargs["text_kwargs"]["word_labels"] = word_labels
 
             # second, apply the tokenizer
             if text is not None and self.image_processor.apply_ocr and text_pair is None:
                 if isinstance(text, str):
                     text = [text]  # add batch dimension (as the image processor always adds a batch dimension)
-                text_pair = features["words"]
+                output_kwargs["text_kwargs"]["text_pair"] = features_words
 
             encoded_inputs = self.tokenizer(
-                text=text if text is not None else features["words"],
-                text_pair=text_pair if text_pair is not None else None,
-                boxes=boxes if boxes is not None else features["boxes"],
-                word_labels=word_labels,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_token_type_ids=return_token_type_ids,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_length=return_length,
-                verbose=verbose,
-                return_tensors=return_tensors,
+                text=text if text is not None else features_words,
+                # text_pair=text_pair,
+                # boxes=boxes if boxes is not None else features_boxes,
+                # word_labels=word_labels,
+                # add_special_tokens=add_special_tokens,
+                # padding=padding,
+                # truncation=truncation,
+                # max_length=max_length,
+                # stride=stride,
+                # pad_to_multiple_of=pad_to_multiple_of,
+                # return_token_type_ids=return_token_type_ids,
+                # return_attention_mask=return_attention_mask,
+                # return_overflowing_tokens=return_overflowing_tokens,
+                # return_special_tokens_mask=return_special_tokens_mask,
+                # return_offsets_mapping=return_offsets_mapping,
+                # return_length=return_length,
+                # verbose=verbose,
+                # return_tensors=return_tensors,
+                **output_kwargs["text_kwargs"],
             )
 
             # add pixel values
-            pixel_values = features.pop("pixel_values")
             if return_overflowing_tokens is True:
-                pixel_values = self.get_overflowing_images(pixel_values, encoded_inputs["overflow_to_sample_mapping"])
-            encoded_inputs["pixel_values"] = pixel_values
+                features["pixel_values"] = self.get_overflowing_images(
+                    features["pixel_values"], encoded_inputs["overflow_to_sample_mapping"]
+                )
+            features.update(encoded_inputs)
 
-            return encoded_inputs
+            return features
 
     # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.get_overflowing_images
     def get_overflowing_images(self, images, overflow_to_sample_mapping):
@@ -198,7 +258,21 @@ def decode(self, *args, **kwargs):
         """
         return self.tokenizer.decode(*args, **kwargs)
 
+    def post_process_image_text_to_text(self, generated_outputs):
+        """
+        Post-process the output of the model to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+
+        Returns:
+            `List[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
+
     @property
     # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.model_input_names
     def model_input_names(self):
-        return ["input_ids", "bbox", "attention_mask", "pixel_values"]
+        return ["pixel_values", "input_ids", "bbox", "attention_mask"]
6 changes: 4 additions & 2 deletions tests/models/udop/test_processor_udop.py
@@ -476,7 +476,7 @@ def test_processor_case_5(self):
         question = "What's his name?"
         words = ["hello", "world"]
         boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
-        input_processor = processor(images[0], question, words, boxes, return_tensors="pt")
+        input_processor = processor(images[0], question, text_pair=words, boxes=boxes, return_tensors="pt")
 
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
@@ -492,7 +492,9 @@ def test_processor_case_5(self):
         questions = ["How old is he?", "what's the time"]
         words = [["hello", "world"], ["my", "name", "is", "niels"]]
         boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
-        input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt")
+        input_processor = processor(
+            images, questions, text_pair=words, boxes=boxes, padding=True, return_tensors="pt"
+        )
 
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
