diff --git a/src/transformers/models/mgp_str/processing_mgp_str.py b/src/transformers/models/mgp_str/processing_mgp_str.py index 7e30a0336b809f..169d8adcec7b8a 100644 --- a/src/transformers/models/mgp_str/processing_mgp_str.py +++ b/src/transformers/models/mgp_str/processing_mgp_str.py @@ -78,28 +78,35 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs): FutureWarning, ) feature_extractor = kwargs.pop("feature_extractor") - if "char_tokenizer" in kwargs: - warnings.warn( - "The `char_tokenizer` argument is deprecated and will be removed in future versions, use `tokenizer`" - " instead.", - FutureWarning, - ) - char_tokenizer = kwargs.pop("char_tokenizer") image_processor = image_processor if image_processor is not None else feature_extractor - tokenizer = tokenizer if tokenizer is not None else char_tokenizer if image_processor is None: raise ValueError("You need to specify an `image_processor`.") if tokenizer is None: raise ValueError("You need to specify a `tokenizer`.") self.tokenizer = tokenizer - self.char_tokenizer = tokenizer # For backwards compatibility self.bpe_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") self.wp_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") super().__init__(image_processor, tokenizer) + @property + def char_tokenizer(self): + warnings.warn( + "The `char_tokenizer` attribute is deprecated and will be removed in future versions, use `tokenizer` instead.", + FutureWarning, + ) + return self.tokenizer + + @char_tokenizer.setter + def char_tokenizer(self, value): + warnings.warn( + "The `char_tokenizer` attribute is deprecated and will be removed in future versions, use `tokenizer` instead.", + FutureWarning, + ) + self.tokenizer = value + def __call__( self, text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, diff --git a/src/transformers/models/tvp/image_processing_tvp.py b/src/transformers/models/tvp/image_processing_tvp.py index 4e9618eef17084..7a4c5db004671e 100644 --- a/src/transformers/models/tvp/image_processing_tvp.py +++ b/src/transformers/models/tvp/image_processing_tvp.py @@ -50,7 +50,13 @@ # Copied from transformers.models.vivit.image_processing_vivit.make_batched def make_batched(videos) -> List[List[ImageInput]]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): + if isinstance(videos, np.ndarray) and videos.ndim == 5: + return videos + + elif isinstance(videos, np.ndarray) and videos.ndim == 4: + return [videos] + + elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): @@ -305,30 +311,20 @@ def _preprocess_image( # All transformations expect numpy arrays. image = to_numpy_array(image) - print(f"{image.shape = }") - if do_resize: image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) - print(f"{image.shape = }") - if do_center_crop: image = self.center_crop(image, size=crop_size, input_data_format=input_data_format) - print(f"{image.shape = }") - if do_rescale: image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) - print(f"{image.shape = }") - if do_normalize: image = self.normalize( image=image.astype(np.float32), mean=image_mean, std=image_std, input_data_format=input_data_format ) - print(f"{image.shape = }") - if do_pad: image = self.pad_image( image=image, @@ -338,18 +334,12 @@ def _preprocess_image( input_data_format=input_data_format, ) - print(f"{image.shape = }") - # the pretrained checkpoints assume images are BGR, not RGB if do_flip_channel_order: image = flip_channel_order(image=image, input_data_format=input_data_format) - print(f"{image.shape = }") - image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) - print(f"{image.shape = }") - return image @filter_out_non_signature_kwargs() diff --git a/src/transformers/models/videomae/image_processing_videomae.py b/src/transformers/models/videomae/image_processing_videomae.py index 413589523aa675..c21210faf6670c 100644 --- a/src/transformers/models/videomae/image_processing_videomae.py +++ b/src/transformers/models/videomae/image_processing_videomae.py @@ -47,8 +47,15 @@ logger = logging.get_logger(__name__) +# Copied from transformers.models.vivit.image_processing_vivit.make_batched def make_batched(videos) -> List[List[ImageInput]]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): + if isinstance(videos, np.ndarray) and videos.ndim == 5: + return videos + + elif isinstance(videos, np.ndarray) and videos.ndim == 4: + return [videos] + + elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): diff --git a/src/transformers/models/vivit/image_processing_vivit.py b/src/transformers/models/vivit/image_processing_vivit.py index 5f251bbd1b95b9..fb959e9f1eddb2 100644 --- a/src/transformers/models/vivit/image_processing_vivit.py +++ b/src/transformers/models/vivit/image_processing_vivit.py @@ -51,7 +51,13 @@ def make_batched(videos) -> List[List[ImageInput]]: - if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): + if isinstance(videos, np.ndarray) and videos.ndim == 5: + return videos + + elif isinstance(videos, np.ndarray) and videos.ndim == 4: + return [videos] + + elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]): return videos elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]): diff --git a/tests/models/tvp/test_processor_tvp.py b/tests/models/tvp/test_processor_tvp.py index 8f5e0bd6b5d05d..40d700e0beea15 100644 --- a/tests/models/tvp/test_processor_tvp.py +++ b/tests/models/tvp/test_processor_tvp.py @@ -2,8 +2,6 @@ import tempfile import unittest -import numpy as np - from transformers import TvpProcessor from transformers.testing_utils import require_torch, require_vision @@ -20,10 +18,6 @@ def setUp(self): processor = self.processor_class.from_pretrained(self.from_pretrained_id) processor.save_pretrained(self.tmpdirname) - @require_vision - def prepare_video_inputs(self): - return [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] - @require_torch @require_vision def test_video_processor_defaults_preserved_by_kwargs(self): diff --git a/tests/models/x_clip/test_processor_x_clip.py b/tests/models/x_clip/test_processor_x_clip.py index e9d0bf4b2539ee..5b34855a67252a 100644 --- a/tests/models/x_clip/test_processor_x_clip.py +++ b/tests/models/x_clip/test_processor_x_clip.py @@ -1,8 +1,6 @@ import tempfile import unittest -import numpy as np - from transformers import XCLIPProcessor from transformers.testing_utils import require_torch, require_vision @@ -19,10 +17,6 @@ def setUp(self): processor = self.processor_class.from_pretrained(self.from_pretrained_id) processor.save_pretrained(self.tmpdirname) - @require_vision - def prepare_video_inputs(self): - return [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] - @require_torch @require_vision def test_image_processor_defaults_preserved_by_image_kwargs(self): diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index ebb5e6f74f3d07..53cfcf5520c053 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -93,7 +93,7 @@ def prepare_image_inputs(self): @require_vision def prepare_video_inputs(self): - return [np.random.randint(255, size=(4, 3, 30, 400), dtype=np.uint8)] + return np.random.randint(255, size=(1, 4, 3, 30, 400), dtype=np.uint8) def test_processor_to_json_string(self): processor = self.get_processor()