[GPU] Poor performance with WarpReduction #19868

Open · IanWood1 opened this issue Jan 31, 2025 · 4 comments
Labels: bug 🐞 Something isn't working

Comments


IanWood1 commented Jan 31, 2025

What happened?

I made some small changes to dispatch creation and was testing SDXL, but found that compilation times were excessively long. I tracked it down to a dispatch that generates a huge number of ops and has a large allocation. It is only slightly different from the previous dispatch; only the number of loops differs.

I tried to collapse the loops as much as possible, but that didn't fix the problem. It seems the pipeline needs a very specific configuration of loops to work well.

Steps to reproduce your issue

  1. Download https://gist.github.com/IanWood1/948a753821f8aa8dc07288c903778fb5
  2. Run iree-compile

What component(s) does this issue relate to?

No response

Version information

36e7593

Additional context

No response

IanWood1 added the bug 🐞 Something isn't working label Jan 31, 2025
IanWood1 added a commit to IanWood1/iree that referenced this issue Feb 6, 2025
The tweak to collapse dims prevents a compilation timeout, but it has terrible effects on runtime performance. When there are multiple reduction ops and the dispatch goes down the warp reduction pipeline, the dispatch has to be in a very specific state to get good results. Otherwise, compilation times out or the compiled dispatch is VERY slow (3x total SDXL runtime).

See: iree-org#19868

I found a few SDXL instances of the following pattern:

1 = op with multiple uses
2 = consumer of "1" (transpose)
3 = consumer of "2" (bit extend)

However, there is a reshape that gets stuck between 1-2 or between 2-3, depending on which pass you look at (maybe always 2-3). 1-2 could be fused with multi-use fusion.

Signed-off-by: Ian Wood <[email protected]>
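
For illustration, the producer/consumer pattern described in the commit message above could look roughly like the sketch below. This is hypothetical MLIR; the op choices (arith.negf standing in for the multi-use op) and the shapes are made up and not taken from the SDXL dispatches.

util.func public @pattern_sketch(%arg0: tensor<2x8x16xf16>) -> (tensor<2x8x16xf16>, tensor<16x16xf32>) {
  // 1 = op with multiple uses (returned below and also consumed by the reshape).
  %multi_use = arith.negf %arg0 : tensor<2x8x16xf16>
  // The reshape that ends up stuck between 1-2 (or between 2-3, depending on the pass).
  %collapsed = tensor.collapse_shape %multi_use [[0, 1], [2]] : tensor<2x8x16xf16> into tensor<16x16xf16>
  // 2 = consumer of "1" (transpose).
  %empty = tensor.empty() : tensor<16x16xf16>
  %transposed = linalg.transpose ins(%collapsed : tensor<16x16xf16>) outs(%empty : tensor<16x16xf16>) permutation = [1, 0]
  // 3 = consumer of "2" (bit extend).
  %extended = arith.extf %transposed : tensor<16x16xf16> to tensor<16x16xf32>
  util.return %multi_use, %extended : tensor<2x8x16xf16>, tensor<16x16xf32>
}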
MaheshRavishankar commented

@pashu123 can you help look into this issue? I know you were trying to deprecate the warp reduction pipeline and use tile-and-fuse. That is the better overall solution, but if there is a quick fix within the warp reduction pipeline to get past this, it would be great.


pashu123 commented Feb 7, 2025

sure!


pashu123 commented Feb 7, 2025

Hi @IanWood1, the problem here is that the size/rank of the first linalg.generic and the last linalg.generic don't match, so the last generic can't be fused, which results in large vector sizes. Here's the dump: https://gist.github.com/pashu123/b01165555eeb84b9cf404cd5c9f51072. @MaheshRavishankar do we want to support fusion in this case?


IanWood1 commented Feb 8, 2025

Following up from the discussion earlier:

Here's a smaller example that produces the same problem:

util.func public @test1(%arg0: tensor<32x102400xf32>, %arg1: tensor<32x10x10240xf32>) -> tensor<32x10x10240xf32> {
  %0 = tensor.empty() : tensor<32xf32>
  %1 = tensor.empty() : tensor<32x10x10240xf32>
  %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<32x102400xf32>) outs(%0 : tensor<32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %4 = arith.addf %in, %out : f32
    linalg.yield %4 : f32
  } -> tensor<32xf32>
  %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg1, %2 : tensor<32x10x10240xf32>, tensor<32xf32>) outs(%1 : tensor<32x10x10240xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %4 = arith.addf %in, %in_0 : f32
    linalg.yield %4 : f32
  } -> tensor<32x10x10240xf32>
  util.return %3 : tensor<32x10x10240xf32>
}
iree-compile repro.mlir --iree-hal-target-backends=rocm --iree-hip-target=gfx1100 --iree-dispatch-creation-enable-aggressive-fusion -o /dev/null

The issue is that the lowering config is set on the reduction op and then propagated to the consumer. E.g.

%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%3 : tensor<32x102400xf32>) outs(%6 : tensor<32xf32>) attrs = 
 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1], [0, 4096]]>} {
^bb0(%in: f32, %out: f32):
  %9 = arith.addf %in, %out : f32
  linalg.yield %9 : f32
} -> tensor<32xf32>

%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %7 : tensor<32x10x10240xf32>, tensor<32xf32>) outs(%5 : tensor<32x10x10240xf32>) attrs = 
 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1], [0, 4096]]>} {   
^bb0(%in: f32, %in_0: f32, %out: f32):
  %9 = arith.addf %in, %in_0 : f32 
  linalg.yield %9 : f32
} -> tensor<32x10x10240xf32>

But this leaves the dim of size 10240 untiled. In the case where the consumer has a lower dimensionality than the producer, the vector size is too small to distribute to threads during ConvertVectorReductionToGPUPass, so the entire parallel generic gets stuck inside the gpu.warp_execute_on_lane_0 op (causing the performance issues).

@pashu123 do you think it would be possible to set a separate config on the consumer? Otherwise, it seems like we shouldn't be forming these dispatches, and it has just been luck that the reduction tiling has matched the parallel consumer.
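
For illustration only, a separate config on the consumer might look something like the following, where the thread-level tile sizes also cover d2 (size 10240). The tile sizes here are hypothetical and only show the shape of such a config, not a verified or proposed choice:

// NOTE: hypothetical tile sizes, shown only to illustrate covering d2 (size 10240).
%8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%4, %7 : tensor<32x10x10240xf32>, tensor<32xf32>) outs(%5 : tensor<32x10x10240xf32>) attrs =
 {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 0], [0, 0, 4096]]>} {
^bb0(%in: f32, %in_0: f32, %out: f32):
  %9 = arith.addf %in, %in_0 : f32
  linalg.yield %9 : f32
} -> tensor<32x10x10240xf32>

Whether a per-consumer config like this is the right fix, versus not forming these dispatches at all, is exactly the question above.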


Also, a bit of a tangent, but I noticed that it was common to end up with doubly nested scf.for loops (from the tiled consumer), so I tried adding funcPassManager.addPass(affine::createLoopCoalescingPass()); directly after the WarpReduction pipeline and saw a ~20% boost in performance vs main. However, it only seemed to have a ~0.8ms impact on punet.
