
Commit 114a142

[LLVMGPU] Embed mma_intrinsic in to_layout and infer contraction's intrinsic from it. (#18842)

To enable faster flash attention, we'd like to be able to force
different vector widths, which means different contractions may
potentially need different intrinsics. This PR introduces a way to set
intrinsic information for individual contractions and have it preserved
until vector distribution.
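For illustration, this is roughly what the IR looks like after layout configuration (adapted from the cast_type_to_fit_mma test updated below; the nested layout attribute is elided here): the to_layout op consuming the contraction carries an explicit mma_kind, and the pass derives the contraction's "iree.amdgpu.mma" attribute from it.

  %0 = vector.contract { ... } %lhs, %rhs, %init
         : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
  %1 = iree_vector_ext.to_layout %0 to layout(#nested_layout)
         {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} : vector<96x64xf16>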

---------

Signed-off-by: Stanley Winata <[email protected]>
Co-authored-by: Kunwar Grover <[email protected]>
raikonenfnu and Groverkss authored Oct 21, 2024
1 parent 66342ab commit 114a142
Showing 7 changed files with 128 additions and 57 deletions.
@@ -40,21 +40,31 @@ def IREEVectorExt_ToLayoutOp : IREEVectorExt_PureOp<"to_layout", [
let arguments = (ins
AnyShaped:$input,
VectorLayoutInterface:$layout,
-DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion
+DefaultValuedAttr<UnitAttr, "false">:$shared_memory_conversion,
+// TODO: Solve cmake IREEGPU and VectorExt cyclic dependency to
+// change mma_Kind type to be of MMAInterfaceAttr.
+OptionalAttr<AnyAttr>:$mma_kind
);
let results = (outs
AnyShaped:$output
);
let builders = [
OpBuilder<(ins "Value":$input,
"VectorLayoutInterface":$layout,
+"Attribute":$mma_kind_attr,
CArg<"bool", "false">:$shared_memory_conversion), [{
+UnitAttr defaultSharedMemoryConversion;
if (shared_memory_conversion) {
-build($_builder, $_state, input.getType(), input, layout, UnitAttr::get(input.getContext()));
-} else{
-build($_builder, $_state, input.getType(), input, layout);
+defaultSharedMemoryConversion = UnitAttr::get(input.getContext());
}
-}]>
+build($_builder, $_state, input.getType(), input, layout, defaultSharedMemoryConversion, mma_kind_attr);
+}]>,
+OpBuilder<(ins "Value":$input,
+"VectorLayoutInterface":$layout), [{
+UnitAttr defaultSharedMemoryConversion;
+Attribute emptyIntrinsic;
+build($_builder, $_state, input.getType(), input, layout, defaultSharedMemoryConversion, emptyIntrinsic);
+}]>,
];
let extraClassDeclaration = [{
bool hasTensorSemantics() {
@@ -61,9 +61,7 @@ struct DropToLayoutUnitDims final
Value rankReducedValue = rankReducingExtract.value();
auto newToLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
loc, rankReducedValue.getType(), rankReducedValue, newLayout,
-toLayoutOp.getSharedMemoryConversion());
-newToLayoutOp->setDiscardableAttrs(
-toLayoutOp->getDiscardableAttrDictionary());
+toLayoutOp.getSharedMemoryConversion(), toLayoutOp.getMmaKindAttr());

// Expand to preserve output shape using insert_slice.
// Here, since the shape comes from the result of a to_layout op, it will
@@ -48,7 +48,7 @@ struct VectorizeToLayoutOpPattern final

// Create the toLayout operation but with vector types instead.
auto newLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-loc, newInput, toLayoutOp.getLayout(),
+loc, newInput, toLayoutOp.getLayout(), toLayoutOp.getMmaKindAttr(),
toLayoutOp.getSharedMemoryConversion());

// Create the write back to a tensor.
@@ -6,8 +6,10 @@

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
+#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -81,6 +83,34 @@ struct UpcastContractOutput final : OpRewritePattern<vector::ContractionOp> {
}
};

+static void inferMmaKind(vector::ContractionOp contract) {
+SetVector<Operation *> slice;
+getForwardSlice(contract.getResult(), &slice);
+
+// Operations in slice are ordered in topological order, so the first
+// to_layout operation we encounter is setting the layout.
+IREE::VectorExt::ToLayoutOp toLayout;
+for (Operation *op : slice) {
+auto candidate = dyn_cast<IREE::VectorExt::ToLayoutOp>(op);
+if (candidate) {
+toLayout = candidate;
+break;
+}
+}
+
+if (!toLayout) {
+return;
+}
+
+auto intrinsic =
+dyn_cast_or_null<IREE::GPU::MmaInterfaceAttr>(toLayout.getMmaKindAttr());
+if (!intrinsic) {
+return;
+}
+
+contract->setAttr("iree.amdgpu.mma", intrinsic);
+}

struct LLVMGPUCastTypeToFitMMAPass final
: impl::LLVMGPUCastTypeToFitMMAPassBase<LLVMGPUCastTypeToFitMMAPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@@ -91,26 +121,15 @@ struct LLVMGPUCastTypeToFitMMAPass final
void runOnOperation() override {
auto func = getOperation();

-llvm::StringLiteral scheduleAttrName =
-IREE::GPU::MMAScheduleAttr::getMnemonic();
-auto scheduleAttr =
-func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
-if (!scheduleAttr) {
-DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
-if (configDict) {
-scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
-configDict.get(scheduleAttrName));
-}
-}
-
-// Import mma type from dispatch schedule attribute if present.
-if (scheduleAttr) {
-func.walk([&](vector::ContractionOp contract) {
-if (!contract->hasAttr("iree.amdgpu.mma")) {
-contract->setAttr("iree.amdgpu.mma", scheduleAttr.getIntrinsic());
-}
-});
-}
+// Set MMA type from config embedded in toLayoutOp of contraction.
+func.walk([&](vector::ContractionOp contract) {
+inferMmaKind(contract);
+if (!contract->hasAttr("iree.amdgpu.mma")) {
+func.emitOpError("Failed to detect valid to_layout consumer of "
+"vector.contract to infer MMA kind.");
+return signalPassFailure();
+}
+});

MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
@@ -56,12 +56,12 @@ static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,

// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(contract);
-auto layoutedLhs =
-rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, lhs, aLayout);
-auto layoutedRhs =
-rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, rhs, bLayout);
-auto layoutedAcc =
-rewriter.create<IREE::VectorExt::ToLayoutOp>(loc, acc, cLayout);
+auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+loc, lhs, aLayout, schedule.getIntrinsic());
+auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+loc, rhs, bLayout, schedule.getIntrinsic());
+auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
+loc, acc, cLayout, schedule.getIntrinsic());

// Promote matmul lhs and rhs.
// TODO: We should read this from the lowering_config on the operation.
@@ -82,7 +82,7 @@ static LogicalResult setContractionAnchor(IREE::GPU::MMAScheduleAttr schedule,
// Set layout for result.
rewriter.setInsertionPointAfter(contract);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-loc, contract->getResult(0), cLayout);
+loc, contract->getResult(0), cLayout, schedule.getIntrinsic());
rewriter.replaceAllUsesExcept(contract->getResult(0), toLayout.getResult(),
toLayout);

@@ -140,11 +140,11 @@ static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
// Set layouts for lhs, rhs and acc.
rewriter.setInsertionPoint(conv);
auto layoutedLhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-loc, lhs.getType(), lhs, aLayout);
+loc, lhs, aLayout, schedule.getIntrinsic());
auto layoutedRhs = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-loc, rhs.getType(), rhs, bLayout);
+loc, rhs, bLayout, schedule.getIntrinsic());
auto layoutedAcc = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-loc, acc.getType(), acc, cLayout);
+loc, acc, cLayout, schedule.getIntrinsic());

// Promote matmul lhs and rhs.
// TODO: We should read this from the lowering_config on the operation.
@@ -160,7 +160,7 @@ static LogicalResult setConvolutionAnchor(IREE::GPU::MMAScheduleAttr schedule,
// Set layout for result.
rewriter.setInsertionPointAfter(conv);
auto toLayout = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-loc, conv->getResult(0).getType(), conv->getResult(0), cLayout);
+loc, conv->getResult(0), cLayout, schedule.getIntrinsic());
rewriter.replaceAllUsesExcept(conv->getResult(0), toLayout.getResult(),
toLayout);

@@ -9,7 +9,11 @@ func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
%lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
-return %0 : vector<96x64xf16>
+%1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf16>
+return %1 : vector<96x64xf16>
}

// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm
@@ -21,7 +25,6 @@ func.func @mfma_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf
// CHECK-SAME: %[[A]], %[[B]], %[[EXT]]
// CHECK-SAME: vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
-// CHECK: return %[[TRUNC]] : vector<96x64xf16>

// -----

@@ -34,7 +37,11 @@ func.func @mfma_matmul_96x64x16_mmt(%lhs: vector<96x16xf16>, %rhs: vector<64x16x
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
%lhs, %rhs, %init : vector<96x16xf16>, vector<64x16xf16> into vector<96x64xf16>
-return %0 : vector<96x64xf16>
+%1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf16>
+return %1 : vector<96x64xf16>
}

// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mmt
@@ -55,7 +62,11 @@ func.func @mfma_matmul_96x64x16_mm_cannot_downcast(%lhs: vector<96x16xf16>, %rhs
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
%lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf64>
-return %0 : vector<96x64xf64>
+%1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+outer_tile = [4, 1], thread_tile = [2, 32], element_tile = [4, 1],
+subgroup_strides = [0, 0], thread_strides = [32, 1]>)
+{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} : vector<96x64xf64>
+return %1 : vector<96x64xf64>
}

// CHECK-LABEL: func.func @mfma_matmul_96x64x16_mm_cannot_downcast
@@ -75,7 +86,11 @@ func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
%lhs, %rhs, %init : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf16>
-return %0 : vector<48x32xf16>
+%1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [3, 2],
+outer_tile = [8, 1], thread_tile = [2, 16], element_tile = [1, 1],
+subgroup_strides = [0, 0], thread_strides = [16, 1]>)
+{mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>} : vector<48x32xf16>
+return %1 : vector<48x32xf16>
}

// CHECK-LABEL: func.func @wmma_matmul_48x32x32_mm
@@ -87,7 +102,36 @@ func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf
// CHECK-SAME: %[[A]], %[[B]], %[[EXT]]
// CHECK-SAME: vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<48x32xf32> to vector<48x32xf16>
-// CHECK: return %[[TRUNC]] : vector<48x32xf16>

// -----

+// This tests cast_type_to_fit_mma works on contract where intrinsic is set by to_layout.
+// "iree.amdgpu.mma" will be generated from the "intrinsic" attribute of to_layout.
+// this also shows that we can overwrite default intrinsics if explicitly set.
+
+func.func @to_layout_config_matmul_96x64x16_mm(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {
+mma_schedule = #iree_gpu.mma_schedule<
+intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
+subgroup_m_count = 1, subgroup_n_count = 1>,
+workgroup_size = [64, 1, 1]} {
+%0 = vector.contract {
+indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+%lhs, %rhs, %init : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+%1 = iree_vector_ext.to_layout %0 to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [6, 4],
+outer_tile = [1, 1], thread_tile = [16, 4], element_tile = [1, 4],
+subgroup_strides = [0, 0], thread_strides = [1, 16]>)
+{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} : vector<96x64xf16>
+return %1 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @to_layout_config_matmul_96x64x16_mm
+// CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+// CHECK: arith.extf
+// CHECK: vector.contract
+// CHECK-SAME: {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
+// CHECK-SAME: : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+// CHECK: arith.truncf
+
+// -----

@@ -42,9 +42,9 @@ func.func @matmul_96x64x16_mfma(%lhs: tensor<96x16xf16>,

// CHECK-LABEL: func.func @matmul_96x64x16_mfma

-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>}
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[ACC]]
@@ -93,9 +93,9 @@ func.func @matmul_96x64x16_wmma(%lhs: tensor<96x16xf16>,

// CHECK-LABEL: func.func @matmul_96x64x16_wmma

-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<WMMA_F32_16x16x16_F16>}
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[ACC]]
@@ -144,9 +144,9 @@ func.func @matmul_128x64x16_multi_subgroup(%lhs: tensor<128x16xf16>,

// CHECK-LABEL: func.func @matmul_128x64x16_multi_subgroup

-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[ACC]]
@@ -195,9 +195,9 @@ func.func @packed_matmul_128x128x128(%lhs: tensor<8x16x16xf16>,
// CHECK-DAG: #[[$NESTED2:.+]] = #iree_vector_ext.nested_layout<subgroup_tile = [2, 1, 2, 1], batch_tile = [4, 1, 4, 1], outer_tile = [1, 1, 1, 1], thread_tile = [1, 4, 1, 16], element_tile = [1, 4, 1, 1], subgroup_strides = [2, 0, 1, 0], thread_strides = [0, 16, 0, 1]>
// CHECK-LABEL: func.func @packed_matmul_128x128x128

-// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {shared_memory_conversion}
-// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {shared_memory_conversion}
-// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]])
+// CHECK-DAG: %[[LHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[RHS:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED1]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, shared_memory_conversion}
+// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_layout %{{.*}} to layout(#[[$NESTED2]]) {mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>}
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[LHS]], %[[RHS]]
// CHECK-SAME: outs(%[[ACC]]
