[ValueTracking] Recognize that and(x, add (x, -1)) clears the low bit
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index a24b154ab96522b52a2ff6e6690b2e92eba51f4f..c2db02fe85a49aaa5cfef7ed849a84e30f9dbdfb 100644 (file)
@@ -3943,6 +3943,11 @@ const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
   if (PN->getNumIncomingValues() == 2) {
     const Loop *L = LI.getLoopFor(PN->getParent());
 
+    // We don't want to break LCSSA, even in a SCEV expression tree.
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+      if (LI.getLoopFor(PN->getIncomingBlock(i)) != L)
+        return nullptr;
+
     // Try to match
     //
     //  br %cond, label %left, label %right
@@ -5923,6 +5928,30 @@ static Constant *EvaluateExpression(Value *V, const Loop *L,
                                   TLI);
 }
 
+
+// If every incoming value to PN except the one for BB is a specific Constant,
+// return that, else return nullptr.
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
+  Constant *IncomingVal = nullptr;
+
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+    if (PN->getIncomingBlock(i) == BB)
+      continue;
+
+    auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
+    if (!CurrentVal)
+      return nullptr;
+
+    if (IncomingVal != CurrentVal) {
+      if (IncomingVal)
+        return nullptr;
+      IncomingVal = CurrentVal;
+    }
+  }
+
+  return IncomingVal;
+}
+
 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
 /// in the header of its containing loop, we know the loop executes a
 /// constant number of times, and the PHI node is just a recurrence
@@ -5948,25 +5977,10 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
   if (!Latch)
     return nullptr;
 
-  // Since the loop has one latch, the PHI node must have two entries.  One
-  // entry must be a constant (coming in from outside of the loop), and the
-  // second must be derived from the same PHI.
-
-  BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
-                             ? PN->getIncomingBlock(1)
-                             : PN->getIncomingBlock(0);
-
-  assert(PN->getNumIncomingValues() == 2 && "Follows from having one latch!");
-
-  // Note: not all PHI nodes in the same block have to have their incoming
-  // values in the same order, so we use the basic block to look up the incoming
-  // value, not an index.
-
   for (auto &I : *Header) {
     PHINode *PHI = dyn_cast<PHINode>(&I);
     if (!PHI) break;
-    auto *StartCST =
-        dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+    auto *StartCST = getOtherIncomingValue(PHI, Latch);
     if (!StartCST) continue;
     CurrentIterVals[PHI] = StartCST;
   }
@@ -6045,21 +6059,11 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
   BasicBlock *Latch = L->getLoopLatch();
   assert(Latch && "Should follow from NumIncomingValues == 2!");
 
-  // NonLatch is the preheader, or something equivalent.
-  BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
-                             ? PN->getIncomingBlock(1)
-                             : PN->getIncomingBlock(0);
-
-  // Note: not all PHI nodes in the same block have to have their incoming
-  // values in the same order, so we use the basic block to look up the incoming
-  // value, not an index.
-
   for (auto &I : *Header) {
     PHINode *PHI = dyn_cast<PHINode>(&I);
     if (!PHI)
       break;
-    auto *StartCST =
-      dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+    auto *StartCST = getOtherIncomingValue(PHI, Latch);
     if (!StartCST) continue;
     CurrentIterVals[PHI] = StartCST;
   }
@@ -7381,6 +7385,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
     if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
         !C.isStrictlyPositive())
       return true;
+    break;
 
   case ICmpInst::ICMP_SGT:
     std::swap(LHS, RHS);
@@ -7393,6 +7398,7 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
     // (X + C)<nsw> s< X if C < 0
     if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
       return true;
+    break;
   }
 
   return false;
@@ -7415,12 +7421,9 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
   // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
   // interesting cases seen in practice.  We can consider "upgrading" L >= 0 to
   // use isKnownPredicate later if needed.
-  if (isKnownNonNegative(RHS) &&
-      isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
-      isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS))
-    return true;
-
-  return false;
+  return isKnownNonNegative(RHS) &&
+         isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
+         isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
 }
 
 /// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
