big change to enable processing input images as numpy arrays
andimarafioti committed Aug 13, 2024
1 parent 0fc4880 commit 5a0c0f4
Showing 1 changed file with 147 additions and 98 deletions.
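For context (not part of the commit), here is a rough sketch of the kind of call this change is meant to support: passing numpy arrays, including single-channel ones, directly to the Idefics3 image processor. The checkpoint name and exact call signature below are assumptions, not taken from the diff.

import numpy as np
from transformers import AutoImageProcessor

# Assumed checkpoint name; any Idefics3 checkpoint with an image processor config should behave the same.
processor = AutoImageProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")

# A grayscale (H, W) numpy image; this commit adds the 2D -> (H, W, 1) expansion and
# palette-aware RGB conversion needed to accept inputs like this.
image = np.random.randint(0, 256, size=(512, 768), dtype=np.uint8)

inputs = processor(images=[image], return_tensors="pt")
print(inputs["pixel_values"].shape)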
245 changes: 147 additions & 98 deletions src/transformers/models/idefics3/image_processing_idefics3.py
@@ -29,6 +29,9 @@
get_image_size,
infer_channel_dimension_format,
is_pil_image,
is_torch_tensor,
is_jax_tensor,
is_tf_tensor,
is_scaled_image,
is_valid_image,
to_numpy_array,
@@ -161,79 +164,6 @@ def get_resize_output_image_size(
return height, width


def split_image(
image: np.ndarray,
max_image_size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.LANCZOS,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
):
"""
Image splitting strategy.
1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
sub-images of approximately the same size each (up to the fact that `vision_encoder_max_image_size` does not divide `height` or
`width`).
3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
"""
if isinstance(image, Image.Image):
width, height = image.size
else:
height, width = get_image_size(image, channel_dim=input_data_format)
max_height = max_width = max_image_size["longest_edge"]

frames = []
if height > max_height or width > max_width:
# Calculate the number of splits
num_splits_h = math.ceil(height / max_height)
num_splits_w = math.ceil(width / max_width)
# Calculate the optimal width and height for the sub-images
optimal_height = math.ceil(height / num_splits_h)
optimal_width = math.ceil(width / num_splits_w)

# Iterate through each row and column
for r in range(num_splits_h):
for c in range(num_splits_w):
# Calculate the starting point of the crop
start_x = c * optimal_width
start_y = r * optimal_height

# Calculate the ending point of the crop
end_x = min(start_x + optimal_width, width)
end_y = min(start_y + optimal_height, height)

# Crop the image
if isinstance(image, Image.Image):
cropped_image = image.crop((start_x, start_y, end_x, end_y))
else:
cropped_image = _crop(
image, start_x, start_y, end_x, end_y, input_data_format=input_data_format, data_format=data_format
)
frames.append(cropped_image)

# For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
global_image_height, global_image_width = max_height, max_width
if height != global_image_height or width != global_image_width:
if isinstance(image, Image.Image):
image = image.resize((global_image_width, global_image_height), resample=resample)
else:
image = resize(
image,
(global_image_height, global_image_width),
resample=resample,
input_data_format=input_data_format,
)
else:
num_splits_h, num_splits_w = 0, 0

if data_format is not None and not isinstance(image, Image.Image):
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

frames.append(image)

return frames, num_splits_h, num_splits_w


# Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
"""
@@ -312,28 +242,67 @@ def make_pixel_mask(
return mask


# Copied from transformers.models.idefics2.image_processing_idefics2.convert_to_rgb
def convert_to_rgb(image: ImageInput) -> ImageInput:
# Copied from transformers.image_transforms.to_pil_image

amyeroberts (Collaborator) commented on Aug 13, 2024:

If it's copied from, then why not just import it directly? We can't import from other model modules (e.g. idefics3 can't do `from transformers.models.idefics2.image_processing_idefics2 import XXX`), but we can freely import utils from the library, like in `image_transforms`.
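For illustration only (not part of the diff), the import convention the comment points to might look like this:

# Fine: shared utilities live at the library level and can be imported from any model module.
from transformers.image_transforms import to_pil_image

# Not fine by library convention: one model module importing from another model module.
# from transformers.models.idefics2.image_processing_idefics2 import convert_to_rgb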

def to_pil_image(
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
image_mode: str = "RGB",
) -> "PIL.Image.Image":
"""
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
needed.
Args:
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
The image to convert to the `PIL.Image` format.
do_rescale (`bool`, *optional*):
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
and `False` otherwise.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Returns:
`PIL.Image.Image`: The converted image.
"""
if isinstance(image, PIL.Image.Image):
return image

# Convert all tensors to numpy arrays before converting to PIL image
if is_torch_tensor(image) or is_tf_tensor(image):
image = image.numpy()
elif is_jax_tensor(image):
image = np.array(image)
elif not isinstance(image, np.ndarray):
raise ValueError("Input image type not supported: {}".format(type(image)))

# If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
image = image.astype(np.uint8)
return PIL.Image.fromarray(image, mode=image_mode)

def convert_to_rgb(image: ImageInput, palette: Optional[PIL.ImagePalette.ImagePalette] = None) -> ImageInput:
"""
Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
as is.
Args:
image (Image):
The image to convert.
palette (List[int], *optional*):
The palette to use if given.
"""
if not isinstance(image, PIL.Image.Image):
return image

# `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
# for transparent images. The call to `alpha_composite` handles this case
if image.mode == "RGB":
return image
mode = "P" if palette is not None else None
image = to_pil_image(image, image_mode=mode)
if image.mode=='P' and palette is not None:
image.putpalette(palette)

image_rgba = image.convert("RGBA")
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
alpha_composite = Image.alpha_composite(background, image_rgba)
alpha_composite = alpha_composite.convert("RGB")
return alpha_composite
return np.array(alpha_composite)


# FIXME Amy: make a more general crop function that isn't just centre crop
@@ -475,21 +444,35 @@ def resize(
raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
if isinstance(image, Image.Image):
return image.resize((size[1], size[0]), resample=resample)
return resize(
image, size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
)
else:
image_mode = None
if image.ndim == 2 or image.shape[-1] == 1:
image_mode = 'P'
image = to_pil_image(image, input_data_format=input_data_format, image_mode=image_mode)

resized_image = image.resize((size[1], size[0]), resample=resample)
resized_array = np.array(resized_image)
if resized_array.ndim == 2:
resized_array = np.expand_dims(resized_array, axis=-1)
return resized_array


def split_image(
self,
image,
max_image_size: Dict[str, int],
input_data_format: Optional[Union[str, ChannelDimension]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
resample: PILImageResampling = PILImageResampling.LANCZOS,
):
"""
Split an image into squares of side max_image_size and the original image resized to max_image_size.
That means that a single image becomes a sequence of images.
This is a "trick" to spend more compute on each image with no changes in the vision encoder.
1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
sub-images of the same size each (image_size, image_size). Typically, 364x364.
3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
Args:
image (`np.ndarray`):
@@ -499,9 +482,67 @@ def split_image(
patches of this size, and the original image will be concatenated with the patches, resized to max_size.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the output image. If not provided, it will be the same as the input image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
Resampling filter to use when resizing the image.
"""
return split_image(image, max_image_size, input_data_format=input_data_format, data_format=data_format)
if isinstance(image, Image.Image):
width, height = image.size
else:
height, width = get_image_size(image, channel_dim=input_data_format)
max_height = max_width = max_image_size["longest_edge"]

frames = []
if height > max_height or width > max_width:
# Calculate the number of splits
num_splits_h = math.ceil(height / max_height)
num_splits_w = math.ceil(width / max_width)
# Calculate the optimal width and height for the sub-images
optimal_height = math.ceil(height / num_splits_h)
optimal_width = math.ceil(width / num_splits_w)

# Iterate through each row and column
for r in range(num_splits_h):
for c in range(num_splits_w):
# Calculate the starting point of the crop
start_x = c * optimal_width
start_y = r * optimal_height

# Calculate the ending point of the crop
end_x = min(start_x + optimal_width, width)
end_y = min(start_y + optimal_height, height)

# Crop the image
if isinstance(image, Image.Image):
cropped_image = image.crop((start_x, start_y, end_x, end_y))
else:
cropped_image = _crop(
image, start_x, start_y, end_x, end_y, input_data_format=input_data_format, data_format=data_format
)
frames.append(cropped_image)

# For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
global_image_height, global_image_width = max_height, max_width
if height != global_image_height or width != global_image_width:
if isinstance(image, Image.Image):
image = image.resize((global_image_width, global_image_height), resample=resample)
else:
image = self.resize(
image,
{'height': global_image_height, 'width': global_image_width},
resample=resample,
input_data_format=input_data_format,
)
else:
num_splits_h, num_splits_w = 0, 0

if data_format is not None and not isinstance(image, Image.Image):
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

frames.append(image)

return frames, num_splits_h, num_splits_w
def _pad_image(
self,
image: np.ndarray,
@@ -708,15 +749,25 @@ def preprocess(
"torch.Tensor, tf.Tensor or jax.ndarray."
)

for img in images_list[0]:
if not is_pil_image(img):
logger.warning_once(
"Idefics3's image processing pipeline is optimized to process PIL images, but you passed a different type of image. "
"This might lead to inconsistent results"
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0])

palettes = [image[0].getpalette() if isinstance(image[0], Image.Image) and image[0].mode == 'P' else None for image in images_list] # save the palettes for conversion to RGB
# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]

new_images_list = []
for images in images_list:
new_images = []
for img in images:
if img.ndim == 2:
img = np.expand_dims(img, axis=-1)
new_images.append(img)
new_images_list.append(new_images)
images_list = new_images_list
del new_images_list

if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))


validate_preprocess_arguments(
@@ -794,10 +845,8 @@ def preprocess(
images_list_cols = [[0] * len(images) for images in images_list]

if do_convert_rgb:
images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
images_list = [[convert_to_rgb(image, palette) for image in images] for images, palette in zip(images_list, palettes)]

# All transformations expect numpy arrays.
images_list = [[to_numpy_array(image) for image in images] for images in images_list]

if is_scaled_image(images_list[0][0]) and do_rescale:
logger.warning_once(
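As a quick sanity check of the splitting arithmetic in split_image above, a standalone sketch with illustrative numbers (not taken from the diff):

import math

# Example: a 1000 x 1600 (height x width) image with max_image_size["longest_edge"] = 364.
height, width, max_edge = 1000, 1600, 364

num_splits_h = math.ceil(height / max_edge)         # ceil(1000 / 364) = 3
num_splits_w = math.ceil(width / max_edge)          # ceil(1600 / 364) = 5
optimal_height = math.ceil(height / num_splits_h)   # ceil(1000 / 3) = 334
optimal_width = math.ceil(width / num_splits_w)     # ceil(1600 / 5) = 320

# 3 x 5 = 15 crops, plus the global image resized to 364 x 364 appended at the end -> 16 frames.
print(num_splits_h * num_splits_w + 1)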
