Replies: 5 comments 3 replies
-
Thanks for the questions!
Could you share a reproducer we could study? It's hard to guess at what's happening here.
Well, that JEP is meant to explain that […]

It's hard to say what's going on with the automatic version, but in general I would suggest adding sharding constraints (`jax.lax.with_sharding_constraint` annotations). Are there any downsides for your work to just sticking with the manual approach? (We're also working on another mode, where you still get to write code as if programming a single device, with a 'global view' of arrays and no explicit collectives, but rather than having sharding and partitioning decisions happen opaquely in the compiler, they're transparent and explicit at trace time, so that e.g. you can reflect on sharding just like you can on shapes.)

Yes, very possible. Actually, I wouldn't consider the automatic version to be "more optimized" in general; there's nothing it can do that you can't express with `shard_map`. If you're losing significant performance, I suggest opening a bug (like this one!) with some kind of reproducer we can go on. And add sharding constraints with `jax.lax.with_sharding_constraint`. What do you think?
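For concreteness, here is a rough sketch of pinning the gradients' sharding inside a jitted train step (illustrative only, not code from this thread; the `'data'` axis name and toy `loss_fn` are placeholders):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))

def loss_fn(params, batch):
    x, y = batch
    return jnp.mean((x @ params["w"] - y) ** 2)

@jax.jit
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # Constrain the gradients to be fully replicated at this point, so the
    # all-reduce must have happened by here rather than wherever the compiler chooses.
    grads = jax.lax.with_sharding_constraint(grads, NamedSharding(mesh, P()))
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
```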
-
Hi Matt, Thanks for the response! A minimal reproducer will take a bit of time and effort -- let me get back to you on that. It's worth noting, though, that I did submit the aforementioned performance issue with a minimal reproducer (at least as minimal as I could get it).
At the moment, the […]

Do you think you could give a bit more detail on where these annotations might be necessary? For simple data parallelism, I would imagine that the sharding propagation is fairly trivial. I've used […]
-
Oops, I see! Sorry we haven't yet followed up on that to its conclusion.
Is FSDP hard to write with `shard_map`?
I actually don't know; I'm not enough of a practitioner. But if you have some code that is misbehaving, e.g. with extra allreduces, I'd try to constrain the sharding around that misbehavior.
I think people usually just add them in the primal computation (i.e. the computation being differentiated), and those have an effect on the backward pass code in that a […]
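For concreteness, a sketch of what that "add them in the primal computation" placement can look like (illustrative names, not code from this thread): the constraint sits inside the function that gets differentiated, so it is part of the computation that autodiff and the partitioner trace through.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
batch_sharding = NamedSharding(mesh, P("data"))  # shard along the leading (batch) dim

def loss_fn(params, x, y):
    # Constraint in the primal: keep the activations sharded along the batch axis.
    h = jax.lax.with_sharding_constraint(x @ params["w"], batch_sharding)
    return jnp.mean((h - y) ** 2)

params = {"w": jnp.ones((8, 4))}
x = jax.device_put(jnp.ones((16, 8)), batch_sharding)
y = jnp.zeros((16, 4))

# Differentiate through the constrained primal under jit; the constraint is part
# of the traced computation the partitioner sees when it lays out both passes.
grads = jax.jit(jax.grad(loss_fn))(params, x, y)
```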
-
Does JAX have pipeline/model/tensor parallelism options? That may be the clue to what is happening under the hood. What is generated is not as trivial as forward + backward + collective. Thanks,
-
Hi @kvablack, just wondering if you are still waiting for a solution? I did some experiments with your code on an 8xA100 GPU BM and found all-reduce was 100%. I was able to make JAX cast to float32 during the all-reduce to avoid numeric instability, while the other necessary operations stayed in bfloat16 for mixed-precision training.
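For reference, one way to get that effect with the manual approach (a sketch only, not necessarily what was done in the experiment above): do the per-device forward/backward in bfloat16, but upcast the local gradients to float32 before the cross-device reduction and keep float32 master parameters.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("data",))

def loss_fn(params_bf16, x, y):
    pred = x @ params_bf16["w"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P(), P("data"), P("data")), out_specs=P())
def grad_step(params_f32, x, y):
    # Forward/backward in bfloat16 on each device's local shard...
    params_bf16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params_f32)
    grads = jax.grad(loss_fn)(params_bf16, x.astype(jnp.bfloat16), y.astype(jnp.bfloat16))
    # ...but upcast before the reduction so the all-reduce itself runs in float32.
    grads = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32), grads)
    grads = jax.lax.pmean(grads, axis_name="data")
    # Update the float32 master parameters.
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params_f32, grads)
```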
-
Hi all,
I'm currently running data-parallel training, and I would like to better understand how the JIT compiler inserts collective communication operations. My current setup, which I think is the "canonical" setup, is like this:
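Concretely, a minimal sketch of this kind of setup (illustrative, with a toy `loss_fn`), assuming a one-axis `('data',)` mesh, replicated parameters, and a batch sharded along its leading dimension:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all local devices; 'data' is the batch axis.
mesh = Mesh(np.array(jax.devices()), ("data",))
replicated = NamedSharding(mesh, P())            # params live in full on every device
batch_sharded = NamedSharding(mesh, P("data"))   # batch split along dim 0

def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, batch, lr=1e-3):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Place params replicated and the batch sharded; jit/XLA inserts the collectives.
params = {
    "w": jax.device_put(jnp.zeros((128, 1)), replicated),
    "b": jax.device_put(jnp.zeros((1,)), replicated),
}
x = jax.device_put(jnp.ones((256, 128)), batch_sharded)
y = jax.device_put(jnp.ones((256, 1)), batch_sharded)
params, loss = train_step(params, (x, y))
```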
My understanding is that the JIT compiler will automatically partition the computation across devices (due to the data parallel sharding) as well as insert the necessary all-reduce operations to keep the parameters replicated. My mental model for the sequence of operations is that each device does its own individual forward + backward pass, then the gradients are all-reduced, then each device does a gradient update.
However, when I look at the optimized HLO dump and Perfetto trace, I see all-reduce operations sprinkled throughout the entire computation. Furthermore, when I add up the number of all-reduced elements, it is only ~60% of the total number of parameters. Lastly, when I enable mixed-precision training (bfloat16 activations), I see some of the all-reduces performed in bfloat16 (even though the loss and gradients are definitely all float32 at tracing time). Here are my concrete questions:

1. Why are the all-reduces scattered throughout the computation rather than grouped at the end, as in my mental model?
2. Why do the all-reduced elements account for only ~60% of the parameters?
3. Why are some all-reduces performed in bfloat16 when the loss and gradients are float32 at tracing time?
To try and understand things better, I implemented a `shard_map` version to manually take control of cross-device communication, like so:
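Concretely, a sketch of this kind of `shard_map` version (illustrative, not the exact code), with `jax.grad` taken inside the mapped function and the gradients explicitly averaged over the `'data'` axis:

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("data",))

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P(), P("data"), P("data")),   # params replicated, batch sharded
         out_specs=(P(), P()))                   # replicated params and loss out
def train_step(params, x, y):
    # grad taken *inside* shard_map: each device differentiates its local shard...
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # ...and the gradients and loss are explicitly averaged across the 'data' axis.
    grads = jax.lax.pmean(grads, axis_name="data")
    loss = jax.lax.pmean(loss, axis_name="data")
    new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return new_params, loss
```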
Looking at the dumped HLO, it matches up much more with what I expect -- there is a big block of all-reduce operations at the end, they are all float32, and the total number of elements roughly adds up to the size of the network (although it's still missing 5% somehow -- maybe I miscounted). This gave me some more questions:

1. The JEP's examples take `jax.grad` outside of `shard_map`, whereas I did it inside. However, this JEP specifically points out that there is an efficiency issue taking a grad of `shard_map`. What's wrong with keeping the grad inside, like I did?
2. What is actually going on with the automatic version -- why does its communication pattern look so different from the `shard_map` version?
3. Is it possible that I'm losing performance with the automatic version compared to the `shard_map` version? I ran into this a while ago, which I documented in this issue. I use automatic partitioning for more advanced parallelism strategies (such as FSDP), but this makes me worry that I'm losing a ton of performance by not manually writing every matmul using `shard_map`.
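For what it's worth, the FSDP-style use of automatic partitioning mentioned above usually amounts to something like the following sketch (illustrative only): shard the parameters themselves over the same axis as the batch and let the compiler insert the all-gathers / reduce-scatters.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))

def loss_fn(params, x, y):
    return jnp.mean((x @ params["w"] - y) ** 2)

@jax.jit
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

# FSDP-style placement: shard the parameters (here along their first axis) and
# the batch over the same 'data' axis; the compiler inserts the collectives
# needed to keep the layouts consistent during the forward and backward passes.
params = {"w": jax.device_put(jnp.zeros((128, 16)), NamedSharding(mesh, P("data", None)))}
x = jax.device_put(jnp.ones((64, 128)), NamedSharding(mesh, P("data", None)))
y = jnp.zeros((64, 16))
params = train_step(params, x, y)
```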