Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions gemma/multimodal/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from kauldron import typing
import numpy as np
from PIL import Image
import tensorflow as tf
import warnings
# Removed TensorFlow dependency: use PIL + NumPy + JAX for decoding/resizing.
# Note: inputs are expected to be image arrays (H,W,C) in uint8 or floats.
# Keep a small fallback warning if the input array isn't a standard type.

_IMAGE_MEAN = (127.5,) * 3
_IMAGE_STD = (127.5,) * 3
Expand Down Expand Up @@ -69,17 +72,28 @@ def pre_process_image(
Returns:
The pre-processed image.
"""
# all inputs are expected to have been jpeg compressed.
# TODO(eyvinec): we should remove tf dependency.
image = jnp.asarray(
tf.image.decode_jpeg(tf.io.encode_jpeg(image), channels=3)
)
image = jax.image.resize(
image,
shape=(image_height, image_width, 3),
method="bilinear",
antialias=True,
)
# Accept numpy / jax arrays or PIL images. Convert to uint8 ndarray for PIL.
arr = np.asarray(image)

# If floats in [0, 1], convert to 0-255 uint8
if np.issubdtype(arr.dtype, np.floating):
if arr.max() <= 1.0:
arr = (arr * 255.0).round().astype(np.uint8)
else:
arr = np.clip(arr, 0, 255).round().astype(np.uint8)
else:
arr = arr.astype(np.uint8)

# PIL expects shape (W, H) ordering for resize tuple; Image.fromarray handles H,W,C.
pil = Image.fromarray(arr)
# Use bilinear resizing; PIL's LANCZOS is a high-quality downsample filter but
# bilinear better matches previous `jax.image.resize(..., method='bilinear')`.
pil = pil.resize((image_width, image_height), resample=Image.BILINEAR)

# Back to numpy -> jax
resized = np.asarray(pil).astype(np.float32)
image = jnp.asarray(resized)

image = normalize_images(image)
image = jnp.clip(image, -1, 1)
return image
Expand Down