Skip to content

Commit

Permalink
change text pair handling when positional arg
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Sep 13, 2024
1 parent f1f360c commit 061710c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
18 changes: 8 additions & 10 deletions src/transformers/models/udop/processing_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@


class UdopTextKwargs(TextKwargs, total=False):
text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]]
word_labels: Optional[Union[List[int], List[List[int]]]]
text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
text_pair_target: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]]
boxes: Union[List[List[int]], List[List[List[int]]]]


Expand Down Expand Up @@ -87,6 +84,8 @@ class UdopProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "LayoutLMv3ImageProcessor"
tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast")
# For backward compatibility. See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
optional_call_args = ["text_pair"]

def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
Expand All @@ -95,6 +94,10 @@ def __call__(
self,
images: Optional[ImageInput] = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
# The following is to capture `text_pair` argument that may be passed as a positional argument.
# See transformers.processing_utils.ProcessorMixin.prepare_and_validate_optional_call_args for more details.
# This behavior is only needed for backward compatibility and will be removed in future versions.
*args,
audio=None,
videos=None,
**kwargs: Unpack[UdopProcessorKwargs],
Expand All @@ -117,13 +120,8 @@ def __call__(
UdopProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
**self.prepare_and_validate_optional_call_args(*args),
)
# for BC
if "text_pair " not in output_kwargs["text_kwargs"] and audio is not None:
logger.warning_once(
"The use of `text_pair` as an argument without specifying it explicitely as `text_pair=` will be deprecated in future versions."
)
output_kwargs["text_kwargs"]["text_pair"] = audio

boxes = output_kwargs["text_kwargs"].pop("boxes", None)
word_labels = output_kwargs["text_kwargs"].pop("word_labels", None)
Expand Down Expand Up @@ -230,4 +228,4 @@ def post_process_image_text_to_text(self, generated_outputs):

@property
def model_input_names(self):
return ["pixel_values", "input_ids", "attention_mask", "bbox"]
return ["pixel_values", "input_ids", "bbox", "attention_mask"]
7 changes: 3 additions & 4 deletions tests/models/udop/test_processor_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
class UdopProcessorTest(ProcessorTesterMixin, unittest.TestCase):
tokenizer_class = UdopTokenizer
rust_tokenizer_class = UdopTokenizerFast
image_processor_class = LayoutLMv3ImageProcessor
processor_class = UdopProcessor
maxDiff = None

Expand All @@ -80,7 +79,7 @@ def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)

def get_image_processor(self, **kwargs):
return self.image_processor_class.from_pretrained(self.tmpdirname, **kwargs).image_processor
return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs)

def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast:
return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs)
Expand Down Expand Up @@ -152,7 +151,7 @@ def test_model_input_names(self):
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
inputs = processor(images=image_input, text=input_str)

self.assertListEqual(list(inputs.keys()), processor.model_input_names)

Expand Down Expand Up @@ -221,8 +220,8 @@ def test_unstructured_kwargs_batched(self):
input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
text=input_str,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="longest",
Expand Down

0 comments on commit 061710c

Please sign in to comment.