[shmap/partial-auto] Fixes lowering for jax.lax.axis_index in shard_map for degenerated shmaps. #25699
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Please take a look at the test I added for the example.
We found that when the manual axis has size 1, the axis_index will be definitely a replicated constant 0 across all the devices. There is no need to add iota and sharding and go full to shard for that, which actually invokes a sharding normalization inside XLA.
Let's say if we do iota and adding the sharding custom ops, then the sharding will be something like,
devices=[1,8,1]<=[8] {replicated, manual} for the iota result.
Then this sharding is normalized to {replicated} in XLA.
FullToShard
custom call cannot takereplicated
sharding which will invoke a crash in GSPMD.