@@ -9089,6 +9092,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       UnsignedRanges(std::move(Arg.UnsignedRanges)),
       SignedRanges(std::move(Arg.SignedRanges)),
       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
+      UniquePreds(std::move(Arg.UniquePreds)),
       SCEVAllocator(std::move(Arg.SCEVAllocator)),
       FirstUnknown(Arg.FirstUnknown) {
   Arg.FirstUnknown = nullptr;
@@ -9592,3 +9596,134 @@ void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
   AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
 }
+
+const SCEVPredicate *
+ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS,
+                                   const SCEVConstant *RHS) {
+  FoldingSetNodeID ID;
+  // Unique this node based on the arguments
+  ID.AddInteger(SCEVPredicate::P_Equal);
+  ID.AddPointer(LHS);
+  ID.AddPointer(RHS);
+  void *IP = nullptr;
+  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
+    return S;
+  SCEVEqualPredicate *Eq = new (SCEVAllocator)
+      SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS);
+  UniquePreds.InsertNode(Eq, IP);
+  return Eq;
+}
+
+class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
+                             SCEVUnionPredicate &A) {
+    SCEVPredicateRewriter Rewriter(SE, A);
+    return Rewriter.visit(Scev);
+  }
+
+  SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P)
+      : SCEVRewriteVisitor(SE), P(P) {}
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    auto ExprPreds = P.getPredicatesForExpr(Expr);
+    for (auto *Pred : ExprPreds)
+      if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred))
+        if (IPred->getLHS() == Expr)
+          return IPred->getRHS();
+
+    return Expr;
+  }
+
+private:
+  SCEVUnionPredicate &P;
+};
+
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+                                                   SCEVUnionPredicate &Preds) {
+  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+}
+
+/// SCEV predicates
+SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
+                             SCEVPredicateKind Kind)
+    : FastID(ID), Kind(Kind) {}
+
+SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
+                                       const SCEVUnknown *LHS,
+                                       const SCEVConstant *RHS)
+    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}
+
+bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+
+  if (!Op)
+    return false;
+
+  return Op->LHS == LHS && Op->RHS == RHS;
+}
+
+bool SCEVEqualPredicate::isAlwaysTrue() const { return false; }
+
+const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; }
+
+void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
+}
+
+/// Union predicates don't get cached so create a dummy set ID for it.
+SCEVUnionPredicate::SCEVUnionPredicate()
+    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+
+bool SCEVUnionPredicate::isAlwaysTrue() const {
+  return std::all_of(Preds.begin(), Preds.end(),
+                     [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
+}
+
+ArrayRef<const SCEVPredicate *>
+SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) {
+  auto I = SCEVToPreds.find(Expr);
+  if (I == SCEVToPreds.end())
+    return ArrayRef<const SCEVPredicate *>();
+  return I->second;
+}
+
+bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const {
+  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N))
+    return std::all_of(
+        Set->Preds.begin(), Set->Preds.end(),
+        [this](const SCEVPredicate *I) { return this->implies(I); });
+
+  auto ScevPredsIt = SCEVToPreds.find(N->getExpr());
+  if (ScevPredsIt == SCEVToPreds.end())
+    return false;
+  auto &SCEVPreds = ScevPredsIt->second;
+
+  return std::any_of(SCEVPreds.begin(), SCEVPreds.end(),
+                     [N](const SCEVPredicate *I) { return I->implies(N); });
+}
+
+const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; }
+
+void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  for (auto Pred : Preds)
+    Pred->print(OS, Depth);
+}
+
+void SCEVUnionPredicate::add(const SCEVPredicate *N) {
+  if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) {
+    for (auto Pred : Set->Preds)
+      add(Pred);
+    return;
+  }
+
+  if (implies(N))
+    return;
+
+  const SCEV *Key = N->getExpr();
+  assert(Key && "Only SCEVUnionPredicate doesn't have an "
+                " associated expression!");
+
+  SCEVToPreds[Key].push_back(N);
+  Preds.push_back(N);
+}