Added AutoSharding Distribution API #21583
Summary of Changes
Hello @buildwithsuhana, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces the foundational AutoShardDistribution API to Keras 3.0, aiming to simplify distributed training by automating model and data parallelism. The initial implementation focuses on providing full functionality for the JAX backend. This feature allows users to enable powerful distributed execution with minimal code changes, as Keras now handles the complexities of sharding. The PR establishes key backend-agnostic interfaces and data structures, paving the way for future support across other backends such as TensorFlow and PyTorch.
Highlights
- New AutoShardDistribution API: The core AutoShardDistribution class has been added, providing an API for automated sharding of Keras models. This class handles the orchestration of sharding by leveraging backend-specific components.
- JAX Backend Integration for Auto-Sharding: Initial support for the JAX backend has been implemented, including a JaxGraph for tracing Keras models into JAX computation graphs, a JaxShardingPlanner to determine optimal sharding layouts, and a JaxShardApplier to apply these plans to model variables.
- Graph Analysis Utility for Sharding: A new MergeableGraph utility has been introduced, which is crucial for the graph analysis performed by the sharding planner to identify and group related axes for efficient sharding (see the sketch after this list).
- New Test Coverage: Comprehensive unit tests have been added for AutoShardDistribution on the JAX backend, ensuring the functionality works as expected for simple model architectures.
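The name MergeableGraph and the axis-grouping behavior come from the PR itself; the snippet below is only a minimal, self-contained sketch of the general union-find idea such a utility could be built on, using hypothetical names, and is not the actual Keras implementation.

```python
class AxisGroups:
    """Illustrative union-find over (variable_name, axis_index) pairs.

    Sketch only: the real MergeableGraph in the PR may differ in API and
    internals; this merely shows how "related axes" can be grouped so a
    planner can assign one mesh dimension per group.
    """

    def __init__(self):
        self._parent = {}

    def _find(self, axis):
        # Resolve an axis to its group representative, compressing paths.
        self._parent.setdefault(axis, axis)
        root = axis
        while self._parent[root] != root:
            root = self._parent[root]
        while self._parent[axis] != root:
            self._parent[axis], axis = root, self._parent[axis]
        return root

    def merge(self, axis_a, axis_b):
        # Record that two axes must receive the same sharding decision,
        # e.g. a matmul output dim feeding the next layer's input dim.
        self._parent[self._find(axis_a)] = self._find(axis_b)

    def groups(self):
        # Group axes by representative for use by a sharding planner.
        grouped = {}
        for axis in list(self._parent):
            grouped.setdefault(self._find(axis), []).append(axis)
        return grouped


groups = AxisGroups()
groups.merge(("dense_1/kernel", 1), ("dense_2/kernel", 0))  # shared hidden dim
print(groups.groups())
```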
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
Feature | Command | Description
---|---|---
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | /gemini help | Displays a list of available commands.
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
[1] Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request introduces the AutoSharding Distribution API, a significant new feature for simplifying distributed training in Keras. The implementation for the JAX backend is comprehensive, including graph parsing, a sharding planner, and a shard applier. My review has identified a few critical issues in the implementation that need to be addressed. Specifically, there are bugs related to handling model arguments with keywords in the JAX backend's sharding logic, and missing attribute initializations in the AutoShardDistribution class. I've also noted a performance concern regarding redundant tracing. Addressing these points will improve the robustness and efficiency of this new API.
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@           Coverage Diff            @@
##           master   #21583    +/-   ##
========================================
  Coverage   82.73%   82.73%
========================================
  Files         567      568      +1
  Lines       56440    56779    +339
  Branches     8818     8869     +51
========================================
+ Hits        46696    46977    +281
- Misses       7581     7632     +51
- Partials     2163     2170      +7
```

Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This PR introduces the initial implementation of the AutoShard API, a new feature designed to simplify distributed training for Keras 3.0 users. The primary goal of this feature is to provide automated sharding capabilities that will eventually work across JAX, TensorFlow, and PyTorch backends, enabling users to leverage model and data parallelism without needing deep knowledge of backend-specific APIs.
This initial implementation focuses on providing a fully functional AutoShardDistribution strategy for the JAX backend.
The core idea is to allow users to focus on model development while Keras handles the complexities of distributed execution. By simply wrapping their model definition and compilation in distribution.scope(), users can enable powerful data and model parallelism with minimal code changes.
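Concretely, that workflow could look roughly like the sketch below. DeviceMesh, scope(), and keras.distribution.list_devices() are existing Keras 3 APIs; the AutoShardDistribution import path and constructor arguments are assumptions based on this description, not confirmed signatures from the PR.

```python
import keras
from keras.distribution import DeviceMesh

# Assumed import location and constructor for the new API in this PR.
from keras.distribution import AutoShardDistribution

devices = keras.distribution.list_devices()
mesh = DeviceMesh(shape=(len(devices),), axis_names=["data"], devices=devices)

# Assumed: the distribution derives model and data sharding automatically
# for anything built inside its scope.
distribution = AutoShardDistribution(device_mesh=mesh)

with distribution.scope():
    model = keras.Sequential(
        [
            keras.layers.Dense(1024, activation="relu"),
            keras.layers.Dense(10),
        ]
    )
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Training then proceeds as usual; Keras applies the generated sharding
# plan to variables and input data behind the scenes.
# model.fit(x_train, y_train, epochs=2)
```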
This PR establishes the foundational, backend-agnostic interfaces (KerasGraph, ShardingPlanner, ShardApplier) and core data structures (DeviceMesh, ShardingPlan) that will be used for future backend implementations as outlined in the design document.
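The exact method signatures of these interfaces are not spelled out in this description, so the sketch below is only a conceptual rendering of how such backend-agnostic contracts might be expressed; the class names come from the PR, while the method names and parameters are illustrative assumptions.

```python
import abc


class KerasGraph(abc.ABC):
    """Backend-agnostic view of a traced Keras model (conceptual sketch)."""

    @abc.abstractmethod
    def trace(self, model, sample_input):
        """Build a computation-graph representation of `model`."""


class ShardingPlanner(abc.ABC):
    """Decides how tensor axes map onto the device mesh (conceptual sketch)."""

    @abc.abstractmethod
    def plan(self, graph, device_mesh):
        """Return a sharding plan for the traced graph on `device_mesh`."""


class ShardApplier(abc.ABC):
    """Applies a sharding plan to concrete model variables (conceptual sketch)."""

    @abc.abstractmethod
    def apply(self, model, sharding_plan):
        """Reshard the model's variables according to `sharding_plan`."""
```

Per the highlights above, the JAX-specific JaxGraph, JaxShardingPlanner, and JaxShardApplier would be the first concrete implementations of these contracts, with other backends to follow.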
For more details, please refer to the Keras Auto-Sharding Design Document.