Skip to content

Commit 2a450e7

Browse files
arsenmPravin Jagtap
authored andcommitted
AMDGPU: Define v_mfma_f32_{16x16x128|32x32x64}_f8f6f4 instructions (llvm#116723)
These use a new VOP3PX encoding for the v_mfma_scale_* instructions, which bundles the pre-scale v_mfma_ld_scale_b32. None of the modifiers are supported yet (op_sel, neg or clamp). I'm not sure the intrinsic should really expose op_sel (or any of the others). If I'm reading the documentation correctly, we should be able to just have the raw scale operands and auto-match op_sel to byte extract patterns. The op_sel syntax also seems extra horrible in this usage, especially with the usual assumed op_sel_hi=-1 behavior.
1 parent 2f4d0ac commit 2a450e7

33 files changed

+9930
-33
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,11 @@ TARGET_BUILTIN(__builtin_amdgcn_cvt_sr_fp8_f32, "ifiiIi", "nc", "fp8-conversion-
434434
//===----------------------------------------------------------------------===//
435435
// GFX950 only builtins.
436436
//===----------------------------------------------------------------------===//
437+
TARGET_BUILTIN(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4, "V4fV8ZiV8ZiV4fIiIiIiiIii", "nc", "gfx950-insts")
438+
TARGET_BUILTIN(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4, "V16fV8ZiV8ZiV16fIiIiIiiIii", "nc", "gfx950-insts")
439+
437440
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_16x16x32_f16, "V4fV8hV8hV4fIiIiIi", "nc", "gfx950-insts")
438441
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_f16, "V16fV8hV8hV16fIiIiIi", "nc", "gfx950-insts")
439-
440442
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_bf16, "V16fV8yV8yV16fIiIiIi", "nc", "gfx950-insts")
441443

442444
//===----------------------------------------------------------------------===//

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18909,7 +18909,20 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1890918909
(uint64_t)0);
1891018910
return Builder.CreateInsertElement(I0, A, 1);
1891118911
}
18912-
18912+
case AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
18913+
case AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
18914+
llvm::FixedVectorType *VT = FixedVectorType::get(Builder.getInt32Ty(), 8);
18915+
Function *F = CGM.getIntrinsic(
18916+
BuiltinID == AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4
18917+
? Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4
18918+
: Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4,
18919+
{VT, VT});
18920+
18921+
SmallVector<Value *, 9> Args;
18922+
for (unsigned I = 0, N = E->getNumArgs(); I != N; ++I)
18923+
Args.push_back(EmitScalarExpr(E->getArg(I)));
18924+
return Builder.CreateCall(F, Args);
18925+
}
1891318926
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
1891418927
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
1891518928
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:

clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ typedef half v16h __attribute__((ext_vector_type(16)));
1616
typedef half v32h __attribute__((ext_vector_type(32)));
1717
typedef int v2i __attribute__((ext_vector_type(2)));
1818
typedef int v4i __attribute__((ext_vector_type(4)));
19+
typedef int v8i __attribute__((ext_vector_type(8)));
1920
typedef int v16i __attribute__((ext_vector_type(16)));
2021
typedef int v32i __attribute__((ext_vector_type(32)));
2122
typedef short v2s __attribute__((ext_vector_type(2)));
@@ -431,4 +432,18 @@ v16f test_mfma_f32_32x32x16_bf16(v8bf16 a, v8bf16 b, v16f c) {
431432
return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 1, 2, 3);
432433
}
433434

435+
// CHECK-GFX950-LABEL: @test_mfma_scale_f32_16x16x128_f8f6f4
436+
// CHECK-GFX950: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <4 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
437+
void test_mfma_scale_f32_16x16x128_f8f6f4(global v4f* out, v8i a, v8i b, v4f c, int scale_a, int scale_b)
438+
{
439+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
440+
}
441+
442+
// CHECK-GFX950-LABEL: @test_mfma_scale_f32_32x32x64_f8f6f4
443+
// CHECK-GFX950: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <16 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
444+
void test_mfma_scale_f32_32x32x64_f8f6f4(global v16f* out, v8i a, v8i b, v16f c, int scale_a, int scale_b)
445+
{
446+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
447+
}
448+
434449
#endif

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950-param.cl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ typedef float float4 __attribute__((ext_vector_type(4)));
55
typedef float float16 __attribute__((ext_vector_type(16)));
66
typedef half half8 __attribute__((ext_vector_type(8)));
77
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
8+
typedef int int8 __attribute__((ext_vector_type(8)));
89

