Skip to content

Commit 6b6041a

Browse files
authored
Merge pull request #1701 from ROCm/3rd-backport-6.3
Workaround for issue 1661
2 parents 631ecff + 126296f commit 6b6041a

File tree

5 files changed

+48
-11
lines changed

5 files changed

+48
-11
lines changed

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,7 +1043,7 @@ struct GridwiseAttentionAccelRewritePattern
10431043
gemm0OutExpTrs, gemm0OutTrs},
10441044
/*bounds=*/ArrayRef<int64_t>{g0Mpt, g0Npt},
10451045
/*strides=*/ArrayRef<int64_t>{1, 1},
1046-
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
1046+
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
10471047
{
10481048
OpBuilder::InsertionGuard guard(rewriter);
10491049
rewriter.setInsertionPointToStart(loop.getBody());
@@ -1105,7 +1105,7 @@ struct GridwiseAttentionAccelRewritePattern
11051105
gemm0OutBufferMaxTrs},
11061106
/*bounds=*/ArrayRef<int64_t>{g0Mpt, 1},
11071107
/*strides=*/ArrayRef<int64_t>{1, 1},
1108-
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
1108+
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
11091109
{
11101110
OpBuilder::InsertionGuard guard(rewriter);
11111111
rewriter.setInsertionPointToStart(loop.getBody());
@@ -1167,7 +1167,7 @@ struct GridwiseAttentionAccelRewritePattern
11671167
ArrayRef<Attribute>{rewriter.getArrayAttr({}), attentionOutAccTrs},
11681168
/*bounds=*/ArrayRef<int64_t>{g1Mpt, g1Npt},
11691169
/*strides=*/ArrayRef<int64_t>{1, 1},
1170-
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
1170+
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
11711171
{
11721172
OpBuilder::InsertionGuard guard(rewriter);
11731173
rewriter.setInsertionPointToStart(loop.getBody());
@@ -1230,7 +1230,7 @@ struct GridwiseAttentionAccelRewritePattern
12301230
attentionOutAccBufferTrs},
12311231
/*bounds=*/ArrayRef<int64_t>{g1Mpt, g1Npt},
12321232
/*strides=*/ArrayRef<int64_t>{1, 1},
1233-
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
1233+
/*forceUnroll=*/true, /*useIndexDiffs=*/true);
12341234
{
12351235
OpBuilder::InsertionGuard guard(rewriter);
12361236
rewriter.setInsertionPointToStart(loop.getBody());
@@ -1333,10 +1333,11 @@ struct GridwiseAttentionAccelRewritePattern
13331333
// post normalization. Therefore, this function creates a trasnforming
13341334
// for loop that overwrites out of bounds values of first gemm output
13351335
// to be negative infinity.
1336-
void createFirstGemmNegInfPadding(
1337-
PatternRewriter &rewriter, Location loc,
1338-
layout::GridCoordinates gridCoords, Value gemm0OutBuffer,
1339-
RegsAsMatrixSubTiles gemm0OutSubTileViews) const {
1336+
void createFirstGemmNegInfPadding(PatternRewriter &rewriter, Location loc,
1337+
layout::GridCoordinates gridCoords,
1338+
Value gemm0OutBuffer,
1339+
RegsAsMatrixSubTiles gemm0OutSubTileViews,
1340+
bool isGfx11) const {
13401341
MemRefType gemm0OutBufferType = cast<MemRefType>(gemm0OutBuffer.getType());
13411342
auto negInfTyped = createConstantFloatOp(
13421343
rewriter, loc, gemm0OutBufferType.getElementType(),
@@ -1346,6 +1347,9 @@ struct GridwiseAttentionAccelRewritePattern
13461347
auto tid = rewriter.create<WorkitemIdOp>(loc, rewriter.getIndexType());
13471348
int64_t elementsInThreadBuffer = gemm0OutBufferType.getNumElements();
13481349
Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);
1350+
1351+
// TODO: fix forceUnroll=false for gfx1100
1352+
// (https://github.com/ROCm/rocMLIR-internal/issues/1661)
13491353
auto loop = rewriter.create<TransformingForOp>(
13501354
loc,
13511355
ArrayRef<ValueRange>{{gridCoords.g_block, gridCoords.m_block,
@@ -1355,7 +1359,7 @@ struct GridwiseAttentionAccelRewritePattern
13551359
rewriter.getArrayAttr({})},
13561360
/*bounds=*/ArrayRef<int64_t>{1, 1, 1, 1, elementsInThreadBuffer},
13571361
/*strides=*/ArrayRef<int64_t>{1, 1, 1, 1, 1},
1358-
/*useIndexDiffs=*/true, /*forceUnroll=*/true);
1362+
/*forceUnroll=*/!isGfx11, /*useIndexDiffs=*/true);
13591363
{
13601364
OpBuilder::InsertionGuard guard(rewriter);
13611365
rewriter.setInsertionPointToStart(loop.getBody());
@@ -2090,16 +2094,17 @@ struct GridwiseAttentionAccelRewritePattern
20902094
postProcessFirstGemmSplat<ElementwiseMultOp>(
20912095
rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, gemm0OutSubTileViews,
20922096
ln2Recip.getDefiningOp<arith::ConstantOp>().getValue());
2093-
#endif
20942097

20952098
// Handle padding
20962099
bool hasPadding =
20972100
op.getPrePadG0M().has_value() || op.getPrePadG0N().has_value();
20982101
if (hasPadding) {
2102+
bool isGfx11 = arch.contains("gfx11");
20992103
createFirstGemmNegInfPadding(rewriter, loc, gridCoordsGemm0,
21002104
gemm0OutBuffer,
2101-
gemm0OutSubTileViewsTrUnPadded);
2105+
gemm0OutSubTileViewsTrUnPadded, isGfx11);
21022106
}
2107+
#endif
21032108

21042109
APInt reductionAxis = APInt(64, 1);
21052110
APInt nrDimPerThread = APInt(64, gemm0MPerBlock / gemm0MPerThread);

mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,25 @@ func.func @gridwise_attn_grid_reversed(%arg0: memref<1x384x64xf32>, %arg1: memre
288288
} : memref<1x64x384xf32>, memref<1x64x384xf32>, memref<1x384x64xf32>, memref<1x384x64xf32>
289289
return
290290
}
291+
292+
// -----
293+
294+
// CHECK: @gridwise_attn_issue_1661_workaround
295+
func.func @gridwise_attn_issue_1661_workaround(%arg0: memref<256xf16>, %arg1: memref<98304xf16>, %arg2: memref<98304xf16>, %arg3: memref<256xf16>) attributes {block_size = 32 : i32, grid_size = 4 : i32, kernel, mhal.arch = "amdgcn-amd-amdhsa:gfx1100"} {
296+
%0 = rock.transform %arg0 by <affine_map<(d0, d1, d2) -> ((d0 + d1) * 64 + d2)> by [<Unmerge{4, 1, 64} ["g", "seq_q", "head_qk"] at [0, 1, 2] -> ["raw"] at [0]>] bounds = [4, 1, 64] -> [256]> : memref<256xf16> to memref<4x1x64xf16>
297+
%1 = rock.transform %arg1 by <affine_map<(d0, d1, d2) -> ((d0 * 64 + d1) * 384 + d2)> by [<Unmerge{4, 64, 384} ["g", "seq_k", "head_qk"] at [0, 1, 2] -> ["raw"] at [0]>] bounds = [4, 64, 384] -> [98304]> : memref<98304xf16> to memref<4x64x384xf16>
298+
%2 = rock.transform %arg2 by <affine_map<(d0, d1, d2) -> ((d0 * 384 + d1) * 64 + d2)> by [<Unmerge{4, 384, 64} ["g", "seq_k", "head_v"] at [0, 1, 2] -> ["raw"] at [0]>] bounds = [4, 384, 64] -> [98304]> : memref<98304xf16> to memref<4x384x64xf16>
299+
%3 = rock.transform %arg3 by <affine_map<(d0, d1, d2) -> ((d0 + d1) * 64 + d2)> by [<Unmerge{4, 1, 64} ["g", "seq_q", "head_v"] at [0, 1, 2] -> ["raw"] at [0]>] bounds = [4, 1, 64] -> [256]> : memref<256xf16> to memref<4x1x64xf16>
300+
%4 = rock.transform %0 by <affine_map<(d0, d1, d2) -> (d0, d2, d1)> by [<PassThrough ["gemmG"] at [0] -> ["gemmG"] at [0]>, <PassThrough ["gemm0K", "gemm0M"] at [1, 2] -> ["gemm0K", "gemm0M"] at [2, 1]>] bounds = [4, 64, 1] -> [4, 1, 64]> : memref<4x1x64xf16> to memref<4x64x1xf16>
301+
%5 = rock.transform %4 by <affine_map<(d0, d1, d2) -> (d0, d1, d2)> by [<PassThrough ["gemmG"] at [0] -> ["gemmG"] at [0]>, <PassThrough ["gemm0K"] at [1] -> ["gemm0K"] at [1]>, <Pad{0, 31} ["gemm0NPad"] at [2] -> ["gemm0N"] at [2]>] bounds = [4, 64, 32] -> [4, 64, 1]> : memref<4x64x1xf16> to memref<4x64x32xf16>
302+
%6 = rock.transform %3 by <affine_map<(d0, d1, d2) -> (d0, d1, d2)> by [<PassThrough ["gemmG"] at [0] -> ["gemmG"] at [0]>, <Pad{0, 31} ["gemm1NPad"] at [1] -> ["gemm1N"] at [1]>, <PassThrough ["gemm1M"] at [2] -> ["gemm1M"] at [2]>] bounds = [4, 32, 64] -> [4, 1, 64]> : memref<4x1x64xf16> to memref<4x32x64xf16>
303+
304+
// CHECK: %[[neginf:.+]] = arith.constant 0xFC00 : f16
305+
// CHECK: rock.transforming_for {useIndexDiffs}
306+
// CHECK: %[[cmpres:.*]] = arith.cmpi eq, %{{.*}}, %false : i1
307+
// CHECK-NEXT: scf.if %[[cmpres]]
308+
// CHECK-NEXT: rock.in_bounds_store %[[neginf]] -> %{{.*}}[%{{.*}}] : f16 -> memref<32xf16, #gpu.address_space<private>>, index
309+
rock.gridwise_attention_accel(%5, %1, %2, %6) features = dot|atomic_add|atomic_fmax_f32|wmma preSoftmaxOps = {
310+
} {arch = "amdgcn-amd-amdhsa:gfx1100", blockSize = 32 : i32, firstGemmIdx = 0 : i32, gridSize = 4 : i32, params0 = #rock.wmma_gemm_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, splitKFactor = 1, forceUnroll = true>, params1 = #rock.wmma_gemm_params<kpackPerBlock = 32, mPerBlock = 32, nPerBlock = 32, kpack = 1, mPerWave = 32, nPerWave = 32, splitKFactor = 1, forceUnroll = true>, prePadG0N = 1 : index} : memref<4x64x32xf16>, memref<4x64x384xf16>, memref<4x384x64xf16>, memref<4x32x64xf16>
311+
return
312+
}

mlir/test/e2e/PrAttentionF16.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,7 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-at
6565
# cross attention
6666
[[suite.test]]
6767
config = "-seq_len_q 128 -seq_len_k 27 -head_dim_qk 64 -head_dim_v 32 --with-attn-scale --with-attn-bias"
68+
69+
# issue 1661
70+
[[suite.test]]
71+
config = "-seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"

mlir/test/e2e/PrAttentionF32.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,6 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-at
4545
[[suite.test]]
4646
config = "-seq_len_q 128 -seq_len_k 27 -head_dim_qk 64 -head_dim_v 32 --with-attn-scale --with-attn-bias"
4747

48+
# issue 1661
49+
[[suite.test]]
50+
config = "-seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"

mlir/test/e2e/PrAttentionI8.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,6 @@ config = "-seq_len_q 384 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-at
5454
[[suite.test]]
5555
config = "-seq_len_q 128 -seq_len_k 27 -head_dim_qk 64 -head_dim_v 32 --with-attn-scale --with-attn-bias"
5656

57+
# issue 1661
58+
[[suite.test]]
59+
config = "-seq_len_q 1 -seq_len_k 384 -head_dim_qk 64 -head_dim_v 64 --with-attn-scale --with-attn-bias"

0 commit comments

Comments
 (0)