Skip to content

Commit

Permalink
add tests and more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
leloykun committed Aug 16, 2024
1 parent e246704 commit 9169bca
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 19 deletions.
2 changes: 2 additions & 0 deletions src/transformers/models/donut/image_processing_donut.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def pad_image(
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size)
output_height, output_width = size["height"], size["width"]
input_height, input_width = get_image_size(image, channel_dim=input_data_format)

Expand Down Expand Up @@ -232,6 +233,7 @@ def thumbnail(
The channel dimension format of the input image. If not provided, it will be inferred.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
size = get_size_dict(size)
output_height, output_width = size["height"], size["width"]

# We always resize to the smallest of either the input or output size.
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/fuyu/image_processing_fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
pad,
resize,
Expand Down Expand Up @@ -344,6 +344,7 @@ def pad_image(
The channel dimension format of the input image. If not provided, it will be inferred.
"""
image_height, image_width = get_image_size(image, input_data_format)
size = get_size_dict(size)
target_height, target_width = size["height"], size["width"]
padding_top = 0
padding_left = 0
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/nougat/image_processing_nougat.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def pad_image(
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
size = get_size_dict(size)
output_height, output_width = size["height"], size["width"]
input_height, input_width = get_image_size(image, channel_dim=input_data_format)

Expand Down Expand Up @@ -292,6 +293,7 @@ def thumbnail(
The channel dimension format of the input image. If not provided, it will be inferred.
"""
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
size = get_size_dict(size)
output_height, output_width = size["height"], size["width"]

# We always resize to the smallest of either the input or output size.
Expand Down
14 changes: 8 additions & 6 deletions src/transformers/models/nougat/processing_nougat.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ def __call__(
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
# Temporary fix for "padding_side" in init_kwargs
_ = output_kwargs["text_kwargs"].pop("padding_side", None)

# For backwards compatibility, we reuse `audio` as `text_pair`
# in case downstream users passed it as a positional argument
if output_kwargs["text_kwargs"].get("text_pair") is not None and audio is not None:
raise ValueError(
"You cannot provide `text_pair` as a positional argument and as a keyword argument at the same time."
Expand All @@ -113,11 +117,11 @@ def __call__(
warnings.warn(
"No `text_pair` kwarg was detected. The use of `text_pair` as an argument without specifying it explicitely as `text_pair=` will be deprecated in future versions."
)
# For backwards compatibility, we reuse `audio` as `text_pair` in case
# downstream users passed it as a positional argument
if audio is not None:
output_kwargs["text_kwargs"]["text_pair"] = audio

# For backwards compatibility, we reuse `videos` as `text_target`
# in case downstream users passed it as a positional argument
if output_kwargs["text_kwargs"].get("text_target") is not None and videos is not None:
raise ValueError(
"You cannot provide `text_target` as a positional argument and as a keyword argument at the same time."
Expand All @@ -127,11 +131,11 @@ def __call__(
warnings.warn(
"No `text_target` kwarg was detected. The use of `text_target` as an argument without specifying it explicitely as `text_target=` will be deprecated in future versions."
)
# For backwards compatibility, we reuse `videos` as `text_target` in case
# downstream users passed it as a positional argument
if videos is not None:
output_kwargs["text_kwargs"]["text_target"] = videos

# For backwards compatibility, we reuse `backwards_compatibility_placeholder_arg` as `text_pair_target`
# in case downstream users passed it as a positional argument
if (
output_kwargs["text_kwargs"].get("text_pair_target") is not None
and backwards_compatibility_placeholder_arg is not None
Expand All @@ -144,8 +148,6 @@ def __call__(
warnings.warn(
"No `text_pair_target` kwarg was detected. The use of `text_pair_target` as an argument without specifying it explicitely as `text_pair_target=` will be deprecated in future versions."
)
# For backwards compatibility, we reuse `backwards_compatibility_placeholder_arg` as `text_pair_target` in case
# downstream users passed it as a positional argument
if backwards_compatibility_placeholder_arg is not None:
output_kwargs["text_kwargs"]["text_pair_target"] = backwards_compatibility_placeholder_arg

Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/sam/image_processing_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def pad_image(
input_data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
pad_size = get_size_dict(pad_size)
output_height, output_width = pad_size["height"], pad_size["width"]
input_height, input_width = get_image_size(image, channel_dim=input_data_format)

Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/tvp/image_processing_tvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def pad_image(
The channel dimension format of the input image. If not provided, it will be inferred.
"""
height, width = get_image_size(image, channel_dim=input_data_format)
pad_size = get_size_dict(pad_size)
max_height = pad_size.get("height", height)
max_width = pad_size.get("width", width)

Expand Down
17 changes: 17 additions & 0 deletions tests/models/nougat/test_processor_nougat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import tempfile
import unittest

from transformers import NougatProcessor

from ...test_processing_common import ProcessorTesterMixin


class NougatProcessorTest(ProcessorTesterMixin, unittest.TestCase):
    """Processor tests for Nougat, driven by the shared ``ProcessorTesterMixin``.

    The Nougat processor exposes its tokenized text under the ``"labels"`` key
    rather than the mixin's default ``"input_ids"``, hence the
    ``text_data_arg_name`` override.
    """

    from_pretrained_id = "facebook/nougat-base"
    text_data_arg_name = "labels"
    processor_class = NougatProcessor

    def setUp(self):
        # Fetch the pretrained processor once and re-save it into a local temp
        # dir so the mixin's tests load from disk via self.tmpdirname.
        self.tmpdirname = tempfile.mkdtemp()
        processor = self.processor_class.from_pretrained(self.from_pretrained_id)
        processor.save_pretrained(self.tmpdirname)

    def tearDown(self):
        # mkdtemp() leaves cleanup to the caller; without this, every test
        # leaks one temp directory.
        import shutil

        shutil.rmtree(self.tmpdirname, ignore_errors=True)
26 changes: 14 additions & 12 deletions tests/test_processing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
@require_vision
@require_torch
class ProcessorTesterMixin:
image_data_arg_name = "pixel_values"
text_data_arg_name = "input_ids"
processor_class = None

def prepare_processor_dict(self):
Expand Down Expand Up @@ -136,7 +138,7 @@ def test_tokenizer_defaults_preserved_by_kwargs(self):
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, return_tensors="pt")
self.assertEqual(len(inputs["input_ids"][0]), 117)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 117)

@require_torch
@require_vision
Expand All @@ -153,7 +155,7 @@ def test_image_processor_defaults_preserved_by_image_kwargs(self):
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)
self.assertEqual(len(inputs[self.image_data_arg_name][0][0]), 234)

@require_vision
@require_torch
Expand All @@ -171,7 +173,7 @@ def test_kwargs_overrides_default_tokenizer_kwargs(self):
inputs = processor(
text=input_str, images=image_input, return_tensors="pt", max_length=112, padding="max_length"
)
self.assertEqual(len(inputs["input_ids"][0]), 112)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 112)

@require_torch
@require_vision
Expand All @@ -188,7 +190,7 @@ def test_kwargs_overrides_default_image_processor_kwargs(self):
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, size=[224, 224])
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)
self.assertEqual(len(inputs[self.image_data_arg_name][0][0]), 224)

@require_torch
@require_vision
Expand All @@ -212,8 +214,8 @@ def test_unstructured_kwargs(self):
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)
self.assertEqual(inputs[self.image_data_arg_name].shape[2], 214)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)

@require_torch
@require_vision
Expand All @@ -237,9 +239,9 @@ def test_unstructured_kwargs_batched(self):
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(inputs[self.image_data_arg_name].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 6)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 6)

@require_torch
@require_vision
Expand Down Expand Up @@ -286,9 +288,9 @@ def test_structured_kwargs_nested(self):
inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(inputs[self.image_data_arg_name].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)

@require_torch
@require_vision
Expand All @@ -312,9 +314,9 @@ def test_structured_kwargs_nested_from_dict(self):
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(inputs[self.image_data_arg_name].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)
self.assertEqual(len(inputs[self.text_data_arg_name][0]), 76)


class MyProcessor(ProcessorMixin):
Expand Down

0 comments on commit 9169bca

Please sign in to comment.