Skip to content

[SPIRV] Support G_IS_FPCLASS #148637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 232 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();
}

getActionDefinitionsBuilder(G_IS_FPCLASS).custom();

getLegacyLegalizerInfo().computeTables();
verify(*ST.getInstrInfo());
}
Expand All @@ -355,9 +357,14 @@ static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpvType,
bool SPIRVLegalizerInfo::legalizeCustom(
LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const {
auto Opc = MI.getOpcode();
MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
if (Opc == TargetOpcode::G_ICMP) {
switch (MI.getOpcode()) {
default:
// TODO: implement legalization for other opcodes.
return true;
case TargetOpcode::G_IS_FPCLASS:
return legalizeIsFPClass(Helper, MI, LocObserver);
case TargetOpcode::G_ICMP: {
assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
auto &Op0 = MI.getOperand(2);
auto &Op1 = MI.getOperand(3);
Expand All @@ -378,6 +385,228 @@ bool SPIRVLegalizerInfo::legalizeCustom(
}
return true;
}
// TODO: implement legalization for other opcodes.
}
}

// Note this code was copied from LegalizerHelper::lowerISFPCLASS and adjusted
// to ensure that all instructions created during the lowering have SPIR-V types
// assigned to them.
bool SPIRVLegalizerInfo::legalizeIsFPClass(
LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const {
auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs();
FPClassTest Mask = static_cast<FPClassTest>(MI.getOperand(2).getImm());

auto &MIRBuilder = Helper.MIRBuilder;
auto &MF = MIRBuilder.getMF();
MachineRegisterInfo &MRI = MF.getRegInfo();

if (Mask == fcNone) {
MIRBuilder.buildConstant(DstReg, 0);
MI.eraseFromParent();
return true;
}
if (Mask == fcAllFlags) {
MIRBuilder.buildConstant(DstReg, 1);
MI.eraseFromParent();
return true;
}

Type *LLVMDstTy =
IntegerType::get(MIRBuilder.getContext(), DstTy.getScalarSizeInBits());
if (DstTy.isVector())
LLVMDstTy = VectorType::get(LLVMDstTy, DstTy.getElementCount());
SPIRVType *SPIRVDstTy = GR->getOrCreateSPIRVType(
LLVMDstTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
/*EmitIR*/ true);

unsigned BitSize = SrcTy.getScalarSizeInBits();
const fltSemantics &Semantics = getFltSemanticForLLT(SrcTy.getScalarType());

LLT IntTy = LLT::scalar(BitSize);
Type *LLVMIntTy = IntegerType::get(MIRBuilder.getContext(), BitSize);
if (SrcTy.isVector()) {
IntTy = LLT::vector(SrcTy.getElementCount(), IntTy);
LLVMIntTy = VectorType::get(LLVMIntTy, SrcTy.getElementCount());
}
SPIRVType *SPIRVIntTy = GR->getOrCreateSPIRVType(
LLVMIntTy, MIRBuilder, SPIRV::AccessQualifier::ReadWrite,
/*EmitIR*/ true);

// Clang doesn't support capture of structured bindings:
LLT DstTyCopy = DstTy;
const auto assignSPIRVTy = [&](MachineInstrBuilder &&MI) {
// Assign this MI's (assumed only) destination to one of the two types we
// expect: either the G_IS_FPCLASS's destination type, or the integer type
// bitcast from the source type.
LLT MITy = MRI.getType(MI.getReg(0));
assert((MITy == IntTy || MITy == DstTyCopy) &&
"Unexpected LLT type while lowering G_IS_FPCLASS");
auto *SPVTy = MITy == IntTy ? SPIRVIntTy : SPIRVDstTy;
GR->assignSPIRVTypeToVReg(SPVTy, MI.getReg(0), MF);
return MI;
};

// Helper to build and assign a constant in one go
const auto buildSPIRVConstant = [&](LLT Ty, auto &&C) {
return assignSPIRVTy(MIRBuilder.buildConstant(Ty, C));
};

// Note that rather than creating a COPY here (between a floating-point and
// integer type of the same size) we create a SPIR-V bitcast immediately. We
// can't create a G_BITCAST because the LLTs are the same, and we can't seem
// to correctly lower COPYs to SPIR-V bitcasts at this moment.
Register ResVReg = MRI.createGenericVirtualRegister(IntTy);
MRI.setRegClass(ResVReg, GR->getRegClass(SPIRVIntTy));
GR->assignSPIRVTypeToVReg(SPIRVIntTy, ResVReg, Helper.MIRBuilder.getMF());
auto AsInt = MIRBuilder.buildInstr(SPIRV::OpBitcast)
.addDef(ResVReg)
.addUse(GR->getSPIRVTypeID(SPIRVIntTy))
.addUse(SrcReg);
AsInt = assignSPIRVTy(std::move(AsInt));

// Various masks.
APInt SignBit = APInt::getSignMask(BitSize);
APInt ValueMask = APInt::getSignedMaxValue(BitSize); // All bits but sign.
APInt Inf = APFloat::getInf(Semantics).bitcastToAPInt(); // Exp and int bit.
APInt ExpMask = Inf;
APInt AllOneMantissa = APFloat::getLargest(Semantics).bitcastToAPInt() & ~Inf;
APInt QNaNBitMask =
APInt::getOneBitSet(BitSize, AllOneMantissa.getActiveBits() - 1);
APInt InversionMask = APInt::getAllOnes(DstTy.getScalarSizeInBits());

auto SignBitC = buildSPIRVConstant(IntTy, SignBit);
auto ValueMaskC = buildSPIRVConstant(IntTy, ValueMask);
auto InfC = buildSPIRVConstant(IntTy, Inf);
auto ExpMaskC = buildSPIRVConstant(IntTy, ExpMask);
auto ZeroC = buildSPIRVConstant(IntTy, 0);

auto Abs = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ValueMaskC));
auto Sign = assignSPIRVTy(
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_NE, DstTy, AsInt, Abs));

auto Res = buildSPIRVConstant(DstTy, 0);

const auto appendToRes = [&](MachineInstrBuilder &&ToAppend) {
Res = assignSPIRVTy(
MIRBuilder.buildOr(DstTyCopy, Res, assignSPIRVTy(std::move(ToAppend))));
};

// Tests that involve more than one class should be processed first.
if ((Mask & fcFinite) == fcFinite) {
// finite(V) ==> abs(V) u< exp_mask
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, Abs,
ExpMaskC));
Mask &= ~fcFinite;
} else if ((Mask & fcFinite) == fcPosFinite) {
// finite(V) && V > 0 ==> V u< exp_mask
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, AsInt,
ExpMaskC));
Mask &= ~fcPosFinite;
} else if ((Mask & fcFinite) == fcNegFinite) {
// finite(V) && V < 0 ==> abs(V) u< exp_mask && signbit == 1
auto Cmp = assignSPIRVTy(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT,
DstTy, Abs, ExpMaskC));
appendToRes(MIRBuilder.buildAnd(DstTy, Cmp, Sign));
Mask &= ~fcNegFinite;
}

