
Add support for post-processing kwargs in image-text-to-text pipeline #35374


Merged · 6 commits · Feb 18, 2025
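In short: this PR routes post-processing kwargs (e.g. `clean_up_tokenization_spaces`) and a new `stop_sequence` option from the `image-text-to-text` pipeline call down to the processors' `post_process_image_text_to_text` methods and the tokenizer's `batch_decode`, and lets a single bare image be passed without wrapping it in a list. A minimal sketch of the resulting pipeline call, modeled on this PR's own tests; the checkpoint and image URL are illustrative, not part of the PR:

from transformers import pipeline

# Illustrative checkpoint; any image-text-to-text model is handled the same way.
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# clean_up_tokenization_spaces is now forwarded to postprocess() and from
# there into processor.post_process_image_text_to_text / batch_decode.
outputs = pipe(
    text=messages,
    max_new_tokens=20,
    return_full_text=False,
    clean_up_tokenization_spaces=False,
)
print(outputs[0]["generated_text"])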
8 changes: 6 additions & 2 deletions src/transformers/models/fuyu/processing_fuyu.py
@@ -682,14 +682,18 @@ def tokens_to_points(tokens, original_size):

return results

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
"""
Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
containing the token ids of the generated sequences.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
Comment on lines +693 to +694

Collaborator: why is this one not a kwarg as well?

Member (author): I set it as an arg since its default value is hardcoded to `True`.

**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode` method.

Returns:
`List[str]`: The decoded text output.
@@ -706,7 +710,7 @@ def post_process_image_text_to_text(self, generated_outputs):
for i, seq in enumerate(unpadded_output_sequences):
padded_output_sequences[i, : len(seq)] = torch.tensor(seq)

return self.batch_decode(padded_output_sequences, skip_special_tokens=True)
return self.batch_decode(padded_output_sequences, skip_special_tokens=skip_special_tokens, **kwargs)

def batch_decode(self, *args, **kwargs):
"""
8 changes: 6 additions & 2 deletions src/transformers/models/kosmos2/processing_kosmos2.py
@@ -428,19 +428,23 @@ def post_process_generation(self, text, cleanup_and_extract=True):
return clean_text_and_extract_entities_with_bboxes(caption)
return caption

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode` method.

Returns:
`List[str]`: The decoded text.
"""
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]

@property
15 changes: 13 additions & 2 deletions src/transformers/models/mllama/processing_mllama.py
@@ -346,20 +346,31 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode` method.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)

@property
15 changes: 13 additions & 2 deletions src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
@@ -192,20 +192,31 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode` method.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)

@property
15 changes: 13 additions & 2 deletions src/transformers/models/qwen2_vl/processing_qwen2_vl.py
@@ -170,20 +170,31 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(
self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
):
"""
Post-process the output of the model to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode` method.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
generated_outputs,
skip_special_tokens=skip_special_tokens,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)

@property
34 changes: 26 additions & 8 deletions src/transformers/pipelines/image_text_to_text.py
@@ -14,6 +14,7 @@
# limitations under the License.

import enum
from collections.abc import Iterable
from typing import Dict, List, Optional, Union

from ..processing_utils import ProcessingKwargs, Unpack
@@ -71,6 +72,8 @@ def retrieve_images_in_messages(
"""
if images is None:
images = []
elif not isinstance(images, Iterable):
images = [images]
idx_images = 0
retrieved_images = []
for message in messages:
@@ -188,14 +191,15 @@ def _sanitize_parameters(
return_full_text=None,
return_tensors=None,
return_type=None,
clean_up_tokenization_spaces=None,
stop_sequence=None,
continue_final_message=None,
**kwargs: Unpack[ProcessingKwargs],
):
forward_kwargs = {}
preprocess_params = {}
postprocess_params = {}

preprocess_params["processing_kwargs"] = kwargs
preprocess_params.update(kwargs)

if timeout is not None:
preprocess_params["timeout"] = timeout
@@ -226,7 +230,16 @@ def _sanitize_parameters(
postprocess_params["return_type"] = return_type
if continue_final_message is not None:
postprocess_params["continue_final_message"] = continue_final_message

if clean_up_tokenization_spaces is not None:
postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
if stop_sequence is not None:
stop_sequence_ids = self.processor.tokenizer.encode(stop_sequence, add_special_tokens=False)
if len(stop_sequence_ids) > 1:
logger.warning_once(
"Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
" the stop sequence will be used as the stop sequence string in the interim."
)
generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
return preprocess_params, forward_kwargs, postprocess_params

def __call__(
@@ -264,6 +277,8 @@ def __call__(
return_full_text (`bool`, *optional*, defaults to `True`):
If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
specified at the same time as `return_text`.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
Whether or not to clean up the potential extra spaces in the text output.
continue_final_message (`bool`, *optional*): This indicates that you want the model to continue the
last message in the input chat rather than starting a new one, allowing you to "prefill" its response.
By default this is `True` when the final message in the input chat has the `assistant` role and
@@ -315,7 +330,7 @@

return super().__call__({"images": images, "text": text}, **kwargs)

def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
def preprocess(self, inputs=None, timeout=None, continue_final_message=None, **processing_kwargs):
# In case we only have text inputs
if isinstance(inputs, (list, tuple, str)):
images = None
@@ -332,6 +347,7 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
add_generation_prompt=not continue_final_message,
continue_final_message=continue_final_message,
return_tensors=self.framework,
**processing_kwargs,
)
inputs_text = inputs
images = inputs.images
@@ -340,7 +356,7 @@ def preprocess(self, inputs=None, timeout=None, continue_final_message=None, processing_kwargs=None):
inputs_text = inputs["text"]
images = inputs["images"]

images = load_images(images)
images = load_images(images, timeout=timeout)

# if batched text inputs, we set padding to True unless specified otherwise
if isinstance(text, (list, tuple)) and len(text) > 1:
@@ -363,7 +379,9 @@ def _forward(self, model_inputs, generate_kwargs=None):

return {"generated_sequence": generated_sequence, "prompt_text": prompt_text, "input_ids": input_ids}

def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
def postprocess(
self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None, **postprocess_kwargs
):
input_texts = model_outputs["prompt_text"]
input_texts = [input_texts] if isinstance(input_texts, (str, Chat)) else input_texts
generated_sequence = model_outputs["generated_sequence"]
@@ -375,8 +393,8 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, continue_final_message=None):
]

# Decode inputs and outputs the same way to remove input text from generated text if present
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence)
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids)
generated_texts = self.processor.post_process_image_text_to_text(generated_sequence, **postprocess_kwargs)
decoded_inputs = self.processor.post_process_image_text_to_text(input_ids, **postprocess_kwargs)

# Force consistent behavior for including the input text in the output
if return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}:
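The `stop_sequence` branch added above encodes the stop string with the processor's tokenizer and uses its first token id as `eos_token_id`, so a multi-token stop string is truncated to its first token (hence the warning). A hedged sketch of the call-site effect, reusing the illustrative checkpoint and image URL from the sketch at the top; the plain-text prompt format is also only illustrative:

from transformers import pipeline

pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")  # illustrative

outputs = pipe(
    images=["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"],
    text="USER: <image>\nDescribe the image. ASSISTANT:",
    max_new_tokens=50,
    # Encoded with add_special_tokens=False; the first token id becomes eos_token_id.
    stop_sequence=".",
)
print(outputs[0]["generated_text"])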
8 changes: 6 additions & 2 deletions src/transformers/processing_utils.py
@@ -1392,19 +1392,23 @@ def apply_chat_template(
return out["input_ids"]
return prompt

def post_process_image_text_to_text(self, generated_outputs):
def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
"""
Post-process the output of a VLM to decode the text.

Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
skip_special_tokens (`bool`, *optional*, defaults to `True`):
Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
**kwargs:
Additional arguments to be passed to the tokenizer's `batch_decode` method.

Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)


def _validate_images_text_input_order(images, text):
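At the processor level, the same kwargs can now be passed straight to `post_process_image_text_to_text`, which this base implementation forwards to `tokenizer.batch_decode`. A small sketch; the checkpoint is illustrative and the token ids are placeholders standing in for `model.generate(...)` output:

import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")  # illustrative

generated_ids = torch.tensor([[9707, 11, 1879, 0]])  # placeholder token ids

# Defaults keep the old behavior (skip_special_tokens=True); any extra kwargs
# flow through to tokenizer.batch_decode.
texts = processor.post_process_image_text_to_text(
    generated_ids,
    skip_special_tokens=False,
    clean_up_tokenization_spaces=True,
)
print(texts)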
29 changes: 8 additions & 21 deletions tests/pipelines/test_pipelines_image_text_to_text.py
@@ -124,7 +124,7 @@ def test_model_pt_chat_template(self):
],
}
]
outputs = pipe([image_ny, image_chicago], text=messages)
outputs = pipe([image_ny, image_chicago], text=messages, return_full_text=False, max_new_tokens=10)
self.assertEqual(
outputs,
[
@@ -139,20 +139,7 @@ def test_model_pt_chat_template(self):
],
}
],
"generated_text": [
{
"role": "user",
"content": [
{"type": "text", "text": "What’s the difference between these two images?"},
{"type": "image"},
{"type": "image"},
],
},
{
"role": "assistant",
"content": "The first image shows a statue of the Statue of Liberty in the foreground, while the second image shows",
},
],
"generated_text": "The first image shows a statue of Liberty in the",
}
],
)
@@ -179,7 +166,7 @@ def test_model_pt_chat_template_continue_final_message(self):
],
},
]
outputs = pipe(text=messages)
outputs = pipe(text=messages, max_new_tokens=10)
self.assertEqual(
outputs,
[
@@ -213,7 +200,7 @@
"content": [
{
"type": "text",
"text": "There is a dog and a person in the image. The dog is sitting on the sand, and the person is sitting on",
"text": "There is a dog and a person in the image. The dog is sitting",
}
],
},
@@ -238,7 +225,7 @@ def test_model_pt_chat_template_new_text(self):
],
}
]
outputs = pipe(text=messages, return_full_text=False)
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)
self.assertEqual(
outputs,
[
@@ -255,15 +242,15 @@
],
}
],
"generated_text": "In the image, a woman is sitting on the sandy beach, her legs crossed in a relaxed manner",
"generated_text": "In the image, a woman is sitting on the",
}
],
)

@slow
@require_torch
def test_model_pt_chat_template_image_url(self):
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
pipe = pipeline("image-text-to-text", model="llava-hf/llava-interleave-qwen-0.5b-hf")
messages = [
{
"role": "user",
Expand All @@ -279,7 +266,7 @@ def test_model_pt_chat_template_image_url(self):
}
]
outputs = pipe(text=messages, return_full_text=False, max_new_tokens=10)[0]["generated_text"]
self.assertEqual(outputs, "The image captures the iconic Statue of Liberty, a")
self.assertEqual(outputs, "A statue of liberty in the foreground of a city")

@slow
@require_torch