Skip to content

[DAG] Refactor X86 combineVSelectWithAllOnesOrZeros fold into a generic DAG Combine #145298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12987,6 +12987,85 @@ SDValue DAGCombiner::visitVP_SELECT(SDNode *N) {
return SDValue();
}

static SDValue combineVSelectWithAllOnesOrZeros(SDValue Cond, SDValue TVal,
SDValue FVal,
const TargetLowering &TLI,
SelectionDAG &DAG,
const SDLoc &DL) {
if (!TLI.isTypeLegal(TVal.getValueType()))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we avoid this if we're not going to fold? Creating unused nodes mid-fold like this can cause problems with oneuse folds later on.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if we get here which mean IsTAllZero or IsFAllOne is ture, so after Creating(swap), IsTAllOne or IsFAllZero must be a true, so we will definitely enter the logic of fold

Copy link
Collaborator

@RKSimon RKSimon Jun 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK - so it just needs to be moved below the ComputeNumSignBits checks somehow I guess?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I will modify it.

return SDValue();

EVT VT = TVal.getValueType();
EVT CondVT = Cond.getValueType();

assert(CondVT.isVector() && "Vector select expects a vector selector!");

// Classify TVal/FVal content
bool IsTAllZero = ISD::isBuildVectorAllZeros(TVal.getNode());
bool IsTAllOne = ISD::isBuildVectorAllOnes(TVal.getNode());
bool IsFAllZero = ISD::isBuildVectorAllZeros(FVal.getNode());
bool IsFAllOne = ISD::isBuildVectorAllOnes(FVal.getNode());

// no vselect(cond, 0/-1, X) or vselect(cond, X, 0/-1), return
if (!(IsTAllZero || IsTAllOne || IsFAllZero || IsFAllOne))
return SDValue();

// select Cond, 0, 0 → 0
if (IsTAllZero && IsFAllZero) {
return VT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, VT)
: DAG.getConstant(0, DL, VT);
}

// To use the condition operand as a bitwise mask, it must have elements that
// are the same size as the select elements. Ie, the condition operand must
// have already been promoted from the IR select condition type <N x i1>.
// Don't check if the types themselves are equal because that excludes
// vector floating-point selects.
if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
return SDValue();

// Try inverting Cond and swapping T/F if it gives all-ones/all-zeros form
if (!IsTAllOne && !IsFAllZero && Cond.hasOneUse() &&
Cond.getOpcode() == ISD::SETCC &&
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
CondVT) {
if (IsTAllZero || IsFAllOne) {
SDValue CC = Cond.getOperand(2);
ISD::CondCode InverseCC = ISD::getSetCCInverse(
cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
InverseCC);
std::swap(TVal, FVal);
std::swap(IsTAllOne, IsFAllOne);
std::swap(IsTAllZero, IsFAllZero);
}
}

// Cond value must be 'sign splat' to be converted to a logical op.
if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
return SDValue();

// select Cond, -1, 0 → bitcast Cond
if (IsTAllOne && IsFAllZero)
return DAG.getBitcast(VT, Cond);

// select Cond, -1, x → or Cond, x
if (IsTAllOne) {
SDValue X = DAG.getBitcast(CondVT, FVal);
SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, X);
return DAG.getBitcast(VT, Or);
}

// select Cond, x, 0 → and Cond, x
if (IsFAllZero) {
SDValue X = DAG.getBitcast(CondVT, TVal);
SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, X);
return DAG.getBitcast(VT, And);
}

return SDValue();
}

SDValue DAGCombiner::visitVSELECT(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
Expand Down Expand Up @@ -13255,6 +13334,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
if (SimplifyDemandedVectorElts(SDValue(N, 0)))
return SDValue(N, 0);

if (SDValue V = combineVSelectWithAllOnesOrZeros(N0, N1, N2, TLI, DAG, DL))
return V;

return SDValue();
}

Expand Down
72 changes: 9 additions & 63 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47264,13 +47264,14 @@ static SDValue combineToExtendBoolVectorInReg(
DAG.getConstant(EltSizeInBits - 1, DL, VT));
}

