Skip to content

Commit

Permalink
Idefics3 - resolve data_format for image processing
Browse files: browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Sep 19, 2024
1 parent 4aad266 commit ab8e6dd
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 110 deletions.
5 changes: 4 additions & 1 deletion src/transformers/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _rescale_for_pil_conversion(image):
def to_pil_image(
image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
do_rescale: Optional[bool] = None,
image_mode: Optional[str] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> "PIL.Image.Image":
"""
Expand All @@ -175,6 +176,8 @@ def to_pil_image(
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
and `False` otherwise.
image_mode (`str`, *optional*):
The mode to use for the PIL image. If unset, will use the default mode for the input image type.
input_data_format (`ChannelDimension`, *optional*):
The channel dimension format of the input image. If unset, will use the inferred format from the input.
Expand Down Expand Up @@ -207,7 +210,7 @@ def to_pil_image(
image = rescale(image, 255)

image = image.astype(np.uint8)
return PIL.Image.fromarray(image)
return PIL.Image.fromarray(image, mode=image_mode)


# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
Expand Down
Loading

0 comments on commit ab8e6dd

Please sign in to comment.