Skip to content

WIP: AMDGPU: Always select the VGPR version of MFMAs #145025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: users/arsenm/amdgpu-add-pass-rewrite-vgpr-mfma-to-agpr
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4865,31 +4865,29 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
// for srcA/srcB?
//
// vdst, srcA, srcB, srcC
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
OpdsMapping[0] =
Info->mayNeedAGPRs()
!Subtarget.hasGFX90AInsts()
? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
OpdsMapping[4] =
Info->mayNeedAGPRs()
!Subtarget.hasGFX90AInsts()
? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
break;
}
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
OpdsMapping[0] =
Info->mayNeedAGPRs()
!Subtarget.hasGFX90AInsts()
? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);

OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
OpdsMapping[4] =
Info->mayNeedAGPRs()
!Subtarget.hasGFX90AInsts()
? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);

Expand Down
20 changes: 1 addition & 19 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16076,7 +16076,6 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,

MachineFunction *MF = MI.getParent()->getParent();
MachineRegisterInfo &MRI = MF->getRegInfo();
SIMachineFunctionInfo *Info = MF->getInfo<SIMachineFunctionInfo>();

if (TII->isVOP3(MI.getOpcode())) {
// Make sure constant bus requirements are respected.
Expand All @@ -16087,15 +16086,14 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
// use between vgpr and agpr as agpr tuples tend to be big.
if (!MI.getDesc().operands().empty()) {
unsigned Opc = MI.getOpcode();
bool HasAGPRs = Info->mayNeedAGPRs();
const SIRegisterInfo *TRI = Subtarget->getRegisterInfo();
int16_t Src2Idx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src2);
for (auto I :
{AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src0),
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::src1), Src2Idx}) {
if (I == -1)
break;
if ((I == Src2Idx) && (HasAGPRs))
if (I == Src2Idx)
break;
MachineOperand &Op = MI.getOperand(I);
if (!Op.isReg() || !Op.getReg().isVirtual())
Expand Down Expand Up @@ -16129,22 +16127,6 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
TII->legalizeOpWithMove(MI, Src1Idx);
}
}

if (!HasAGPRs)
return;

// Resolve the rest of AV operands to AGPRs.
if (auto *Src2 = TII->getNamedOperand(MI, AMDGPU::OpName::src2)) {
if (Src2->isReg() && Src2->getReg().isVirtual()) {
auto *RC = TRI->getRegClassForReg(MRI, Src2->getReg());
if (TRI->isVectorSuperClass(RC)) {
auto *NewRC = TRI->getEquivalentAGPRClass(RC);
MRI.setRegClass(Src2->getReg(), NewRC);
if (Src2->isTied())
MRI.setRegClass(MI.getOperand(0).getReg(), NewRC);
}
}
}
}

return;
Expand Down
6 changes: 0 additions & 6 deletions llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,6 @@ SIMachineFunctionInfo::SIMachineFunctionInfo(const Function &F,
PSInputAddr = AMDGPU::getInitialPSInputAddr(F);
}

MayNeedAGPRs = ST.hasMAIInsts();
if (ST.hasGFX90AInsts() &&
ST.getMaxNumVGPRs(F) <= AMDGPU::VGPR_32RegClass.getNumRegs() &&
!mayUseAGPRs(F))
MayNeedAGPRs = false; // We will select all MAI with VGPR operands.

