@@ -1043,7 +1043,7 @@ struct GridwiseAttentionAccelRewritePattern
1043
1043
gemm0OutExpTrs, gemm0OutTrs},
1044
1044
/* bounds=*/ ArrayRef<int64_t >{g0Mpt, g0Npt},
1045
1045
/* strides=*/ ArrayRef<int64_t >{1 , 1 },
1046
- /* useIndexDiffs =*/ true , /* forceUnroll =*/ true );
1046
+ /* forceUnroll =*/ true , /* useIndexDiffs =*/ true );
1047
1047
{
1048
1048
OpBuilder::InsertionGuard guard (rewriter);
1049
1049
rewriter.setInsertionPointToStart (loop.getBody ());
@@ -1105,7 +1105,7 @@ struct GridwiseAttentionAccelRewritePattern
1105
1105
gemm0OutBufferMaxTrs},
1106
1106
/* bounds=*/ ArrayRef<int64_t >{g0Mpt, 1 },
1107
1107
/* strides=*/ ArrayRef<int64_t >{1 , 1 },
1108
- /* useIndexDiffs =*/ true , /* forceUnroll =*/ true );
1108
+ /* forceUnroll =*/ true , /* useIndexDiffs =*/ true );
1109
1109
{
1110
1110
OpBuilder::InsertionGuard guard (rewriter);
1111
1111
rewriter.setInsertionPointToStart (loop.getBody ());
@@ -1167,7 +1167,7 @@ struct GridwiseAttentionAccelRewritePattern
1167
1167
ArrayRef<Attribute>{rewriter.getArrayAttr ({}), attentionOutAccTrs},
1168
1168
/* bounds=*/ ArrayRef<int64_t >{g1Mpt, g1Npt},
1169
1169
/* strides=*/ ArrayRef<int64_t >{1 , 1 },
1170
- /* useIndexDiffs =*/ true , /* forceUnroll =*/ true );
1170
+ /* forceUnroll =*/ true , /* useIndexDiffs =*/ true );
1171
1171
{
1172
1172
OpBuilder::InsertionGuard guard (rewriter);
1173
1173
rewriter.setInsertionPointToStart (loop.getBody ());
@@ -1230,7 +1230,7 @@ struct GridwiseAttentionAccelRewritePattern
1230
1230
attentionOutAccBufferTrs},
1231
1231
/* bounds=*/ ArrayRef<int64_t >{g1Mpt, g1Npt},
1232
1232
/* strides=*/ ArrayRef<int64_t >{1 , 1 },
1233
- /* useIndexDiffs =*/ true , /* forceUnroll =*/ true );
1233
+ /* forceUnroll =*/ true , /* useIndexDiffs =*/ true );
1234
1234
{
1235
1235
OpBuilder::InsertionGuard guard (rewriter);
1236
1236
rewriter.setInsertionPointToStart (loop.getBody ());
@@ -1333,10 +1333,11 @@ struct GridwiseAttentionAccelRewritePattern
1333
1333
// post normalization. Therefore, this function creates a transforming
1334
1334
// for loop that overwrites out-of-bounds values of the first gemm output
1335
1335
// to be negative infinity.
1336
- void createFirstGemmNegInfPadding (
1337
- PatternRewriter &rewriter, Location loc,
1338
- layout::GridCoordinates gridCoords, Value gemm0OutBuffer,
1339
- RegsAsMatrixSubTiles gemm0OutSubTileViews) const {
1336
+ void createFirstGemmNegInfPadding (PatternRewriter &rewriter, Location loc,
1337
+ layout::GridCoordinates gridCoords,
1338
+ Value gemm0OutBuffer,
1339
+ RegsAsMatrixSubTiles gemm0OutSubTileViews,
1340
+ bool isGfx11) const {
1340
1341
MemRefType gemm0OutBufferType = cast<MemRefType>(gemm0OutBuffer.getType ());
1341
1342
auto negInfTyped = createConstantFloatOp (
1342
1343
rewriter, loc, gemm0OutBufferType.getElementType (),
@@ -1346,6 +1347,9 @@ struct GridwiseAttentionAccelRewritePattern
1346
1347
auto tid = rewriter.create <WorkitemIdOp>(loc, rewriter.getIndexType ());
1347
1348
int64_t elementsInThreadBuffer = gemm0OutBufferType.getNumElements ();
1348
1349
Value zero = rewriter.createOrFold <ConstantIndexOp>(loc, 0 );
1350
+
1351
+ // TODO: fix forceUnroll=false for gfx1100
1352
+ // (https://github.com/ROCm/rocMLIR-internal/issues/1661)
1349
1353
auto loop = rewriter.create <TransformingForOp>(
1350
1354
loc,
1351
1355
ArrayRef<ValueRange>{{gridCoords.g_block , gridCoords.m_block ,
@@ -1355,7 +1359,7 @@ struct GridwiseAttentionAccelRewritePattern
1355
1359
rewriter.getArrayAttr ({})},
1356
1360
/* bounds=*/ ArrayRef<int64_t >{1 , 1 , 1 , 1 , elementsInThreadBuffer},
1357
1361
/* strides=*/ ArrayRef<int64_t >{1 , 1 , 1 , 1 , 1 },
1358
- /* useIndexDiffs =*/ true , /* forceUnroll =*/ true );
1362
+ /* forceUnroll =*/ !isGfx11 , /* useIndexDiffs =*/ true );
1359
1363
{
1360
1364
OpBuilder::InsertionGuard guard (rewriter);
1361
1365
rewriter.setInsertionPointToStart (loop.getBody ());
@@ -2090,16 +2094,17 @@ struct GridwiseAttentionAccelRewritePattern
2090
2094
postProcessFirstGemmSplat<ElementwiseMultOp>(
2091
2095
rewriter, loc, gridCoordsGemm0, gemm0OutBuffer, gemm0OutSubTileViews,
2092
2096
ln2Recip.getDefiningOp <arith::ConstantOp>().getValue ());
2093
- #endif
2094
2097
2095
2098
// Handle padding
2096
2099
bool hasPadding =
2097
2100
op.getPrePadG0M ().has_value () || op.getPrePadG0N ().has_value ();
2098
2101
if (hasPadding) {
2102
+ bool isGfx11 = arch.contains (" gfx11" );
2099
2103
createFirstGemmNegInfPadding (rewriter, loc, gridCoordsGemm0,
2100
2104
gemm0OutBuffer,
2101
- gemm0OutSubTileViewsTrUnPadded);
2105
+ gemm0OutSubTileViewsTrUnPadded, isGfx11 );
2102
2106
}
2107
+ #endif
2103
2108
2104
2109
APInt reductionAxis = APInt (64 , 1 );
2105
2110
APInt nrDimPerThread = APInt (64 , gemm0MPerBlock / gemm0MPerThread);
0 commit comments