[JAX] Fix breakage due to default dtype handling in ctc_loss. #21588
Summary of Changes
Hello @divyashreepathihalli, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request addresses a breakage in Keras's `ctc_loss` functionality when used with JAX. The issue arose because Keras was implicitly relying on a deprecated JAX API for default dtype handling, which was affected by a recent JAX update. The core change is to explicitly set the data type for one-hot encoding operations within the `ctc_loss` implementation in Keras's JAX backend. By passing the `dtype` argument to `jax.nn.one_hot`, the code no longer depends on the deprecated behavior, which keeps it compatible with newer JAX versions.
Highlights
- Explicit Dtype Specification: The `jax.nn.one_hot` function calls within the `ctc_loss` implementation in `keras/src/backend/jax/nn.py` now explicitly pass the `dtype` argument.
- Addressing Deprecated JAX API Usage: This change ensures that the one-hot encoded tensors have the correct data type, preventing implicit reliance on JAX's default dtype handling, which was causing issues due to a deprecated API.
- Fixing Keras Test Breakage: The fix directly resolves a test breakage in Keras's `ctc_loss` that occurred after a recent JAX update related to default dtypes.
Code Review
This pull request addresses a breakage in `ctc_loss` for the JAX backend, caused by a change in JAX's default dtype handling. The fix explicitly specifies the `dtype` in two `jax.nn.one_hot` calls within the `ctc_loss` function. This change is correct and makes the implementation more robust by removing the dependency on the deprecated default dtype behavior. The chosen dtypes match the other tensors used in the subsequent operations, ensuring correctness. The changes are well-targeted and I approve them.
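The pattern described in the review can be sketched as follows. This is a minimal illustration and not the actual Keras diff: the helper `one_hot_like` and the arrays `labels` and `logits` are hypothetical names chosen for the example. The only library behavior it relies on is that `jax.nn.one_hot` accepts a `dtype` keyword argument, so the result no longer depends on JAX's default float dtype.

```python
import jax
import jax.numpy as jnp


def one_hot_like(labels, num_classes, reference):
    """One-hot encode `labels`, using the dtype of `reference`.

    Hypothetical helper for illustration only. Passing `dtype` explicitly
    means the output dtype does not depend on JAX's default dtype settings.
    """
    return jax.nn.one_hot(labels, num_classes, dtype=reference.dtype)


# Hypothetical inputs: three integer labels and a float32 tensor whose
# dtype the one-hot encoding should match.
labels = jnp.array([0, 2, 1])
logits = jnp.zeros((3, 4), dtype=jnp.float32)

encoded = one_hot_like(labels, 4, logits)
```

Tying the one-hot dtype to a tensor it will later be combined with keeps the downstream arithmetic in a single dtype, which is the property the review points to.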
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff           @@
##           master   #21588   +/-   ##
=======================================
  Coverage   82.73%   82.73%
=======================================
  Files         567      567
  Lines       56678    56678
  Branches     8839     8839
=======================================
  Hits        46895    46895
  Misses       7609     7609
  Partials     2174     2174
from cl/794615827
"_Keras is the only user of a deprecated JAX api (jax_default_dtype_bits) for electing float32 default dtypes. This API will be removed in an upcoming JAX release.
A recent JAX change to default dtype handling if that flag is set broke a Keras test (cl/794553125). Since the behavior that Keras is relying on is deprecated it is probably best to update Keras not to depend on that default dtype handling in the first place in the test that broke (ctc_loss)._"