if (AMDGPU::isChainCC(CC)) {
// Chain functions don't receive an SP from their caller, but are free to
// set one up. For now, we can use s32 to match what amdgpu_gfx functions
Expand Down
6 changes: 0 additions & 6 deletions llvm/lib/Target/AMDGPU/SIMachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,6 @@ class SIMachineFunctionInfo final : public AMDGPUMachineFunction,
// user arguments. This is an offset from the KernargSegmentPtr.
bool ImplicitArgPtr : 1;

bool MayNeedAGPRs : 1;

// The hard-wired high half of the address of the global information table
// for AMDPAL OS type. 0xffffffff represents no hard-wired high half, since
// current hardware only allows a 16 bit value.
Expand Down Expand Up @@ -1165,10 +1163,6 @@ class SIMachineFunctionInfo final : public AMDGPUMachineFunction,

unsigned getMaxMemoryClusterDWords() const { return MaxMemoryClusterDWords; }

bool mayNeedAGPRs() const {
return MayNeedAGPRs;
}

// \returns true if a function uses AGPRs via inline asm or has a call which
// may use them.
bool mayUseAGPRs(const Function &F) const;
Expand Down
55 changes: 30 additions & 25 deletions llvm/lib/Target/AMDGPU/VOP3PInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -856,17 +856,11 @@ defvar MayNotNeedAGPRs_gisel = [{
return !MF.getInfo<SIMachineFunctionInfo>()->mayNeedAGPRs();
}];

class AgprMAIFrag<SDPatternOperator Op, bit HasAbid = true,
bit Scaled = false> :
MAIFrag<Op, MayNeedAGPRs, HasAbid, Scaled> {
let GISelPredicateCode = MayNeedAGPRs_gisel;
}
class AgprMAIFrag<SDPatternOperator Op, bit HasAbid = true, bit Scaled = false>
: MAIFrag<Op, [{}], HasAbid, Scaled> {}

class VgprMAIFrag<SDPatternOperator Op, bit HasAbid = true,
bit Scaled = false> :
MAIFrag<Op, MayNotNeedAGPRs, HasAbid, Scaled> {
let GISelPredicateCode = MayNotNeedAGPRs_gisel;
}
class VgprMAIFrag<SDPatternOperator Op, bit HasAbid = true, bit Scaled = false>
: MAIFrag<Op, [{}], HasAbid, Scaled> {}

let isAsCheapAsAMove = 1, isReMaterializable = 1 in {
defm V_ACCVGPR_READ_B32 : VOP3Inst<"v_accvgpr_read_b32", VOPProfileAccRead>;
Expand Down Expand Up @@ -917,10 +911,14 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
!if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
MFMATable<0, "AGPR", NAME # "_e64">;

let OtherPredicates = [isGFX90APlus], Mnemonic = OpName in
def _vgprcd_e64 : MAIInst<OpName # "_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"),
!if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
MFMATable<0, "VGPR", NAME # "_vgprcd_e64", NAME # "_e64">;
let OtherPredicates = [isGFX90APlus], Mnemonic = OpName,
AddedComplexity = 10 in def _vgprcd_e64
: MAIInst<OpName#"_vgprcd",
!cast<VOPProfileMAI>("VOPProfileMAI_"#P#"_VCD"),
!if(!or(NoDstOverlap, !eq(node, null_frag)), null_frag,
VgprMAIFrag<node, HasAbid, Scaled>),
Scaled>,
MFMATable<0, "VGPR", NAME#"_vgprcd_e64", NAME#"_e64">;
}

if NoDstOverlap then {
Expand All @@ -931,16 +929,22 @@ multiclass MAIInst<string OpName, string P, SDPatternOperator node = null_frag,
!if(!eq(node, null_frag), null_frag, AgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
MFMATable<1, "AGPR", NAME # "_e64", NAME # "_mac_e64">;

let OtherPredicates = [isGFX90APlus] in
def _mac_vgprcd_e64 : MAIInst<OpName # "_mac_vgprcd", !cast<VOPProfileMAI>("VOPProfileMAI_" # P # "_VCD"),
!if(!eq(node, null_frag), null_frag, VgprMAIFrag<node, HasAbid, Scaled>), Scaled>,
MFMATable<1, "VGPR", NAME # "_vgprcd_e64", NAME # "_mac_e64">;
let OtherPredicates = [isGFX90APlus],
AddedComplexity = 10 in def _mac_vgprcd_e64
: MAIInst<OpName#"_mac_vgprcd",
!cast<VOPProfileMAI>("VOPProfileMAI_"#P#"_VCD"),
!if(!eq(node, null_frag), null_frag,
VgprMAIFrag<node, HasAbid, Scaled>),
Scaled>,
MFMATable<1, "VGPR", NAME#"_vgprcd_e64", NAME#"_mac_e64">;
}
}
} // End isConvergent = 1, mayRaiseFPException = 0, ReadsModeReg = 1
}

// Provide a wrapper around MAIInst that provides the appended operands from V_MFMA_LD_SCALE_B32
// Provide a wrapper around MAIInst that provides the appended operands from
// V_MFMA_LD_SCALE_B32. AGPR variants are never selected; the VGPR variant is
// selected and may later be rewritten to AGPR.
multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOperator node> {
defvar VariantSuffix = !subst(!toupper(OpName), "", NAME); // Drop the main opcode name prefix to get the "_fN_fM" suffix.
defvar UnscaledOpName = UnscaledOpName_#VariantSuffix;
Expand All @@ -949,9 +953,9 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper

defvar NoDstOverlap = !cast<VOPProfileMAI>(!cast<MAIInst>(UnscaledOpName#"_e64").Pfl).NoDstOverlap;

def _e64 : ScaledMAIInst<OpName,
!cast<MAIInst>(UnscaledOpName#"_e64"), !if(NoDstOverlap, null_frag, AgprMAIFrag<node, HasAbid, true>)>,
MFMATable<0, "AGPR", NAME # "_e64">;
def _e64
: ScaledMAIInst<OpName, !cast<MAIInst>(UnscaledOpName#"_e64"), null_frag>,
MFMATable<0, "AGPR", NAME#"_e64">;

def _vgprcd_e64 : ScaledMAIInst<OpName # "_vgprcd",
!cast<MAIInst>(UnscaledOpName#"_vgprcd_e64"), !if(NoDstOverlap, null_frag, VgprMAIFrag<node, HasAbid, true>)>,
Expand All @@ -961,9 +965,10 @@ multiclass ScaledMAIInst_mc<string OpName, string UnscaledOpName_, SDPatternOper
let Constraints = !if(NoDstOverlap, "$vdst = $src2", ""),
isConvertibleToThreeAddress = NoDstOverlap,
Mnemonic = UnscaledOpName_ in {
def _mac_e64 : ScaledMAIInst<OpName # "_mac",
!cast<MAIInst>(UnscaledOpName # "_mac_e64"), AgprMAIFrag<node, HasAbid, true>>,
MFMATable<1, "AGPR", NAME # "_e64">;
def _mac_e64
: ScaledMAIInst<OpName#"_mac",
!cast<MAIInst>(UnscaledOpName#"_mac_e64"), null_frag>,
MFMATable<1, "AGPR", NAME#"_e64">;

def _mac_vgprcd_e64 : ScaledMAIInst<OpName # " _mac_vgprcd",
!cast<MAIInst>(UnscaledOpName # "_mac_vgprcd_e64"), VgprMAIFrag<node, HasAbid, true>>,
Expand Down
Loading