From 114a1427810f3da0234f98c22f58390773b0489a Mon Sep 17 00:00:00 2001
From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com>
Date: Sun, 20 Oct 2024 18:16:23 -0700
Subject: [PATCH] [LLVMGPU] Embed mma_intrinsic in to_layout and infer
 contraction's intrinsic from it. (#18842)

To enable faster flash attention, we'd like to be able to force different
vector widths, which means different contractions may need different
intrinsics.

This PR introduces a way to set intrinsic information on individual
contractions and have it preserved until vector distribution.

---------

Signed-off-by: Stanley Winata
Co-authored-by: Kunwar Grover
---
 .../Dialect/VectorExt/IR/VectorExtOps.td      | 20 +++++--
 .../VectorExtFoldUnitExtentDims.cpp           |  4 +-
 .../Transforms/VectorizeIREEVectorExtOps.cpp  |  2 +-
 .../LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp       | 57 ++++++++++++-------
 .../LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp | 22 +++----
 .../LLVMGPU/test/cast_type_to_fit_mma.mlir    | 56 ++++++++++++++++--
 .../LLVMGPU/test/configure_tensor_layout.mlir | 24 ++++----
 7 files changed, 128 insertions(+), 57 deletions(-)

diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
index 496ee0f25403..4e40cd87ca5b 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td
@@ -40,7 +40,10 @@ def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
   let arguments = (ins
     AnyShaped:$input,
     VectorLayoutInterface:$layout,
-    DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion
+    DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion,
+    // TODO: Solve cmake IREEGPU and VectorExt cyclic dependency to
+    // change mma_kind type to be of MMAInterfaceAttr.
+    OptionalAttr<AnyAttr>:$mma_kind
   );
   let results = (outs
     AnyShaped:$output
   );
   let builders = [
     OpBuilder<(ins "Value":$input,
                    "VectorLayoutInterface":$layout,
+                   "Attribute":$mma_kind_attr,
                    CArg<"bool", "false">:$shared_memory_conversion), [{
+      UnitAttr defaultSharedMemoryConversion;
       if (shared_memory_conversion) {
-        build($_builder, $_state, input.getType(), input, layout, UnitAttr::get(input.getContext()));
-      } else{
-        build($_builder, $_state, input.getType(), input, layout);
+        defaultSharedMemoryConversion = UnitAttr::get(input.getContext());
       }
-    }]>
+      build($_builder, $_state, input.getType(), input, layout, defaultSharedMemoryConversion, mma_kind_attr);
+    }]>,
+    OpBuilder<(ins "Value":$input,
+                   "VectorLayoutInterface":$layout), [{
+      UnitAttr defaultSharedMemoryConversion;
+      Attribute emptyIntrinsic;
+      build($_builder, $_state, input.getType(), input, layout, defaultSharedMemoryConversion, emptyIntrinsic);
+    }]>,
   ];
   let extraClassDeclaration = [{
     bool hasTensorSemantics() {
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
index 64edf0003f76..fe1f425a6d79 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorExtFoldUnitExtentDims.cpp
@@ -61,9 +61,7 @@ struct DropToLayoutUnitDims final
     Value rankReducedValue = rankReducingExtract.value();
     auto newToLayoutOp = rewriter.create<ToLayoutOp>(
         loc, rankReducedValue.getType(), rankReducedValue, newLayout,
-        toLayoutOp.getSharedMemoryConversion());
-    newToLayoutOp->setDiscardableAttrs(
-        toLayoutOp->getDiscardableAttrDictionary());
+        toLayoutOp.getSharedMemoryConversion(), toLayoutOp.getMmaKindAttr());

     // Expand to preserve output shape using insert_slice.
     // Here, since the shape comes from the result of a to_layout op, it will
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
index e2c5c0c47bd5..7a1696e919af 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp
@@ -48,7 +48,7 @@ struct VectorizeToLayoutOpPattern final

     // Create the toLayout operation but with vector types instead.
     auto newLayoutOp = rewriter.create<ToLayoutOp>(
-        loc, newInput, toLayoutOp.getLayout(),
+        loc, newInput, toLayoutOp.getLayout(), toLayoutOp.getMmaKindAttr(),
         toLayoutOp.getSharedMemoryConversion());

     // Create the write back to a tensor.
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index 359c6ffa0fcd..26fb949b61b2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -6,8 +6,10 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
 #include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -81,6 +83,34 @@ struct UpcastContractOutput final : OpRewritePattern<vector::ContractionOp> {
   }
 };

+static void inferMmaKind(vector::ContractionOp contract) {
+  SetVector<Operation *> slice;
+  getForwardSlice(contract.getResult(), &slice);
+
+  // Operations in the slice are in topological order, so the first
+  // to_layout operation we encounter is the one setting the layout.
+  IREE::VectorExt::ToLayoutOp toLayout;
+  for (Operation *op : slice) {
+    auto candidate = dyn_cast<IREE::VectorExt::ToLayoutOp>(op);
+    if (candidate) {
+      toLayout = candidate;
+      break;
+    }
+  }
+
+  if (!toLayout) {
+    return;
+  }
+
+  auto intrinsic =
+      dyn_cast_or_null<IREE::GPU::MmaInterfaceAttr>(toLayout.getMmaKindAttr());
+  if (!intrinsic) {
+    return;
+  }
+
+  contract->setAttr("iree.amdgpu.mma", intrinsic);
+}
+
 struct LLVMGPUCastTypeToFitMMAPass final
     : impl::LLVMGPUCastTypeToFitMMAPassBase<LLVMGPUCastTypeToFitMMAPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
@@ -91,26 +121,15 @@ struct LLVMGPUCastTypeToFitMMAPass final

   void runOnOperation() override {
     auto func = getOperation();
-    llvm::StringLiteral scheduleAttrName =
-        IREE::GPU::MMAScheduleAttr::getMnemonic();
-    auto scheduleAttr =
-        func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
-    if (!scheduleAttr) {
-      DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
-      if (configDict) {
-        scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
-            configDict.get(scheduleAttrName));
+    // Set the MMA kind from the intrinsic embedded in the contraction's to_layout op.
+    func.walk([&](vector::ContractionOp contract) {
+      inferMmaKind(contract);
+      if (!contract->hasAttr("iree.amdgpu.mma")) {
+        func.emitOpError("Failed to detect valid to_layout consumer of "
+                         "vector.contract to infer MMA kind.");
+        return signalPassFailure();
       }
-    }
-
-    // Import mma type from dispatch schedule attribute if present.
-    if (scheduleAttr) {
-      func.walk([&](vector::ContractionOp contract) {
-        if (!contract->hasAttr("iree.amdgpu.mma")) {
-          contract->setAttr("iree.amdgpu.mma", scheduleAttr.getIntrinsic());
-        }
-      });
-    }
+    });

     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
index 3f84454268dc..22b570b9795b 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUConfigureTensorLayouts.cpp
@@ -56,12 +56,12 @@ static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,

   // Set layouts for lhs, rhs and acc.
   rewriter.setInsertionPoint(contract);
-  auto layoutedLhs =
-      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, lhs, aLayout);
-  auto layoutedRhs =
-      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, rhs, bLayout);
-  auto layoutedAcc =
-      rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, acc, cLayout);
+  auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+      loc, lhs, aLayout, schedule.getIntrinsic());
+  auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+      loc, rhs, bLayout, schedule.getIntrinsic());
+  auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+      loc, acc, cLayout, schedule.getIntrinsic());

   // Promote matmul lhs and rhs.
   // TODO: We should read this from the lowering_config on the operation.
@@ -82,7 +82,7 @@ static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
   // Set layout for result.
   rewriter.setInsertionPointAfter(contract);
   auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, contract->getResult(0), cLayout);
+      loc, contract->getResult(0), cLayout, schedule.getIntrinsic());
   rewriter.replaceAllUsesExcept(contract->getResult(0), toLayout.getResult(),
                                 toLayout);
@@ -140,11 +140,11 @@ static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
   // Set layouts for lhs, rhs and acc.
   rewriter.setInsertionPoint(conv);
   auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, lhs.getType(), lhs, aLayout);
+      loc, lhs, aLayout, schedule.getIntrinsic());
   auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, rhs.getType(), rhs, bLayout);
+      loc, rhs, bLayout, schedule.getIntrinsic());
   auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, acc.getType(), acc, cLayout);
+      loc, acc, cLayout, schedule.getIntrinsic());

   // Promote matmul lhs and rhs.
   // TODO: We should read this from the lowering_config on the operation.
@@ -160,7 +160,7 @@ static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
   // Set layout for result.
   rewriter.setInsertionPointAfter(conv);
   auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-      loc, conv->getResult(0).getType(), conv->getResult(0), cLayout);
+      loc, conv->getResult(0), cLayout, schedule.getIntrinsic());
   rewriter.replaceAllUsesExcept(conv->getResult(0), toLayout.getResult(),
                                 toLayout);
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index 4f97eb6d4163..7da3e141d9a9 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -9,7 +9,11 @@ func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf
     indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
     %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
-  return %0 : vector<96x64xf16>
+  %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout)
+    {mma_kind = #iree_gpu.mma_layout} : vector<96x64xf16>
+  return %1 : vector<96x64xf16>
 }

 // CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm
 // CHECK-SAME: %[[A]], %[[B]], %[[EXT]]
 // CHECK-SAME: vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
 // CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
-// CHECK: return %[[TRUNC]] : vector<96x64xf16>

// -----

func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16x
     indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
     %lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf16>
-  return %0 : vector<96x64xf16>
+  %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout)
+    {mma_kind = #iree_gpu.mma_layout} : vector<96x64xf16>
+  return %1 : vector<96x64xf16>
 }

 // CHECK-LABEL: func.func @mfma_matmul_96x64x16_mmt

// -----

func.func @mfma_matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs
     indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
     %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf64>
-  return %0 : vector<96x64xf64>
+  %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout)
+    {mma_kind = #iree_gpu.mma_layout} : vector<96x64xf64>
+  return %1 : vector<96x64xf64>
 }

 // CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm_cannot_downcast

// -----

func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf
     indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
     %lhs, %rhs, %init : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf16>
-  return %0 : vector<48x32xf16>
+  %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout)
+    {mma_kind = #iree_gpu.mma_layout} : vector<48x32xf16>
+  return %1 : vector<48x32xf16>
 }

 // CHECK-LABEL: func.func @wmma_matmul_48x32x32_mm
 // CHECK-SAME: %[[A]], %[[B]], %[[EXT]]
 // CHECK-SAME: vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
 // CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<48x32xf32> to vector<48x32xf16>
-// CHECK: return %[[TRUNC]] : vector<48x32xf16>
+
+// -----
+
+// This tests that cast_type_to_fit_mma works on a contract whose intrinsic is set by to_layout.
+// "iree.amdgpu.mma" will be generated from the "mma_kind" attribute of to_layout.
+// This also shows that we can overwrite the default intrinsic when one is explicitly set.
+
+func.func @to_layout_config_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
+    mma_schedule = #iree_gpu.mma_schedule<
+      intrinsic = #iree_gpu.mma_layout,
+      subgroup_m_count = 1, subgroup_n_count = 1>,
+    workgroup_size = [64, 1, 1]} {
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+  %1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout)
+    {mma_kind = #iree_gpu.mma_layout} : vector<96x64xf16>
+  return %1 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @to_layout_config_matmul_96x64x16_mm
+// CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+// CHECK: arith.extf
+// CHECK: vector.contract
+// CHECK-SAME: {iree.amdgpu.mma = #iree_gpu.mma_layout}
+// CHECK-SAME: : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+// CHECK: arith.truncf

 // -----

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
index b9ecc4a2583a..8b761f8fac50 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/configure_tensor_layout.mlir
@@ -42,9 +42,9 @@ func.func @matmul_96x64x16_mfma(%lhs: tensor<96x16xf16>,

 // CHECK-LABEL: func.func @matmul_96x64x16_mfma

-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout}

 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]
@@ -93,9 +93,9 @@ func.func @matmul_96x64x16_wmma(%lhs: tensor<96x16xf16>,

 // CHECK-LABEL: func.func @matmul_96x64x16_wmma

-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout}

 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]
@@ -144,9 +144,9 @@ func.func @matmul_128x64x16_multi_subgroup(%lhs: tensor<128x16xf16>,

 // CHECK-LABEL: func.func @matmul_128x64x16_multi_subgroup

-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout}

 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]
@@ -195,9 +195,9 @@ func.func @packed_matmul_128x128x128(%lhs: tensor<8x16x16xf16>,

 // CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout
 // CHECK-LABEL: func.func @packed_matmul_128x128x128

-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout}

 // CHECK: linalg.generic
 // CHECK-SAME: ins(%[[LHS]], %[[RHS]]
 // CHECK-SAME: outs(%[[ACC]]
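
For reference, the per-contraction intrinsic is carried by the to_layout op that
consumes the contraction result. A minimal sketch of the intended IR, based on the
tests above (the #layout alias and the MFMA_F32_16x16x16_F16 intrinsic are
illustrative placeholders; in practice LLVMGPUConfigureTensorLayouts emits both
from the MMA schedule):

  %0 = vector.contract {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d2, d1)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"],
      kind = #vector.kind<add>}
      %lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
  %1 = iree_vector_ext.to_layout %0 to layout(#layout)
      {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} : vector<96x64xf16>

LLVMGPUCastTypeToFitMMA then walks the forward slice of each vector.contract, takes
the mma_kind of the first to_layout it finds, and records it on the contraction as
the "iree.amdgpu.mma" attribute, which is how the chosen intrinsic is preserved
until vector distribution.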