diff --git a/src/transformers/models/idefics3/image_processing_idefics3.py b/src/transformers/models/idefics3/image_processing_idefics3.py
index 7806d22947359a..2d54c7c47d00ba 100644
--- a/src/transformers/models/idefics3/image_processing_idefics3.py
+++ b/src/transformers/models/idefics3/image_processing_idefics3.py
@@ -29,6 +29,9 @@
     get_image_size,
     infer_channel_dimension_format,
     is_pil_image,
+    is_jax_tensor,
+    is_tf_tensor,
+    is_torch_tensor,
     is_scaled_image,
     is_valid_image,
     to_numpy_array,
@@ -161,79 +164,6 @@ def get_resize_output_image_size(
     return height, width
 
 
-def split_image(
-    image: np.ndarray,
-    max_image_size: Dict[str, int],
-    resample: PILImageResampling = PILImageResampling.LANCZOS,
-    input_data_format: Optional[Union[str, ChannelDimension]] = None,
-    data_format: Optional[Union[str, ChannelDimension]] = None,
-):
-    """
-    Image splitting strategy.
-    1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while preserving the aspect ratio.
-    2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)`
-    sub-images of approximately the same size each (up to the fact that `vision_encoder_max_image_size` does not divide `height` or
-    `width`).
-    3) Returns the list of the crops and the original image, in addition to the number of splits for the height and the width.
-    """
-    if isinstance(image, Image.Image):
-        width, height = image.size
-    else:
-        height, width = get_image_size(image, channel_dim=input_data_format)
-    max_height = max_width = max_image_size["longest_edge"]
-
-    frames = []
-    if height > max_height or width > max_width:
-        # Calculate the number of splits
-        num_splits_h = math.ceil(height / max_height)
-        num_splits_w = math.ceil(width / max_width)
-        # Calculate the optimal width and height for the sub-images
-        optimal_height = math.ceil(height / num_splits_h)
-        optimal_width = math.ceil(width / num_splits_w)
-
-        # Iterate through each row and column
-        for r in range(num_splits_h):
-            for c in range(num_splits_w):
-                # Calculate the starting point of the crop
-                start_x = c * optimal_width
-                start_y = r * optimal_height
-
-                # Calculate the ending point of the crop
-                end_x = min(start_x + optimal_width, width)
-                end_y = min(start_y + optimal_height, height)
-
-                # Crop the image
-                if isinstance(image, Image.Image):
-                    cropped_image = image.crop((start_x, start_y, end_x, end_y))
-                else:
-                    cropped_image = _crop(
-                        image, start_x, start_y, end_x, end_y, input_data_format=input_data_format, data_format=data_format
-                    )
-                frames.append(cropped_image)
-
-        # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency
-        global_image_height, global_image_width = max_height, max_width
-        if height != global_image_height or width != global_image_width:
-            if isinstance(image, Image.Image):
-                image = image.resize((global_image_width, global_image_height), resample=resample)
-            else:
-                image = resize(
-                    image,
-                    (global_image_height, global_image_width),
-                    resample=resample,
-                    input_data_format=input_data_format,
-                )
-    else:
-        num_splits_h, num_splits_w = 0, 0
-
-    if data_format is not None and not isinstance(image, Image.Image):
-        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
-
-    frames.append(image)
-
-    return frames, num_splits_h, num_splits_w
-
-
 # Copied from transformers.models.idefics2.image_processing_idefics2.make_list_of_images
 def make_list_of_images(images: ImageInput) -> List[List[np.ndarray]]:
     """
@@ -312,28 +242,67 @@ def make_pixel_mask(
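For intuition about the splitting strategy removed here (it reappears below as a `split_image` method on the processor class), this standalone sketch reproduces just the grid arithmetic. It is illustrative only, not part of the patch, and the helper name `crop_grid` is ours:

```python
import math


def crop_grid(height: int, width: int, longest_edge: int = 364):
    """Illustrative only: compute the crop boxes the splitting strategy produces."""
    if height <= longest_edge and width <= longest_edge:
        return 0, 0, []  # no splitting; only the (resized) global image is kept
    num_splits_h = math.ceil(height / longest_edge)
    num_splits_w = math.ceil(width / longest_edge)
    optimal_h = math.ceil(height / num_splits_h)
    optimal_w = math.ceil(width / num_splits_w)
    boxes = [
        (c * optimal_w, r * optimal_h, min((c + 1) * optimal_w, width), min((r + 1) * optimal_h, height))
        for r in range(num_splits_h)
        for c in range(num_splits_w)
    ]
    return num_splits_h, num_splits_w, boxes


# A 1000x700 image yields a 3x2 grid of roughly 350x334 crops,
# so 6 crops + 1 global image = 7 frames in total.
print(crop_grid(1000, 700))
```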
     return mask
 
 
-# Copied from transformers.models.idefics2.image_processing_idefics2.convert_to_rgb
-def convert_to_rgb(image: ImageInput) -> ImageInput:
+# Adapted (simplified) from transformers.image_transforms.to_pil_image
+def to_pil_image(
+    image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
+    do_rescale: Optional[bool] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    image_mode: Optional[str] = None,
+) -> "PIL.Image.Image":
+    """
+    Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
+    needed.
+
+    Args:
+        image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
+            The image to convert to the `PIL.Image` format.
+        do_rescale (`bool`, *optional*):
+            Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Defaults to
+            `True` for floating-point images (whose values are assumed to be in `[0, 1]`) and `False` otherwise.
+        input_data_format (`ChannelDimension`, *optional*):
+            The channel dimension format of the input image. If unset, will use the inferred format from the input.
+        image_mode (`str`, *optional*):
+            The PIL image mode to use (e.g. `"P"` for palettized images). If unset, PIL infers it from the array.
+
+    Returns:
+        `PIL.Image.Image`: The converted image.
+    """
+    if isinstance(image, PIL.Image.Image):
+        return image
+
+    # Convert all tensors to numpy arrays before converting to PIL image
+    if is_torch_tensor(image) or is_tf_tensor(image):
+        image = image.numpy()
+    elif is_jax_tensor(image):
+        image = np.array(image)
+    elif not isinstance(image, np.ndarray):
+        raise ValueError(f"Input image type not supported: {type(image)}")
+
+    # Put the channel dimension last, as PIL expects.
+    if image.ndim == 3:
+        image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
+
+    # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
+    image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
+
+    # PIL.Image can only store uint8 values, so rescale floating-point images (assumed in [0, 1]) to [0, 255].
+    # This is a simplified heuristic compared to transformers.image_transforms.to_pil_image.
+    if do_rescale is None:
+        do_rescale = np.issubdtype(image.dtype, np.floating)
+    if do_rescale:
+        image = image * 255
+    image = image.astype(np.uint8)
+    return PIL.Image.fromarray(image, mode=image_mode)
+
+
+def convert_to_rgb(image: ImageInput, palette: Optional[List[int]] = None) -> ImageInput:
     """
-    Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
-    as is.
+    Converts an image to RGB format. Array inputs are first converted to a `PIL.Image.Image` (re-applying `palette` if
+    given) and the result is returned as an RGB `np.ndarray`.
 
     Args:
         image (Image):
             The image to convert.
+        palette (`List[int]`, *optional*):
+            A palette, as returned by `PIL.Image.Image.getpalette()`, to re-apply when the array came from a
+            palettized ("P" mode) image.
     """
     if not isinstance(image, PIL.Image.Image):
-        return image
-
-    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
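A quick sanity check of the simplified `to_pil_image` above. This assumes the patched module is importable; note the rescale heuristic is the simplified one added here, not the full `transformers.image_transforms` logic:

```python
import numpy as np

from transformers.models.idefics3.image_processing_idefics3 import to_pil_image

rng = np.random.default_rng(0)

# A float image in [0, 1] is rescaled to [0, 255] by the dtype heuristic.
pil = to_pil_image(rng.random((32, 32, 3)).astype(np.float32))
assert pil.mode == "RGB" and pil.size == (32, 32)

# A single-channel uint8 array is squeezed and kept single-channel via "P" mode.
pil_gray = to_pil_image(rng.integers(0, 256, (32, 32, 1), dtype=np.uint8), image_mode="P")
assert pil_gray.mode == "P"
```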
-    # for transparent images. The call to `alpha_composite` handles this case
-    if image.mode == "RGB":
-        return image
+        mode = "P" if palette is not None else None
+        image = to_pil_image(image, image_mode=mode)
+        # Re-apply the palette that was saved off before the image was converted to a numpy array.
+        if image.mode == "P" and palette is not None:
+            image.putpalette(palette)
 
+    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
+    # for transparent images. The call to `alpha_composite` handles this case.
     image_rgba = image.convert("RGBA")
     background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
     alpha_composite = Image.alpha_composite(background, image_rgba)
     alpha_composite = alpha_composite.convert("RGB")
-    return alpha_composite
+    return np.array(alpha_composite)
 
 
 # FIXME Amy: make a more general crop function that isn't just centre crop
@@ -475,9 +444,18 @@ def resize(
             raise ValueError("size must be a dictionary with key 'longest_edge' or 'height' and 'width'.")
         if isinstance(image, Image.Image):
             return image.resize((size[1], size[0]), resample=resample)
-        return resize(
-            image, size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
-        )
+        else:
+            # Resize via PIL so every resampling filter (e.g. LANCZOS) behaves identically for array and PIL inputs.
+            # Single-channel images go through "P" mode so they stay single-channel.
+            image_mode = None
+            if image.ndim == 2 or image.shape[-1] == 1:
+                image_mode = "P"
+            image = to_pil_image(image, input_data_format=input_data_format, image_mode=image_mode)
+
+            resized_image = image.resize((size[1], size[0]), resample=resample)
+            resized_array = np.array(resized_image)
+            # PIL drops the channel axis of single-channel images; restore it.
+            if resized_array.ndim == 2:
+                resized_array = np.expand_dims(resized_array, axis=-1)
+            return resized_array
 
     def split_image(
         self,
@@ -485,11 +463,16 @@ def split_image(
         max_image_size: Dict[str, int],
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         data_format: Optional[Union[str, ChannelDimension]] = None,
+        resample: PILImageResampling = PILImageResampling.LANCZOS,
     ):
         """
        Split an image into squares of side max_image_size and the original image resized to max_image_size. That
        means that a single image becomes a sequence of images.
        This is a "trick" to spend more compute on each image with no changes in the vision encoder.
+        1) If one side of the original image is larger than `max_image_size`, resize it to `max_image_size` while
+           preserving the aspect ratio.
+        2) Divide the resulting image into `ceil(height / max_image_size)` x `ceil(width / max_image_size)` sub-images
+           of the same size each (`max_image_size`, `max_image_size`). Typically, 364x364.
+        3) Return the list of crops and the original image, along with the number of splits for the height and the
+           width.
 
         Args:
             image (`np.ndarray`):
                 Images to split.
             max_image_size (`Dict[str, int]`):
                 Maximum size of the output image. If the image is larger than this size, it will be split into
                 patches of this size, and the original image will be concatenated with the patches, resized to max_size.
             input_data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the input image. If not provided, it will be inferred.
@@ -499,9 +482,67 @@ def split_image(
+            data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the output image. If not provided, it will be the same as the input
+                image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.LANCZOS`):
+                Resampling filter to use when resizing the image.
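Why the palette has to travel alongside the array: `np.array` on a "P"-mode image yields palette *indices*, not colors. A hedged illustration (the file name is hypothetical; any palettized PNG works):

```python
import numpy as np
from PIL import Image

img = Image.open("sprite.png")  # hypothetical "P"-mode input
assert img.mode == "P"

palette = img.getpalette()  # flat [r, g, b, r, g, b, ...] list
indices = np.array(img)     # 2D array of palette indices, not RGB values

# Re-applying the palette before the RGB conversion recovers the original
# colors; this is exactly what the patched convert_to_rgb does.
restored = Image.fromarray(indices, mode="P")
restored.putpalette(palette)
assert np.array_equal(np.array(restored.convert("RGB")), np.array(img.convert("RGB")))
```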
""" - return split_image(image, max_image_size, input_data_format=input_data_format, data_format=data_format) + if isinstance(image, Image.Image): + width, height = image.size + else: + height, width = get_image_size(image, channel_dim=input_data_format) + max_height = max_width = max_image_size["longest_edge"] + + frames = [] + if height > max_height or width > max_width: + # Calculate the number of splits + num_splits_h = math.ceil(height / max_height) + num_splits_w = math.ceil(width / max_width) + # Calculate the optimal width and height for the sub-images + optimal_height = math.ceil(height / num_splits_h) + optimal_width = math.ceil(width / num_splits_w) + + # Iterate through each row and column + for r in range(num_splits_h): + for c in range(num_splits_w): + # Calculate the starting point of the crop + start_x = c * optimal_width + start_y = r * optimal_height + + # Calculate the ending point of the crop + end_x = min(start_x + optimal_width, width) + end_y = min(start_y + optimal_height, height) + + # Crop the image + if isinstance(image, Image.Image): + cropped_image = image.crop((start_x, start_y, end_x, end_y)) + else: + cropped_image = _crop( + image, start_x, start_y, end_x, end_y, input_data_format=input_data_format, data_format=data_format + ) + frames.append(cropped_image) + + # For the global image at the end, we resize it to match the max_image_size, for cpu memory efficiency + global_image_height, global_image_width = max_height, max_width + if height != global_image_height or width != global_image_width: + if isinstance(image, Image.Image): + image = image.resize((global_image_width, global_image_height), resample=resample) + else: + image = self.resize( + image, + {'height': global_image_height, 'width': global_image_width}, + resample=resample, + input_data_format=input_data_format, + ) + else: + num_splits_h, num_splits_w = 0, 0 + + if data_format is not None and not isinstance(image, Image.Image): + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + + frames.append(image) + return frames, num_splits_h, num_splits_w def _pad_image( self, image: np.ndarray, @@ -708,15 +749,25 @@ def preprocess( "torch.Tensor, tf.Tensor or jax.ndarray." ) - for img in images_list[0]: - if not is_pil_image(img): - logger.warning_once( - "Idefics3's image processing pipeline is optimized to process PIL images, but you passed a different type of image. " - "This might lead to inconsistent results" - ) - if input_data_format is None: - # We assume that all images have the same channel dimension format. - input_data_format = infer_channel_dimension_format(images_list[0][0]) + + palettes = [image[0].getpalette() if isinstance(image[0], Image.Image) and image[0].mode == 'P' else None for image in images_list] # save the palettes for conversion to RGB + # All transformations expect numpy arrays. + images_list = [[to_numpy_array(image) for image in images] for images in images_list] + + new_images_list = [] + for images in images_list: + new_images = [] + for img in images: + if img.ndim == 2: + img = np.expand_dims(img, axis=-1) + new_images.append(img) + new_images_list.append(new_images) + images_list = new_images_list + del new_images_list + + if input_data_format is None: + # We assume that all images have the same channel dimension format. 
+            input_data_format = infer_channel_dimension_format(images_list[0][0], num_channels=(1, 3, 4))
 
         validate_preprocess_arguments(
@@ -794,10 +845,8 @@ def preprocess(
             images_list_cols = [[0] * len(images) for images in images_list]
 
         if do_convert_rgb:
-            images_list = [[convert_to_rgb(image) for image in images] for images in images_list]
+            images_list = [[convert_to_rgb(image, palette) for image in images] for images, palette in zip(images_list, palettes)]
 
-        # All transformations expect numpy arrays.
-        images_list = [[to_numpy_array(image) for image in images] for images in images_list]
-
         if is_scaled_image(images_list[0][0]) and do_rescale:
             logger.warning_once(
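Taken together, palettized and grayscale inputs should now survive the whole pipeline. A hedged end-to-end sketch (the checkpoint name is the released Idefics3 model and may differ from what CI uses; output keys follow the usual image-processor convention):

```python
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")

# A tiny palettized image: indices 0/1 mapped to black and red.
img = Image.fromarray(np.array([[0, 1], [1, 0]], dtype=np.uint8), mode="P")
img.putpalette([0, 0, 0, 255, 0, 0])

out = processor(images=img, return_tensors="np")
print(out["pixel_values"].shape)  # batched, RGB-converted pixel values
```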