address @zucchini-nlp's comments
leloykun committed Aug 20, 2024
1 parent 5fd2c32 commit a44a76a
Showing 9 changed files with 88 additions and 64 deletions.
25 changes: 16 additions & 9 deletions src/transformers/models/mgp_str/processing_mgp_str.py
@@ -78,28 +78,35 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")
if "char_tokenizer" in kwargs:
warnings.warn(
"The `char_tokenizer` argument is deprecated and will be removed in future versions, use `tokenizer`"
" instead.",
FutureWarning,
)
char_tokenizer = kwargs.pop("char_tokenizer")

image_processor = image_processor if image_processor is not None else feature_extractor
tokenizer = tokenizer if tokenizer is not None else char_tokenizer
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")

self.tokenizer = tokenizer
self.char_tokenizer = tokenizer # For backwards compatibility
self.bpe_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
self.wp_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

super().__init__(image_processor, tokenizer)

@property
def char_tokenizer(self):
warnings.warn(
"The `char_tokenizer` attribute is deprecated and will be removed in future versions, use `tokenizer` instead.",
FutureWarning,
)
return self.tokenizer

@char_tokenizer.setter
def char_tokenizer(self, value):
warnings.warn(
"The `char_tokenizer` attribute is deprecated and will be removed in future versions, use `tokenizer` instead.",
FutureWarning,
)
self.tokenizer = value

def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
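A minimal usage sketch of the deprecation shim added above (an illustration, not part of the diff; assumes the public `alibaba-damo/mgp-str-base` checkpoint):

```python
import warnings

from transformers import MgpstrProcessor

processor = MgpstrProcessor.from_pretrained("alibaba-damo/mgp-str-base")

# The deprecated alias still resolves to the underlying tokenizer,
# but access now goes through the new property and emits a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert processor.char_tokenizer is processor.tokenizer

assert any(issubclass(w.category, FutureWarning) for w in caught)
```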
24 changes: 7 additions & 17 deletions src/transformers/models/tvp/image_processing_tvp.py
@@ -50,7 +50,13 @@

# Copied from transformers.models.vivit.image_processing_vivit.make_batched
def make_batched(videos) -> List[List[ImageInput]]:
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
if isinstance(videos, np.ndarray) and videos.ndim == 5:
return videos

elif isinstance(videos, np.ndarray) and videos.ndim == 4:
return [videos]

elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
return videos

elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
@@ -305,30 +311,20 @@ def _preprocess_image(
# All transformations expect numpy arrays.
image = to_numpy_array(image)

print(f"{image.shape = }")

if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)

print(f"{image.shape = }")

if do_center_crop:
image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)

print(f"{image.shape = }")

if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

print(f"{image.shape = }")

if do_normalize:
image = self.normalize(
image=image.astype(np.float32), mean=image_mean, std=image_std, input_data_format=input_data_format
)

print(f"{image.shape = }")

if do_pad:
image = self.pad_image(
image=image,
@@ -338,18 +334,12 @@ def _preprocess_image(
input_data_format=input_data_format,
)

print(f"{image.shape = }")

# the pretrained checkpoints assume images are BGR, not RGB
if do_flip_channel_order:
image = flip_channel_order(image=image, input_data_format=input_data_format)

print(f"{image.shape = }")

image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

print(f"{image.shape = }")

return image

@filter_out_non_signature_kwargs()
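As a quick aside on the BGR comment in `_preprocess_image`, the channel flip at the end of that method reduces to reversing the channel axis. A short sanity check (a sketch, not part of the diff; `flip_channel_order` lives in `transformers.image_transforms`):

```python
import numpy as np

from transformers.image_transforms import flip_channel_order

# For a channels-last image, flipping the channel order (RGB -> BGR)
# is the same as reversing the last axis.
rgb = np.arange(2 * 2 * 3, dtype=np.float32).reshape(2, 2, 3)
bgr = flip_channel_order(rgb)
assert np.array_equal(bgr, rgb[..., ::-1])
```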
src/transformers/models/x_clip/image_processing_x_clip.py
@@ -47,8 +47,15 @@
logger = logging.get_logger(__name__)


# Copied from transformers.models.vivit.image_processing_vivit.make_batched
def make_batched(videos) -> List[List[ImageInput]]:
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
if isinstance(videos, np.ndarray) and videos.ndim == 5:
return videos

elif isinstance(videos, np.ndarray) and videos.ndim == 4:
return [videos]

elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
return videos

elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py
@@ -16,10 +16,24 @@
Processor class for VisionTextDualEncoder
"""

import sys
import warnings
from typing import List, Optional, Union

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import PreTokenizedInput, TextInput


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack


class VisionTextDualEncoderProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


class VisionTextDualEncoderProcessor(ProcessorMixin):
@@ -61,7 +75,14 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor

def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
def __call__(
self,
text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
images: Optional[ImageInput] = None,
audio=None,
videos=None,
**kwargs: Unpack[VisionTextDualEncoderProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to VisionTextDualEncoderTokenizer's [`~PreTrainedTokenizer.__call__`] if `text` is not
@@ -70,24 +91,16 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
images (`ImageInput`, *optional*):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
@@ -99,19 +112,25 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
if text is None and images is None:
raise ValueError("You have to specify either text or images. Both cannot be None.")

output_kwargs = self._merge_kwargs(
VisionTextDualEncoderProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors, **kwargs)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])

return_tensors = output_kwargs["common_kwargs"].get("return_tensors")
if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
return encoding
return BatchFeature(data=dict(**encodings, **image_features), tensor_type=return_tensors)
elif text is not None:
return encoding
return BatchFeature(data=dict(**encodings), tensor_type=return_tensors)
else:
return BatchEncoding(data=dict(**image_features), tensor_type=return_tensors)
return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)

def batch_decode(self, *args, **kwargs):
"""
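A hedged usage sketch of the reworked `__call__` (assumes the `clip-italian/clip-italian` checkpoint, the same one the updated tests below pull): tokenizer kwargs such as `padding` and common kwargs such as `return_tensors` are now merged through `ProcessingKwargs` instead of being forwarded verbatim to both sub-processors.

```python
import numpy as np
from PIL import Image

from transformers import VisionTextDualEncoderProcessor

processor = VisionTextDualEncoderProcessor.from_pretrained("clip-italian/clip-italian")
image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))

# `padding` is routed into text_kwargs and `return_tensors` into common_kwargs;
# the result is a single BatchFeature covering both modalities.
batch = processor(text="una foto di un gatto", images=image, padding=True, return_tensors="pt")
print(sorted(batch.keys()))  # e.g. ['attention_mask', 'input_ids', 'pixel_values', 'token_type_ids']
```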
8 changes: 7 additions & 1 deletion src/transformers/models/vivit/image_processing_vivit.py
@@ -51,7 +51,13 @@


def make_batched(videos) -> List[List[ImageInput]]:
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
if isinstance(videos, np.ndarray) and videos.ndim == 5:
return videos

elif isinstance(videos, np.ndarray) and videos.ndim == 4:
return [videos]

elif isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
return videos

elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
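The same `make_batched` change is mirrored, via the `# Copied from` machinery, in the other image processors touched by this commit. A quick sketch of the new NumPy fast paths, using the 5D shape that `prepare_video_inputs` in `tests/test_processing_common.py` now returns (see the last file below):

```python
import numpy as np

from transformers.models.vivit.image_processing_vivit import make_batched

video = np.random.randint(0, 255, size=(4, 3, 30, 400), dtype=np.uint8)  # (frames, C, H, W)
batch = np.random.randint(0, 255, size=(1, 4, 3, 30, 400), dtype=np.uint8)  # (batch, frames, C, H, W)

assert len(make_batched(video)) == 1  # a 4D array is wrapped as a single-video batch
assert len(make_batched(batch)) == 1  # a 5D array is returned as-is
```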
6 changes: 0 additions & 6 deletions tests/models/tvp/test_processor_tvp.py
@@ -2,8 +2,6 @@
import tempfile
import unittest

import numpy as np

from transformers import TvpProcessor
from transformers.testing_utils import require_torch, require_vision

@@ -20,10 +18,6 @@ def setUp(self):
processor = self.processor_class.from_pretrained(self.from_pretrained_id)
processor.save_pretrained(self.tmpdirname)

@require_vision
def prepare_video_inputs(self):
return [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

@require_torch
@require_vision
def test_video_processor_defaults_preserved_by_kwargs(self):
tests/models/vision_text_dual_encoder/test_processor_vision_text_dual_encoder.py
@@ -20,21 +20,25 @@

import numpy as np

from transformers import BertTokenizerFast
from transformers import BertTokenizerFast, VisionTextDualEncoderProcessor
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, BertTokenizer
from transformers.testing_utils import require_tokenizers, require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_vision_available():
from PIL import Image

from transformers import VisionTextDualEncoderProcessor, ViTImageProcessor
from transformers import ViTImageProcessor


@require_tokenizers
@require_vision
class VisionTextDualEncoderProcessorTest(unittest.TestCase):
class VisionTextDualEncoderProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = VisionTextDualEncoderProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

@@ -54,6 +58,9 @@ def setUp(self):
with open(self.image_processor_file, "w", encoding="utf-8") as fp:
json.dump(image_processor_map, fp)

processor = VisionTextDualEncoderProcessor.from_pretrained("clip-italian/clip-italian")
processor.save_pretrained(self.tmpdirname)

def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)

6 changes: 0 additions & 6 deletions tests/models/x_clip/test_processor_x_clip.py
@@ -1,8 +1,6 @@
import tempfile
import unittest

import numpy as np

from transformers import XCLIPProcessor
from transformers.testing_utils import require_torch, require_vision

@@ -19,10 +17,6 @@ def setUp(self):
processor = self.processor_class.from_pretrained(self.from_pretrained_id)
processor.save_pretrained(self.tmpdirname)

@require_vision
def prepare_video_inputs(self):
return [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]

@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
2 changes: 1 addition & 1 deletion tests/test_processing_common.py
@@ -93,7 +93,7 @@ def prepare_image_inputs(self):

@require_vision
def prepare_video_inputs(self):
return [np.random.randint(255, size=(4, 3, 30, 400), dtype=np.uint8)]
return np.random.randint(255, size=(1, 4, 3, 30, 400), dtype=np.uint8)

def test_processor_to_json_string(self):
processor = self.get_processor()
