[PM/AA] Run clang-format over LibCallAliasAnalysis prior to making
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 81e07e99dca14dc91f6a7eb7cf4567e1991371a9..3e3032b48fe6f614ed90eff544016a45191b93fb 100644 (file)
@@ -627,7 +627,7 @@ namespace {
       llvm_unreachable("Unknown SCEV kind!");
     }
   };
-} // namespace
+}
 
 /// GroupByComplexity - Given a list of SCEV objects, order them by their
 /// complexity, and group objects of the same complexity together by value.
@@ -689,7 +689,7 @@ struct FindSCEVSize {
     return false;
   }
 };
-} // namespace
+}
 
 // Returns the size of the SCEV S.
 static inline int sizeOfSCEV(const SCEV *S) {
@@ -937,7 +937,7 @@ private:
   const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
 };
 
-} // namespace
+}
 
 //===----------------------------------------------------------------------===//
 //                      Simple SCEV method implementations
@@ -1248,7 +1248,7 @@ struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
 
 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
     SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
-} // namespace
+}
 
 // The recurrence AR has been shown to have no signed/unsigned wrap or something
 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
@@ -2936,7 +2936,8 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr,
   // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP
   // instruction to its SCEV, because the Instruction may be guarded by control
   // flow and the no-overflow bits may not be valid for the expression in any
-  // context.
+  // context. This can be fixed similarly to how these flags are handled for
+  // adds.
   SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
 
   const SCEV *TotalOffset = getConstant(IntPtrTy, 0);
@@ -3300,7 +3301,7 @@ namespace {
     }
     bool isDone() const { return FindOne; }
   };