if (FPClassTest PartialCheck = Mask & (fcZero | fcSubnormal)) {
// fcZero | fcSubnormal => test all exponent bits are 0
// TODO: Handle sign bit specific cases
// TODO: Handle inverted case
if (PartialCheck == (fcZero | fcSubnormal)) {
auto ExpBits = assignSPIRVTy(MIRBuilder.buildAnd(IntTy, AsInt, ExpMaskC));
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
ExpBits, ZeroC));
Mask &= ~PartialCheck;
}
}

// Check for individual classes.
if (FPClassTest PartialCheck = Mask & fcZero) {
if (PartialCheck == fcPosZero)
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
AsInt, ZeroC));
else if (PartialCheck == fcZero)
appendToRes(
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, ZeroC));
else // fcNegZero
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
AsInt, SignBitC));
}

if (FPClassTest PartialCheck = Mask & fcSubnormal) {
// issubnormal(V) ==> unsigned(abs(V) - 1) u< (all mantissa bits set)
// issubnormal(V) && V>0 ==> unsigned(V - 1) u< (all mantissa bits set)
auto V = (PartialCheck == fcPosSubnormal) ? AsInt : Abs;
auto OneC = buildSPIRVConstant(IntTy, 1);
auto VMinusOne = MIRBuilder.buildSub(IntTy, V, OneC);
auto SubnormalRes = assignSPIRVTy(
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, VMinusOne,
buildSPIRVConstant(IntTy, AllOneMantissa)));
if (PartialCheck == fcNegSubnormal)
SubnormalRes = MIRBuilder.buildAnd(DstTy, SubnormalRes, Sign);
appendToRes(std::move(SubnormalRes));
}

