-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Fold X udiv Y
to X lshr cttz(Y)
if Y is a power of 2
#121386
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Veera (veera-sivarajan) ChangesFixes #115767 This PR folds Proof: https://alive2.llvm.org/ce/z/qHmLta Full diff: https://github.com/llvm/llvm-project/pull/121386.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f85a3c93651353..00779fe5fa2ee1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1632,6 +1632,16 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
}
+ // Op0 udiv Op1 -> Op0 lshr cttz(Op1), if Op1 is a power of 2.
+ if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, /*Depth*/ 0, &I)) {
+ // This will increase instruction count but it's okay
+ // since bitwise operations are substantially faster than
+ // division.
+ auto *Cttz =
+ Builder.CreateBinaryIntrinsic(Intrinsic::cttz, Op1, Builder.getTrue());
+ return BinaryOperator::CreateLShr(Op0, Cttz);
+ }
+
return nullptr;
}
diff --git a/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll b/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
index 1956f454a52bbf..fa47d06d859e97 100644
--- a/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
+++ b/llvm/test/Transforms/IndVarSimplify/rewrite-loop-exit-value.ll
@@ -218,7 +218,8 @@ define i32 @vscale_slt_with_vp_umin(ptr nocapture %A, i32 %n) mustprogress vscal
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
; CHECK: for.end:
; CHECK-NEXT: [[TMP0:%.*]] = add nsw i32 [[N]], -1
-; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[TMP0]], [[VF]]
+; CHECK-NEXT: [[TMP5:%.*]] = call range(i32 2, 33) i32 @llvm.cttz.i32(i32 [[VF]], i1 true)
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], [[TMP5]]
; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP1]], [[VSCALE]]
; CHECK-NEXT: [[TMP3:%.*]] = shl i32 [[TMP2]], 2
; CHECK-NEXT: [[TMP4:%.*]] = sub i32 [[N]], [[TMP3]]
@@ -270,7 +271,8 @@ define i32 @vscale_slt_with_vp_umin2(ptr nocapture %A, i32 %n) mustprogress vsca
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
; CHECK: for.end:
; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[N]], -1
-; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[TMP0]], [[VF]]
+; CHECK-NEXT: [[TMP5:%.*]] = call range(i32 2, 33) i32 @llvm.cttz.i32(i32 [[VF]], i1 true)
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], [[TMP5]]
; CHECK-NEXT: [[TMP2:%.*]] = mul i32 [[TMP1]], [[VSCALE]]
; CHECK-NEXT: [[TMP3:%.*]] = shl i32 [[TMP2]], 2
; CHECK-NEXT: [[TMP4:%.*]] = sub i32 [[N]], [[TMP3]]
diff --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll
index 8dd6d4a2e83712..005daed087c169 100644
--- a/llvm/test/Transforms/InstCombine/div-shift.ll
+++ b/llvm/test/Transforms/InstCombine/div-shift.ll
@@ -148,7 +148,8 @@ define i8 @udiv_umin_extra_use(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[Y2]], i8 [[Z2]])
; CHECK-NEXT: call void @use(i8 [[M]])
-; CHECK-NEXT: [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT: [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[D]]
;
%y2 = shl i8 1, %y
@@ -165,7 +166,8 @@ define i8 @udiv_smin(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[Y2:%.*]] = shl nuw i8 1, [[Y:%.*]]
; CHECK-NEXT: [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.smin.i8(i8 [[Y2]], i8 [[Z2]])
-; CHECK-NEXT: [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT: [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[D]]
;
%y2 = shl i8 1, %y
@@ -181,7 +183,8 @@ define i8 @udiv_smax(i8 %x, i8 %y, i8 %z) {
; CHECK-NEXT: [[Y2:%.*]] = shl nuw i8 1, [[Y:%.*]]
; CHECK-NEXT: [[Z2:%.*]] = shl nuw i8 1, [[Z:%.*]]
; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.smax.i8(i8 [[Y2]], i8 [[Z2]])
-; CHECK-NEXT: [[D:%.*]] = udiv i8 [[X:%.*]], [[M]]
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[M]], i1 true)
+; CHECK-NEXT: [[D:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[D]]
;
%y2 = shl i8 1, %y
@@ -1006,7 +1009,8 @@ define i8 @udiv_fail_shl_overflow(i8 %x, i8 %y) {
; CHECK-LABEL: @udiv_fail_shl_overflow(
; CHECK-NEXT: [[SHL:%.*]] = shl i8 2, [[Y:%.*]]
; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 1)
-; CHECK-NEXT: [[MUL:%.*]] = udiv i8 [[X:%.*]], [[MIN]]
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[MIN]], i1 true)
+; CHECK-NEXT: [[MUL:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[MUL]]
;
%shl = shl i8 2, %y
@@ -1294,3 +1298,100 @@ entry:
%div = sdiv i32 %add, %add2
ret i32 %div
}
+
+define i8 @udiv_if_power_of_two(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_if_power_of_two(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call range(i8 0, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT: br i1 [[TMP1]], label [[BB1:%.*]], label [[BB3:%.*]]
+; CHECK: bb1:
+; CHECK-NEXT: [[TMP2:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT: [[TMP3:%.*]] = lshr i8 [[X:%.*]], [[TMP2]]
+; CHECK-NEXT: br label [[BB3]]
+; CHECK: bb3:
+; CHECK-NEXT: [[_0_SROA_0_0:%.*]] = phi i8 [ [[TMP3]], [[BB1]] ], [ 0, [[START:%.*]] ]
+; CHECK-NEXT: ret i8 [[_0_SROA_0_0]]
+;
+start:
+ %0 = tail call i8 @llvm.ctpop.i8(i8 %y)
+ %1 = icmp eq i8 %0, 1
+ br i1 %1, label %bb1, label %bb3
+
+bb1:
+ %2 = udiv i8 %x, %y
+ br label %bb3
+
+bb3:
+ %_0.sroa.0.0 = phi i8 [ %2, %bb1 ], [ 0, %start ]
+ ret i8 %_0.sroa.0.0
+}
+
+define i8 @udiv_exact_assume_power_of_two(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_exact_assume_power_of_two(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call range(i8 1, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT: [[COND:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT: tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT: [[_0:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT: ret i8 [[_0]]
+;
+start:
+ %0 = tail call i8 @llvm.ctpop.i8(i8 %y)
+ %cond = icmp eq i8 %0, 1
+ tail call void @llvm.assume(i1 %cond)
+ %_0 = udiv exact i8 %x, %y
+ ret i8 %_0
+}
+
+define i7 @udiv_assume_power_of_two_illegal_type(i7 %x, i7 %y) {
+; CHECK-LABEL: @udiv_assume_power_of_two_illegal_type(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call range(i7 1, 8) i7 @llvm.ctpop.i7(i7 [[Y:%.*]])
+; CHECK-NEXT: [[COND:%.*]] = icmp eq i7 [[TMP0]], 1
+; CHECK-NEXT: tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i7 0, 8) i7 @llvm.cttz.i7(i7 [[Y]], i1 true)
+; CHECK-NEXT: [[_0:%.*]] = lshr i7 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT: ret i7 [[_0]]
+;
+start:
+ %0 = tail call i7 @llvm.ctpop.i7(i7 %y)
+ %cond = icmp eq i7 %0, 1
+ tail call void @llvm.assume(i1 %cond)
+ %_0 = udiv i7 %x, %y
+ ret i7 %_0
+}
+
+define i8 @udiv_assume_power_of_two_multiuse(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_assume_power_of_two_multiuse(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[TMP0:%.*]] = tail call range(i8 1, 9) i8 @llvm.ctpop.i8(i8 [[Y:%.*]])
+; CHECK-NEXT: [[COND:%.*]] = icmp eq i8 [[TMP0]], 1
+; CHECK-NEXT: tail call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT: [[TMP1:%.*]] = call range(i8 0, 9) i8 @llvm.cttz.i8(i8 [[Y]], i1 true)
+; CHECK-NEXT: [[_0:%.*]] = lshr i8 [[X:%.*]], [[TMP1]]
+; CHECK-NEXT: call void @use(i8 [[_0]])
+; CHECK-NEXT: ret i8 [[_0]]
+;
+start:
+ %0 = tail call i8 @llvm.ctpop.i8(i8 %y)
+ %cond = icmp eq i8 %0, 1
+ tail call void @llvm.assume(i1 %cond)
+ %_0 = udiv i8 %x, %y
+ call void @use(i8 %_0)
+ ret i8 %_0
+}
+
+define i8 @udiv_power_of_two_negative(i8 %x, i8 %y) {
+; CHECK-LABEL: @udiv_power_of_two_negative(
+; CHECK-NEXT: start:
+; CHECK-NEXT: [[_0:%.*]] = udiv i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i8 [[_0]]
+;
+start:
+ %0 = tail call i8 @llvm.ctpop.i8(i8 %y)
+ %cond = icmp eq i8 %0, 1
+ %_0 = udiv i8 %x, %y
+ ret i8 %_0
+}
|
X udiv Y
to 'X lshr cttz(Y)` if Y is a power of 2X udiv Y
to X lshr cttz(Y)
if Y is a power of 2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The IR diff looks good. It also eliminates some redundant llvm.umul.with.overflow
calls :)
; CHECK-NEXT: ret i8 [[_0_SROA_0_0]] | ||
; | ||
start: | ||
%0 = tail call i8 @llvm.ctpop.i8(i8 %y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use named values.
@@ -1632,6 +1632,16 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { | |||
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact())); | |||
} | |||
|
|||
// Op0 udiv Op1 -> Op0 lshr cttz(Op1), if Op1 is a power of 2. | |||
if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, /*Depth*/ 0, &I)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, /*Depth*/ 0, &I)) { | |
if (isKnownToBeAPowerOfTwo(Op1, /*OrZero=*/true, /*Depth=*/0, &I)) { |
if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, /*Depth*/ 0, &I)) { | ||
// This will increase instruction count but it's okay | ||
// since bitwise operations are substantially faster than | ||
// division. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIRC udiv with a power of 2 divisor only takes 4 cycles on cortex-a53. Therefore this transformation is not always a win. Can you double-check this behavior, if possible?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I confirmed this behavior exists on cortex-a72 with asmjit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But it seems that cttz
+lshr
takes only 3 cycles? https://godbolt.org/z/n9sjzb47o.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 extra instructions might not be worth it. Maybe this should be in CodegenPrepare?
auto *Cttz = | ||
Builder.CreateBinaryIntrinsic(Intrinsic::cttz, Op1, Builder.getTrue()); | ||
return BinaryOperator::CreateLShr(Op0, Cttz); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can preserve exact: https://alive2.llvm.org/ce/z/n8jnNn
Also it might be cleaner to merge this with the above fold using takeLog2
. Maybe create a lambda like:
auto GetLog2Denom = [&](Value * Denom) {
if (auto * R = takeLog2(Denom, ...) {
return R;
}
if (IsKnownPowerOf2(Denom, ....) {
return Builder.CreateBinaryIntrinsic(cttz, ....);
}
return nullptr;
}
Then:
if(auto * Res = GetLog2Denom(Op1)) {
/// Above fold using `takeLog2`
}
Fixes #115767
This PR folds
X udiv Y
toX lshr cttz(Y)
if Y is a power of two sincebitwise operations are faster than division.
Proof: https://alive2.llvm.org/ce/z/qHmLta