Skip to content

Conversation

@dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Aug 2, 2024

Proof (please run alive-tv with a larger `smt-to` timeout): https://alive2.llvm.org/ce/z/-aqixk
FMF propagation: https://alive2.llvm.org/ce/z/zyKK_p

sqrt(X) < 0.0 --> false
sqrt(X) u>= 0.0 --> true
sqrt(X) u< 0.0 --> X u< 0.0
sqrt(X) u<= 0.0 --> X u<= 0.0
sqrt(X) > 0.0 --> X > 0.0
sqrt(X) >= 0.0 --> X >= 0.0
sqrt(X) == 0.0 --> X == 0.0
sqrt(X) u!= 0.0 --> X u!= 0.0
sqrt(X) <= 0.0 --> X == 0.0
sqrt(X) u> 0.0 --> X u!= 0.0
sqrt(X) u== 0.0 --> X u<= 0.0
sqrt(X) != 0.0 --> X > 0.0
!isnan(sqrt(X)) --> X >= 0.0
isnan(sqrt(X)) --> X u< 0.0

In most cases, the sqrt call itself cannot be eliminated because it has other uses. However, this patch breaks the data dependency between the compare and the sqrt, which allows the optimizer to sink the expensive sqrt call into successor blocks.

@dtcxzyw dtcxzyw requested a review from nikic as a code owner August 2, 2024 07:26
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Aug 2, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 2, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

Proof (please run alive-tv with a larger `smt-to` timeout): https://alive2.llvm.org/ce/z/-aqixk

sqrt(X) < 0.0 --> false
sqrt(X) u>= 0.0 --> true
sqrt(X) u< 0.0 --> X u< 0.0
sqrt(X) u<= 0.0 --> X u<= 0.0
sqrt(X) > 0.0 --> X > 0.0
sqrt(X) >= 0.0 --> X >= 0.0
sqrt(X) == 0.0 --> X == 0.0
sqrt(X) u!= 0.0 --> X u!= 0.0
sqrt(X) <= 0.0 --> X == 0.0
sqrt(X) u> 0.0 --> X u!= 0.0
sqrt(X) u== 0.0 --> X u<= 0.0
sqrt(X) != 0.0 --> X > 0.0
!isnan(sqrt(X)) --> X >= 0.0
isnan(sqrt(X)) --> X u< 0.0

In most cases, the sqrt call itself cannot be eliminated because it has other uses. However, this patch breaks the data dependency between the compare and the sqrt, which allows the optimizer to sink the expensive sqrt call into successor blocks.


Full diff: https://github.com/llvm/llvm-project/pull/101626.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+60)
  • (modified) llvm/test/Transforms/InstCombine/fcmp.ll (+172)
  • (modified) llvm/test/Transforms/InstCombine/known-never-nan.ll (+1-3)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3b6df2760ecc2..622e7a420dd95 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7980,6 +7980,78 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
   }
 }
 
