Update the xla commit to experiment with a potential fix for OOM #1307
Conversation
This is failing to fetch the GB-25 repository, without much information: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16985145489/job/48166877204#step:11:63. Sigh.
Also, the sed for the xla replacement is now wrong (per a lot of bazel reshuffling done over the past week to fix things more stably). The new way to do so is in https://github.com/EnzymeAD/Reactant.jl/blob/07d4fcacb935f3e915ca40c1ab7c98a210a93efc/deps/ReactantExtra/WORKSPACE#L198:
becomes
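(For illustration only, a minimal sketch of such a commit override, assuming the WORKSPACE pins the commit in a variable like `XLA_COMMIT`; that variable name is an assumption here, and the linked WORKSPACE line shows the actual pattern.)

```python
# Hypothetical sketch only: point the "xla" dependency at a different commit by
# rewriting the pinned hash in ReactantExtra's WORKSPACE. The variable name
# XLA_COMMIT is an assumption, not the confirmed pattern at the linked line.
import re
from pathlib import Path

def override_xla_commit(workspace_path: str, new_commit: str) -> None:
    ws = Path(workspace_path)
    text = ws.read_text()
    # Replace e.g. `XLA_COMMIT = "<old sha>"` with the new commit hash.
    text = re.sub(r'XLA_COMMIT\s*=\s*"[0-9a-f]+"', f'XLA_COMMIT = "{new_commit}"', text)
    ws.write_text(text)

override_xla_commit("deps/ReactantExtra/WORKSPACE", "7b06885")
```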
Force-pushed from 480fc95 to 7b06885.
https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16999045099/job/48196660878?pr=1307#step:18:720
same as #1243 (comment) 😢 XLA dump uploaded to https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16999045099/artifacts/3777223905
Thanks @giordano for the quick verification. I'm just trying to understand the effect of the above XLA commit, to help us better prioritize the optimization direction. Can I assume that, without the commit (link), after rematerialization, the memory usage is
and that after the commit (link), after rematerialization, the memory usage is
Question:
cc: @wsmoses
Force-pushed from 7b06885 to a36cc1b.
I believe I linked two different lines, i.e. after two different compilation stages (we compile two different kernels); this should be a more direct comparison (always after the first kernel):
In any case, I changed this PR to always run "vanilla XLA" vs your PR, for a quicker comparison.
x/ref openxla/xla#30307
Force-pushed from e068323 to aa9dbfb.
Force-pushed from aa9dbfb to cb3b0c4.
@felixwqp I warmly recommend using
Ah, thank you for this suggestion! Will use it going forward. Apologies for overriding your commits; I'm new to the GitHub review process, so any suggestions are helpful and welcome!
I wanted to only update the xla commit to a new commit, but it failed with
Should I adjust
That was supposed to be addressed by #1297 🤔
Force-pushed from 7026eb7 to a76fba6.
Hopefully fixed by #1316. I rebased on
Force-pushed from a76fba6 to 4160e61.
…large number of zero-broadcasts
I don't understand why the commit doesn't work now when it worked in the other PR a few days ago, but the branch name should be good.
Update xla commit
Force-pushed from 4160e61 to f887a3d.
The simulation step is quite a bit faster than on main currently:
I think most of the improvement is in compile time; we don't get the warning about long XLA compilation anymore.
So we're now down to 56.6 GB from 58.5 GB. In short: definite compile-time improvement confirmed, and a slight memory reduction, though there is more memory reduction to go.
Somewhat good news: only calling
The XLA dump is quite a bit smaller.
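(Side note for anyone reproducing such dumps from the JAX side: a minimal sketch using XLA's standard `--xla_dump_to` debug flag; the Reactant/Julia runs in this thread set this up differently, so treat the path and setup here as illustrative only.)

```python
# Minimal sketch: ask XLA to dump HLO modules (like the artifacts attached above).
# Set the flag before the backend initializes, i.e. before importing/using jax.
import os
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_dump_to=/tmp/xla_dump"

import jax
import jax.numpy as jnp

jax.jit(lambda x: x * 2.0)(jnp.ones(8))  # compiled module text lands in /tmp/xla_dump
```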
I further reduced the code by removing some kernels, and this is still using more memory than necessary:
The job didn't crash because the 12.4 GiB still fit within the total memory, but you can see at the end of the job that the peak memory usage is larger than the memory used after the model creation.
The XLA dump is even smaller.
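(For reference, a generic way to observe that current-vs-peak gap from JAX; a minimal sketch rather than the GB-25 driver, and `memory_stats()` is only populated on some backends, e.g. GPU.)

```python
# Sketch: compare memory currently in use vs the allocator's peak, per device.
import jax

def report_device_memory() -> None:
    for dev in jax.local_devices():
        stats = dev.memory_stats() or {}  # may be None on backends without stats
        in_use = stats.get("bytes_in_use", 0)
        peak = stats.get("peak_bytes_in_use", 0)
        print(f"{dev}: in use {in_use / 2**30:.2f} GiB, peak {peak / 2**30:.2f} GiB")

report_device_memory()
```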
Thank you Mose! There are primarily two modules, |
No, I think we're a bit clueless about where memory is going 🫠 |
Based on the dump, I performed a memory profile analysis on the HLO input.
Top Temporary Ops Summary
Here is a breakdown of memory usage by operation type, focusing on temporary allocations:
Questions
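(As a side note on the profile above: from the JAX side, a comparable arguments/outputs/temporaries split can be read off the compiled executable. This is a minimal sketch with a stand-in function `f`, not the actual module; the exact fields reported by `memory_analysis()` can vary by backend.)

```python
# Sketch: XLA's per-executable memory breakdown as exposed through JAX.
import jax
import jax.numpy as jnp

def f(x):
    # Stand-in computation; the real module in this thread is the GB-25 model.
    return jnp.sin(x) @ jnp.cos(x).T

x = jnp.ones((4096, 4096))
compiled = jax.jit(f).lower(x).compile()
mem = compiled.memory_analysis()
if mem is not None:
    print("arguments:", mem.argument_size_in_bytes)
    print("outputs:  ", mem.output_size_in_bytes)
    print("temps:    ", mem.temp_size_in_bytes)
```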
Billy, please do correct me, but I believe that's indeed expected: that's the memory we have right after the model generation. In #1243 (comment) I had anticipated about 20 GB, based only on scaling the input parameters from previous runs, so I think these 20 GB are what we expect.
I need to stare more at the minimization, but the basic gist here is that the parameters will have some amount of memory usage, but in principle we should be able to use no additional memory. Specifically, the original code just had that original allocation and updated it in place. In practice, when we do the full, loop-based version, I expect we'll have a factor of two for the induction variables or something, but here I would expect that these allocations should be eliminable.
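(To make the "update it in place, no additional memory" point concrete, a minimal JAX sketch rather than the Reactant/Oceananigans code: donating the state buffer lets XLA reuse it for the output instead of holding a second full-size copy.)

```python
# Sketch: in-place-style time stepping via buffer donation, so the updated state
# can reuse the donated input buffer instead of requiring a fresh allocation.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)
def step(state, dt):
    # Toy update standing in for the real time step.
    return state + dt * jnp.roll(state, 1, axis=0)

state = jnp.zeros((1024, 1024))
for _ in range(3):
    state = step(state, 0.1)  # old `state` buffer may be reused for the result
```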
With EnzymeAD/Reactant.jl#1619 we get a slightly lower peak memory:
but the warning we get during compilation mentions a much larger memory buffer:
vs PR
I'm curious if #1363 creates similar improvements [or more], or causes more chaos.
also note that the absence of the dus_to_pad comm op in that PR causes all-gathers to return, which is bad |
@felixwqp: @glwagner further reduced the kernels, so I hope we now have something even more useful. In this run we launched the program three times separately, using 3 different kernels: one filling the halo regions only in the east-west direction (which is also the direction in which we do the sharding, so we expect device-device communication here), one filling halo regions only in the north-south direction (no sharding), and another one filling all the halo regions (device-device communication also here). We see
Here is the XLA dump from the run:

```
% wc -l simulation-xla-dump-1.11--main-main/xla_dump_*/module_00*.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
 468 simulation-xla-dump-1.11--main-main/xla_dump_all/module_0049.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
 111 simulation-xla-dump-1.11--main-main/xla_dump_east_west/module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
 197 simulation-xla-dump-1.11--main-main/xla_dump_north_south/module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
 776 total
```

I presume the more interesting ones to look at are the east-west and north-south ones, perhaps especially the former, which does the device-device communication (and is also the shorter one), but comparing it with the north-south one may also be useful. Hope this helps! Greg, please do correct me if I said anything wrong above!
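(To spell out why only the east-west direction should imply device-device communication, a toy JAX sketch, not the actual fill_halo kernels: with the array sharded along x, a halo-style shift along x has to move data between devices, while the same shift along y stays local to each shard.)

```python
# Toy sketch of the setup: shard along "x" (east-west) only, so shifting along x
# crosses shard boundaries (communication) while shifting along y stays local.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
arr = jax.device_put(jnp.arange(64.0).reshape(8, 8), NamedSharding(mesh, P("x", None)))

@jax.jit
def halo_like_shifts(a):
    east_west = jnp.roll(a, 1, axis=0)    # sharded axis: needs device-device comms
    north_south = jnp.roll(a, 1, axis=1)  # unsharded axis: purely local
    return east_west, north_south

ew, ns = halo_like_shifts(arr)
```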
Try to fix the large number of zero-broadcasts generated from a big pad in the SPMD partitioner.
For a sample HLO,
before the fix, the result looks like:
after the fix, the result looks like: https://screenshot.googleplex.com/
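(For readers without access to the internal screenshots, a stand-in sketch of the kind of pattern this targets, not the actual sample HLO: a sharded operand padded by a large number of zeros, which is the pad the SPMD partitioner previously expanded into many zero broadcasts.)

```python
# Stand-in for the "big pad" pattern: heavily zero-pad a sharded array; the fix
# concerns how the SPMD partitioner lowers this pad without piles of broadcasts.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
x = jax.device_put(jnp.ones((256, 256)), NamedSharding(mesh, P("x", None)))

@jax.jit
def big_pad(a):
    # padding_config is (low, high, interior) per dimension.
    return jax.lax.pad(a, jnp.array(0.0, a.dtype), ((0, 4096, 0), (0, 0, 0)))

y = big_pad(x)
```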
