Add TPU tests for JAX and Tensorflow. #160
Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the foundational infrastructure required to enable comprehensive testing of Keras with both JAX and TensorFlow backends on Tensor Processing Units (TPUs). It establishes distinct build environments, including a specialized Dockerfile for JAX/TPU and dedicated dependency lists for both frameworks, ensuring all necessary TPU-related libraries are correctly installed for robust test execution.
Code Review
This pull request adds Docker and requirements files for running JAX and TensorFlow tests on TPUs. My review has identified several critical issues that will cause the build and dependency installation to fail. In the Dockerfile, there's a syntax error. In the requirements files, non-existent versions of TensorFlow are specified. I have also provided some medium-severity suggestions for the Dockerfile to improve build performance and maintainability.
39cb0a8 to 5c86f87
Thanks! Left some comments
@@ -0,0 +1,10 @@
# Tensorflow cpu-only version.
tensorflow-cpu>=2.20.0
Are we keeping this for data processing?
Yes, right now the unit tests use tf.data on TPU. I can try to remove that dependency (but in a separate PR).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
    tpu=_TPU.value
)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver("")
So, this automatically picks up the TPU_NAME env var when we pass an empty string? Should this be None instead of an empty string?
> So, this automatically picks up the TPU_NAME env var when we pass an empty string?

Exactly.

> Should this be None instead of empty string?

Changed.
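For readers following this thread, here is a minimal sketch of the flag-to-environment-variable pattern discussed above. The helper name `tpu_name_from_env` is hypothetical and does not come from this PR; it only illustrates normalizing an unset or empty `TPU_NAME` to `None` before the value is handed to the resolver.

```python
import os


def tpu_name_from_env(environ=None):
    """Return the TPU name from TPU_NAME, or None when unset or empty.

    Hypothetical helper for illustration only; not part of the PR.
    """
    env = os.environ if environ is None else environ
    # Treat a missing variable and an empty string the same way.
    return env.get("TPU_NAME") or None


# Possible usage (assumes TensorFlow is installed; kept as a comment so
# the sketch runs without it):
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
#       tpu=tpu_name_from_env()
#   )
```

Passing a mapping explicitly, as in `tpu_name_from_env({"TPU_NAME": "my-tpu"})`, keeps the helper easy to exercise in a unit test without touching the real environment.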
import jax
import jax.experimental.sparse as jax_sparse
Assuming that JAX will not be present on the TensorFlow backend, can we do something like this at the top?

    try:
        import jax
        import jax.experimental.sparse as jax_sparse
    except ImportError:
        jax = None
        jax_sparse = None

BTW, in the future, when we move to 3.14, we can use lazy imports: https://peps.python.org/pep-0810/ :P
> Assuming that JAX will not be present on the TensorFlow backend, can we do something like this on the top?
>
>     try:
>         import jax
>         import jax.experimental.sparse as jax_sparse
>     except ImportError as e:
>         jax = None
>         jax_sparse = None

Done

> BTW, in the future, when we move to 3.14, we can use lazy imports: https://peps.python.org/pep-0810/ :P

I can't wait for this. However, we can only use these once we drop support for 3.10, 3.11, 3.12 and 3.13... so maybe in 4 years :/
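The guarded import agreed on in this thread can also be expressed with `importlib`, which may be convenient if more optional backends appear later. This is only a sketch: `optional_import` is a hypothetical helper name, not something defined in the PR or in Keras.

```python
import importlib


def optional_import(module_name):
    """Import and return a module, or return None if it is unavailable.

    Hypothetical helper for illustration only; not part of the PR.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so a missing
        # package (or a missing parent package) also ends up here.
        return None


# Both names are None when JAX is not installed, for example when
# running against the TensorFlow backend only.
jax = optional_import("jax")
jax_sparse = optional_import("jax.experimental.sparse")
```

Test code can then check `jax is None` before touching any JAX-specific path.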
This is using the self-hosted TPU runners.

Only the multi-backend `keras_rs/src/layers/embedding/distributed_embedding_test.py` test is run for now. Many other tests have failures. They will be addressed in subsequent PRs.

- Modified `distributed_embedding_test.py` to replace the `TPU` flag with a `TPU_NAME` environment variable, as plumbing the flag through `pytest` is needlessly complicated.
- Modified `distributed_embedding_test.py` to not require JAX to be installed when running against the TensorFlow backend.
Thanks!
@hertschuh - oh, actually, do you want to do what KerasHub does, i.e., run TPU tests only when we add a GitHub label to the PR?
On second thoughts, maybe we can skip it because we don't have large tests to run right now