+/// Optimize sqrt(X) compared with zero.
+///
+/// Matches `fcmp Pred sqrt(X), +0.0` and rewrites \p I in place so that it
+/// compares X itself against zero (adjusting the predicate where needed),
+/// which removes the compare's dependency on the sqrt call.
+///
+/// \returns the result of IC.replaceOperand when the fold applies, or nullptr
+/// when \p I does not have the expected shape.
+static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
+  Value *X;
+  if (!match(I.getOperand(0), m_Sqrt(m_Value(X))))
+    return nullptr;
+
+  if (!match(I.getOperand(1), m_PosZeroFP()))
+    return nullptr;
+
+  auto ReplacePredAndOp0 = [&](FCmpInst::Predicate P) {
+    I.setPredicate(P);
+    return IC.replaceOperand(I, 0, X);
+  };
+
+  // Clear ninf flag if sqrt doesn't have it. sqrt(-inf) is NaN, not an
+  // infinity, so `fcmp ninf sqrt(X), 0.0` is well defined for X == -inf while
+  // `fcmp ninf X, 0.0` would be poison; keeping the flag would make the
+  // rewritten compare more poisonous than the original.
+  if (!cast<Instruction>(I.getOperand(0))->hasNoInfs())
+    I.setHasNoInfs(false);
+
+  switch (I.getPredicate()) {
+  case FCmpInst::FCMP_OLT:
+  case FCmpInst::FCMP_UGE:
+    // sqrt(X) < 0.0 --> false
+    // sqrt(X) u>= 0.0 --> true
+    // Both are expected to have been folded to constants before reaching here.
+    llvm_unreachable("fcmp should have simplified");
+  case FCmpInst::FCMP_ULT:
+  case FCmpInst::FCMP_ULE:
+  case FCmpInst::FCMP_OGT:
+  case FCmpInst::FCMP_OGE:
+  case FCmpInst::FCMP_OEQ:
+  case FCmpInst::FCMP_UNE:
+    // sqrt(X) u< 0.0 --> X u< 0.0
+    // sqrt(X) u<= 0.0 --> X u<= 0.0
+    // sqrt(X) > 0.0 --> X > 0.0
+    // sqrt(X) >= 0.0 --> X >= 0.0
+    // sqrt(X) == 0.0 --> X == 0.0
+    // sqrt(X) u!= 0.0 --> X u!= 0.0
+    return IC.replaceOperand(I, 0, X);
+
+  case FCmpInst::FCMP_OLE:
+    // sqrt(X) <= 0.0 --> X == 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_OEQ);
+  case FCmpInst::FCMP_UGT:
+    // sqrt(X) u> 0.0 --> X u!= 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_UNE);
+  case FCmpInst::FCMP_UEQ:
+    // sqrt(X) u== 0.0 --> X u<= 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_ULE);
+  case FCmpInst::FCMP_ONE:
+    // sqrt(X) != 0.0 --> X > 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_OGT);
+  case FCmpInst::FCMP_ORD:
+    // !isnan(sqrt(X)) --> X >= 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_OGE);
+  case FCmpInst::FCMP_UNO:
+    // isnan(sqrt(X)) --> X u< 0.0
+    return ReplacePredAndOp0(FCmpInst::FCMP_ULT);
+  default:
+    llvm_unreachable("Unexpected predicate!");
+  }
+}
+
 static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) {
   CmpInst::Predicate Pred = I.getPredicate();
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -8247,6 +8304,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
   if (Instruction *R = foldFabsWithFcmpZero(I, *this))
     return R;
 
+  if (Instruction *R = foldSqrtWithFcmpZero(I, *this))
+    return R;
+
   if (match(Op0, m_FNeg(m_Value(X)))) {
     // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
     Constant *C;
diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll
index 656b3d2c49206..3ea93149094c8 100644
--- a/llvm/test/Transforms/InstCombine/fcmp.ll
+++ b/llvm/test/Transforms/InstCombine/fcmp.ll
@@ -2117,3 +2117,175 @@ define <8 x i1> @fcmp_ogt_fsub_const_vec_denormal_preserve-sign(<8 x float> %x,
   %cmp = fcmp ogt <8 x float> %fs, zeroinitializer
   ret <8 x i1> %cmp
 }
+
+define i1 @fcmp_sqrt_zero_olt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_olt(
+; CHECK-NEXT:    ret i1 false
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp olt half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ult(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ult half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer
+; CHECK-NEXT:    ret <2 x i1> [[CMP]]
+;
+  %sqrt = call <2 x half> @llvm.sqrt(<2 x half> %x)
+  %cmp = fcmp ult <2 x half> %sqrt, zeroinitializer
+  ret <2 x i1> %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ole(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ole(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ole half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ule(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ule(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ule half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ogt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ogt(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ogt half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ugt(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ugt(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ugt half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_oge(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_oge(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp oge half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_uge(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_uge(
+; CHECK-NEXT:    ret i1 true
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp uge half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_oeq(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_oeq(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp oeq half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ueq(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ueq(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ueq half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_one(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_one(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp one half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_une(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_une(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp une half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ord(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ord(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ord half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_uno(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_uno(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp uno half %sqrt, 0.0
+  ret i1 %cmp
+}
+
+; negative tests
+
+define i1 @fcmp_sqrt_zero_ult_var(half %x, half %y) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_var(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[SQRT]], [[Y:%.*]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ult half %sqrt, %y
+  ret i1 %cmp
+}
+
+define i1 @fcmp_sqrt_zero_ult_nonzero(half %x) {
+; CHECK-LABEL: @fcmp_sqrt_zero_ult_nonzero(
+; CHECK-NEXT:    [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]])
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH3C00
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %sqrt = call half @llvm.sqrt(half %x)
+  %cmp = fcmp ult half %sqrt, 1.000000e+00
+  ret i1 %cmp
+}
diff --git a/llvm/test/Transforms/InstCombine/known-never-nan.ll b/llvm/test/Transforms/InstCombine/known-never-nan.ll
index a1cabc29682b4..82075b37b4361 100644
--- a/llvm/test/Transforms/InstCombine/known-never-nan.ll
+++ b/llvm/test/Transforms/InstCombine/known-never-nan.ll
@@ -9,9 +9,7 @@
 
 define i1 @fabs_sqrt_src_maybe_nan(double %arg0, double %arg1) {
 ; CHECK-LABEL: @fabs_sqrt_src_maybe_nan(
-; CHECK-NEXT:    [[FABS:%.*]] = call double @llvm.fabs.f64(double [[ARG0:%.*]])
-; CHECK-NEXT:    [[OP:%.*]] = call double @llvm.sqrt.f64(double [[FABS]])
-; CHECK-NEXT:    [[TMP:%.*]] = fcmp ord double [[OP]], 0.000000e+00
+; CHECK-NEXT:    [[TMP:%.*]] = fcmp ord double [[ARG0:%.*]], 0.000000e+00
 ; CHECK-NEXT:    ret i1 [[TMP]]
 ;
   %fabs = call double @llvm.fabs.f64(double %arg0)

@dtcxzyw dtcxzyw merged commit 8bd9ade into llvm:main Aug 3, 2024
@dtcxzyw dtcxzyw deleted the perf/fcmp-sqrt-zero branch August 3, 2024 05:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants