Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for post-processing kwargs in image-text-to-text pipeline #35374

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/transformers/models/fuyu/processing_fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ def tokens_to_points(tokens, original_size):

return results

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(self, generated_outputs, **kwargs):
"""
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.

Expand All @@ -706,7 +706,9 @@ def post_process_image_text_to_text(self, generated_outputs):
for i, seq in enumerate(unpadded_output_sequences):
padded_output_sequences[i, : len(seq)] = torch.tensor(seq)

return self.batch_decode(padded_output_sequences, skip_special_tokens=True)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)

return self.batch_decode(padded_output_sequences, skip_special_tokens=skip_special_tokens, **kwargs)

def batch_decode(self, *args, **kwargs):
"""
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/models/kosmos2/processing_kosmos2.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def post_process_generation(self, text, cleanup_and_extract=True):
return clean_text_and_extract_entities_with_bboxes(caption)
return caption

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(self, generated_outputs, **kwargs):
"""
Post-process the output of the model to decode the text.

Expand All @@ -440,7 +440,8 @@ def post_process_image_text_to_text(self, generated_outputs):
Returns:
`List[str]`: The decoded text.
"""
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]

@property
Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/mllama/processing_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(self, generated_outputs, **kwargs):
"""
Post-process the output of the model to decode the text.

Expand All @@ -359,8 +359,13 @@ def post_process_image_text_to_text(self, generated_outputs):
Returns:
`List[str]`: The decoded text.
"""
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False)
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)

@property
Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/qwen2_vl/processing_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(self, generated_outputs, **kwargs):
"""
Post-process the output of the model to decode the text.

Expand All @@ -182,8 +182,13 @@ def post_process_image_text_to_text(self, generated_outputs):
Returns:
`List[str]`: The decoded text.
"""
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False)
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)

@property
Expand Down
33 changes: 25 additions & 8 deletions src/transformers/pipelines/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import enum
from collections.abc import Iterable # pylint: disable=g-importing-member
from typing import Dict, List, Optional, Union

from ..processing_utils import ProcessingKwargs, Unpack
Expand Down Expand Up @@ -71,6 +72,8 @@ def retrieve_images_in_messages(
"""
if images is None:
images = []
elif not isinstance(images, Iterable):
images = [images]
idx_images = 0
retrieved_images = []
for message in messages:
Expand Down Expand Up @@ -188,14 +191,15 @@ def _sanitize_parameters(
return_full_text=None,
return_tensors=None,
return_type=None,
clean_up_tokenization_spaces=None,
stop_sequence=None,
continue_final_message=None,
**kwargs: Unpack[ProcessingKwargs],
):
forward_kwargs = {}
preprocess_params = {}
postprocess_params = {}

preprocess_params["processing_kwargs"] = kwargs
preprocess_params.update(kwargs)

if timeout is not None:
preprocess_params["timeout"] = timeout
Expand Down Expand Up @@ -226,7 +230,16 @@ def _sanitize_parameters(
postprocess_params["return_type"] = return_type
if continue_final_message is not None:
postprocess_params["continue_final_message"] = continue_final_message

if clean_up_tokenization_spaces is not None:
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
if stop_sequence is not None:
stop_sequence_ids = self.processor.tokenizer.encode(stop_sequence, add_special_tokens=False)
if len(stop_sequence_ids) > 1:
logger.warning_once(
"Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
" the stop sequence will be used as the stop sequence string in the interim."
)
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
return preprocess_params, forward_kwargs, postprocess_params

def __call__(
Expand Down Expand Up @@ -264,6 +277,8 @@ def __call__(
return_full_text (`bool`, *optional*, defaults to `True`):
If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
specified at the same time as `return_text`.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
                Whether or not to clean up the potential extra spaces in the text output.
continue_final_message( `bool`, *optional*): This indicates that you want the model to continue the
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
By default this is `True` when the final message in the input chat has the `assistant` role and
Expand Down Expand Up @@ -315,7 +330,7 @@ def __call__(

return super().__call__({"images": images, "text": text}, **kwargs)

def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
# In case we only have text inputs
if isinstance(inputs, (list, tuple, str)):
images = None
Expand All @@ -340,7 +355,7 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, pro
inputs_text = inputs["text"]
images = inputs["images"]

images = load_images(images)
images = load_images(images, timeout=timeout)

# if batched text inputs, we set padding to True unless specified otherwise
if isinstance(text, (list, tuple)) and len(text) > 1:
Expand All @@ -363,7 +378,9 @@ def _forward(self, model_inputs, generate_kwargs=None):

return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}

def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
def postprocess(
self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None, **postprocess_kwargs
):
input_texts = model_outputs["prompt_text"]
input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
generated_sequence = model_outputs["generated_sequence"]
Expand All @@ -375,8 +392,8 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_
]

# Decode inputs and outputs the same way to remove input text from generated text if present
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence)
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids)
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence, **postprocess_kwargs)
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids, **postprocess_kwargs)

# Force consistent behavior for including the input text in the output
if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
Expand Down
5 changes: 3 additions & 2 deletions src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@ def apply_chat_template(
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
)

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(self, generated_outputs, **kwargs):
"""
Post-process the output of a vlm to decode the text.

Expand All @@ -1144,7 +1144,8 @@ def post_process_image_text_to_text(self, generated_outputs):
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
skip_special_tokens = kwargs.pop("skip_special_tokens", True)
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)


def _validate_images_text_input_order(images, text):
Expand Down
25 changes: 6 additions & 19 deletions tests/pipelines/test_pipelines_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_model_pt_chat_template(self):
],
}
]
outputs = pipe([image_ny, image_chicago], text=messages)
outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=False, max_new_tokens=10)
self.assertEqual(
outputs,
[
Expand All @@ -139,20 +139,7 @@ def test_model_pt_chat_template(self):
],
}
],
"generated_text": [
{
"role": "user",
"content": [
{"type": "text", "text": "What’s the difference between these two images?"},
{"type": "image"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": "The first image shows a statue of the Statue of Liberty in the foreground, while the second image shows",
},
],
"generated_text": "The first image shows a statue of Liberty in the",
}
],
)
Expand All @@ -179,7 +166,7 @@ def test_model_pt_chat_template_continue_final_message(self):
],
},
]
outputs = pipe(text=messages)
outputs = pipe(text=messages, max_new_tokens=10)
self.assertEqual(
outputs,
[
Expand Down Expand Up @@ -213,7 +200,7 @@ def test_model_pt_chat_template_continue_final_message(self):
"content": [
{
"type": "text",
"text": "There is a dog and a person in the image. The dog is sitting on the sand, and the person is sitting on",
"text": "There is a dog and a person in the image. The dog is sitting",
}
],
},
Expand All @@ -238,7 +225,7 @@ def test_model_pt_chat_template_new_text(self):
],
}
]
outputs = pipe(text=messages, return_full_text=False)
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)
self.assertEqual(
outputs,
[
Expand All @@ -255,7 +242,7 @@ def test_model_pt_chat_template_new_text(self):
],
}
],
"generated_text": "In the image, a woman is sitting on the sandy beach, her legs crossed in a relaxed manner",
"generated_text": "In the image, a woman is sitting on the",
}
],
)
Expand Down