-
Notifications
You must be signed in to change notification settings - Fork 791
Add sharding propagation support in nnx.eval_shape (clone of #5111) #5247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
samanklesaria
wants to merge
2
commits into
google:main
Choose a base branch
from
samanklesaria:eval_shape_sharding
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+42
−0
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -359,6 +359,36 @@ def test_sharding_axis_types(self, mode): | |
| self.assertEqual(v.sharding.mesh, mesh) | ||
| self.assertEqual(v.sharding.spec, P('row', 'col')) | ||
|
|
||
| def test_eval_shape_with_sharding0(self): | ||
| # based on https://github.com/google/flax/issues/5110 | ||
| mesh1 = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)) | ||
| mesh2 = jax.make_mesh((1, 4), ("c", "d"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)) | ||
|
|
||
| class Model(nnx.Module): | ||
| def __init__(self): | ||
| self.p1 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("a", "b"), "mesh": mesh1}) | ||
| self.p2 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("c", "d"), "mesh": mesh2}) | ||
|
|
||
| abs_model = nnx.eval_shape(lambda: Model()) | ||
| assert isinstance(abs_model.p1.kernel.sharding, jax.sharding.NamedSharding) | ||
| assert abs_model.p1.kernel.sharding.mesh is mesh1 | ||
| assert abs_model.p1.kernel.sharding.spec == jax.P("a", "b") | ||
| assert isinstance(abs_model.p2.kernel.sharding, jax.sharding.NamedSharding) | ||
| assert abs_model.p2.kernel.sharding.mesh is mesh2 | ||
| assert abs_model.p2.kernel.sharding.spec == jax.P("c", "d") | ||
|
Comment on lines
+362
to
+378
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def test_eval_shape_with_sharding1(self): | ||
| class Model(nnx.Module): | ||
| def __init__(self): | ||
| self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("a", "b")}) | ||
|
|
||
| mesh = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto)) | ||
| with jax.set_mesh(mesh): | ||
| abs_model = nnx.eval_shape(lambda: Model()) | ||
| assert isinstance(abs_model.linear.kernel.sharding, jax.sharding.NamedSharding) | ||
| assert abs_model.linear.kernel.sharding.mesh is mesh | ||
| assert abs_model.linear.kernel.sharding.spec == jax.P("a", "b") | ||
|
Comment on lines
+380
to
+390
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| def has_sharding_spec(array): | ||
| sharding = array.sharding | ||
| if hasattr(sharding, 'spec'): | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for applying sharding within
_to_variableseems to correctly handle cases where a mesh is explicitly provided in metadata or a global mesh is available. This ensures thatjax.ShapeDtypeStructis created with the appropriate sharding information.