Add support for multiple images per prompt using <image> token, standardize not returning input prompt in the generated text.
yonigozlan committed Jul 26, 2024
1 parent 0d9b671 commit 320cd7a
Showing 1 changed file with 84 additions and 24 deletions.
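For context, a minimal usage sketch of the behavior this commit enables. The model id, URLs, and prompt format are illustrative assumptions, not part of the commit; the exact output nesting can vary by input shape.

    from transformers import pipeline

    # Illustrative model id: any image-text-to-text model whose prompt format
    # uses an <image> placeholder should behave similarly.
    pipe = pipeline("image-text-to-text", model="HuggingFaceM4/idefics2-8b")

    # Two <image> tokens in a single prompt now consume two images, in order.
    result = pipe(
        images=[
            "https://example.com/cat.jpg",  # hypothetical URLs
            "https://example.com/dog.jpg",
        ],
        text="What is the difference between <image> and <image>?",
        max_new_tokens=30,
    )

    # The generated text no longer echoes the input prompt back.
    print(result)  # e.g. [{'input_text': '...', 'generated_text': '...'}]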
108 changes: 84 additions & 24 deletions src/transformers/pipelines/image_text_to_text.py
@@ -33,9 +33,12 @@

if is_torch_available():
from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
+from .pt_utils import KeyDataset

logger = logging.get_logger(__name__)

+IMAGE_TOKEN = "<image>"


class Chat:
"""This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
@@ -57,7 +60,7 @@ class ImageText:
"""This class is intended to just be used internally in this pipeline and not exposed to users. We used this class
as the base pipeline does not support multiple inputs, so we need to convert multiple inputs to a single input."""

-def __init__(self, images: List, text: Union[str, List[str]]):
+def __init__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], text: Union[str, List[str]]):
self.images = images
self.text = text

@@ -72,7 +75,7 @@ def count_images_in_chat(chat):
@add_end_docstrings(build_pipeline_init_args(has_processor=True))
class ImageTextToTextPipeline(Pipeline):
"""
-Image-text-to-text pipeline using an `AutoModelForImageTextToText`. This pipeline generates text given an image and text.
+Image-text-to-text pipeline. This pipeline generates text given an image and text.
Example:
@@ -98,16 +101,35 @@ def __init__(self, *args, **kwargs):
requires_backends(self, "vision")
self.check_model_type(MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES)

-def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, text=None, timeout=None):
+def _sanitize_parameters(
+    self,
+    max_new_tokens=None,
+    generate_kwargs=None,
+    text=None,
+    truncation=None,
+    padding=None,
+    max_length=None,
+    timeout=None,
+):
forward_kwargs = {}
preprocess_params = {}
post_process_params = {}

if timeout is not None:
preprocess_params["timeout"] = timeout

+if truncation is not None:
+    preprocess_params["truncation"] = truncation
+
+if padding is not None:
+    preprocess_params["padding"] = padding
+
+if max_length is not None:
+    preprocess_params["max_length"] = max_length

if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs

if max_new_tokens is not None:
if "generate_kwargs" not in forward_kwargs:
forward_kwargs["generate_kwargs"] = {}
@@ -125,7 +147,7 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag
Generate a text given text and the image(s) passed as inputs.
Args:
-images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
+images (`str`, `List[str]`, `PIL.Image or `List[PIL.Image]`):
The pipeline handles three types of images:
- A string containing a HTTP(s) link pointing to an image
@@ -146,16 +168,18 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag
A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
- **generated_text** (`str`) -- The generated text.
+- **input_text** (`str`) -- The input text.
"""
text = kwargs.pop("text")
+batch_size = kwargs.get("batch_size", 1)

if images is None or text is None:
raise ValueError("You have to specify both `images` and `text`")

if not isinstance(images, (list, tuple)):
images = [images]

-if isinstance(text, (list, tuple, text) if is_torch_available() else (list, tuple)) and isinstance(
+if isinstance(text, (list, tuple, KeyDataset) if is_torch_available() else (list, tuple)) and isinstance(
text[0], (list, tuple, dict)
):
# We have one or more prompts in list-of-dicts format, so this is chat mode
@@ -167,14 +191,51 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag

if isinstance(text, str):
text = [text] * len(images)
+if not isinstance(text[0], str):
+    raise ValueError("The pipeline does not support nested lists of prompts.")

+# Check number of IMAGE_TOKEN token in each text
+num_images_in_text = [text_single.count(IMAGE_TOKEN) for text_single in text]
+if sum(num_images_in_text) > 0:
+    if any(num > 1 for num in num_images_in_text) and batch_size > 1:
+        raise ValueError(
+            "The pipeline does not support multiple images for a single prompt with batch_size > 1."
+        )
+    # Check if already nested images and consistency
+    if isinstance(images[0], (list, tuple)):
+        if len(images) != len(text):
+            raise ValueError("The number of nested image groups and prompts should be the same.")
+        num_images_in_images = [len(image) for image in images]
+        if num_images_in_text != num_images_in_images:
+            raise ValueError(
+                f"The number of images in each nested image group should be the same as the number of {IMAGE_TOKEN} tokens in the corresponding prompt."
+            )
+    elif sum(num_images_in_text) != len(images):
+        raise ValueError(
+            f"The total number of {IMAGE_TOKEN} tokens in the prompts should be the same as the number of images passed."
+        )
+    else:
+        # Reorganize the images to match the prompts
+        images_reorganized = []
+        for num_images in num_images_in_text:
+            images_reorganized.append(images[:num_images])
+            images = images[num_images:]
+        images = images_reorganized
+# After reorganizing, these should be the same
+if len(images) != len(text):
+    raise ValueError("The number of images and text should be the same.")

return super().__call__([ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs)

-def preprocess(self, inputs=None, timeout=None):
-    kwargs = {"legacy": False}
+def preprocess(self, inputs=None, truncation=None, padding="longest", max_length=None, timeout=None):
+    kwargs = {
+        "legacy": False,
+        "truncation": truncation,
+        "padding": padding,
+        "max_length": max_length,
+    }
images = inputs.images

if isinstance(inputs, Chat):
kwargs["chats"] = inputs.messages
text = self.processor.apply_chat_template(
@@ -192,11 +253,9 @@ def preprocess(self, inputs=None, timeout=None):
images = [load_image(image, timeout=timeout) for image in images]

try:
kwargs["padding"] = True
model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **kwargs)
except TypeError:
-kwargs = {}
-kwargs["padding"] = True
+kwargs.pop("legacy", None)
model_inputs = self.processor(images=images, text=text, return_tensors=self.framework, **kwargs)

model_inputs["text"] = text
@@ -206,28 +265,29 @@
def _forward(self, model_inputs, generate_kwargs=None):
if generate_kwargs is None:
generate_kwargs = {}

input_text = model_inputs.pop("text")
model_outputs = self.model.generate(**model_inputs, **generate_kwargs)
return {"outputs": model_outputs, "input_text": input_text}
return {"outputs": model_outputs, "input_text": input_text, "input_ids": model_inputs["input_ids"]}

def postprocess(self, model_outputs):
records = []
input_text = model_outputs["input_text"]
+input_text = [input_text] if isinstance(input_text, str) else input_text
outputs = model_outputs["outputs"]
+inputs_id = model_outputs["input_ids"]

+# Decode inputs and outputs the same way to remove input text from generated text if present
generated_texts = self.processor.post_process_image_text_to_text(outputs)
-# cleanup the generated text
+decoded_inputs = self.processor.post_process_image_text_to_text(inputs_id)
generated_texts = [text.strip() for text in generated_texts]
-if isinstance(input_text, str):
-    input_text = [input_text]
-if input_text is not None:
-    # remove the input text from the generated text if the generated text starts with the input text
-    generated_texts = [
-        text_generated[len(input_text[i]) :].strip()
-        if text_generated.startswith(input_text[i])
-        else text_generated
-        for i, text_generated in enumerate(generated_texts)
-    ]
+decoded_inputs = [text.strip() for text in decoded_inputs]
+# Remove the input text from the generated text if the generated text starts with the input text
+generated_texts = [
+    text_generated[len(decoded_inputs[i]) :].strip()
+    if text_generated.startswith(decoded_inputs[i])
+    else text_generated
+    for i, text_generated in enumerate(generated_texts)
+]

records = [
{"input_text": input_text[i], "generated_text": generated_text}
for i, generated_text in enumerate(generated_texts)
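To make the new `__call__` logic concrete, here is a standalone sketch of how images are matched to prompts by counting `<image>` tokens. It is simplified from the diff above; the helper name is ours, not part of the commit, and the nested-image and batch-size checks are omitted.

    IMAGE_TOKEN = "<image>"

    def group_images_by_prompt(images, texts):
        # One group of images per prompt, sized by that prompt's <image> count.
        num_images_in_text = [text.count(IMAGE_TOKEN) for text in texts]
        if sum(num_images_in_text) != len(images):
            raise ValueError("Total <image> tokens must match the number of images passed.")
        grouped = []
        for num_images in num_images_in_text:
            grouped.append(images[:num_images])
            images = images[num_images:]
        return grouped

    # First prompt consumes two images, second prompt one:
    print(group_images_by_prompt(
        ["a.jpg", "b.jpg", "c.jpg"],
        ["Compare <image> and <image>.", "Describe <image>."],
    ))
    # [['a.jpg', 'b.jpg'], ['c.jpg']]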

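The second half of the commit standardizes how the echoed prompt is removed from the output: the input ids are now decoded with the same post-processing as the generated ids, so the prefix comparison happens between two strings produced identically. A simplified sketch of that step, with a function name of our own choosing:

    def remove_prompt_prefix(generated_texts, decoded_inputs):
        # Both lists are stripped the same way before comparison, mirroring postprocess.
        generated_texts = [text.strip() for text in generated_texts]
        decoded_inputs = [text.strip() for text in decoded_inputs]
        return [
            generated[len(prompt):].strip() if generated.startswith(prompt) else generated
            for generated, prompt in zip(generated_texts, decoded_inputs)
        ]

    print(remove_prompt_prefix(
        ["What is shown here? A tabby cat on a sofa."],
        ["What is shown here?"],
    ))
    # ['A tabby cat on a sofa.']

Decoding both sides with `post_process_image_text_to_text` avoids the old failure mode where the raw `input_text` and the decoded output differed only by tokenizer normalization, which made the `startswith` check silently miss.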