if (FPClassTest PartialCheck = Mask & fcInf) {
if (PartialCheck == fcPosInf)
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
AsInt, InfC));
else if (PartialCheck == fcInf)
appendToRes(
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy, Abs, InfC));
else { // fcNegInf
APInt NegInf = APFloat::getInf(Semantics, true).bitcastToAPInt();
auto NegInfC = buildSPIRVConstant(IntTy, NegInf);
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_EQ, DstTy,
AsInt, NegInfC));
}
}

if (FPClassTest PartialCheck = Mask & fcNan) {
auto InfWithQnanBitC = buildSPIRVConstant(IntTy, Inf | QNaNBitMask);
if (PartialCheck == fcNan) {
// isnan(V) ==> abs(V) u> int(inf)
appendToRes(
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
} else if (PartialCheck == fcQNan) {
// isquiet(V) ==> abs(V) u>= (unsigned(Inf) | quiet_bit)
appendToRes(MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGE, DstTy, Abs,
InfWithQnanBitC));
} else { // fcSNan
// issignaling(V) ==> abs(V) u> unsigned(Inf) &&
// abs(V) u< (unsigned(Inf) | quiet_bit)
auto IsNan = assignSPIRVTy(
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_UGT, DstTy, Abs, InfC));
auto IsNotQnan = assignSPIRVTy(MIRBuilder.buildICmp(
CmpInst::Predicate::ICMP_ULT, DstTy, Abs, InfWithQnanBitC));
appendToRes(MIRBuilder.buildAnd(DstTy, IsNan, IsNotQnan));
}
}

if (FPClassTest PartialCheck = Mask & fcNormal) {
// isnormal(V) ==> (0 u< exp u< max_exp) ==> (unsigned(exp-1) u<
// (max_exp-1))
APInt ExpLSB = ExpMask & ~(ExpMask.shl(1));
auto ExpMinusOne = assignSPIRVTy(
MIRBuilder.buildSub(IntTy, Abs, buildSPIRVConstant(IntTy, ExpLSB)));
APInt MaxExpMinusOne = ExpMask - ExpLSB;
auto NormalRes = assignSPIRVTy(
MIRBuilder.buildICmp(CmpInst::Predicate::ICMP_ULT, DstTy, ExpMinusOne,
buildSPIRVConstant(IntTy, MaxExpMinusOne)));
if (PartialCheck == fcNegNormal)
NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, Sign);
else if (PartialCheck == fcPosNormal) {
auto PosSign = assignSPIRVTy(MIRBuilder.buildXor(
DstTy, Sign, buildSPIRVConstant(DstTy, InversionMask)));
NormalRes = MIRBuilder.buildAnd(DstTy, NormalRes, PosSign);
}
appendToRes(std::move(NormalRes));
}

MIRBuilder.buildCopy(DstReg, Res);
MI.eraseFromParent();
return true;
}
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class SPIRVLegalizerInfo : public LegalizerInfo {
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const override;
SPIRVLegalizerInfo(const SPIRVSubtarget &ST);

private:
bool legalizeIsFPClass(LegalizerHelper &Helper, MachineInstr &MI,
LostDebugLocObserver &LocObserver) const;
};
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
Loading
Loading