[WIP] Standardize inputs and outputs for existing image-text-to-text models #32059

Draft · wants to merge 13 commits into `main`
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/auto.md
@@ -381,3 +381,7 @@ The following auto classes are available for the following multimodal tasks.
### FlaxAutoModelForVision2Seq

[[autodoc]] FlaxAutoModelForVision2Seq

### AutoModelForImageTextToText

[[autodoc]] AutoModelForImageTextToText
4 changes: 4 additions & 0 deletions docs/source/ja/model_doc/auto.md
@@ -368,3 +368,7 @@ AutoModel.register(NewModelConfig, NewModel)
### FlaxAutoModelForVision2Seq

[[autodoc]] FlaxAutoModelForVision2Seq

### AutoModelForImageTextToText

[[autodoc]] AutoModelForImageTextToText
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -1353,6 +1353,7 @@
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_IMAGE_MAPPING",
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
"MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
"MODEL_FOR_IMAGE_TO_IMAGE_MAPPING",
"MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING",
"MODEL_FOR_KEYPOINT_DETECTION_MAPPING",
@@ -1394,6 +1395,7 @@
"AutoModelForDocumentQuestionAnswering",
"AutoModelForImageClassification",
"AutoModelForImageSegmentation",
"AutoModelForImageTextToText",
"AutoModelForImageToImage",
"AutoModelForInstanceSegmentation",
"AutoModelForKeypointDetection",
@@ -6056,6 +6058,7 @@
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_IMAGE_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
@@ -6097,6 +6100,7 @@
AutoModelForDocumentQuestionAnswering,
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForImageTextToText,
AutoModelForImageToImage,
AutoModelForInstanceSegmentation,
AutoModelForKeypointDetection,
4 changes: 4 additions & 0 deletions src/transformers/models/auto/__init__.py
@@ -74,6 +74,7 @@
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
"MODEL_WITH_LM_HEAD_MAPPING",
@@ -119,6 +120,7 @@
"AutoModelWithLMHead",
"AutoModelForZeroShotImageClassification",
"AutoModelForZeroShotObjectDetection",
"AutoModelForImageTextToText",
]

try:
@@ -238,6 +240,7 @@
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_IMAGE_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
MODEL_FOR_IMAGE_TO_IMAGE_MAPPING,
MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING,
MODEL_FOR_KEYPOINT_DETECTION_MAPPING,
@@ -279,6 +282,7 @@
AutoModelForDocumentQuestionAnswering,
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForImageTextToText,
AutoModelForImageToImage,
AutoModelForInstanceSegmentation,
AutoModelForKeypointDetection,
30 changes: 30 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -719,6 +719,26 @@
]
)

MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
[
("blip", "BlipForConditionalGeneration"),
("blip-2", "Blip2ForConditionalGeneration"),
("fuyu", "FuyuForCausalLM"),
("git", "GitForCausalLM"),
("idefics", "IdeficsForVisionText2Text"),
("idefics2", "Idefics2ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"),
("udop", "UdopForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
]
)

MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
@@ -1371,6 +1391,9 @@
CONFIG_MAPPING_NAMES, MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES
)
MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)
@@ -1665,6 +1688,13 @@ class AutoModelForVision2Seq(_BaseAutoModelClass):
AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling")


class AutoModelForImageTextToText(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING


AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling")


class AutoModelForAudioClassification(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING

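For context, a minimal usage sketch of the new auto class (the checkpoint name below is an assumption for illustration, not part of this diff): `from_pretrained` resolves the checkpoint config's `model_type` through `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING` to the concrete head class.

```python
# Sketch only: resolve a checkpoint through the new image-text-to-text mapping.
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained("Salesforce/blip-image-captioning-base")  # assumed checkpoint
print(type(model).__name__)  # -> BlipForConditionalGeneration, via the "blip" mapping entry
```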
2 changes: 2 additions & 0 deletions src/transformers/models/auto/processing_auto.py
@@ -93,11 +93,13 @@
("trocr", "TrOCRProcessor"),
("tvlt", "TvltProcessor"),
("tvp", "TvpProcessor"),
("udop", "UdopProcessor"),
("unispeech", "Wav2Vec2Processor"),
("unispeech-sat", "Wav2Vec2Processor"),
("video_llava", "VideoLlavaProcessor"),
("vilt", "ViltProcessor"),
("vipllava", "LlavaProcessor"),
("vision-encoder-decoder", "DonutProcessor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
("wav2vec2", "Wav2Vec2Processor"),
("wav2vec2-bert", "Wav2Vec2Processor"),
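With these two entries, `AutoProcessor` can now route UDOP and vision-encoder-decoder checkpoints to a concrete processor class; a hedged sketch (checkpoint name assumed):

```python
# Sketch only: the new "udop" entry routes AutoProcessor to UdopProcessor.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/udop-large")  # assumed checkpoint
print(type(processor).__name__)  # -> UdopProcessor
```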
14 changes: 14 additions & 0 deletions src/transformers/models/blip/processing_blip.py
@@ -144,6 +144,20 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
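A short end-to-end sketch of how the standardized hook is meant to be used (the checkpoint name and image path are assumptions for illustration):

```python
# Sketch only: generate, then decode via the standardized post-processing hook.
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")  # assumed checkpoint
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("cat.png")  # placeholder image path
inputs = processor(images=image, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.post_process_image_text_to_text(generated_ids))  # list of decoded captions
```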
14 changes: 14 additions & 0 deletions src/transformers/models/blip_2/processing_blip_2.py
@@ -148,6 +148,20 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)

@property
# Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
def model_input_names(self):
31 changes: 29 additions & 2 deletions src/transformers/models/donut/processing_donut.py
@@ -71,6 +71,12 @@ def __call__(self, *args, **kwargs):
[`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
"""
# For backward compatibility
legacy = kwargs.pop("legacy", True)
if legacy:
warnings.warn(
"The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False."
)

if self._in_target_context_manager:
return self.current_processor(*args, **kwargs)

@@ -85,15 +91,22 @@

if images is not None:
inputs = self.image_processor(images, *args, **kwargs)
if text is not None:
if text is not None and images is None:
encodings = self.tokenizer(text, **kwargs)
elif text is not None:
if not legacy:
kwargs.update({"add_special_tokens": False})
encodings = self.tokenizer(text, **kwargs)

if text is None:
return inputs
elif images is None:
return encodings
else:
inputs["labels"] = encodings["input_ids"]
if not legacy:
inputs["decoder_input_ids"] = encodings["input_ids"]
else:
inputs["labels"] = encodings["input_ids"]
return inputs

def batch_decode(self, *args, **kwargs):
@@ -180,6 +193,20 @@ def token2json(self, tokens, is_inner_value=False, added_vocab=None):
else:
return [] if is_inner_value else {"text_sequence": tokens}

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)

@property
def feature_extractor_class(self):
warnings.warn(
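To make the `legacy` switch concrete, a hedged sketch of the two call paths (the checkpoint name, task prompt, and image path are assumptions):

```python
# Sketch only: contrast the legacy and new Donut processing behavior.
from PIL import Image
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")  # assumed checkpoint
image = Image.open("receipt.png")  # placeholder image path

# Legacy path (the default; emits the deprecation warning): the prompt is
# tokenized with special tokens and returned as `labels`.
legacy_batch = processor(images=image, text="<s_cord-v2>", return_tensors="pt")

# New path: the prompt is tokenized without special tokens and returned as
# `decoder_input_ids`, ready to seed generation.
new_batch = processor(images=image, text="<s_cord-v2>", legacy=False, return_tensors="pt")
```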
24 changes: 24 additions & 0 deletions src/transformers/models/fuyu/processing_fuyu.py
@@ -681,6 +681,30 @@ def tokens_to_points(tokens, original_size):

return results

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
containing the token ids of the generated sequences.

Returns:
`List[str]`: The decoded text output.
"""
boa = self.tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
# find the beginning-of-answer (boa) token in each generated sequence
# keep only the tokens that follow it, then pad so all answers share one length
unpadded_output_sequences = [seq[(seq == boa).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs]
max_len = max(len(seq) for seq in unpadded_output_sequences)
# convert to torch and pad sequences
padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
for i, seq in enumerate(unpadded_output_sequences):
padded_output_sequences[i, : len(seq)] = torch.tensor(seq)

return self.batch_decode(padded_output_sequences, skip_special_tokens=True)

def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
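A hedged usage sketch for the Fuyu variant, which decodes only the span after the beginning-of-answer marker (checkpoint, prompt, and image path are assumptions):

```python
# Sketch only: Fuyu generation followed by answer-only decoding.
from PIL import Image
from transformers import FuyuForCausalLM, FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")  # assumed checkpoint
model = FuyuForCausalLM.from_pretrained("adept/fuyu-8b")

image = Image.open("bus.png")  # placeholder image path
inputs = processor(images=image, text="Generate a coco-style caption.\n", return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=16)
# Only tokens after the beginning-of-answer token are decoded.
print(processor.post_process_image_text_to_text(generated_ids))
```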
25 changes: 25 additions & 0 deletions src/transformers/models/git/processing_git.py
@@ -16,6 +16,8 @@
Image/Text processor class for GIT
"""

import warnings

from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding

@@ -76,6 +78,12 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
warnings.warn(
"The use of legacy will be deprecated in the future. Please use the new processing behavior by setting legacy=False."
)

tokenizer_kwargs, image_processor_kwargs = {}, {}
if kwargs:
tokenizer_kwargs = {k: v for k, v in kwargs.items() if k not in self.image_processor._valid_processor_keys}
@@ -94,6 +102,9 @@ def __call__(self, text=None, images=None, return_tensors=None, **kwargs):

if text is not None and images is not None:
encoding["pixel_values"] = image_features.pixel_values
if not legacy:
encoding["input_ids"] = encoding["input_ids"][:, :-1]
encoding["attention_mask"] = encoding["attention_mask"][:, :-1]
return encoding
elif text is not None:
return encoding
@@ -114,6 +125,20 @@
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)

@property
def model_input_names(self):
return ["input_ids", "attention_mask", "pixel_values"]