910

1011
void test_mfma_f32_16x16x32_f16(__global float4* out, half8 a, half8 b, float4 c, int X) {
@@ -26,3 +27,17 @@ void test_mfma_f32_32x32x16_bf16(__global float16* out, bfloat8 a, bfloat8 b, fl
2627
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, X, 0); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
2728
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, 0, X); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
2829
}
30+
31+
void test_mfma_scale_f32_16x16x128_f8f6f4(__global float4* out, int8 a, int8 b, float4 c, int X, int Y) {
32+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, X, 0, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
33+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, X, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
34+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, 0, X, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
35+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, 0, 0, Y, X, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
36+
}
37+
38+
void test_mfma_scale_f32_32x32x64_f8f6f4(__global float16* out, int8 a, int8 b, float16 c, int X, int Y) {
39+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, X, 0, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
40+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, X, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
41+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, 0, X, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
42+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, 0, 0, Y, X, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
43+
}

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950.cl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,33 @@
44
typedef float float4 __attribute__((ext_vector_type(4)));
55
typedef float float16 __attribute__((ext_vector_type(16)));
66
typedef half half8 __attribute__((ext_vector_type(8)));
7+
typedef half half16 __attribute__((ext_vector_type(16)));
78
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
9+
typedef __bf16 bfloat16 __attribute__((ext_vector_type(16)));
10+
typedef unsigned int uint2 __attribute__((ext_vector_type(2)));
11+
typedef int int4 __attribute__((ext_vector_type(4)));
12+
typedef int int8 __attribute__((ext_vector_type(8)));
13+
typedef int int16 __attribute__((ext_vector_type(16)));
814

915
void test(__global float4* out0, half8 a0, half8 b0, float4 c0,
1016
__global float16* out1, half8 a1, half8 b1, float16 c1,
11-
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2) {
17+
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2,
18+
__global int4* out3, int4 a3, int4 b3, int4 c3,
19+
__global int16* out4, int4 a4, int4 b4, int16 c4,
20+
__global float4* out5, bfloat8 a5, bfloat8 b5, float4 c5,
21+
__global float4* out6, half8 a6, half16 b6, float4 c6,
22+
__global float16* out7, half8 a7, half16 b7, float16 c7,
23+
__global float4* out8, bfloat8 a8, bfloat16 b8, float4 c8,
24+
__global float16* out9, bfloat8 a9, bfloat16 b9, float16 c9,
25+
__global int4* out10, int4 a10, int8 b10, int4 c10,
26+
__global int16* out11, int4 a11, int8 b11, int16 c11,
27+
__global float4* out12, int4 a12, int8 b12, float4 c12,
28+
__global float16* out13, int4 a13, int8 b13, float16 c13,
29+
__global float4* out14, int8 a14, int8 b14, float4 c14, int d14, int e14,
30+
__global float16* out15, int8 a15, int8 b15, float16 c15, int d15, int e15) {
1231
*out0 = __builtin_amdgcn_mfma_f32_16x16x32_f16(a0, b0, c0, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_16x16x32_f16' needs target feature gfx950-insts}}
1332
*out1 = __builtin_amdgcn_mfma_f32_32x32x16_f16(a1, b1, c1, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_f16' needs target feature gfx950-insts}}
1433
*out2 = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a2, b2, c2, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_bf16' needs target feature gfx950-insts}}
34+
*out14 = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a14, b14, c14, 0, 0, 0, d14, 0, e14); // expected-error{{'__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' needs target feature gfx950-insts}}
35+
*out15 = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a15, b15, c15, 0, 0, 0, d15, 0, e15); // expected-error{{'__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' needs target feature gfx950-insts}}
1536
}

llvm/docs/AMDGPUUsage.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,6 +1391,16 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
13911391
sign-extended from the width of the underlying PC hardware register even on
13921392
processors where the s_getpc_b64 instruction returns a zero-extended value.
13931393

1394+
llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4 Emit `v_mfma_scale_f32_16x16x128_f8f6f4` to set the scale factor. The
1395+
last 4 operands correspond to the scale inputs.
1396+
1397+
- 2-bit byte index to use for each lane for matrix A
1398+
- Matrix A scale values
1399+
- 2-bit byte index to use for each lane for matrix B
1400+
- Matrix B scale values
1401+
1402+
llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4 Emit `v_mfma_scale_f32_32x32x64_f8f6f4`
1403+
13941404
============================================== ==========================================================
13951405

