Skip to content

Add ops.image.scale_and_translate. #21577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 14, 2025

Conversation

james77777778
Copy link
Contributor

This op is useful if we want to perform the same functionality as torch.nn.functional.interpolate with align_corners=True.

Some depth estimation models require this functionality.

Here is an example:

import numpy as np
import torch

from keras.src import ops

torch_y = torch.nn.functional.interpolate(
    torch.from_numpy(x).unsqueeze(0).unsqueeze(0),
    size=(5, 5),
    mode="bilinear",
    align_corners=True,
)
torch_y = torch_y.detach().cpu().numpy()[0, 0]


def keras_interpolate(images, size, interpolation, antialias=False):
    shape = tuple([1] + list(size) + [1])
    spatial_dims = (1, 2)
    scale = ops.convert_to_tensor(
        [(shape[i] - 1.0) / (images.shape[i] - 1.0) for i in spatial_dims]
    )
    translation = -(scale / 2.0 - 0.5)
    return ops.image.scale_and_translate(
        images,
        shape,
        spatial_dims,
        scale,
        translation,
        interpolation,
        antialias,
    )


keras_y = keras_interpolate(
    ops.expand_dims(ops.convert_to_tensor(x, dtype="float32"), (0, -1)),
    size=(5, 5),
    interpolation="bilinear",
    antialias=False,
)
keras_y = ops.convert_to_numpy(keras_y)[0, ..., 0]
np.testing.assert_allclose(torch_y, keras_y, rtol=1e-5, atol=1e-5)

The benchmark:

  • (1, 256, 256, 64) to (1, 128, 128, 64)
  • Jitted / Compiled
Backend Time
jax 0.0114s
tensorflow 0.0668s
torch 0.0274s
benchmark.py
import time

import numpy as np

from keras.src import backend
from keras.src import ops

images = np.random.uniform(size=(2, 256, 256, 64)).astype("float32")
shape = (2, 128, 128, 64)
spatial_dims = (1, 2)
scale = np.array(
    [(shape[i] - 1.0) / (images.shape[i] - 1.0) for i in spatial_dims],
    dtype="float32",
)
translation = -(scale / 2.0 - 0.5)
method = "linear"

if backend.backend() == "jax":
    import jax

    jit_scale_and_translate = jax.jit(
        ops.image.scale_and_translate, static_argnums=(1, 2, 5)
    )

elif backend.backend() == "tensorflow":
    import tensorflow as tf

    @tf.function(jit_compile=True)
    def jit_scale_and_translate(*args, **kwargs):
        return ops.image.scale_and_translate(*args, **kwargs)

elif backend.backend() == "torch":
    import torch

    @torch.compile
    def jit_scale_and_translate(*args, **kwargs):
        return ops.image.scale_and_translate(*args, **kwargs)
else:

    def jit_scale_and_translate(*args, **kwargs):
        return ops.image.scale_and_translate(*args, **kwargs)


# Warmup
for _ in range(3):
    jit_scale_and_translate(
        images, shape, spatial_dims, scale, translation, method
    )
print("Warmup done.")

st = time.time()
for i in range(5):
    resized_images = jit_scale_and_translate(
        images, shape, spatial_dims, scale, translation, method
    )
ed = time.time()

print(resized_images.shape)
print(f"Time taken: {ed - st:.4f} seconds")

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @james77777778, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new ops.image.scale_and_translate operation to the Keras library. This operation is designed to provide robust image scaling and translation capabilities, particularly useful for tasks such as depth estimation where precise interpolation with align_corners=True behavior is required. The implementation is fully integrated across Keras's multi-backend system, ensuring consistent behavior and performance across JAX, NumPy, TensorFlow, and Torch. This addition enhances Keras's image processing toolkit, offering a flexible and performant solution for common image manipulation needs in deep learning workflows.

Highlights

  • New Image Operation: Introduces a new image operation, ops.image.scale_and_translate, which provides functionality similar to torch.nn.functional.interpolate with align_corners=True, crucial for applications like depth estimation models.
  • Multi-Backend Implementation: The new operation is implemented across JAX, NumPy, TensorFlow, and Torch backends, ensuring broad compatibility and leveraging backend-specific optimizations.
  • Code Refactoring and Internalization: Refactored several internal helper functions (e.g., _compute_weight_mat, _compute_homography_matrix, and kernel functions) by prefixing them with an underscore, clearly indicating their private nature within the backend implementations.
  • Improved Test Coverage: Comprehensive unit tests have been added for the scale_and_translate operation, covering symbolic and static shape inference, as well as behavioral correctness across various interpolation methods and antialiasing settings.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new image operation, ops.image.scale_and_translate, which is a useful addition, especially for functionalities similar to torch.nn.functional.interpolate with align_corners=True. The implementations for JAX, TensorFlow, and PyTorch backends are well-structured and follow the patterns of existing operations. The tests are comprehensive, covering symbolic execution and correctness against a reference implementation.

I've identified a bug in the NumPy backend's _resize function that was introduced during refactoring, which would break existing resize functionality. I've also suggested a clarification in the docstring for the new scale_and_translate operation to improve clarity for users. Please see the detailed comments for suggestions.

@codecov-commenter
Copy link

codecov-commenter commented Aug 13, 2025

Codecov Report

❌ Patch coverage is 85.27919% with 29 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.76%. Comparing base (ce0d278) to head (80b26c2).

Files with missing lines Patch % Lines
keras/src/backend/tensorflow/image.py 81.53% 7 Missing and 5 partials ⚠️
keras/src/backend/torch/image.py 84.61% 6 Missing and 4 partials ⚠️
keras/src/backend/numpy/image.py 91.17% 1 Missing and 2 partials ⚠️
keras/src/backend/jax/image.py 83.33% 1 Missing and 1 partial ⚠️
keras/api/_tf_keras/keras/ops/image/__init__.py 0.00% 1 Missing ⚠️
keras/src/ops/image.py 93.33% 1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##           master   #21577    +/-   ##
========================================
  Coverage   82.75%   82.76%            
========================================
  Files         567      567            
  Lines       56531    56694   +163     
  Branches     8824     8841    +17     
========================================
+ Hits        46785    46920   +135     
- Misses       7585     7602    +17     
- Partials     2161     2172    +11     
Flag Coverage Δ
keras 82.56% <85.27%> (+<0.01%) ⬆️
keras-jax 63.67% <28.42%> (-0.11%) ⬇️
keras-numpy 58.20% <39.59%> (-0.09%) ⬇️
keras-openvino 34.60% <14.72%> (-0.06%) ⬇️
keras-tensorflow 64.17% <46.19%> (-0.04%) ⬇️
keras-torch 63.78% <46.70%> (-0.04%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@james77777778
Copy link
Contributor Author

/gemini review

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new image operation, ops.image.scale_and_translate, which is useful for tasks like depth estimation. The implementation is added for JAX, TensorFlow, and NumPy backends, with a NotImplementedError for OpenVINO. The changes also include some nice refactoring, such as moving constants to the top of backend files and improving helper functions. The new operation is well-tested.

I have one minor suggestion to improve the clarity of the docstring for the new function. Otherwise, the changes look great.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for the contribution!

@@ -14,6 +14,34 @@
"lanczos5",
"bicubic",
)
AFFINE_TRANSFORM_INTERPOLATIONS = { # map to order
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this is clearer.

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Aug 14, 2025
@fchollet fchollet merged commit 45c98ec into keras-team:master Aug 14, 2025
11 checks passed
@james77777778 james77777778 deleted the add-scale-and-translate branch August 15, 2025 00:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready to pull Ready to be merged into the codebase size:L
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants