Skip to content

Commit 29c451b

Browse files
authored
Yet more IREEGPUAttrs cleanup: drop get{A,B,C}SingleSubgroupLayout methods (#19169)
These methods existed before we added the unified `getSingleSubgroupLayout` taking a `MMAFragment` argument. Now they can go away. Actually polymorphic callers, which motivated this being an interface method, are taken care of by a new overload of `getSingleSubgroupLayout` taking a `MMAInterfaceAttr`. --------- Signed-off-by: Benoit Jacob <[email protected]>
1 parent e10342d commit 29c451b

File tree

6 files changed

+45
-114
lines changed

6 files changed

+45
-114
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/AMDGPUDistributeContract.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,19 @@ static LogicalResult isIntrinsicLayoutCompatible(
5757
auto [lhsM, rhsN] = opInfo.getOperandMNIndex();
5858
auto [lhsK, rhsK] = opInfo.getOperandKIndex();
5959
auto [accM, accN] = opInfo.getResultMNIndex();
60-
if (failed(isSubgroupLayoutCompatible(getASingleSubgroupLayout(intrinsic),
61-
lhsLayout, lhsM, lhsK))) {
60+
if (failed(isSubgroupLayoutCompatible(
61+
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Lhs),
62+
lhsLayout, lhsM, lhsK))) {
6263
return failure();
6364
}
64-
if (failed(isSubgroupLayoutCompatible(getBSingleSubgroupLayout(intrinsic),
65-
rhsLayout, rhsK, rhsN))) {
65+
if (failed(isSubgroupLayoutCompatible(
66+
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Rhs),
67+
rhsLayout, rhsK, rhsN))) {
6668
return failure();
6769
}
68-
if (failed(isSubgroupLayoutCompatible(getCSingleSubgroupLayout(intrinsic),
69-
accLayout, accM, accN))) {
70+
if (failed(isSubgroupLayoutCompatible(
71+
getSingleSubgroupLayout(intrinsic, IREE::GPU::MMAFragment::Acc),
72+
accLayout, accM, accN))) {
7073
return failure();
7174
}
7275
return success();

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ static bool is_AMD_WMMA(MMAIntrinsic intrinsic) {
6969

7070
static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
7171
// Not using Wave64 at all at the moment, so the only place where the
72-
// subgroup size is CDNA* architectures.
72+
// subgroup size is 64 is on CDNA* architectures.
7373
return is_AMD_MFMA(intrinsic) ? 64 : 32;
7474
}
7575

@@ -292,38 +292,14 @@ OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
292292
return getOpaqueMMALayout<IREE::GPU::MMAIntrinsic>(context, intrinsic);
293293
}
294294

295-
//===----------------------------------------------------------------------===//
296-
// MmaInterface Attribute Helper Functions
297-
//===----------------------------------------------------------------------===//
298-
299-
MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
300-
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
301-
return mmaAttr.getASingleSubgroupLayout();
302-
}
303-
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
304-
return vmmaAttr.getASingleSubgroupLayout();
305-
}
306-
assert(false && "unhandled MMA Interface type.");
307-
return {};
308-
}
309-
310-
MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
311-
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
312-
return mmaAttr.getBSingleSubgroupLayout();
313-
}
314-
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
315-
return vmmaAttr.getBSingleSubgroupLayout();
316-
}
317-
assert(false && "unhandled MMA Interface type.");
318-
return {};
319-
}
320-
321-
MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind) {
295+
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
296+
MMAFragment fragment) {
322297
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
323-
return mmaAttr.getCSingleSubgroupLayout();
298+
return getSingleSubgroupLayout(mmaAttr.getIntrinsic().getValue(), fragment);
324299
}
325300
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
326-
return vmmaAttr.getCSingleSubgroupLayout();
301+
return getSingleSubgroupLayout(vmmaAttr.getIntrinsic().getValue(),
302+
fragment);
327303
}
328304
assert(false && "unhandled MMA Interface type.");
329305
return {};
@@ -407,18 +383,6 @@ FailureOr<IREE::GPU::MMAScope> MMAAttr::getMmaScope() const {
407383
return IREE::GPU::MMAScope::Subgroup;
408384
}
409385

410-
MMASingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
411-
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs);
412-
}
413-
414-
MMASingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
415-
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Rhs);
416-
}
417-
418-
MMASingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
419-
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
420-
}
421-
422386
// Get virtual intrinsics that is composed/based on queried op.
423387
SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
424388
switch (getIntrinsic().getValue()) {
@@ -1098,18 +1062,6 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
10981062
return {};
10991063
}
11001064

1101-
MMASingleSubgroupLayout VirtualMMAAttr::getASingleSubgroupLayout() const {
1102-
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Lhs);
1103-
}
1104-
1105-
MMASingleSubgroupLayout VirtualMMAAttr::getBSingleSubgroupLayout() const {
1106-
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Rhs);
1107-
}
1108-
1109-
MMASingleSubgroupLayout VirtualMMAAttr::getCSingleSubgroupLayout() const {
1110-
return getSingleSubgroupLayout(getIntrinsic().getValue(), MMAFragment::Acc);
1111-
}
1112-
11131065
//===----------------------------------------------------------------------===//
11141066
// Target Attributes
11151067
//===----------------------------------------------------------------------===//

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ namespace mlir::iree_compiler::IREE::GPU {
3333
// semantics in that case are that threads within the subgroup whose thread-ids
3434
// differ by a multiple of `P`, are accessing the same elements.
3535
//
36-
// Example observed in RDNA3 WMMA Wave64 intrinsics:
37-
// If the subgroup size is 64 but the product `P` of `thread` sizes is 32, that
38-
// means that each element is being accessed by 2 threads (2 = 64/32), and the
39-
// threads accessing the same element are those whose tids are exactly 32 apart.
36+
// Example observed in RDNA3 WMMA Wave32 intrinsics:
37+
// If the subgroup size is 32 but the product `P` of `thread` sizes is 16, that
38+
// means that each element is being accessed by 2 threads (2 = 32/16), and the
39+
// threads accessing the same element are those whose tids are exactly 16 apart.
4040
struct MMASingleSubgroupLayout {
4141
// Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
4242
// outer-most in the layout. This happens when a MMA op, seen on a single
@@ -54,7 +54,7 @@ struct MMASingleSubgroupLayout {
5454
// Internal dimensions (as in TileSwizzle::Dim::Kind::Internal) that are
5555
// inner-most in the layout. This happens when a MMA op, seen on a single
5656
// thread, has an operand that consists of multiple elements, and these elems
57-
// are NOT contiguous.
57+
// are contiguous.
5858
// This is not used by every MMA op; ops which don't use that simply have 1's.
5959
SmallVector<int64_t, 2> element;
6060
};
@@ -65,11 +65,8 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
6565
MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
6666
MMAFragment fragment);
6767

68-
MMASingleSubgroupLayout getASingleSubgroupLayout(MmaInterfaceAttr mmaKind);
69-
70-
MMASingleSubgroupLayout getBSingleSubgroupLayout(MmaInterfaceAttr mmaKind);
71-
72-
MMASingleSubgroupLayout getCSingleSubgroupLayout(MmaInterfaceAttr mmaKind);
68+
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
69+
MMAFragment fragment);
7370

7471
// Struct describing the shape of a MMA operation, but not the detailed layout.
7572
// TODO(bjacob): the only user outside of IREEGPUAttrs.cpp is

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,6 @@ class IREEGPU_MmaVectorLayoutAttr<string attrname, string mmaintrinsic> :
151151
"getMNKShape",
152152
"getSubgroupSize",
153153
"getMmaScope",
154-
"getASingleSubgroupLayout",
155-
"getBSingleSubgroupLayout",
156-
"getCSingleSubgroupLayout",
157154
"buildMmaOperation",
158155
"populateOperandOffsetsSizesStrides",
159156
]>
@@ -225,14 +222,6 @@ def IREEGPU_MMAAttr : IREEGPU_MmaVectorLayoutAttr<"MMA", "MMAIntrinsicAttr"> {
225222
let extraClassDeclaration = [{
226223
int64_t getBlockSize() const;
227224

228-
// Returns the A/B/C matrix's partial nested layout shape inside a single
229-
// subgroup. Shape at each outer/thread/element level is a 2-D value,
230-
// following canonical matmul order--(M, K) for A, (K, N) for B, and
231-
// (M, N) for C.
232-
MMASingleSubgroupLayout getASingleSubgroupLayout() const;
233-
MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
234-
MMASingleSubgroupLayout getCSingleSubgroupLayout() const;
235-
236225
SmallVector<VirtualMMAIntrinsic> getVirtualIntrinsics() const;
237226
}];
238227
}
@@ -287,9 +276,6 @@ def IREEGPU_VirtualMMAAttr :
287276
"getMNKShape",
288277
"getSubgroupSize",
289278
"getMmaScope",
290-
"getASingleSubgroupLayout",
291-
"getBSingleSubgroupLayout",
292-
"getCSingleSubgroupLayout",
293279
"populateOperandOffsetsSizesStrides",
294280
"buildMmaOperation",
295281
]>
@@ -319,14 +305,6 @@ def IREEGPU_VirtualMMAAttr :
319305
let extraClassDeclaration = [{
320306
int64_t getBlockSize() const;
321307

322-
// Returns the A/B/C matrix's partial nested layout shape inside a single
323-
// subgroup. Shape at each outer/thread/element level is a 2-D value,
324-
// following canonical matmul order--(M, K) for A, (K, N) for B, and
325-
// (M, N) for C.
326-
MMASingleSubgroupLayout getASingleSubgroupLayout() const;
327-
MMASingleSubgroupLayout getBSingleSubgroupLayout() const;
328-
MMASingleSubgroupLayout getCSingleSubgroupLayout() const;
329-
330308
// Factor to unroll K from native MMA/intrinsic size to virtual size.
331309
// e.g MFMA_F32_16x16x16 has K of 16, while VMFMA_F32_16x16x32 has K of 32
332310
// in this example, unrollK = 32/16 = 2.

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/ConcretizeMmaShapes.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,20 @@ LogicalResult materializeOperandConcreteShape(
3030
SmallVector<ReassociationIndices> &reassociations,
3131
RankedTensorType &resultType) {
3232

33-
SmallVector<int64_t, 2> outerSizes;
33+
MMASingleSubgroupLayout layout = getSingleSubgroupLayout(mma, fragment);
34+
SmallVector<int64_t, 2> outerSizes = layout.outer;
3435
SmallVector<int64_t, 2> opaqueSizes;
3536
auto [m, n, k] = mma.getMNKShape();
3637
switch (fragment) {
3738
case IREE::GPU::MMAFragment::Lhs: {
38-
outerSizes = mma.getASingleSubgroupLayout().outer;
3939
opaqueSizes.append({m, k});
4040
break;
4141
}
4242
case IREE::GPU::MMAFragment::Rhs: {
43-
outerSizes = mma.getBSingleSubgroupLayout().outer;
4443
opaqueSizes.append({k, n});
4544
break;
4645
}
4746
case IREE::GPU::MMAFragment::Acc: {
48-
outerSizes = mma.getCSingleSubgroupLayout().outer;
4947
opaqueSizes.append({m, n});
5048
break;
5149
}

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
311311
cSubgroupStrides[dim] = subgroupNStrides[i];
312312
}
313313

314-
auto cLayout = createNestedLayout(context, cRank, m, n,
315-
/*subgroupCount=*/cSubgroupSizes,
316-
/*subgroupStrides=*/cSubgroupStrides,
317-
/*batchCount=*/cBatchSizes,
318-
getCSingleSubgroupLayout(mmaAttr));
314+
IREE::VectorExt::NestedLayoutAttr cLayout = createNestedLayout(
315+
context, cRank, m, n,
316+
/*subgroupCount=*/cSubgroupSizes,
317+
/*subgroupStrides=*/cSubgroupStrides,
318+
/*batchCount=*/cBatchSizes,
319+
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Acc));
319320
LLVM_DEBUG({ llvm::dbgs() << "C layout: " << cLayout << "\n"; });
320321

321322
// A matrix layout
@@ -339,11 +340,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
339340
}
340341
aBatchSizes[afk] = bounds[opInfo.getKDims().back()] / intrinsicK;
341342

342-
auto aLayout = createNestedLayout(context, aRank, afm, afk,
343-
/*subgroupCount=*/aSubgroupSizes,
344-
/*subgroupStrides=*/aSubgroupStrides,
345-
/*batchCount=*/aBatchSizes,
346-
getASingleSubgroupLayout(mmaAttr));
343+
IREE::VectorExt::NestedLayoutAttr aLayout = createNestedLayout(
344+
context, aRank, afm, afk,
345+
/*subgroupCount=*/aSubgroupSizes,
346+
/*subgroupStrides=*/aSubgroupStrides,
347+
/*batchCount=*/aBatchSizes,
348+
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Lhs));
347349
LLVM_DEBUG({ llvm::dbgs() << "A layout: " << aLayout << "\n"; });
348350

349351
int64_t bRank = opInfo.getBRank();
@@ -363,11 +365,12 @@ getContractionLayout(IREE::GPU::MMAScheduleAttr schedule,
363365
}
364366
bBatchSizes[bfk] = bounds[opInfo.getKDims().back()] / intrinsicK;
365367

366-
auto bLayout = createNestedLayout(context, bRank, bfk, bfn,
367-
/*subgroupCount=*/bSubgroupSizes,
368-
/*subgroupStrides=*/bSubgroupStrides,
369-
/*batchCount=*/bBatchSizes,
370-
getBSingleSubgroupLayout(mmaAttr));
368+
IREE::VectorExt::NestedLayoutAttr bLayout = createNestedLayout(
369+
context, bRank, bfk, bfn,
370+
/*subgroupCount=*/bSubgroupSizes,
371+
/*subgroupStrides=*/bSubgroupStrides,
372+
/*batchCount=*/bBatchSizes,
373+
getSingleSubgroupLayout(mmaAttr, IREE::GPU::MMAFragment::Rhs));
371374
LLVM_DEBUG({ llvm::dbgs() << "B layout: " << bLayout << "\n"; });
372375

373376
std::tuple<VectorLayoutInterface, VectorLayoutInterface,
@@ -618,11 +621,11 @@ static LogicalResult setAttentionMatmulAnchor(RewriterBase &rewriter,
618621
auto pvIntrinsic =
619622
cast<IREE::GPU::MmaInterfaceAttr>(pvSchedule.getIntrinsic());
620623
IREE::GPU::MMASingleSubgroupLayout lhsLayout =
621-
getASingleSubgroupLayout(pvIntrinsic);
624+
getSingleSubgroupLayout(pvIntrinsic, IREE::GPU::MMAFragment::Lhs);
622625
IREE::GPU::MMASingleSubgroupLayout rhsLayout =
623-
getBSingleSubgroupLayout(pvIntrinsic);
626+
getSingleSubgroupLayout(pvIntrinsic, IREE::GPU::MMAFragment::Rhs);
624627
IREE::GPU::MMASingleSubgroupLayout outLayout =
625-
getCSingleSubgroupLayout(qkIntrinsic);
628+
getSingleSubgroupLayout(qkIntrinsic, IREE::GPU::MMAFragment::Acc);
626629

627630
auto matchLayout = [](IREE::GPU::MMASingleSubgroupLayout layoutA,
628631
IREE::GPU::MMASingleSubgroupLayout layoutB) -> bool {

0 commit comments

Comments
 (0)