13961406
.. TODO::

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2987,6 +2987,35 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
29872987
[IntrConvergent, IntrNoMem,
29882988
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;
29892989

2990+
2991+
// srcA's format is determined by cbsz. srcB's format is determined by
2992+
// blgp.
2993+
//
2994+
// These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats,
2995+
// and <4 x i32> for f4 formats. If the format control bits imply a
2996+
// smaller type than used, the high elements will be truncated.
2997+
//
2998+
// If the format control bits imply a larger type than used, the high
2999+
// elements are padded with undef.
3000+
3001+
class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> :
3002+
DefaultAttrsIntrinsic<[DestTy],
3003+
[llvm_anyvector_ty, llvm_anyvector_ty, DestTy,
3004+
llvm_i32_ty, // cbsz
3005+
llvm_i32_ty, // blgp
3006+
// llvm_i1_ty, // TODO: neg_src2
3007+
// llvm_i1_ty, // TODO: abs_src2
3008+
// llvm_i1_ty, // TODO: clamp
3009+
llvm_i32_ty, // op_sel (A matrix scale, 2-bits) // TODO: Make i2?
3010+
llvm_i32_ty, // v_mfma_ld_scale_b32 src0 (A matrix scale)
3011+
llvm_i32_ty, // op_sel (B matrix scale, 2-bits) // TODO: Make i2?
3012+
llvm_i32_ty // v_mfma_ld_scale_b32 src1 (B matrix scale)
3013+
],
3014+
[IntrConvergent, IntrNoMem,
3015+
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>,
3016+
ImmArg<ArgIndex<5>>, ImmArg<ArgIndex<7>>
3017+
]>;
3018+
29903019
defset list<Intrinsic> AMDGPUMFMAIntrinsics908 = {
29913020
def int_amdgcn_mfma_f32_32x32x1f32 : AMDGPUMfmaIntrinsic<llvm_v32f32_ty, llvm_float_ty>;
29923021
def int_amdgcn_mfma_f32_16x16x1f32 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_float_ty>;
@@ -3148,6 +3177,8 @@ def int_amdgcn_mfma_f32_16x16x32_f16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v
31483177
def int_amdgcn_mfma_f32_32x32x16_f16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8f16_ty>;
31493178

31503179
def int_amdgcn_mfma_f32_32x32x16_bf16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8bf16_ty>;
3180+
def int_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v4f32_ty>;
3181+
def int_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v16f32_ty>;
31513182
}
31523183

31533184
//===----------------------------------------------------------------------===//

llvm/lib/Target/AMDGPU/AMDGPUGISel.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,6 @@ def gi_fp_pow2_to_exponent : GICustomOperandRenderer<"renderFPPow2ToExponent">,
418418

419419
def gi_as_hw_round_mode : GICustomOperandRenderer<"renderRoundMode">,
420420
GISDNodeXFormEquiv<as_hw_round_mode>;
421+
422+
def gi_MFMALdScaleModifierOp : GICustomOperandRenderer<"renderScaledMAIIntrinsicOperand">,
423+
GISDNodeXFormEquiv<MFMALdScaleXForm>;

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,6 +1281,7 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
12811281
if (isa<UndefValue>(Src)) {
12821282
return IC.replaceInstUsesWith(II, Src);
12831283
}
1284+
return std::nullopt;
12841285
}
12851286
}
12861287
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5711,6 +5711,18 @@ void AMDGPUInstructionSelector::renderRoundMode(MachineInstrBuilder &MIB,
57115711
MIB.addImm((MI.getOperand(OpIdx).getImm() + 3) % 4);
57125712
}
57135713

5714+
/// Convert from 2-bit value to enum values used for op_sel* source modifiers.
5715+
void AMDGPUInstructionSelector::renderScaledMAIIntrinsicOperand(
5716+
MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const {
5717+
unsigned Val = MI.getOperand(OpIdx).getImm();
5718+
unsigned New = 0;
5719+
if (Val & 0x1)
5720+
New |= SISrcMods::OP_SEL_0;
5721+
if (Val & 0x2)
5722+
New |= SISrcMods::OP_SEL_1;
5723+
MIB.addImm(New);
5724+
}
5725+
57145726
bool AMDGPUInstructionSelector::isInlineImmediate(const APInt &Imm) const {
57155727
return TII.isInlineConstant(Imm);
57165728
}

0 commit comments

Comments
 (0)