-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[InstCombine] Fold fcmp pred sqrt(X), 0.0 -> fcmp pred2 X, 0.0
#101626
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-llvm-transforms Author: Yingwei Zheng (dtcxzyw) ChangesProof (Please run alive-tv with larger smt-to): https://alive2.llvm.org/ce/z/-aqixk In most cases, Full diff: https://github.com/llvm/llvm-project/pull/101626.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3b6df2760ecc2..622e7a420dd95 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7980,6 +7980,63 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
}
}
+/// Optimize sqrt(X) compared with zero.
+static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
+ Value *X;
+ if (!match(I.getOperand(0), m_Sqrt(m_Value(X))))
+ return nullptr;
+
+ if (!match(I.getOperand(1), m_PosZeroFP()))
+ return nullptr;
+
+ auto ReplacePredAndOp0 = [&](FCmpInst::Predicate P) {
+ I.setPredicate(P);
+ return IC.replaceOperand(I, 0, X);
+ };
+
+ switch (I.getPredicate()) {
+ case FCmpInst::FCMP_OLT:
+ case FCmpInst::FCMP_UGE:
+ // sqrt(X) < 0.0 --> false
+ // sqrt(X) u>= 0.0 --> true
+ llvm_unreachable("fcmp should have simplified");
+ case FCmpInst::FCMP_ULT:
+ case FCmpInst::FCMP_ULE:
+ case FCmpInst::FCMP_OGT:
+ case FCmpInst::FCMP_OGE:
+ case FCmpInst::FCMP_OEQ:
+ case FCmpInst::FCMP_UNE:
+ // sqrt(X) u< 0.0 --> X u< 0.0
+ // sqrt(X) u<= 0.0 --> X u<= 0.0
+ // sqrt(X) > 0.0 --> X > 0.0
+ // sqrt(X) >= 0.0 --> X >= 0.0
+ // sqrt(X) == 0.0 --> X == 0.0
+ // sqrt(X) u!= 0.0 --> X u!= 0.0
+ return IC.replaceOperand(I, 0, X);
+
+ case FCmpInst::FCMP_OLE:
+ // sqrt(X) <= 0.0 --> X == 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_OEQ);
+ case FCmpInst::FCMP_UGT:
+ // sqrt(X) u> 0.0 --> X u!= 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_UNE);
+ case FCmpInst::FCMP_UEQ:
+ // sqrt(X) u== 0.0 --> X u<= 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_ULE);
+ case FCmpInst::FCMP_ONE:
+ // sqrt(X) != 0.0 --> X > 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_OGT);
+ case FCmpInst::FCMP_ORD:
+ // !isnan(sqrt(X)) --> X >= 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_OGE);
+ case FCmpInst::FCMP_UNO:
+ // isnan(sqrt(X)) --> X u< 0.0
+ return ReplacePredAndOp0(FCmpInst::FCMP_ULT);
+ default:
+ llvm_unreachable("Unexpected predicate!");
+ }
+}
+
static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
CmpInst::Predicate Pred = I.getPredicate();
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -8247,6 +8304,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
if (Instruction *R = foldFabsWithFcmpZero(I, *this))
return R;
+ if (Instruction *R = foldSqrtWithFcmpZero(I, *this))
+ return R;
+
if (match(Op0, m_FNeg(m_Value(X)))) {
// fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
Constant *C;
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 656b3d2c49206..3ea93149094c8 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -2117,3 +2117,175 @@ define <8 x i1> @fcmp_ogt_fsub_const_vec_denormal_preserve-sign(<8 x float> %x,
%cmp = fcmp ogt <8 x float> %fs, zeroinitializer
ret <8 x i1> %cmp
}
+
+define i1 @fcmp_sqrt_zero_olt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_olt(
+; CHECK-NEXT: ret i1 false
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp olt half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ult(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ult half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
+; CHECK-NEXT: ret <2 x i1> [[CMP]]
+;
+ %sqrt = call <2 x half> @llvm.sqrt(<2 x half> %x)
+ %cmp = fcmp ult <2 x half> %sqrt, zeroinitializer
+ ret <2 x i1> %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ole(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ole(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ole half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ule(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ule(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ule half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ogt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ogt(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ogt half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ugt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ugt(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ugt half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_oge(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_oge(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp oge half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_uge(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_uge(
+; CHECK-NEXT: ret i1 true
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp uge half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_oeq(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_oeq(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp oeq half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ueq(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ueq(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ueq half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_one(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_one(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp one half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_une(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_une(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp une half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ord(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ord(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ord half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_uno(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_uno(
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp uno half %sqrt, 0.0
+ ret i1 %cmp
+}
+
+; negative tests
+
+define i1 @fcmp_sqrt_zero_ult_var(half %x, half %y) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_var(
+; CHECK-NEXT: [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[SQRT]], [[Y:%.*]]
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ult half %sqrt, %y
+ ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ult_nonzero(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_nonzero(
+; CHECK-NEXT: [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH3C00
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %sqrt = call half @llvm.sqrt(half %x)
+ %cmp = fcmp ult half %sqrt, 1.000000e+00
+ ret i1 %cmp
+}
diff --git a/llvm/test/Transforms/InstCombine/known-never-nan.ll b/llvm/test/Transforms/InstCombine/known-never-nan.ll
index a1cabc29682b4..82075b37b4361 100644
--- a/llvm/test/Transforms/InstCombine/known-never-nan.ll
+++ b/llvm/test/Transforms/InstCombine/known-never-nan.ll
@@ -9,9 +9,7 @@
define i1 @fabs_sqrt_src_maybe_nan(double %arg0, double %arg1) {
; CHECK-LABEL: @fabs_sqrt_src_maybe_nan(
-; CHECK-NEXT: [[FABS:%.*]] = call double @llvm.fabs.f64(double [[ARG0:%.*]])
-; CHECK-NEXT: [[OP:%.*]] = call double @llvm.sqrt.f64(double [[FABS]])
-; CHECK-NEXT: [[TMP:%.*]] = fcmp ord double [[OP]], 0.000000e+00
+; CHECK-NEXT: [[TMP:%.*]] = fcmp ord double [[ARG0:%.*]], 0.000000e+00
; CHECK-NEXT: ret i1 [[TMP]]
;
%fabs = call double @llvm.fabs.f64(double %arg0)
|
Proof (Please run alive-tv with larger smt-to): https://alive2.llvm.org/ce/z/-aqixk
FMF propagation: https://alive2.llvm.org/ce/z/zyKK_p
In most cases,
sqrtcannot be eliminated since it has multiple uses. But this patch will break data dependencies and allow optimizer to sink expensivesqrtcalls into successor blocks.