Commit 9e00f68
address @zucchini-nlp's comments
leloykun committed Aug 20, 2024
1 parent 5fd2c32 commit 9e00f68
Showing 7 changed files with 39 additions and 41 deletions.
25 changes: 16 additions & 9 deletions src/transformers/models/mgp_str/processing_mgp_str.py
@@ -78,28 +78,35 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
                 FutureWarning,
             )
             feature_extractor = kwargs.pop("feature_extractor")
-        if "char_tokenizer" in kwargs:
-            warnings.warn(
-                "The `char_tokenizer` argument is deprecated and will be removed in future versions, use `tokenizer`"
-                " instead.",
-                FutureWarning,
-            )
-            char_tokenizer = kwargs.pop("char_tokenizer")
 
         image_processor = image_processor if image_processor is not None else feature_extractor
-        tokenizer = tokenizer if tokenizer is not None else char_tokenizer
         if image_processor is None:
             raise ValueError("You need to specify an `image_processor`.")
         if tokenizer is None:
             raise ValueError("You need to specify a `tokenizer`.")
 
         self.tokenizer = tokenizer
-        self.char_tokenizer = tokenizer  # For backwards compatibility
         self.bpe_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
         self.wp_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
 
         super().__init__(image_processor, tokenizer)
 
+    @property
+    def char_tokenizer(self):
+        warnings.warn(
+            "The `char_tokenizer` attribute is deprecated and will be removed in future versions, use `tokenizer` instead.",
+            FutureWarning,
+        )
+        return self.tokenizer
+
+    @char_tokenizer.setter
+    def char_tokenizer(self, value):
+        warnings.warn(
+            "The `char_tokenizer` attribute is deprecated and will be removed in future versions, use `tokenizer` instead.",
+            FutureWarning,
+        )
+        self.tokenizer = value
+
     def __call__(
         self,
         text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
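Note: with the kwargs-based deprecation replaced by a property, reading or writing char_tokenizer still works but now warns and delegates to tokenizer. A minimal sketch of the resulting behavior (illustrative only; uses the public alibaba-damo/mgp-str-base checkpoint and requires network access):

import warnings

from transformers import MgpstrProcessor

processor = MgpstrProcessor.from_pretrained("alibaba-damo/mgp-str-base")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    char_tok = processor.char_tokenizer  # deprecated alias; delegates to .tokenizer
assert char_tok is processor.tokenizer
assert any(issubclass(w.category, FutureWarning) for w in caught)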
24 changes: 7 additions & 17 deletions src/transformers/models/tvp/image_processing_tvp.py
@@ -50,7 +50,13 @@
 
 # Copied from transformers.models.vivit.image_processing_vivit.make_batched
 def make_batched(videos) -> List[List[ImageInput]]:
-    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+    if isinstance(videos, np.ndarray) and videos.ndim == 5:
+        return videos
+
+    elif isinstance(videos, np.ndarray) and videos.ndim == 4:
+        return [videos]
+
+    elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
         return videos
 
     elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
@@ -305,30 +311,20 @@ def _preprocess_image(
         # All transformations expect numpy arrays.
         image = to_numpy_array(image)
 
-        print(f"{image.shape = }")
-
         if do_resize:
             image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
 
-        print(f"{image.shape = }")
-
         if do_center_crop:
             image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
 
-        print(f"{image.shape = }")
-
         if do_rescale:
             image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
 
-        print(f"{image.shape = }")
-
         if do_normalize:
             image = self.normalize(
                 image=image.astype(np.float32), mean=image_mean, std=image_std, input_data_format=input_data_format
             )
 
-        print(f"{image.shape = }")
-
         if do_pad:
             image = self.pad_image(
                 image=image,
@@ -338,18 +334,12 @@ def _preprocess_image(
                 input_data_format=input_data_format,
             )
 
-        print(f"{image.shape = }")
-
         # the pretrained checkpoints assume images are BGR, not RGB
         if do_flip_channel_order:
             image = flip_channel_order(image=image, input_data_format=input_data_format)
 
-        print(f"{image.shape = }")
-
         image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
 
-        print(f"{image.shape = }")
-
         return image
 
     @filter_out_non_signature_kwargs()
@@ -47,8 +47,15 @@
 logger = logging.get_logger(__name__)
 
 
+# Copied from transformers.models.vivit.image_processing_vivit.make_batched
 def make_batched(videos) -> List[List[ImageInput]]:
-    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+    if isinstance(videos, np.ndarray) and videos.ndim == 5:
+        return videos
+
+    elif isinstance(videos, np.ndarray) and videos.ndim == 4:
+        return [videos]
+
+    elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
         return videos
 
     elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
8 changes: 7 additions & 1 deletion src/transformers/models/vivit/image_processing_vivit.py
@@ -51,7 +51,13 @@
 
 
 def make_batched(videos) -> List[List[ImageInput]]:
-    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+    if isinstance(videos, np.ndarray) and videos.ndim == 5:
+        return videos
+
+    elif isinstance(videos, np.ndarray) and videos.ndim == 4:
+        return [videos]
+
+    elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
         return videos
 
     elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
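The new branches, duplicated across the image processors touched above, let a pre-batched 5-D array pass through untouched and wrap a single 4-D video into a batch of one. A standalone sketch of just the two new branches (make_batched_sketch is a trimmed re-implementation for illustration; the list/tuple branches of the real function are elided):

import numpy as np

def make_batched_sketch(videos):
    if isinstance(videos, np.ndarray) and videos.ndim == 5:
        return videos  # already (batch, frames, channels, height, width)
    elif isinstance(videos, np.ndarray) and videos.ndim == 4:
        return [videos]  # single video -> batch of one
    raise ValueError(f"Could not make batched video from {type(videos)}")

single_video = np.zeros((4, 3, 30, 40), dtype=np.uint8)    # one 4-frame video
video_batch = np.zeros((2, 4, 3, 30, 40), dtype=np.uint8)  # two videos

assert make_batched_sketch(video_batch) is video_batch     # passed through
assert len(make_batched_sketch(single_video)) == 1         # wrapped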
6 changes: 0 additions & 6 deletions tests/models/tvp/test_processor_tvp.py
@@ -2,8 +2,6 @@
 import tempfile
 import unittest
 
-import numpy as np
-
 from transformers import TvpProcessor
 from transformers.testing_utils import require_torch, require_vision
 
@@ -20,10 +18,6 @@ def setUp(self):
         processor = self.processor_class.from_pretrained(self.from_pretrained_id)
         processor.save_pretrained(self.tmpdirname)
 
-    @require_vision
-    def prepare_video_inputs(self):
-        return [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
-
     @require_torch
     @require_vision
     def test_video_processor_defaults_preserved_by_kwargs(self):
6 changes: 0 additions & 6 deletions tests/models/x_clip/test_processor_x_clip.py
@@ -1,8 +1,6 @@
 import tempfile
 import unittest
 
-import numpy as np
-
 from transformers import XCLIPProcessor
 from transformers.testing_utils import require_torch, require_vision
 
@@ -19,10 +17,6 @@ def setUp(self):
         processor = self.processor_class.from_pretrained(self.from_pretrained_id)
         processor.save_pretrained(self.tmpdirname)
 
-    @require_vision
-    def prepare_video_inputs(self):
-        return [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
-
     @require_torch
     @require_vision
     def test_image_processor_defaults_preserved_by_image_kwargs(self):
2 changes: 1 addition & 1 deletion tests/test_processing_common.py
@@ -93,7 +93,7 @@ def prepare_image_inputs(self):
 
     @require_vision
     def prepare_video_inputs(self):
-        return [np.random.randint(255, size=(4, 3, 30, 400), dtype=np.uint8)]
+        return np.random.randint(255, size=(1, 4, 3, 30, 400), dtype=np.uint8)
 
     def test_processor_to_json_string(self):
         processor = self.get_processor()
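The shared fixture now returns one pre-batched 5-D array (batch, frames, channels, height, width) instead of a list of 4-D videos, which is why the per-model prepare_video_inputs overrides in the tvp and x_clip tests above could be dropped. A quick sanity check of what the updated make_batched receives (shape copied from the fixture):

import numpy as np

videos = np.random.randint(255, size=(1, 4, 3, 30, 400), dtype=np.uint8)
assert videos.ndim == 5  # hits the 5-D pass-through branch of make_batched

# One batch entry: a 4-frame video of 3x30x400 (channels, height, width).
assert videos[0].shape == (4, 3, 30, 400)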
