ScalarEvolution: HowFarToZero was wrongly using signed division
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 7324344c3e0e71ec421ce8d4061fd2985bf628bb..2026f8aa79cf089b30977444c05c10b1f064b449 100644 (file)
@@ -675,32 +675,32 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
   }
 }
 
-static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) {
+static const APInt urem(const SCEVConstant *C1, const SCEVConstant *C2) {
   APInt A = C1->getValue()->getValue();
   APInt B = C2->getValue()->getValue();
   uint32_t ABW = A.getBitWidth();
   uint32_t BBW = B.getBitWidth();
 
   if (ABW > BBW)
-    B = B.sext(ABW);
+    B = B.zext(ABW);
   else if (ABW < BBW)
-    A = A.sext(BBW);
+    A = A.zext(BBW);
 
-  return APIntOps::srem(A, B);
+  return APIntOps::urem(A, B);
 }
 
-static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) {
+static const APInt udiv(const SCEVConstant *C1, const SCEVConstant *C2) {
   APInt A = C1->getValue()->getValue();
   APInt B = C2->getValue()->getValue();
   uint32_t ABW = A.getBitWidth();
   uint32_t BBW = B.getBitWidth();
 
   if (ABW > BBW)
-    B = B.sext(ABW);
+    B = B.zext(ABW);
   else if (ABW < BBW)
-    A = A.sext(BBW);
+    A = A.zext(BBW);
 
-  return APIntOps::sdiv(A, B);
+  return APIntOps::udiv(A, B);
 }
 
 namespace {
@@ -803,8 +803,8 @@ public:
 
   void visitConstant(const SCEVConstant *Numerator) {
     if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
-      Quotient = SE.getConstant(sdiv(Numerator, D));
-      Remainder = SE.getConstant(srem(Numerator, D));
+      Quotient = SE.getConstant(udiv(Numerator, D));
+      Remainder = SE.getConstant(urem(Numerator, D));
       return;
     }
   }
@@ -6782,6 +6782,66 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
                                    RHS, LHS, FoundLHS, FoundRHS);
   }
 
+  // Check if we can make progress by sharpening ranges.
+  if (FoundPred == ICmpInst::ICMP_NE &&
+      (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
+
+    const SCEVConstant *C = nullptr;
+    const SCEV *V = nullptr;
+
+    if (isa<SCEVConstant>(FoundLHS)) {
+      C = cast<SCEVConstant>(FoundLHS);
+      V = FoundRHS;
+    } else {
+      C = cast<SCEVConstant>(FoundRHS);
+      V = FoundLHS;
+    }
+
+    // The guarding predicate tells us that C != V. If the known range
+    // of V is [C, t), we can sharpen the range to [C + 1, t).  The
+    // range we consider has to correspond to same signedness as the
+    // predicate we're interested in folding.
+
+    APInt Min = ICmpInst::isSigned(Pred) ?
+        getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin();
+
+    if (Min == C->getValue()->getValue()) {
+      // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
+      // This is true even if (Min + 1) wraps around -- in case of
+      // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
+
+      APInt SharperMin = Min + 1;
+
+      switch (Pred) {
+        case ICmpInst::ICMP_SGE:
+        case ICmpInst::ICMP_UGE:
+          // We know V `Pred` SharperMin.  If this implies LHS `Pred`
+          // RHS, we're done.
+          if (isImpliedCondOperands(Pred, LHS, RHS, V,
+                                    getConstant(SharperMin)))
+            return true;
+
+        case ICmpInst::ICMP_SGT:
+        case ICmpInst::ICMP_UGT:
+          // We know from the range information that (V `Pred` Min ||
+          // V == Min).  We know from the guarding condition that !(V
+          // == Min).  This gives us
+          //
+          //       V `Pred` Min || V == Min && !(V == Min)
+          //   =>  V `Pred` Min
+          //
+          // If V `Pred` Min implies LHS `Pred` RHS, we're done.
+
+          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+            return true;
+
+        default:
+          // No change
+          break;
+      }
+    }
+  }
+
   // Check whether the actual condition is beyond sufficient.
   if (FoundPred == ICmpInst::ICMP_EQ)
     if (ICmpInst::isTrueWhenEqual(Pred))