/// If a vector select has an operand that is -1 or 0, try to simplify the
/// If a vector select has an left operand that is 0, try to simplify the
/// select to a bitwise logic operation.
/// TODO: Move to DAGCombiner, possibly using TargetLowering::hasAndNot()?
static SDValue
combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
/// TODO: Move to DAGCombiner.combineVSelectWithAllOnesOrZeros, possibly using
/// TargetLowering::hasAndNot()?
static SDValue combineVSelectWithLastZeros(SDNode *N, SelectionDAG &DAG,
const SDLoc &DL,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
SDValue Cond = N->getOperand(0);
SDValue LHS = N->getOperand(1);
SDValue RHS = N->getOperand(2);
Expand All @@ -47283,20 +47284,6 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,

assert(CondVT.isVector() && "Vector select expects a vector selector!");

// TODO: Use isNullOrNullSplat() to distinguish constants with undefs?
// TODO: Can we assert that both operands are not zeros (because that should
// get simplified at node creation time)?
bool TValIsAllZeros = ISD::isBuildVectorAllZeros(LHS.getNode());
bool FValIsAllZeros = ISD::isBuildVectorAllZeros(RHS.getNode());

// If both inputs are 0/undef, create a complete zero vector.
// FIXME: As noted above this should be handled by DAGCombiner/getNode.
if (TValIsAllZeros && FValIsAllZeros) {
if (VT.isFloatingPoint())
return DAG.getConstantFP(0.0, DL, VT);
return DAG.getConstant(0, DL, VT);
}

// To use the condition operand as a bitwise mask, it must have elements that
// are the same size as the select elements. Ie, the condition operand must
// have already been promoted from the IR select condition type <N x i1>.
Expand All @@ -47305,56 +47292,15 @@ combineVSelectWithAllOnesOrZeros(SDNode *N, SelectionDAG &DAG, const SDLoc &DL,
if (CondVT.getScalarSizeInBits() != VT.getScalarSizeInBits())
return SDValue();

// Try to invert the condition if true value is not all 1s and false value is
// not all 0s. Only do this if the condition has one use.
bool TValIsAllOnes = ISD::isBuildVectorAllOnes(LHS.getNode());
if (!TValIsAllOnes && !FValIsAllZeros && Cond.hasOneUse() &&
// Check if the selector will be produced by CMPP*/PCMP*.
Cond.getOpcode() == ISD::SETCC &&
// Check if SETCC has already been promoted.
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT) ==
CondVT) {
bool FValIsAllOnes = ISD::isBuildVectorAllOnes(RHS.getNode());

if (TValIsAllZeros || FValIsAllOnes) {
SDValue CC = Cond.getOperand(2);
ISD::CondCode NewCC = ISD::getSetCCInverse(
cast<CondCodeSDNode>(CC)->get(), Cond.getOperand(0).getValueType());
Cond = DAG.getSetCC(DL, CondVT, Cond.getOperand(0), Cond.getOperand(1),
NewCC);
std::swap(LHS, RHS);
TValIsAllOnes = FValIsAllOnes;
FValIsAllZeros = TValIsAllZeros;
}
}

// Cond value must be 'sign splat' to be converted to a logical op.
if (DAG.ComputeNumSignBits(Cond) != CondVT.getScalarSizeInBits())
return SDValue();

// vselect Cond, 111..., 000... -> Cond
if (TValIsAllOnes && FValIsAllZeros)
return DAG.getBitcast(VT, Cond);

if (!TLI.isTypeLegal(CondVT))
return SDValue();

// vselect Cond, 111..., X -> or Cond, X
if (TValIsAllOnes) {
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
SDValue Or = DAG.getNode(ISD::OR, DL, CondVT, Cond, CastRHS);
return DAG.getBitcast(VT, Or);
}

// vselect Cond, X, 000... -> and Cond, X
if (FValIsAllZeros) {
SDValue CastLHS = DAG.getBitcast(CondVT, LHS);
SDValue And = DAG.getNode(ISD::AND, DL, CondVT, Cond, CastLHS);
return DAG.getBitcast(VT, And);
}

// vselect Cond, 000..., X -> andn Cond, X
if (TValIsAllZeros) {
if (ISD::isBuildVectorAllZeros(LHS.getNode())) {
SDValue CastRHS = DAG.getBitcast(CondVT, RHS);
SDValue AndN;
// The canonical form differs for i1 vectors - x86andnp is not used
Expand Down Expand Up @@ -48115,7 +48061,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
if (!TLI.isTypeLegal(VT) || isSoftF16(VT, Subtarget))
return SDValue();

if (SDValue V = combineVSelectWithAllOnesOrZeros(N, DAG, DL, DCI, Subtarget))
if (SDValue V = combineVSelectWithLastZeros(N, DAG, DL, DCI, Subtarget))
return V;

if (SDValue V = combineVSelectToBLENDV(N, DAG, DL, DCI, Subtarget))
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/arm64-zip.ll
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ define <4 x float> @shuffle_zip1(<4 x float> %arg) {
; CHECK-NEXT: fmov.4s v1, #1.00000000
; CHECK-NEXT: zip1.4h v0, v0, v0
; CHECK-NEXT: sshll.4s v0, v0, #0
; CHECK-NEXT: and.16b v0, v1, v0
; CHECK-NEXT: and.16b v0, v0, v1
; CHECK-NEXT: ret
bb:
%inst = fcmp olt <4 x float> zeroinitializer, %arg
Expand Down
16 changes: 8 additions & 8 deletions llvm/test/CodeGen/AArch64/cmp-select-sign.ll
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ define i64 @not_sign_i64_4(i64 %a) {
define <7 x i8> @sign_7xi8(<7 x i8> %a) {
; CHECK-LABEL: sign_7xi8:
; CHECK: // %bb.0:
; CHECK-NEXT: movi v1.8b, #1
; CHECK-NEXT: cmlt v0.8b, v0.8b, #0
; CHECK-NEXT: orr v0.8b, v0.8b, v1.8b
; CHECK-NEXT: movi v1.2d, #0xffffffffffffffff
; CHECK-NEXT: movi v2.8b, #1
; CHECK-NEXT: cmge v0.8b, v1.8b, v0.8b
; CHECK-NEXT: orr v0.8b, v0.8b, v2.8b
; CHECK-NEXT: ret
%c = icmp sgt <7 x i8> %a, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
%res = select <7 x i1> %c, <7 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>, <7 x i8> <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
Expand Down Expand Up @@ -150,7 +151,8 @@ define <16 x i8> @sign_16xi8(<16 x i8> %a) {
define <3 x i32> @sign_3xi32(<3 x i32> %a) {
; CHECK-LABEL: sign_3xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: cmlt v0.4s, v0.4s, #0
; CHECK-NEXT: movi v1.2d, #0xffffffffffffffff
; CHECK-NEXT: cmge v0.4s, v1.4s, v0.4s
; CHECK-NEXT: orr v0.4s, #1
; CHECK-NEXT: ret
%c = icmp sgt <3 x i32> %a, <i32 -1, i32 -1, i32 -1>
Expand Down Expand Up @@ -197,11 +199,9 @@ define <4 x i32> @not_sign_4xi32(<4 x i32> %a) {
; CHECK-LABEL: not_sign_4xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: adrp x8, .LCPI16_0
; CHECK-NEXT: movi v2.4s, #1
; CHECK-NEXT: ldr q1, [x8, :lo12:.LCPI16_0]
; CHECK-NEXT: cmgt v0.4s, v0.4s, v1.4s
; CHECK-NEXT: and v1.16b, v0.16b, v2.16b
; CHECK-NEXT: orn v0.16b, v1.16b, v0.16b
; CHECK-NEXT: cmge v0.4s, v1.4s, v0.4s
; CHECK-NEXT: orr v0.4s, #1
; CHECK-NEXT: ret
%c = icmp sgt <4 x i32> %a, <i32 1, i32 -1, i32 -1, i32 -1>
%res = select <4 x i1> %c, <4 x i32> <i32 1, i32 1, i32 1, i32 1>, <4 x i32> <i32 -1, i32 -1, i32 -1, i32 -1>
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/concatbinop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ define <16 x i8> @signOf_neon(ptr nocapture noundef readonly %a, ptr nocapture n
; CHECK-NEXT: uzp1 v3.16b, v5.16b, v6.16b
; CHECK-NEXT: uzp1 v1.16b, v1.16b, v2.16b
; CHECK-NEXT: and v0.16b, v3.16b, v0.16b
; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b
; CHECK-NEXT: ret
entry:
%0 = load <8 x i16>, ptr %a, align 2
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/AArch64/sat-add.ll
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ define <16 x i8> @unsigned_sat_variable_v16i8_using_cmp_notval(<16 x i8> %x, <16
; CHECK-NEXT: mvn v2.16b, v1.16b
; CHECK-NEXT: add v1.16b, v0.16b, v1.16b
; CHECK-NEXT: cmhi v0.16b, v0.16b, v2.16b
; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b
; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
; CHECK-NEXT: ret
%noty = xor <16 x i8> %y, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
%a = add <16 x i8> %x, %y
Expand Down Expand Up @@ -570,7 +570,7 @@ define <8 x i16> @unsigned_sat_variable_v8i16_using_cmp_notval(<8 x i16> %x, <8
; CHECK-NEXT: mvn v2.16b, v1.16b
; CHECK-NEXT: add v1.8h, v0.8h, v1.8h
; CHECK-NEXT: cmhi v0.8h, v0.8h, v2.8h
; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b
; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
; CHECK-NEXT: ret
%noty = xor <8 x i16> %y, <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>
%a = add <8 x i16> %x, %y
Expand Down Expand Up @@ -610,7 +610,7 @@ define <4 x i32> @unsigned_sat_variable_v4i32_using_cmp_notval(<4 x i32> %x, <4
; CHECK-NEXT: mvn v2.16b, v1.16b
; CHECK-NEXT: add v1.4s, v0.4s, v1.4s
; CHECK-NEXT: cmhi v0.4s, v0.4s, v2.4s
; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b
; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
; CHECK-NEXT: ret
%noty = xor <4 x i32> %y, <i32 -1, i32 -1, i32 -1, i32 -1>
%a = add <4 x i32> %x, %y
Expand Down Expand Up @@ -651,7 +651,7 @@ define <2 x i64> @unsigned_sat_variable_v2i64_using_cmp_notval(<2 x i64> %x, <2
; CHECK-NEXT: mvn v2.16b, v1.16b
; CHECK-NEXT: add v1.2d, v0.2d, v1.2d
; CHECK-NEXT: cmhi v0.2d, v0.2d, v2.2d
; CHECK-NEXT: orr v0.16b, v1.16b, v0.16b
; CHECK-NEXT: orr v0.16b, v0.16b, v1.16b
; CHECK-NEXT: ret
%noty = xor <2 x i64> %y, <i64 -1, i64 -1>
%a = add <2 x i64> %x, %y
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/CodeGen/AArch64/select_cc.ll
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ define <2 x double> @select_olt_load_cmp(<2 x double> %a, ptr %src) {
; CHECK-SD-NEXT: ldr d1, [x0]
; CHECK-SD-NEXT: fcmgt v1.2s, v1.2s, #0.0
; CHECK-SD-NEXT: sshll v1.2d, v1.2s, #0
; CHECK-SD-NEXT: and v0.16b, v0.16b, v1.16b
; CHECK-SD-NEXT: and v0.16b, v1.16b, v0.16b
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: select_olt_load_cmp:
Expand Down
3 changes: 0 additions & 3 deletions llvm/test/CodeGen/AArch64/selectcc-to-shiftand.ll
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,6 @@ define <16 x i8> @sel_shift_bool_v16i8(<16 x i1> %t) {
; CHECK-SD-LABEL: sel_shift_bool_v16i8:
; CHECK-SD: // %bb.0:
; CHECK-SD-NEXT: shl v0.16b, v0.16b, #7
; CHECK-SD-NEXT: movi v1.16b, #128
; CHECK-SD-NEXT: cmlt v0.16b, v0.16b, #0
; CHECK-SD-NEXT: and v0.16b, v0.16b, v1.16b
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: sel_shift_bool_v16i8:
Expand Down
12 changes: 6 additions & 6 deletions llvm/test/CodeGen/AArch64/tbl-loops.ll
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ define void @loop1(ptr noalias nocapture noundef writeonly %dst, ptr nocapture n
; CHECK-NEXT: add x13, x13, #32
; CHECK-NEXT: fcmgt v3.4s, v1.4s, v0.4s
; CHECK-NEXT: fcmgt v4.4s, v2.4s, v0.4s
; CHECK-NEXT: fcmlt v5.4s, v1.4s, #0.0
; CHECK-NEXT: fcmlt v6.4s, v2.4s, #0.0
; CHECK-NEXT: bit v1.16b, v0.16b, v3.16b
; CHECK-NEXT: bit v2.16b, v0.16b, v4.16b
; CHECK-NEXT: bic v1.16b, v1.16b, v5.16b
; CHECK-NEXT: bic v2.16b, v2.16b, v6.16b
; CHECK-NEXT: bsl v3.16b, v0.16b, v1.16b
; CHECK-NEXT: bsl v4.16b, v0.16b, v2.16b
; CHECK-NEXT: fcmlt v1.4s, v1.4s, #0.0
; CHECK-NEXT: fcmlt v2.4s, v2.4s, #0.0
; CHECK-NEXT: bic v1.16b, v3.16b, v1.16b
; CHECK-NEXT: bic v2.16b, v4.16b, v2.16b
; CHECK-NEXT: fcvtzs v1.4s, v1.4s
; CHECK-NEXT: fcvtzs v2.4s, v2.4s
; CHECK-NEXT: xtn v1.4h, v1.4s
Expand Down
4 changes: 1 addition & 3 deletions llvm/test/CodeGen/AArch64/vselect-constants.ll
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,8 @@ define <4 x i32> @cmp_sel_0_or_minus1_vec(<4 x i32> %x, <4 x i32> %y) {
define <4 x i32> @sel_1_or_0_vec(<4 x i1> %cond) {
; CHECK-LABEL: sel_1_or_0_vec:
; CHECK: // %bb.0:
; CHECK-NEXT: ushll v0.4s, v0.4h, #0
; CHECK-NEXT: movi v1.4s, #1
; CHECK-NEXT: shl v0.4s, v0.4s, #31
; CHECK-NEXT: cmlt v0.4s, v0.4s, #0
; CHECK-NEXT: ushll v0.4s, v0.4h, #0
; CHECK-NEXT: and v0.16b, v0.16b, v1.16b
; CHECK-NEXT: ret
%add = select <4 x i1> %cond, <4 x i32> <i32 1, i32 1, i32 1, i32 1>, <4 x i32> <i32 0, i32 0, i32 0, i32 0>
Expand Down
Loading
Loading