remove optional args and udop uniformization from this PR
yonigozlan committed Sep 20, 2024
1 parent 41b5d4c commit 7420705
Showing 5 changed files with 94 additions and 154 deletions.
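
In short, this commit backs UdopProcessor out of the uniformized kwargs path and restores its explicit call signature, so `text_pair` is once again a declared (and positional) argument rather than a value captured through UdopProcessorKwargs. A minimal sketch of the restored calling convention, assuming a local document image and pre-extracted OCR words and boxes (the file name and word/box values are illustrative, not from the commit):

from PIL import Image

from transformers import UdopProcessor

# apply_ocr=False because we supply our own boxes; passing boxes while the
# image processor runs OCR itself raises a ValueError (see the diff below).
processor = UdopProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)

image = Image.open("document.png").convert("RGB")  # hypothetical local file
question = "Question answering. What is the date on the form?"
words = ["7/25/2024", "Invoice"]  # hypothetical OCR output
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]  # one box per word, illustrative values

# Restored convention: the OCR words go in as the positional text_pair argument.
encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")

# The uniformized variant removed here routed the same values through kwargs:
# encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")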
2 changes: 1 addition & 1 deletion src/transformers/models/udop/modeling_udop.py
@@ -1790,7 +1790,7 @@ def forward(
>>> # one can use the various task prefixes (prompts) used during pre-training
>>> # e.g. the task prefix for DocVQA is "Question answering. "
>>> question = "Question answering. What is the date on the form?"
>>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")
>>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
>>> # autoregressive generation
>>> predicted_ids = model.generate(**encoding)
163 changes: 68 additions & 95 deletions src/transformers/models/udop/processing_udop.py
@@ -16,47 +16,12 @@
Processor class for UDOP.
"""

import sys
from typing import List, Optional, Union

from transformers import logging

from ...image_processing_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
from ...tokenization_utils_base import PreTokenizedInput, TextInput


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


logger = logging.get_logger(__name__)


class UdopTextKwargs(TextKwargs, total=False):
word_labels: Optional[Union[List[int], List[List[int]]]]
boxes: Union[List[List[int]], List[List[List[int]]]]


class UdopProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: UdopTextKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"truncation": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_length": False,
"verbose": True,
},
"images_kwargs": {},
}
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType


class UdopProcessor(ProcessorMixin):
@@ -84,8 +49,6 @@ class UdopProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "LayoutLMv3ImageProcessor"
tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast")
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = ["text_pair"]

def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
@@ -94,14 +57,28 @@ def __call__(
self,
images: Optional[ImageInput] = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
# The following is to capture `text_pair` argument that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args,
audio=None,
videos=None,
**kwargs: Unpack[UdopProcessorKwargs],
) -> BatchFeature:
text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None,
boxes: Union[List[List[int]], List[List[List[int]]]] = None,
word_labels: Optional[Union[List[int], List[List[int]]]] = None,
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
text_pair_target: Optional[
Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchEncoding:
"""
This method first forwards the `images` argument to [`~UdopImageProcessor.__call__`]. In case
[`UdopImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and
@@ -116,19 +93,6 @@ def __call__(
Please refer to the docstring of the above two methods for more information.
"""
# verify input
output_kwargs = self._merge_kwargs(
UdopProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)

boxes = output_kwargs["text_kwargs"].pop("boxes", None)
word_labels = output_kwargs["text_kwargs"].pop("word_labels", None)
text_pair = output_kwargs["text_kwargs"].pop("text_pair", None)
return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False)
return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False)

if self.image_processor.apply_ocr and (boxes is not None):
raise ValueError(
"You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True."
@@ -142,44 +106,66 @@ def __call__(
if return_overflowing_tokens is True and return_offsets_mapping is False:
raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.")

if output_kwargs["text_kwargs"].get("text_target", None) is not None:
if text_target is not None:
# use the processor to prepare the targets of UDOP
return self.tokenizer(
**output_kwargs["text_kwargs"],
text_target=text_target,
text_pair_target=text_pair_target,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
)

else:
# use the processor to prepare the inputs of UDOP
# first, apply the image processor
features = self.image_processor(images=images, **output_kwargs["images_kwargs"])
features_words = features.pop("words", None)
features_boxes = features.pop("boxes", None)

_ = output_kwargs["text_kwargs"].pop("text_target", None)
_ = output_kwargs["text_kwargs"].pop("text_pair_target", None)
output_kwargs["text_kwargs"]["text_pair"] = text_pair
output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes
output_kwargs["text_kwargs"]["word_labels"] = word_labels
features = self.image_processor(images=images, return_tensors=return_tensors)

# second, apply the tokenizer
if text is not None and self.image_processor.apply_ocr and text_pair is None:
if isinstance(text, str):
text = [text] # add batch dimension (as the image processor always adds a batch dimension)
output_kwargs["text_kwargs"]["text_pair"] = features_words
text_pair = features["words"]

encoded_inputs = self.tokenizer(
text=text if text is not None else features_words,
**output_kwargs["text_kwargs"],
text=text if text is not None else features["words"],
text_pair=text_pair if text_pair is not None else None,
boxes=boxes if boxes is not None else features["boxes"],
word_labels=word_labels,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_token_type_ids=return_token_type_ids,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
)

# add pixel values
pixel_values = features.pop("pixel_values")
if return_overflowing_tokens is True:
features["pixel_values"] = self.get_overflowing_images(
features["pixel_values"], encoded_inputs["overflow_to_sample_mapping"]
)
features.update(encoded_inputs)
pixel_values = self.get_overflowing_images(pixel_values, encoded_inputs["overflow_to_sample_mapping"])
encoded_inputs["pixel_values"] = pixel_values

return features
return encoded_inputs

# Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.get_overflowing_images
def get_overflowing_images(self, images, overflow_to_sample_mapping):
@@ -212,20 +198,7 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)

@property
# Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.model_input_names
def model_input_names(self):
return ["pixel_values", "input_ids", "bbox", "attention_mask"]
return ["input_ids", "bbox", "attention_mask", "pixel_values"]
6 changes: 0 additions & 6 deletions src/transformers/processing_utils.py
@@ -38,9 +38,7 @@

from .tokenization_utils_base import (
PaddingStrategy,
PreTokenizedInput,
PreTrainedTokenizerBase,
TextInput,
TruncationStrategy,
)
from .utils import (
@@ -116,9 +114,6 @@ class TextKwargs(TypedDict, total=False):
The side on which padding will be applied.
"""

text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
text_pair_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
add_special_tokens: Optional[bool]
padding: Union[bool, str, PaddingStrategy]
truncation: Union[bool, str, TruncationStrategy]
@@ -333,7 +328,6 @@ class ProcessorMixin(PushToHubMixin):

attributes = ["feature_extractor", "tokenizer"]
optional_attributes = ["chat_template"]
optional_call_args: List[str] = []
# Names need to be attr_class for attr in attributes
feature_extractor_class = None
tokenizer_class = None
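
For context on what is being deleted here: optional_call_args let a uniformized processor keep accepting legacy positional arguments (UDOP's text_pair) even though its __call__ only declared *args. A simplified toy sketch of the idea - not the library's actual implementation, whose real helper is ProcessorMixin.prepare_and_validate_optional_call_args:

class ToyProcessor:
    # Names, in order, for legacy positionals that the uniformized
    # __call__(images, text, *args, **kwargs) no longer declares.
    optional_call_args = ["text_pair"]

    def prepare_and_validate_optional_call_args(self, *args):
        if len(args) > len(self.optional_call_args):
            raise ValueError(
                f"expected at most {len(self.optional_call_args)} optional "
                f"positional arguments, got {len(args)}"
            )
        # Map each legacy positional onto its named kwarg, e.g. text_pair=words.
        return dict(zip(self.optional_call_args, args))

    def __call__(self, images=None, text=None, *args, **kwargs):
        kwargs.update(self.prepare_and_validate_optional_call_args(*args))
        return kwargs

# processor(image, question, words) still works: words lands in kwargs["text_pair"].
print(ToyProcessor()("img", "question?", ["hello", "world"]))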
63 changes: 19 additions & 44 deletions tests/models/udop/test_processor_udop.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import tempfile
import unittest
@@ -30,10 +32,9 @@
require_sentencepiece,
require_tokenizers,
require_torch,
require_vision,
slow,
)
from transformers.utils import cached_property, is_pytesseract_available, is_torch_available
from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available, is_torch_available

from ...test_processing_common import ProcessorTesterMixin

@@ -54,19 +55,20 @@
class UdopProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenizer_class = UdopTokenizer
rust_tokenizer_class = UdopTokenizerFast
processor_class = UdopProcessor
maxDiff = None
processor_class = UdopProcessor

def setUp(self):
image_processor_map = {
"do_resize": True,
"size": 224,
"apply_ocr": True,
}

self.tmpdirname = tempfile.mkdtemp()
image_processor = LayoutLMv3ImageProcessor(
do_resize=True,
size=224,
apply_ocr=True,
)
tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large")
processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer)
processor.save_pretrained(self.tmpdirname)
self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME)
with open(self.feature_extraction_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(image_processor_map) + "\n")

self.tokenizer_pretrained_name = "microsoft/udop-large"

@@ -78,15 +80,15 @@ def setUp(self):
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)

def get_image_processor(self, **kwargs):
return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)

def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]:
return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)]

def get_image_processor(self, **kwargs):
return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

def tearDown(self):
shutil.rmtree(self.tmpdirname)

@@ -151,7 +153,7 @@ def test_model_input_names(self):
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(images=image_input, text=input_str)
inputs = processor(text=input_str, images=image_input)

self.assertListEqual(list(inputs.keys()), processor.model_input_names)

@@ -206,31 +208,6 @@ def preprocess_data(examples):

self.assertEqual(len(train_data["pixel_values"]), len(train_data["input_ids"]))

@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
images=image_input,
text=input_str,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 5)


# different use cases tests
@require_sentencepiece
@@ -495,7 +472,7 @@ def test_processor_case_5(self):
question = "What's his name?"
words = ["hello", "world"]
boxes = [[1, 2, 3, 4], [5, 6, 7, 8]]
input_processor = processor(images[0], question, text_pair=words, boxes=boxes, return_tensors="pt")
input_processor = processor(images[0], question, words, boxes, return_tensors="pt")

# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
@@ -511,9 +488,7 @@
questions = ["How old is he?", "what's the time"]
words = [["hello", "world"], ["my", "name", "is", "niels"]]
boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]]
input_processor = processor(
images, questions, text_pair=words, boxes=boxes, padding=True, return_tensors="pt"
)
input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt")

# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]