Skip to content

Commit fcae83d

Browse files
[fixup] Skip the two-stage LLVM IR generation, map the op directly to the LLVM IR intrinsic
1 parent feb578d commit fcae83d

File tree

4 files changed

+11
-29
lines changed

4 files changed

+11
-29
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,10 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
293293
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
294294
}
295295

296-
297-
def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
298-
AllTypesMatch<["src1", "src2"]>,
299-
AllTypesMatch<["acc", "dst"]>]> {
296+
def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure,
297+
AllTypesMatch<["src1", "src2"]>,
298+
AllTypesMatch<["acc", "res"]>,
299+
]> {
300300
let summary = "BFloat16 matrix multiply-accumulate";
301301
let description = [{
302302
BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices.
@@ -317,9 +317,9 @@ def BfmmlaOp : ArmSVE_Op<"bfmmla", [Pure,
317317
ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
318318
ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
319319
);
320-
let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$dst);
320+
let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$res);
321321
let assemblyFormat =
322-
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
322+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
323323
}
324324

325325
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
@@ -619,12 +619,6 @@ def UsmmlaIntrOp :
619619
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
620620
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
621621

622-
def BfmmlaIntrOp :
623-
ArmSVE_IntrOp<"bfmmla", [Pure, TypeIs<"res", ScalableVectorOfLengthAndType<[4], [F32]>>]>,
624-
Arguments<(ins Arg<ScalableVectorOfLengthAndType<[4], [F32]>, "acc">:$acc,
625-
Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "lhs">:$lhs,
626-
Arg<ScalableVectorOfLengthAndType<[8], [BF16]>, "rhs">:$rhs)>;
627-
628622
def SdotIntrOp :
629623
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
630624
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
2525
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
2626
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
2727
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
28-
using BfmmlaOpLowering = OneToOneConvertToLLVMPattern<BfmmlaOp, BfmmlaIntrOp>;
2928
using DupQLaneLowering =
3029
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
3130
using ScalableMaskedAddIOpLowering =
@@ -192,8 +191,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
192191
// Populate conversion patterns
193192

194193
// clang-format off
195-
patterns.add<BfmmlaOpLowering,
196-
ConvertFromSvboolOpLowering,
194+
patterns.add<ConvertFromSvboolOpLowering,
197195
ConvertToSvboolOpLowering,
198196
DupQLaneLowering,
199197
PselOpLowering,
@@ -222,7 +220,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
222220
void mlir::configureArmSVELegalizeForExportTarget(
223221
LLVMConversionTarget &target) {
224222
// clang-format off
225-
target.addLegalOp<BfmmlaIntrOp,
223+
target.addLegalOp<BfmmlaOp,
226224
ConvertFromSvboolIntrOp,
227225
ConvertToSvboolIntrOp,
228226
DupQLaneIntrOp,
@@ -244,8 +242,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
244242
ZipX2IntrOp,
245243
ZipX4IntrOp,
246244
SdotIntrOp>();
247-
target.addIllegalOp<BfmmlaOp,
248-
ConvertFromSvboolOp,
245+
target.addIllegalOp<ConvertFromSvboolOp,
249246
ConvertToSvboolOp,
250247
DupQLaneOp,
251248
PselOp,

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,6 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
6060

6161
// -----
6262

63-
func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
64-
%b: vector<[8]xbf16>,
65-
%c: vector<[4]xf32>) -> vector<[4]xf32> {
66-
// CHECK: arm_sve.intr.bfmmla
67-
%0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
68-
return %0 : vector<[4]xf32>
69-
}
70-
// -----
71-
7263
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
7364
%b: vector<[4]xi32>,
7465
%c: vector<[4]xi32>,

mlir/test/Dialect/ArmSVE/roundtrip.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
5858
func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
5959
%b: vector<[8]xbf16>,
6060
%c: vector<[4]xf32>) -> vector<[4]xf32> {
61-
// CHECK: arm_sve.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
62-
%0 = arm_sve.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
61+
// CHECK: arm_sve.intr.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
62+
%0 = arm_sve.intr.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
6363
return %0 : vector<[4]xf32>
6464
}
6565

0 commit comments

Comments
 (0)