-} // namespace
+}
 
 bool ScalarEvolution::checkValidity(const SCEV *S) const {
   FindInvalidSCEVUnknown F;
@@ -3315,22 +3316,25 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const {
 const SCEV *ScalarEvolution::getSCEV(Value *V) {
   assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
 
+  const SCEV *S = getExistingSCEV(V);
+  if (S == nullptr) {
+    S = createSCEV(V);
+    ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
+  }
+  return S;
+}
+
+const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
+  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
+
   ValueExprMapType::iterator I = ValueExprMap.find_as(V);
   if (I != ValueExprMap.end()) {
     const SCEV *S = I->second;
     if (checkValidity(S))
       return S;
-    else
-      ValueExprMap.erase(I);
+    ValueExprMap.erase(I);
   }
-  const SCEV *S = createSCEV(V);
-
-  // The process of creating a SCEV for V may have caused other SCEVs
-  // to have been created, so it's necessary to insert the new entry
-  // from scratch, rather than trying to remember the insert position
-  // above.
-  ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
-  return S;
+  return nullptr;
 }
 
 /// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
@@ -4089,8 +4093,63 @@ ScalarEvolution::getRange(const SCEV *S,
   return setRange(S, SignHint, ConservativeResult);
 }
 
-/// createSCEV - We know that there is no SCEV for the specified value.
-/// Analyze the expression.
+SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
+  const BinaryOperator *BinOp = cast<BinaryOperator>(V);
+
+  // Return early if there are no flags to propagate to the SCEV.
+  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
+  if (BinOp->hasNoUnsignedWrap())
+    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
+  if (BinOp->hasNoSignedWrap())
+    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
+  if (Flags == SCEV::FlagAnyWrap) {
+    return SCEV::FlagAnyWrap;
+  }
+
+  // Here we check that BinOp is in the header of the innermost loop
+  // containing BinOp, since we only deal with instructions in the loop
+  // header. The actual loop we need to check later will come from an add
+  // recurrence, but getting that requires computing the SCEV of the operands,
+  // which can be expensive. This check we can do cheaply to rule out some
+  // cases early.
+  Loop *innermostContainingLoop = LI->getLoopFor(BinOp->getParent());
+  if (innermostContainingLoop == nullptr ||
+      innermostContainingLoop->getHeader() != BinOp->getParent())
+    return SCEV::FlagAnyWrap;
+
+  // Only proceed if we can prove that BinOp does not yield poison.
+  if (!isKnownNotFullPoison(BinOp)) return SCEV::FlagAnyWrap;
+
+  // At this point we know that if V is executed, then it does not wrap
+  // according to at least one of NSW or NUW. If V is not executed, then we do
+  // not know if the calculation that V represents would wrap. Multiple
+  // instructions can map to the same SCEV. If we apply NSW or NUW from V to
+  // the SCEV, we must guarantee no wrapping for that SCEV also when it is
+  // derived from other instructions that map to the same SCEV. We cannot make
+  // that guarantee for cases where V is not executed. So we need to find the
+  // loop that V is considered in relation to and prove that V is executed for
+  // every iteration of that loop. That implies that the value that V
+  // calculates does not wrap anywhere in the loop, so then we can apply the
+  // flags to the SCEV.
+  //
+  // We check isLoopInvariant to disambiguate in case we are adding two
+  // recurrences from different loops, so that we know which loop to prove
+  // that V is executed in.
+  for (int OpIndex = 0; OpIndex < 2; ++OpIndex) {
+    const SCEV *Op = getSCEV(BinOp->getOperand(OpIndex));
+    if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
+      const int OtherOpIndex = 1 - OpIndex;
+      const SCEV *OtherOp = getSCEV(BinOp->getOperand(OtherOpIndex));
+      if (isLoopInvariant(OtherOp, AddRec->getLoop()) &&
+          isGuaranteedToExecuteForEveryIteration(BinOp, AddRec->getLoop()))
+        return Flags;
+    }
+  }
+  return SCEV::FlagAnyWrap;
+}
+
+/// createSCEV - We know that there is no SCEV for the specified value.  Analyze
+/// the expression.
 ///
 const SCEV *ScalarEvolution::createSCEV(Value *V) {
   if (!isSCEVable(V->getType()))
@@ -4127,29 +4186,52 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     // Instead, gather up all the operands and make a single getAddExpr call.
     // LLVM IR canonical form means we need only traverse the left operands.
     //
-    // Don't apply this instruction's NSW or NUW flags to the new
-    // expression. The instruction may be guarded by control flow that the
-    // no-wrap behavior depends on. Non-control-equivalent instructions can be
-    // mapped to the same SCEV expression, and it would be incorrect to transfer
-    // NSW/NUW semantics to those operations.
+    // FIXME: Expand this handling of NSW and NUW to other instructions, like
+    // sub and mul.
     SmallVector<const SCEV *, 4> AddOps;
-    AddOps.push_back(getSCEV(U->getOperand(1)));
-    for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) {
-      unsigned Opcode = Op->getValueID() - Value::InstructionVal;
-      if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
+    for (Value *Op = U;; Op = U->getOperand(0)) {
+      U = dyn_cast<Operator>(Op);
+      unsigned Opcode = U ? U->getOpcode() : 0;
+      if (!U || (Opcode != Instruction::Add && Opcode != Instruction::Sub)) {
+        assert(Op != V && "V should be an add");
+        AddOps.push_back(getSCEV(Op));
         break;
-      U = cast<Operator>(Op);
+      }
+
+      if (auto *OpSCEV = getExistingSCEV(Op)) {
+        AddOps.push_back(OpSCEV);
+        break;
+      }
+
+      // If a NUW or NSW flag can be applied to the SCEV for this
+      // addition, then compute the SCEV for this addition by itself
+      // with a separate call to getAddExpr. We need to do that
+      // instead of pushing the operands of the addition onto AddOps,
+      // since the flags are only known to apply to this particular
+      // addition - they may not apply to other additions that can be
+      // formed with operands from AddOps.
+      //
+      // FIXME: Expand this to sub instructions.
+      if (Opcode == Instruction::Add && isa<BinaryOperator>(U)) {
+        SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U);
+        if (Flags != SCEV::FlagAnyWrap) {
+          AddOps.push_back(getAddExpr(getSCEV(U->getOperand(0)),
+                                      getSCEV(U->getOperand(1)), Flags));
+          break;
+        }
+      }
+
       const SCEV *Op1 = getSCEV(U->getOperand(1));
       if (Opcode == Instruction::Sub)
         AddOps.push_back(getNegativeSCEV(Op1));
       else
         AddOps.push_back(Op1);
     }
-    AddOps.push_back(getSCEV(U->getOperand(0)));
     return getAddExpr(AddOps);
   }
