Skip to content

Commit 369af41

Browse files
momchil-velikov authored and rlavaee committed
[MLIR][ArmSVE] Add an ArmSVE dialect operation mapping to bfmmla (llvm#145064)
1 parent bdeabea commit 369af41

File tree

5 files changed

+113
-1
lines changed

5 files changed

+113
-1
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
293293
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
294294
}
295295

296+
def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure,
297+
AllTypesMatch<["src1", "src2"]>,
298+
AllTypesMatch<["acc", "res"]>,
299+
]> {
300+
let summary = "BFloat16 matrix multiply-accumulate";
301+
let description = [{
302+
BFMMLA: BFloat16 matrix multiply-accumulate into 2x2 matrices.
303+
304+
This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
305+
segment of the first source vector by the 4x2 BFloat16 matrix in the
306+
corresponding segment of the second source vector, then accumulates
307+
this intermediate result with the 2x2 Float32 matrix in the corresponding
308+
segment of the accumulator vector, yielding the final 2x2 Float32
309+
segment of the result.
310+
311+
Source:
312+
https://developer.arm.com/documentation/100987/0000
313+
}];
314+
// Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
315+
let arguments = (ins
316+
ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
317+
ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
318+
ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
319+
);
320+
let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$res);
321+
let assemblyFormat =
322+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
323+
}
324+
296325
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
297326
"expected corresponding svbool type widened to [16]xi1",
298327
lhsArg, rhsArg,

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
220220
void mlir::configureArmSVELegalizeForExportTarget(
221221
LLVMConversionTarget &target) {
222222
// clang-format off
223-
target.addLegalOp<ConvertFromSvboolIntrOp,
223+
target.addLegalOp<BfmmlaOp,
224+
ConvertFromSvboolIntrOp,
224225
ConvertToSvboolIntrOp,
225226
DupQLaneIntrOp,
226227
PselIntrOp,

mlir/test/Dialect/ArmSVE/invalid.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,63 @@ func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
7272
arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
7373
return
7474
}
75+
76+
// -----
77+
78+
func.func @bfmmla_invalid_element_type_lhs_rhs(%acc: vector<[4]xf32>,
79+
%lhs: vector<[8]xf16>,
80+
%rhs: vector<[8]xf16>) -> vector<[4]xf32> {
81+
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[8]xf16>'}}
82+
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xf16> to vector<[4]xf32>
83+
return %0 : vector<[4]xf32>
84+
}
85+
86+
// -----
87+
88+
func.func @bfmmla_invalid_dimension_lhs_rhs(%acc: vector<[4]xf32>,
89+
%lhs: vector<[4]xbf16>,
90+
%rhs: vector<[4]xbf16>) -> vector<[4]xf32> {
91+
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[4]xbf16>'}}
92+
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[4]xbf16> to vector<[4]xf32>
93+
return %0 : vector<[4]xf32>
94+
}
95+
96+
// -----
97+
98+
func.func @bfmmla_fixed_dimension_lhs_rhs(%acc: vector<[4]xf32>,
99+
%lhs: vector<8xbf16>,
100+
%rhs: vector<8xbf16>) -> vector<[4]xf32> {
101+
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<8xbf16>'}}
102+
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<[4]xf32>
103+
return %0 : vector<[4]xf32>
104+
}
105+
106+
// -----
107+
108+
func.func @bfmmla_invalid_element_type_acc(%acc: vector<[4]xi32>,
109+
%lhs: vector<[8]xbf16>,
110+
%rhs: vector<[8]xbf16>) -> vector<[4]xi32> {
111+
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[4]xi32>'}}
112+
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[4]xi32>
113+
return %0 : vector<[4]xi32>
114+
}
115+
116+
// -----
117+
118+
func.func @bfmmla_invalid_dimension_acc(%acc: vector<[8]xf32>,
119+
%lhs: vector<[8]xbf16>,
120+
%rhs: vector<[8]xbf16>) -> vector<[8]xf32> {
121+
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[8]xf32>'}}
122+
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[8]xf32>
123+
return %0 : vector<[8]xf32>
124+
}
125+
126+
// -----
127+
128+
func.func @bfmmla_fixed_dimension_acc(%acc: vector<4xf32>,
129+
%lhs: vector<[8]xbf16>,
130+
%rhs: vector<[8]xbf16>) -> vector<4xf32> {
131+
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<4xf32>'}}
132+
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<4xf32>
133+
return %0 : vector<4xf32>
134+
}

mlir/test/Dialect/ArmSVE/roundtrip.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
5555

5656
// -----
5757

58+
func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
59+
%b: vector<[8]xbf16>,
60+
%c: vector<[4]xf32>) -> vector<[4]xf32> {
61+
// CHECK: arm_sve.intr.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
62+
%0 = arm_sve.intr.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
63+
return %0 : vector<[4]xf32>
64+
}
65+
66+
// -----
67+
5868
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
5969
%b: vector<[4]xi32>,
6070
%c: vector<[4]xi32>,

mlir/test/Target/LLVMIR/arm-sve.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
6060
llvm.return %0 : vector<[4]xi32>
6161
}
6262

63+
// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
64+
llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>,
65+
%arg1: vector<[8]xbf16>,
66+
%arg2: vector<[4]xf32>)
67+
-> vector<[4]xf32> {
68+
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
69+
%0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) :
70+
(vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
71+
-> vector<[4]xf32>
72+
llvm.return %0 : vector<[4]xf32>
73+
}
74+
6375
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
6476
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
6577
%arg1: vector<[4]xi32>,

0 commit comments

Comments
 (0)