
Conversation

felixwqp
Collaborator

Try to fix the large number of zero-broadcasts generated from big pads in the SPMD partitioner.

For a sample HLO:

  HloModule module_0076.reactant_loop_after_spmd_partitioner_cp_pattern

    ENTRY main() -> f64[3056,123] {
      %constant.946 = f64[] constant(0)
      %subtract.1055 = f64[3056,12272]{1,0} parameter(0), sharding={devices=[2,2]<=[2,2]T(1,0)}
      %slice.1057 = f64[3056,1]{1,0} slice(%subtract.1055), slice={[0:3056], [12271:12272]}, sharding={devices=[2,2]<=[2,2]T(1,0)}
      ROOT %pad.1059 = f64[3056,123]{1,0} pad(%slice.1057, %constant.946), padding=0_0x0_122, sharding={devices=[2,2]<=[2,2]T(1,0)}
    }

before the fix, the result looks like: https://screenshot.googleplex.com/4LjPYjEP2XJGwcn

after the fix, the result looks like: https://screenshot.googleplex.com/3BZP5MRC5kMpK2L

@giordano
Member

I fixed a bug in the workflow where we were using the wrong xla commit (apparently this is the first time we're using a custom xla commit, so we didn't spot it before) and also copied the changes from #1243 to replicate the conditions which trigger the OOM.

For readers, this is related to #671.

@giordano
Member

This is failing to fetch the GB-25 repository, without much information: https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16985145489/job/48166877204#step:11:63. Sigh.

@wsmoses
Member

wsmoses commented Aug 15, 2025

Also, the sed for the XLA replacement is now wrong (due to a lot of Bazel reshuffling done over the past week to make things more stable).

The new way to do so is in https://github.com/EnzymeAD/Reactant.jl/blob/07d4fcacb935f3e915ca40c1ab7c98a210a93efc/deps/ReactantExtra/WORKSPACE#L198:

xla_workspace(NEW_XLA_PATCHES)

becomes

xla_workspace(NEW_XLA_PATCHES, 'abc2342343432432432')
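
For the CI workflow, one way that replacement could be scripted is sketched below. This is an illustration only: it assumes the WORKSPACE still contains a line of the exact form xla_workspace(NEW_XLA_PATCHES), and XLA_COMMIT is a placeholder for the commit you want to pin.

    # Sketch: pin deps/ReactantExtra/WORKSPACE to a specific XLA commit by
    # rewriting the xla_workspace(...) call. XLA_COMMIT is a placeholder.
    XLA_COMMIT=abc2342343432432432
    sed -i "s/xla_workspace(NEW_XLA_PATCHES)/xla_workspace(NEW_XLA_PATCHES, '${XLA_COMMIT}')/" deps/ReactantExtra/WORKSPACE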

@giordano force-pushed the wfelix_xla_dev branch 2 times, most recently from 480fc95 to 7b06885 on August 15, 2025 at 20:45
@giordano
Member

giordano commented Aug 15, 2025

https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16999045099/job/48196660878?pr=1307#step:18:720

2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally

same as #1243 (comment) 😢 XLA dump uploaded to https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/16999045099/artifacts/3777223905

@felixwqp
Collaborator Author

Thanks @giordano for the quick verification. I'm just trying to understand the effect of the above XLA commit, to help us better prioritize the optimization direction.

Can I assume the following?

Without the commit (link), after rematerialization, memory usage is 67.69 GiB:

2025-08-02 23:46:25.965629: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3423] Can't reduce memory use below 10.64GiB (11429599125 bytes) by rematerialization; only reduced to 67.69GiB (72679996812 bytes), down from 71.05GiB (76287081352 bytes) originally

With the commit (link), after rematerialization, memory usage is 58.50 GiB:

2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally

Question:

  1. Is there any Enzyme-level difference between these two runs that accounts for this ~10 GB memory-usage reduction? If not, it means the HLO optimization itself improved memory usage, and we can prioritize the same methodology for further XLA memory optimization.

cc: @wsmoses

@giordano
Member

giordano commented Aug 15, 2025

I believe I linked two different lines, i.e. from two different compilation stages (we compile two different kernels). This should be a more direct comparison (always after the first kernel):
#1243:

2025-08-02 23:41:14.191988: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3423] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally

this PR:

2025-08-15 21:13:25.777440: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3181] Can't reduce memory use below 10.64GiB (11429599132 bytes) by rematerialization; only reduced to 58.50GiB (62817533356 bytes), down from 61.39GiB (65919097928 bytes) originally

In any case I changed this PR to always run "vanilla XLA" vs your PR, for a quicker comparison.

@wsmoses
Member

wsmoses commented Aug 16, 2025

x/ref openxla/xla#30307

@giordano force-pushed the wfelix_xla_dev branch 2 times, most recently from e068323 to aa9dbfb on August 16, 2025 at 22:17
@giordano
Member

@felixwqp I warmly recommend using git push --force-with-lease instead of --force, so you don't keep undoing my fixes 🙂
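
For example, with the branch used in this PR:

    # Refuses the push if the remote branch has commits you haven't fetched yet,
    # instead of silently discarding them as a plain --force would.
    git push --force-with-lease origin wfelix_xla_dev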

@felixwqp
Collaborator Author

Ah, thank you for the suggestion! I'll use it going forward. Apologies for overwriting your commits; I'm new to the GitHub review process, so any suggestions are helpful and welcome!

@felixwqp
Collaborator Author

I only want to update the XLA commit to 1ac176a9b8b4800bc2753d944eec62a39e6189b8 to verify that the HLO dump looks as intended. There's no need to trigger the OOM anymore.

The new commit failed with:

Error: The artifact name is not valid: julia-environment-1.11-1ac176a9b8b4800bc2753d944eec62a39e6189b8-mg_sharded-factors-ap/persistent_compile_cache. Contains the following character:  Forward slash /

Should I adjust reactant_commit?

  reactant_commit:
    - 'ap/persistent_compile_cache'
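
For illustration only (not necessarily what the actual fix does), one workaround would be to strip the slash before composing the artifact name; the variable names below are hypothetical:

    # Hypothetical workflow step: GitHub artifact names must not contain '/',
    # so replace it before building the name.
    # XLA_COMMIT is a placeholder for the pinned XLA commit.
    REACTANT_COMMIT='ap/persistent_compile_cache'
    SAFE_REACTANT_COMMIT="${REACTANT_COMMIT//\//-}"   # -> ap-persistent_compile_cache
    echo "artifact-name=julia-environment-1.11-${XLA_COMMIT}-mg_sharded-factors-${SAFE_REACTANT_COMMIT}" >> "$GITHUB_OUTPUT"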

@giordano
Member

That was supposed to be addressed by #1297 🤔

@giordano
Member

giordano commented Aug 18, 2025

Hopefully fixed by #1316. I rebased on main. Please don't override my changes again 😅 Edit: confirmed it fixed the artifact issue.

felixwqp and others added 5 commits August 19, 2025 07:13
I don't understand why the commit doesn't work now when it worked in the other
PR a few days ago, but the branch name should be good.
@giordano
Member

The simulation step is quite a bit faster than on main currently. I think most of the improvement is in compile time; we don't get the warning about long XLA compilation anymore.

@wsmoses
Member

wsmoses commented Aug 19, 2025

2025-08-19 09:15:10.544797: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11429599125 bytes) by rematerialization; only reduced to 56.60GiB (60775368484 bytes), down from 66.05GiB (70922254524 bytes) originally

So we're now down to 56.60 GiB from 58.50 GiB.

In short: the compile-time improvement is confirmed, along with a slight memory reduction, though there is more memory reduction to go.

@giordano
Member

Somewhat good news: only calling initialize! + update_state! (without time_step!) is sufficient to get an OOM on the device (time_step! would use even more memory): https://github.com/EnzymeAD/Enzyme-JAX/actions/runs/17173076863/job/48727161438?pr=1307#step:19:728

2025-08-23 09:38:22.863446: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802313 bytes) by rematerialization; only reduced to 46.84GiB (50297252461 bytes), down from 63.27GiB (67935638125 bytes) originally
[...]
E0000 00:00:1755941951.386733    3750 pjrt_stream_executor_client.cc:3081] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 51712165888 bytes. [tf-allocator-allocation-error='']
2025-08-23 09:39:11.389852: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_2_bfc) ran out of memory trying to allocate 48.16GiB (rounded to 51712165888)requested by op 
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 

The XLA dump is quite a bit smaller.

@giordano
Member

I further reduced the code by removing some kernels, and this is still using more memory than necessary:

2025-08-23 14:04:03.144019: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 12.40GiB (13313787004 bytes), down from 12.40GiB (13313787004 bytes) originally

The job didn't crash because the 12.4 GiB still fit within the total memory, but you can see at the end of the job that the peak memory usage is larger than the memory in use after the model creation:

┌ Info: [0] allocations
│   GordonBell25.allocatorstats() =
│    AllocatorStats
│    --------------
│    num_allocs: 41
│    bytes_in_use: 21894366976
│    peak_bytes_in_use: 28943843072
│    largest_alloc_size: 7049475840
│    bytes_limit: 31804391424
│    bytes_reserved: 0
│    peak_bytes_reserved: 0
│    bytes_reservable_limit: nothing
│    largest_free_block_bytes: 0
│    pool_bytes: 31804391424
│    peak_pool_bytes: 31804391424
└    

The XLA dump is even smaller
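
As a rough sanity check (my own arithmetic, not from the job output), the gap between peak and steady-state usage in the stats above is essentially the size of the single largest allocation, so one large transient buffer seems to account for the extra peak memory:

    # Gap between peak_bytes_in_use and bytes_in_use from the AllocatorStats above.
    echo $(( 28943843072 - 21894366976 ))   # 7049476096 bytes (~6.57 GiB)
    # largest_alloc_size reported above is 7049475840 bytes, only 256 bytes smaller.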

@felixwqp
Collaborator Author

Thank you Mose!

There are primarily two modules, reactant__100 and reactant__first. I will start with reactant__first because it looks like it has more memory allocated, but if you happen to know where the memory bottleneck is, I will focus on that module instead.

@giordano
Member

No, I think we're a bit clueless about where memory is going 🫠

@felixwqp
Collaborator Author

Based on the dump, I performed a memory-profile analysis on the HLO input (module_0097.reactant_first_t....before_optimizations.txt), run in a Google-internal H100 environment.

Top Temporary Ops Summary

Here is a breakdown of memory usage by operation type, focusing on temporary allocations:

| Operation Type | Total Memory (MiB) | Percentage | Number of Operations | Op Names | Framework Op |
| --- | --- | --- | --- | --- | --- |
| param | 20880 | 75.87% | * | | |
| copy | 2880 | 10.46% | 2 | copy.29, copy.30 | N/A |
| loop_add_fusion | 2880 | 10.46% | 2 | loop_add_fusion{0}, loop_add_fusion{2} | add.221 |
| loop_select_fusion | 862.89 | 3.14% | 3 | loop_select_fusion{0}, loop_select_fusion{1}, loop_select_fusion{2} | pad.178 |
| rest ops | 18.51 | 0.07% | | | |
| TOTAL | 27521.5 | 100% | | | |

Screenshot: https://screenshot.googleplex.com/8prNPahc8JZAzpP

Questions

  1. Parameter Memory Usage: The HLO input parameters consume 20880 MiB, which accounts for 75.87% of the total memory allocation. Is this magnitude of memory usage for input parameters considered reasonable or expected from an MLIR perspective for this type of model or operation?
  2. If the parameter memory usage is larger than expected, could this suggest potential inefficiencies or optimization opportunities in the StableHLO to HLO conversion process?

@giordano
Member

Is this magnitude of memory usage for input parameters considered reasonable or expected from an MLIR perspective for this type of model or operation?

Billy please do correct me, but I believe that's indeed expected. That's the memory we have right after the model generation. In #1243 (comment) I had anticipated about 20 GB based only on scaling the input parameters from previous runs, so I think these 20 GB are what we expect.

@wsmoses
Member

wsmoses commented Aug 26, 2025

I need to stare more at the minimization, but the basic gist here is that the parameters will have some amount of memory usage, but in principle we should be able to use no additional memory. Specifically, the original code just had that original allocation and updated it in place.

In practice, when we do the full, loop-based version, I expect we'll have a factor of two for the induction variables or something, but here I would expect these allocations to be eliminable.

@giordano
Member

giordano commented Sep 3, 2025

With EnzymeAD/Reactant.jl#1619 we get a slightly lower peak memory:
main:

│    peak_bytes_in_use: 28943842560

vs PR:

│    peak_bytes_in_use: 28534432768

but the warning we get during compilation mentions a much larger memory buffer:
main:

W0000 00:00:1756938493.424201    5944 hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 12.40GiB (13313787004 bytes), down from 12.40GiB (13313787004 bytes) originally

vs PR:

W0000 00:00:1756941213.091395   59266 hlo_rematerialization.cc:3183] Can't reduce memory use below 10.64GiB (11428802359 bytes) by rematerialization; only reduced to 16.16GiB (17352329320 bytes), down from 16.23GiB (17427728524 bytes) originally

XLA dumps: main vs PR.
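
For reference, the difference between the two peak_bytes_in_use values quoted above works out to roughly 0.38 GiB (my own arithmetic):

    # Peak-memory delta between main and the Reactant.jl#1619 run quoted above.
    echo $(( 28943842560 - 28534432768 ))   # 409409792 bytes, about 0.38 GiB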

@wsmoses
Member

wsmoses commented Sep 3, 2025

I'm curious if #1363 creates similar improvements [or more], or causes more chaos.

@wsmoses
Member

wsmoses commented Sep 3, 2025

also note that the absence of the dus_to_pad comm op in that PR causes all-gathers to return, which is bad

@giordano

This comment was marked as outdated.

@giordano marked this pull request as draft on September 13, 2025 at 18:25
@giordano
Member

@felixwqp: @glwagner further reduced the kernels; I hope we now have something even more useful.

In this run we launched the program three times separately, using three different kernels: one filling the halo regions only in the east-west direction (which is also the direction in which we do the sharding, so we expect device-device communication here), one filling the halo regions only in the north-south direction (no sharding), and another one filling all the halo regions (device-device communication also here). We see:

│    num_allocs: 41
│    bytes_in_use: 23404316416
│    peak_bytes_in_use: 24918199296
│    largest_alloc_size: 1513882880
│    bytes_limit: 31804391424

│    num_allocs: 41
│    bytes_in_use: 23404316416
│    peak_bytes_in_use: 23443442432
│    largest_alloc_size: 1509949440
│    bytes_limit: 31804391424

│    num_allocs: 40
│    bytes_in_use: 21894366976
│    peak_bytes_in_use: 23409627136
│    largest_alloc_size: 1515260160
│    bytes_limit: 31804391424

Here is the XLA dump from the run: simulation-xla-dump-1.11--main-main.zip. Here are the line counts of the three modules:

% wc -l simulation-xla-dump-1.11--main-main/xla_dump_*/module_00*.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
     468 simulation-xla-dump-1.11--main-main/xla_dump_all/module_0049.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
     111 simulation-xla-dump-1.11--main-main/xla_dump_east_west/module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
     197 simulation-xla-dump-1.11--main-main/xla_dump_north_south/module_0051.reactant_fill_ha....sm_8.0_gpu_after_optimizations.txt
     776 total

I presume the most interesting ones to look at are the east-west and north-south ones, perhaps especially the former, which does the device-device communication (and is also the shorter one); comparing it with the north-south one may also be useful. Hope this helps!

Greg please do correct me if I said anything wrong above!