+
   case Instruction::Mul: {
-    // Don't transfer NSW/NUW for the same reason as AddExpr.
+    // FIXME: Transfer NSW/NUW as in AddExpr.
     SmallVector<const SCEV *, 4> MulOps;
     MulOps.push_back(getSCEV(U->getOperand(1)));
     for (Value *Op = U->getOperand(0);
@@ -4744,12 +4826,12 @@ void ScalarEvolution::forgetValue(Value *V) {
 }
 
 /// getExact - Get the exact loop backedge taken count considering all loop
-/// exits. A computable result can only be return for loops with a single exit.
-/// Returning the minimum taken count among all exits is incorrect because one
-/// of the loop's exit limit's may have been skipped. HowFarToZero assumes that
-/// the limit of each loop test is never skipped. This is a valid assumption as
-/// long as the loop exits via that test. For precise results, it is the
-/// caller's responsibility to specify the relevant loop exit using
+/// exits. A computable result can only be returned for loops with a single
+/// exit.  Returning the minimum taken count among all exits is incorrect
+/// because one of the loop's exit limit's may have been skipped. HowFarToZero
+/// assumes that the limit of each loop test is never skipped. This is a valid
+/// assumption as long as the loop exits via that test. For precise results, it
+/// is the caller's responsibility to specify the relevant loop exit using
 /// getExact(ExitingBlock, SE).
 const SCEV *
 ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const {
@@ -4952,8 +5034,7 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) {
       if (!Pred)
         return getCouldNotCompute();
       TerminatorInst *PredTerm = Pred->getTerminator();
-      for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
-        BasicBlock *PredSucc = PredTerm->getSuccessor(i);
+      for (const BasicBlock *PredSucc : PredTerm->successors()) {
         if (PredSucc == BB)
           continue;
         // If the predecessor has a successor that isn't BB and isn't
@@ -6616,6 +6697,133 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
   return isKnownPredicateWithRanges(Pred, LHS, RHS);
 }
 
+bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
+                                           ICmpInst::Predicate Pred,
+                                           bool &Increasing) {
+  bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing);
+
+#ifndef NDEBUG
+  // Verify an invariant: inverting the predicate should turn a monotonically
+  // increasing change to a monotonically decreasing one, and vice versa.
+  bool IncreasingSwapped;
+  bool ResultSwapped = isMonotonicPredicateImpl(
+      LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped);
+
+  assert(Result == ResultSwapped && "should be able to analyze both!");
+  if (ResultSwapped)
+    assert(Increasing == !IncreasingSwapped &&
+           "monotonicity should flip as we flip the predicate");
+#endif
+
+  return Result;
+}
+
+bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
+                                               ICmpInst::Predicate Pred,
+                                               bool &Increasing) {
+
+  // A zero step value for LHS means the induction variable is essentially a
+  // loop invariant value. We don't really depend on the predicate actually
+  // flipping from false to true (for increasing predicates, and the other way
+  // around for decreasing predicates), all we care about is that *if* the
+  // predicate changes then it only changes from false to true.
+  //
+  // A zero step value in itself is not very useful, but there may be places
+  // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
+  // as general as possible.
+
+  switch (Pred) {
+  default:
+    return false; // Conservative answer
+
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
+    if (!LHS->getNoWrapFlags(SCEV::FlagNUW))
+      return false;
+
+    Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE;
+    return true;
+
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_SGE:
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLE: {
+    if (!LHS->getNoWrapFlags(SCEV::FlagNSW))
+      return false;
+
+    const SCEV *Step = LHS->getStepRecurrence(*this);
+
+    if (isKnownNonNegative(Step)) {
+      Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE;
+      return true;
+    }
+
+    if (isKnownNonPositive(Step)) {
+      Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE;
+      return true;
+    }
+
+    return false;
+  }
+
+  }
+
+  llvm_unreachable("switch has default clause!");
+}
+
+bool ScalarEvolution::isLoopInvariantPredicate(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+    ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS,
+    const SCEV *&InvariantRHS) {
+
+  // If there is a loop-invariant, force it into the RHS, otherwise bail out.
+  if (!isLoopInvariant(RHS, L)) {
+    if (!isLoopInvariant(LHS, L))
+      return false;
+
+    std::swap(LHS, RHS);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
+  if (!ArLHS || ArLHS->getLoop() != L)
+    return false;
+
+  bool Increasing;
+  if (!isMonotonicPredicate(ArLHS, Pred, Increasing))
+    return false;
+
+  // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
+  // true as the loop iterates, and the backedge is control dependent on
+  // "ArLHS `Pred` RHS" == true then we can reason as follows:
+  //
+  //   * if the predicate was false in the first iteration then the predicate
+  //     is never evaluated again, since the loop exits without taking the
+  //     backedge.
+  //   * if the predicate was true in the first iteration then it will
+  //     continue to be true for all future iterations since it is
+  //     monotonically increasing.
+  //
+  // For both the above possibilities, we can replace the loop varying
+  // predicate with its value on the first iteration of the loop (which is
+  // loop invariant).
+  //
+  // A similar reasoning applies for a monotonically decreasing predicate, by
+  // replacing true with false and false with true in the above two bullets.
+
+  auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
+
+  if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
+    return false;
+
+  InvariantPred = Pred;
+  InvariantLHS = ArLHS->getStart();
+  InvariantRHS = RHS;
+  return true;
+}
+
 bool
 ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS) {
@@ -6731,7 +6939,7 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
     ScalarEvolution &SE;
 
     explicit ClearWalkingBEDominatingCondsOnExit(ScalarEvolution &SE)
-        : SE(SE){};
+        : SE(SE){}
 
     ~ClearWalkingBEDominatingCondsOnExit() {
       SE.WalkingBEDominatingConds = false;
@@ -7594,7 +7802,7 @@ struct FindUndefs {
     return Found;
   }
 };
-} // namespace
+}
 
 // Return true when S contains at least an undef value.
 static inline bool
