
Added AutoSharding Distribution API #21583


Closed
wants to merge 5 commits

Conversation

buildwithsuhana (Contributor)

This PR introduces the initial implementation of the AutoShard API, a new feature designed to simplify distributed training for Keras 3.0 users. The primary goal of this feature is to provide automated sharding capabilities that will eventually work across JAX, TensorFlow, and PyTorch backends, enabling users to leverage model and data parallelism without needing deep knowledge of backend-specific APIs.

This initial implementation focuses on providing a fully functional AutoShardDistribution strategy for the JAX backend.

The core idea is to let users focus on model development while Keras handles the complexities of distributed execution. By wrapping their model definition and compilation in the distribution's scope(), users can enable data and model parallelism with minimal code changes.
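A minimal usage sketch of that pattern follows. AutoShardDistribution is introduced by this PR; its import path and constructor signature are not shown in the description, so the ones below are assumptions:

```python
# Minimal usage sketch. The AutoShardDistribution import path and
# constructor signature are assumptions; DeviceMesh is the existing
# Keras distribution API.
import keras
from keras.distribution import DeviceMesh, AutoShardDistribution

# A 2x4 mesh over 8 devices: 2-way data parallelism, 4-way model parallelism.
mesh = DeviceMesh(shape=(2, 4), axis_names=["data", "model"])
distribution = AutoShardDistribution(mesh)  # assumed constructor

# Wrapping model definition and compilation in the scope lets Keras
# decide how to shard variables and inputs automatically.
with distribution.scope():
    model = keras.Sequential(
        [
            keras.layers.Dense(1024, activation="relu"),
            keras.layers.Dense(10),
        ]
    )
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```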

This PR establishes the foundational, backend-agnostic interfaces (KerasGraph, ShardingPlanner, ShardApplier) and core data structures (DeviceMesh, ShardingPlan) that will be used for future backend implementations as outlined in the design document.
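For orientation, here is an illustrative sketch of how those interfaces might fit together. Only the class names come from the PR; the method names, signatures, and plan format below are assumptions:

```python
# Illustrative sketch of the backend-agnostic interfaces named above.
# Method names and signatures are assumptions, not the PR's actual code.
import abc
from dataclasses import dataclass, field


@dataclass
class ShardingPlan:
    # Maps a variable path to a layout, e.g. {"dense/kernel": (None, "model")}.
    variable_layouts: dict = field(default_factory=dict)


class KerasGraph(abc.ABC):
    @abc.abstractmethod
    def trace(self, model, sample_inputs):
        """Trace a Keras model into a backend computation graph."""


class ShardingPlanner(abc.ABC):
    @abc.abstractmethod
    def plan(self, graph, device_mesh):
        """Analyze the graph and propose a ShardingPlan for its variables."""


class ShardApplier(abc.ABC):
    @abc.abstractmethod
    def apply(self, model, plan):
        """Place each model variable according to the plan."""
```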

For more details, please refer to the Keras Auto-Sharding Design Document.

@gemini-code-assist bot left a comment

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the foundational AutoShardDistribution API to Keras 3.0, aiming to simplify distributed training by automating model and data parallelism. The initial implementation focuses on providing full functionality for the JAX backend. This feature allows users to enable powerful distributed execution with minimal code changes, as Keras will now handle the complexities of sharding. The PR establishes key backend-agnostic interfaces and data structures, paving the way for future support across other backends like TensorFlow and PyTorch.

Highlights

  • New AutoShardDistribution API: The core AutoShardDistribution class has been added, providing an API for automated sharding of Keras models. This class handles the orchestration of sharding by leveraging backend-specific components.
  • JAX Backend Integration for Auto-Sharding: Initial support for the JAX backend has been implemented, including a JaxGraph for tracing Keras models into JAX computation graphs, a JaxShardingPlanner to determine sharding layouts, and a JaxShardApplier to apply these plans to model variables (see the conceptual sketch after this list).
  • Graph Analysis Utility for Sharding: A new MergeableGraph utility has been introduced, which is crucial for the graph analysis performed by the sharding planner to identify and group related axes for efficient sharding.
  • New Test Coverage: Comprehensive unit tests have been added for the AutoShardDistribution on the JAX backend, ensuring the functionality works as expected for simple model architectures.
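To make the JAX highlights concrete, here is a conceptual sketch of the kind of placement a JaxShardApplier performs, written against public JAX APIs (jax.sharding.Mesh, NamedSharding, jax.device_put). The plan format is an assumption; only the component names come from this PR:

```python
# Conceptual sketch of applying a sharding plan with public JAX APIs.
# The {path: PartitionSpec} plan format is an assumption.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 2x4 device mesh with "data" and "model" axes (assumes 8 devices).
devices = np.array(jax.devices()).reshape(2, 4)
mesh = Mesh(devices, axis_names=("data", "model"))

# A plan such as a planner might produce: shard the kernel's output
# dimension across the "model" axis; replicate the bias on every device.
plan = {
    "dense/kernel": P(None, "model"),
    "dense/bias": P(),
}

# Applying the plan means placing each variable with its requested layout.
params = {
    "dense/kernel": np.zeros((512, 1024), np.float32),
    "dense/bias": np.zeros((1024,), np.float32),
}
sharded = {
    path: jax.device_put(value, NamedSharding(mesh, plan[path]))
    for path, value in params.items()
}
```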

@gemini-code-assist bot left a comment

Code Review

This pull request introduces the AutoSharding Distribution API, a significant new feature for simplifying distributed training in Keras. The implementation for the JAX backend is comprehensive, including graph parsing, a sharding planner, and a shard applier. My review identified a few critical issues that need to be addressed: bugs in how the JAX backend's sharding logic handles keyword arguments passed to the model, and missing attribute initializations in the AutoShardDistribution class. I've also noted a performance concern regarding redundant tracing. Addressing these points will improve the robustness and efficiency of this new API.

@codecov-commenter commented Aug 14, 2025

Codecov Report

❌ Patch coverage is 76.98413% with 58 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.73%. Comparing base (0c6c363) to head (5f30daa).
⚠️ Report is 16 commits behind head on master.

File                                                  Patch %   Missing lines
keras/src/backend/jax/distribution_lib.py             80.79%    28 missing, 6 partials
keras/src/distribution/distribution_lib.py            59.18%    18 missing, 2 partials
keras/src/distribution/autoshard_utils.py             87.50%    1 missing, 2 partials
keras/api/_tf_keras/keras/distribution/__init__.py     0.00%    1 missing
Additional details and impacted files
@@           Coverage Diff            @@
##           master   #21583    +/-   ##
========================================
  Coverage   82.73%   82.73%            
========================================
  Files         567      568     +1     
  Lines       56440    56779   +339     
  Branches     8818     8869    +51     
========================================
+ Hits        46696    46977   +281     
- Misses       7581     7632    +51     
- Partials     2163     2170     +7     
Flag Coverage Δ
keras 82.54% <76.98%> (+<0.01%) ⬆️
keras-jax 63.83% <76.98%> (+0.02%) ⬆️
keras-numpy 58.16% <7.93%> (-0.14%) ⬇️
keras-openvino 34.63% <7.93%> (-0.02%) ⬇️
keras-tensorflow 64.06% <7.93%> (-0.19%) ⬇️
keras-torch 63.66% <7.93%> (-0.21%) ⬇️

Flags with carried-forward coverage won't be shown.


@buildwithsuhana marked this pull request as draft on August 14, 2025 at 17:02
@buildwithsuhana marked this pull request as ready for review on August 14, 2025 at 18:52
@buildwithsuhana marked this pull request as draft on August 18, 2025 at 03:46