From d179ca7071951666f8d7a8228f3cb27ce15bf632 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Mon, 16 Sep 2024 19:44:08 +0000 Subject: [PATCH 1/3] Add optional kwargs and uniformize udop --- src/transformers/models/udop/modeling_udop.py | 2 +- .../models/udop/processing_udop.py | 163 ++++++++++-------- tests/models/udop/test_processor_udop.py | 37 ++-- 3 files changed, 114 insertions(+), 88 deletions(-) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 972248daaae599..6f7b6cf060495a 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1790,7 +1790,7 @@ def forward( >>> # one can use the various task prefixes (prompts) used during pre-training >>> # e.g. the task prefix for DocVQA is "Question answering. " >>> question = "Question answering. What is the date on the form?" - >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt") + >>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt") >>> # autoregressive generation >>> predicted_ids = model.generate(**encoding) diff --git a/src/transformers/models/udop/processing_udop.py b/src/transformers/models/udop/processing_udop.py index 2902541d6f5b46..3d4cfc9ce4334e 100644 --- a/src/transformers/models/udop/processing_udop.py +++ b/src/transformers/models/udop/processing_udop.py @@ -16,12 +16,47 @@ Processor class for UDOP. """ +import sys from typing import List, Optional, Union +from transformers import logging + +from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessorMixin -from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy -from ...utils import TensorType +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...tokenization_utils_base import PreTokenizedInput, TextInput + + +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack + + +logger = logging.get_logger(__name__) + + +class UdopTextKwargs(TextKwargs, total=False): + word_labels: Optional[Union[List[int], List[List[int]]]] + boxes: Union[List[List[int]], List[List[List[int]]]] + + +class UdopProcessorKwargs(ProcessingKwargs, total=False): + text_kwargs: UdopTextKwargs + _defaults = { + "text_kwargs": { + "add_special_tokens": True, + "padding": False, + "truncation": False, + "stride": 0, + "return_overflowing_tokens": False, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "verbose": True, + }, + "images_kwargs": {}, + } class UdopProcessor(ProcessorMixin): @@ -49,6 +84,8 @@ class UdopProcessor(ProcessorMixin): attributes = ["image_processor", "tokenizer"] image_processor_class = "LayoutLMv3ImageProcessor" tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast") + # For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. 
+ optional_call_args = ["text_pair"] def __init__(self, image_processor, tokenizer): super().__init__(image_processor, tokenizer) @@ -57,28 +94,14 @@ def __call__( self, images: Optional[ImageInput] = None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, - boxes: Union[List[List[int]], List[List[List[int]]]] = None, - word_labels: Optional[Union[List[int], List[List[int]]]] = None, - text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, - text_pair_target: Optional[ - Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] - ] = None, - add_special_tokens: bool = True, - padding: Union[bool, str, PaddingStrategy] = False, - truncation: Union[bool, str, TruncationStrategy] = False, - max_length: Optional[int] = None, - stride: int = 0, - pad_to_multiple_of: Optional[int] = None, - return_token_type_ids: Optional[bool] = None, - return_attention_mask: Optional[bool] = None, - return_overflowing_tokens: bool = False, - return_special_tokens_mask: bool = False, - return_offsets_mapping: bool = False, - return_length: bool = False, - verbose: bool = True, - return_tensors: Optional[Union[str, TensorType]] = None, - ) -> BatchEncoding: + # The following is to capture `text_pair` argument that may be passed as a positional argument. + # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + # This behavior is only needed for backward compatibility and will be removed in future versions. + *args, + audio=None, + videos=None, + **kwargs: Unpack[UdopProcessorKwargs], + ) -> BatchFeature: """ This method first forwards the `images` argument to [`~UdopImageProcessor.__call__`]. In case [`UdopImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and @@ -93,6 +116,19 @@ def __call__( Please refer to the docstring of the above two methods for more information. """ # verify input + output_kwargs = self._merge_kwargs( + UdopProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + + boxes = output_kwargs["text_kwargs"].pop("boxes", None) + word_labels = output_kwargs["text_kwargs"].pop("word_labels", None) + text_pair = output_kwargs["text_kwargs"].pop("text_pair", None) + return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False) + return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False) + if self.image_processor.apply_ocr and (boxes is not None): raise ValueError( "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True." 
@@ -106,66 +142,44 @@ def __call__( if return_overflowing_tokens is True and return_offsets_mapping is False: raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") - if text_target is not None: + if output_kwargs["text_kwargs"].get("text_target", None) is not None: # use the processor to prepare the targets of UDOP return self.tokenizer( - text_target=text_target, - text_pair_target=text_pair_target, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, + **output_kwargs["text_kwargs"], ) else: # use the processor to prepare the inputs of UDOP # first, apply the image processor - features = self.image_processor(images=images, return_tensors=return_tensors) + features = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + features_words = features.pop("words", None) + features_boxes = features.pop("boxes", None) + + _ = output_kwargs["text_kwargs"].pop("text_target", None) + _ = output_kwargs["text_kwargs"].pop("text_pair_target", None) + output_kwargs["text_kwargs"]["text_pair"] = text_pair + output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes + output_kwargs["text_kwargs"]["word_labels"] = word_labels # second, apply the tokenizer if text is not None and self.image_processor.apply_ocr and text_pair is None: if isinstance(text, str): text = [text] # add batch dimension (as the image processor always adds a batch dimension) - text_pair = features["words"] + output_kwargs["text_kwargs"]["text_pair"] = features_words encoded_inputs = self.tokenizer( - text=text if text is not None else features["words"], - text_pair=text_pair if text_pair is not None else None, - boxes=boxes if boxes is not None else features["boxes"], - word_labels=word_labels, - add_special_tokens=add_special_tokens, - padding=padding, - truncation=truncation, - max_length=max_length, - stride=stride, - pad_to_multiple_of=pad_to_multiple_of, - return_token_type_ids=return_token_type_ids, - return_attention_mask=return_attention_mask, - return_overflowing_tokens=return_overflowing_tokens, - return_special_tokens_mask=return_special_tokens_mask, - return_offsets_mapping=return_offsets_mapping, - return_length=return_length, - verbose=verbose, - return_tensors=return_tensors, + text=text if text is not None else features_words, + **output_kwargs["text_kwargs"], ) # add pixel values - pixel_values = features.pop("pixel_values") if return_overflowing_tokens is True: - pixel_values = self.get_overflowing_images(pixel_values, encoded_inputs["overflow_to_sample_mapping"]) - encoded_inputs["pixel_values"] = pixel_values + features["pixel_values"] = self.get_overflowing_images( + features["pixel_values"], encoded_inputs["overflow_to_sample_mapping"] + ) + features.update(encoded_inputs) - return encoded_inputs + return features # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.get_overflowing_images def get_overflowing_images(self, images, overflow_to_sample_mapping): @@ -198,7 +212,20 @@ def decode(self, *args, **kwargs): """ return 
self.tokenizer.decode(*args, **kwargs) + def post_process_image_text_to_text(self, generated_outputs): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + + Returns: + `List[str]`: The decoded text. + """ + return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True) + @property - # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.model_input_names def model_input_names(self): - return ["input_ids", "bbox", "attention_mask", "pixel_values"] + return ["pixel_values", "input_ids", "bbox", "attention_mask"] diff --git a/tests/models/udop/test_processor_udop.py b/tests/models/udop/test_processor_udop.py index 749ec7c3d6df78..621b761b5f17a1 100644 --- a/tests/models/udop/test_processor_udop.py +++ b/tests/models/udop/test_processor_udop.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import os import shutil import tempfile import unittest @@ -34,7 +32,7 @@ require_torch, slow, ) -from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available, is_torch_available +from transformers.utils import cached_property, is_pytesseract_available, is_torch_available from ...test_processing_common import ProcessorTesterMixin @@ -55,20 +53,19 @@ class UdopProcessorTest(ProcessorTesterMixin, unittest.TestCase): tokenizer_class = UdopTokenizer rust_tokenizer_class = UdopTokenizerFast - maxDiff = None processor_class = UdopProcessor + maxDiff = None def setUp(self): - image_processor_map = { - "do_resize": True, - "size": 224, - "apply_ocr": True, - } - self.tmpdirname = tempfile.mkdtemp() - self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) - with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: - fp.write(json.dumps(image_processor_map) + "\n") + image_processor = LayoutLMv3ImageProcessor( + do_resize=True, + size=224, + apply_ocr=True, + ) + tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large") + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + processor.save_pretrained(self.tmpdirname) self.tokenizer_pretrained_name = "microsoft/udop-large" @@ -80,15 +77,15 @@ def setUp(self): def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer: return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs) + def get_image_processor(self, **kwargs): + return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs) + def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast: return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs) def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]: return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)] - def get_image_processor(self, **kwargs): - return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs) - def tearDown(self): shutil.rmtree(self.tmpdirname) @@ -153,7 +150,7 @@ def test_model_input_names(self): input_str = "lower newer" image_input = self.prepare_image_inputs() - inputs = processor(text=input_str, images=image_input) + inputs = processor(images=image_input, text=input_str) self.assertListEqual(list(inputs.keys()), processor.model_input_names) @@ -472,7 
+469,7 @@ def test_processor_case_5(self): question = "What's his name?" words = ["hello", "world"] boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] - input_processor = processor(images[0], question, words, boxes, return_tensors="pt") + input_processor = processor(images[0], question, text_pair=words, boxes=boxes, return_tensors="pt") # verify keys expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] @@ -488,7 +485,9 @@ def test_processor_case_5(self): questions = ["How old is he?", "what's the time"] words = [["hello", "world"], ["my", "name", "is", "niels"]] boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]] - input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt") + input_processor = processor( + images, questions, text_pair=words, boxes=boxes, padding=True, return_tensors="pt" + ) # verify keys expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] From 7fd3037a7d0e07a7ee0ab2630dfaa639946a2319 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 20 Sep 2024 16:46:43 +0000 Subject: [PATCH 2/3] cleanup Unpack --- src/transformers/models/udop/processing_udop.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/transformers/models/udop/processing_udop.py b/src/transformers/models/udop/processing_udop.py index 3d4cfc9ce4334e..1eae94ea4609ff 100644 --- a/src/transformers/models/udop/processing_udop.py +++ b/src/transformers/models/udop/processing_udop.py @@ -16,23 +16,16 @@ Processor class for UDOP. """ -import sys from typing import List, Optional, Union from transformers import logging from ...image_processing_utils import BatchFeature from ...image_utils import ImageInput -from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs +from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput -if sys.version_info >= (3, 11): - from typing import Unpack -else: - from typing_extensions import Unpack - - logger = logging.get_logger(__name__) From bee3a60a6d5696de3f43feb0771c00e9b9a8223a Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Fri, 20 Sep 2024 19:53:32 +0000 Subject: [PATCH 3/3] nit Udop --- src/transformers/models/udop/processing_udop.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/udop/processing_udop.py b/src/transformers/models/udop/processing_udop.py index 1eae94ea4609ff..ddd5d484a98883 100644 --- a/src/transformers/models/udop/processing_udop.py +++ b/src/transformers/models/udop/processing_udop.py @@ -88,8 +88,10 @@ def __call__( images: Optional[ImageInput] = None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, # The following is to capture `text_pair` argument that may be passed as a positional argument. - # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details. + # See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details, + # or this conversation for more context: https://github.com/huggingface/transformers/pull/32544#discussion_r1720208116 # This behavior is only needed for backward compatibility and will be removed in future versions. 
+ # *args, audio=None, videos=None, @@ -121,6 +123,7 @@ def __call__( text_pair = output_kwargs["text_kwargs"].pop("text_pair", None) return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False) return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False) + text_target = output_kwargs["text_kwargs"].get("text_target", None) if self.image_processor.apply_ocr and (boxes is not None): raise ValueError( @@ -132,10 +135,10 @@ def __call__( "You cannot provide word labels if you initialized the image processor with apply_ocr set to True." ) - if return_overflowing_tokens is True and return_offsets_mapping is False: + if return_overflowing_tokens and not return_offsets_mapping: raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") - if output_kwargs["text_kwargs"].get("text_target", None) is not None: + if text_target is not None: # use the processor to prepare the targets of UDOP return self.tokenizer( **output_kwargs["text_kwargs"], @@ -148,8 +151,8 @@ def __call__( features_words = features.pop("words", None) features_boxes = features.pop("boxes", None) - _ = output_kwargs["text_kwargs"].pop("text_target", None) - _ = output_kwargs["text_kwargs"].pop("text_pair_target", None) + output_kwargs["text_kwargs"].pop("text_target", None) + output_kwargs["text_kwargs"].pop("text_pair_target", None) output_kwargs["text_kwargs"]["text_pair"] = text_pair output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes output_kwargs["text_kwargs"]["word_labels"] = word_labels