@hertschuh hertschuh commented Oct 8, 2025

This is using the self-hosted TPU runners.

Only the multi-backend keras_rs/src/layers/embedding/distributed_embedding_test.py test is run for now. Many other tests have failures. They will be addressed in subsequent PRs.

  • Modified distributed_embedding_test.py to replace the TPU flag with a TPU_NAME environment variable as plumbing the flag through pytest is needlessly complicated.
  • Modified distributed_embedding_test.py to not require JAX to be installed when running against the TensorFlow backend.
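
As an illustration of the first point, a test module can read an environment variable directly, with no command-line flag to thread through `pytest`. This is a minimal sketch, not the code in this PR: the helper `tpu_configured` and the class `ExampleTpuTest` are hypothetical names, and skipping when `TPU_NAME` is unset is an assumption made for the example.

```python
import os
import unittest

# Name of the environment variable mentioned in the PR description.
TPU_NAME_ENV = "TPU_NAME"


def tpu_configured(environ=os.environ):
    """Return True if TPU_NAME is set to a non-empty value (hypothetical helper)."""
    return bool(environ.get(TPU_NAME_ENV))


class ExampleTpuTest(unittest.TestCase):
    # Assumption for this sketch: skip rather than fail when no TPU is configured.
    @unittest.skipUnless(tpu_configured(), "TPU_NAME is not set")
    def test_runs_only_with_tpu(self):
        # Inside the test, the TPU name is read straight from the environment.
        self.assertTrue(os.environ[TPU_NAME_ENV])
```

With this shape, the variable is set in the shell or the CI job before invoking `pytest`, and the test file needs no flag parsing.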

@hertschuh hertschuh marked this pull request as draft October 8, 2025 00:42

Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the foundational infrastructure required to enable comprehensive testing of Keras with both JAX and TensorFlow backends on Tensor Processing Units (TPUs). It establishes distinct build environments, including a specialized Dockerfile for JAX/TPU and dedicated dependency lists for both frameworks, ensuring all necessary TPU-related libraries are correctly installed for robust test execution.

Highlights

  • JAX TPU Test Environment: A new Dockerfile has been added to establish a dedicated testing environment for Keras using the JAX backend with TPU support. This includes Python 3.10, essential system packages, and configuration for the JAX backend.
  • JAX TPU Dependencies: A new requirements-jax-tpu.txt file has been introduced, specifying the necessary dependencies such as tensorflow-cpu, jax[tpu], and jax-tpu-embedding to facilitate JAX-based TPU testing.
  • TensorFlow TPU Dependencies: A new requirements-tensorflow-tpu.txt file has been added, outlining the specific tensorflow-tpu dependency required for TensorFlow-based TPU testing.
Ignored Files
  • Ignored by pattern: .github/workflows/** (1)
    • .github/workflows/actions.yml

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds Docker and requirements files for running JAX and TensorFlow tests on TPUs. My review has identified several critical issues that will cause the build and dependency installation to fail. In the Dockerfile, there's a syntax error. In the requirements files, non-existent versions of TensorFlow are specified. I have also provided some medium-severity suggestions for the Dockerfile to improve build performance and maintainability.

@hertschuh hertschuh force-pushed the tpu_tests branch 26 times, most recently from 39cb0a8 to 5c86f87 October 10, 2025 00:43
@hertschuh hertschuh marked this pull request as ready for review October 10, 2025 00:52
@hertschuh hertschuh requested a review from abheesht17 October 10, 2025 00:52
@hertschuh (Collaborator, Author):

FYI @cantonios @adityagupta1089

@abheesht17 abheesht17 left a comment

Thanks! Left some comments

@@ -0,0 +1,10 @@
# Tensorflow cpu-only version.
tensorflow-cpu>=2.20.0

@abheesht17 (Collaborator):

Are we keeping this for data processing?

@hertschuh (Collaborator, Author):

Yes, right now the unit tests use tf.data on TPU. I can try to remove that dependency (but in a separate PR).

- resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-     tpu=_TPU.value
- )
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver("")

@abheesht17 (Collaborator):

So, this automatically picks up the TPU_NAME env var when we pass an empty string? Should this be None instead of empty string?

@hertschuh (Collaborator, Author):

> So, this automatically picks up the TPU_NAME env var when we pass an empty string?

Exactly.

> Should this be None instead of empty string?

Changed.
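
The fallback confirmed in this exchange can be modeled in a few lines of plain Python. This is only a sketch of the behavior described above, not TensorFlow's implementation: `resolve_tpu_name` is a hypothetical helper, and it assumes that any falsy argument (`None` or `""`) defers to the `TPU_NAME` environment variable.

```python
import os


def resolve_tpu_name(tpu=None, environ=os.environ):
    """Model of the fallback discussed above (hypothetical helper).

    An explicit, non-empty name wins. Otherwise the TPU_NAME environment
    variable is used, and None is returned when it is not set either.
    """
    if tpu:  # Non-empty string: the caller chose a TPU explicitly.
        return tpu
    # None and "" are both falsy, so both fall through to the environment.
    return environ.get("TPU_NAME")


# Passing None and passing "" behave the same way in this model.
env = {"TPU_NAME": "my-tpu"}
assert resolve_tpu_name(None, env) == resolve_tpu_name("", env) == "my-tpu"
```

Under that assumption the two spellings are equivalent, and `None` states the intent ("no name given") more clearly than an empty string.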

Comment on lines 183 to 184
import jax
import jax.experimental.sparse as jax_sparse

@abheesht17 (Collaborator):

Assuming that JAX will not be present on the TensorFlow backend, can we do something like this on the top?

try:
  import jax
  import jax.experimental.sparse as jax_sparse
except ImportError as e:
  jax = None
  jax_sparse = None

BTW, in the future, when we move to 3.14, we can use lazy imports: https://peps.python.org/pep-0810/ :P

@hertschuh (Collaborator, Author):

> Assuming that JAX will not be present on the TensorFlow backend, can we do something like this on the top?
>
> try:
>   import jax
>   import jax.experimental.sparse as jax_sparse
> except ImportError as e:
>   jax = None
>   jax_sparse = None

Done

> BTW, in the future, when we move to 3.14, we can use lazy imports: https://peps.python.org/pep-0810/ :P

I can't wait for this. However, we can only use these once we drop support for 3.10, 3.11, 3.12 and 3.13... so maybe in 4 years :/
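
Until lazy imports are an option, the guard suggested above can be factored into a small helper that works on every supported Python version. This is a sketch under stated assumptions: `optional_import` and `require_jax` are hypothetical names that do not come from this PR, and only the standard-library `importlib.import_module` is used.

```python
import importlib


def optional_import(name):
    """Return the imported module, or None if it is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so a missing
        # package (or a missing parent package) lands here.
        return None


# On a TensorFlow-only install both names end up as None.
jax = optional_import("jax")
jax_sparse = optional_import("jax.experimental.sparse")


def require_jax():
    """Return the jax module, or fail with a clear message if it is absent."""
    if jax is None:
        raise RuntimeError("This code path needs JAX, which is not installed.")
    return jax
```

Code paths that run only on the JAX backend can call `require_jax()` first, so a missing install surfaces as a readable error rather than an `AttributeError` on `None`.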

@abheesht17 abheesht17 left a comment

Thanks!

@abheesht17 abheesht17 left a comment

@hertschuh - oh, actually, do you want to do what KerasHub does, i.e., run TPU tests only when we add a GitHub label to the PR?

On second thoughts, maybe we can skip it because we don't have large tests to run right now

@hertschuh hertschuh merged commit 47ab13d into keras-team:main Oct 14, 2025
7 checks passed