When donating a buffer whose donation target has an under-specified sharding, one can get an input-output aliasing mismatch error if XLA infers a sharding for the donation target (the output) that differs from the sharding of the donation source (the input).
Here is an example of the error:
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Expected aliased input 0 at index {} and output at index {0} to have the same size. Input sub-shape is s32[16]{0} with size 64, output sub-shape is s32[8]{0} with size 32
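The byte counts in this message can be checked by hand: `s32` elements are 4 bytes each, so 16 elements occupy 64 bytes and 8 elements occupy 32 bytes. Assuming a 16-element global array (the error does not state the global shape), this is consistent with the input holding the whole array on each device while the inferred output sharding splits it two ways. A minimal sketch of that arithmetic; `per_shard_bytes` is a hypothetical helper, not a JAX API:

```python
def per_shard_bytes(global_len: int, num_shards: int, itemsize: int = 4) -> int:
    """Bytes held by one shard of a 1-D array split evenly into num_shards pieces.

    itemsize defaults to 4 because s32 (int32) elements are 4 bytes each.
    """
    assert global_len % num_shards == 0, "this sketch assumes an even split"
    return (global_len // num_shards) * itemsize

# Assumption: the global array has 16 elements (not stated in the error).
# Input sub-shape s32[16]: the whole array on each device (unsplit).
print(per_shard_bytes(16, 1))  # 64
# Output sub-shape s32[8]: the same array split two ways.
print(per_shard_bytes(16, 2))  # 32
```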
Either the error message should be improved (for example, by recommending concrete shardings for all inputs and outputs), or JAX should ensure that the mismatch never happens (by passing the same fixed sharding to XLA for each aliased input-output pair, or perhaps by making XLA's sharding inference respect aliasing during SPMD partitioning).
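A minimal sketch of the first workaround, assuming a recent JAX where `jax.jit` accepts `in_shardings`, `out_shardings`, and `donate_argnums`: give the donated argument and the output that may alias it the same concrete `NamedSharding`, so the output layout is not left to XLA's sharding inference. The function `step`, the mesh axis name `"x"`, and the choice of a replicated layout are illustrative assumptions; this is not the repro from this issue.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever devices are available (a single device on a plain CPU install).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

# Illustrative choice: keep the donated buffer fully replicated.
replicated = NamedSharding(mesh, P())

def step(x):
    # Hypothetical update whose output has the same shape/dtype as the donated input.
    return x + 1

# Pin the same concrete sharding on the donated input and on the output,
# so the aliased pair cannot end up with different per-device shapes.
step_jit = jax.jit(
    step,
    in_shardings=(replicated,),
    out_shardings=replicated,
    donate_argnums=(0,),
)

x = jax.device_put(jnp.arange(16, dtype=jnp.int32), replicated)
y = step_jit(x)  # x may be invalidated here if the backend honors donation
print(y.sharding)
```

Pinning the output sharding trades some flexibility for predictability: XLA can no longer pick a different (possibly cheaper) output layout, but the donated input and its aliased output should then have matching per-device shapes.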
Repro:
System info (python version, jaxlib version, accelerator, etc.)