@@ -7644,14 +7852,14 @@ struct SCEVCollectTerms {
   }
   bool isDone() const { return false; }
 };
-} // namespace
+}
 
 /// Find parametric terms in this SCEVAddRecExpr.
-void SCEVAddRecExpr::collectParametricTerms(
-    ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &Terms) const {
+void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
+    SmallVectorImpl<const SCEV *> &Terms) {
   SmallVector<const SCEV *, 4> Strides;
-  SCEVCollectStrides StrideCollector(SE, Strides);
-  visitAll(this, StrideCollector);
+  SCEVCollectStrides StrideCollector(*this, Strides);
+  visitAll(Expr, StrideCollector);
 
   DEBUG({
       dbgs() << "Strides:\n";
@@ -7737,7 +7945,7 @@ struct FindParameter {
     return FoundParameter;
   }
 };
-} // namespace
+}
 
 // Returns true when S contains at least a SCEVUnknown parameter.
 static inline bool
@@ -7867,19 +8075,23 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
 
 /// Third step of delinearization: compute the access functions for the
 /// Subscripts based on the dimensions in Sizes.
-void SCEVAddRecExpr::computeAccessFunctions(
-    ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &Subscripts,
-    SmallVectorImpl<const SCEV *> &Sizes) const {
+void ScalarEvolution::computeAccessFunctions(
+    const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
+    SmallVectorImpl<const SCEV *> &Sizes) {
 
   // Early exit in case this SCEV is not an affine multivariate function.
-  if (Sizes.empty() || !this->isAffine())
+  if (Sizes.empty())
     return;
 
-  const SCEV *Res = this;
+  if (auto AR = dyn_cast<SCEVAddRecExpr>(Expr))
+    if (!AR->isAffine())
+      return;
+
+  const SCEV *Res = Expr;
   int Last = Sizes.size() - 1;
   for (int i = Last; i >= 0; i--) {
     const SCEV *Q, *R;
-    SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R);
+    SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R);
 
     DEBUG({
         dbgs() << "Res: " << *Res << "\n";
@@ -7971,31 +8183,31 @@ void SCEVAddRecExpr::computeAccessFunctions(
 /// asking for the SCEV of the memory access with respect to all enclosing
 /// loops, calling SCEV->delinearize on that and printing the results.
 
-void SCEVAddRecExpr::delinearize(ScalarEvolution &SE,
+void ScalarEvolution::delinearize(const SCEV *Expr,
                                  SmallVectorImpl<const SCEV *> &Subscripts,
                                  SmallVectorImpl<const SCEV *> &Sizes,
-                                 const SCEV *ElementSize) const {
+                                 const SCEV *ElementSize) {
   // First step: collect parametric terms.
   SmallVector<const SCEV *, 4> Terms;
-  collectParametricTerms(SE, Terms);
+  collectParametricTerms(Expr, Terms);
 
   if (Terms.empty())
     return;
 
   // Second step: find subscript sizes.
-  SE.findArrayDimensions(Terms, Sizes, ElementSize);
+  findArrayDimensions(Terms, Sizes, ElementSize);
 
   if (Sizes.empty())
     return;
 
   // Third step: compute the access functions for each subscript.
-  computeAccessFunctions(SE, Subscripts, Sizes);
+  computeAccessFunctions(Expr, Subscripts, Sizes);
 
   if (Subscripts.empty())
     return;
 
   DEBUG({
-      dbgs() << "succeeded to delinearize " << *this << "\n";
+      dbgs() << "succeeded to delinearize " << *Expr << "\n";
       dbgs() << "ArrayDecl[UnknownSize]";
       for (const SCEV *S : Sizes)
         dbgs() << "[" << *S << "]";
@@ -8103,10 +8315,10 @@ void ScalarEvolution::releaseMemory() {
 
 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
-  AU.addRequired<AssumptionCacheTracker>();
+  AU.addRequiredTransitive<AssumptionCacheTracker>();
   AU.addRequiredTransitive<LoopInfoWrapperPass>();
   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
-  AU.addRequired<TargetLibraryInfoWrapperPass>();
+  AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>();
 }
 
 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
@@ -8418,7 +8630,7 @@ struct SCEVSearch {
   }
   bool isDone() const { return IsFound; }
 };
-} // namespace
+}
 
 bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
   SCEVSearch Search(Op);