[PM] Port TargetLibraryInfo to the new pass manager, provided by the
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 06ae9c16d145c04bf355379872448a60de63ad26..f0c8f3b7178c013494ef6760af17e7e63bac96ef 100644 (file)
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
-#include "llvm/Analysis/AssumptionTracker.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
@@ -87,7 +88,6 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/Target/TargetLibraryInfo.h"
 #include <algorithm>
 using namespace llvm;
 
@@ -116,10 +116,10 @@ VerifySCEV("verify-scev",
 
 INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
                 "Scalar Evolution Analysis", false, true)
-INITIALIZE_PASS_DEPENDENCY(AssumptionTracker)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(LoopInfo)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
                 "Scalar Evolution Analysis", false, true)
 char ScalarEvolution::ID = 0;
@@ -1737,6 +1737,36 @@ namespace {
   };
 }
 
+// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
+// `OldFlags' as can't-wrap behavior.  Infer a more aggressive set of
+// can't-overflow flags for the operation if possible.
+static SCEV::NoWrapFlags
+StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
+                      const SmallVectorImpl<const SCEV *> &Ops,
+                      SCEV::NoWrapFlags OldFlags) {
+  using namespace std::placeholders;
+
+  bool CanAnalyze =
+      Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
+  (void)CanAnalyze;
+  assert(CanAnalyze && "don't call from other places!");
+
+  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
+  SCEV::NoWrapFlags SignOrUnsignWrap =
+      ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask);
+
+  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
+  auto IsKnownNonNegative =
+    std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1);
+
+  if (SignOrUnsignWrap == SCEV::FlagNSW &&
+      std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative))
+    return ScalarEvolution::setFlags(OldFlags,
+                                     (SCEV::NoWrapFlags)SignOrUnsignMask);
+
+  return OldFlags;
+}
+
 /// getAddExpr - Get a canonical add expression, or something simpler if
 /// possible.
 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
@@ -1752,20 +1782,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
            "SCEVAddExpr operand types don't match!");
 #endif
 
-  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  // And vice-versa.
-  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
-  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
-  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
-    bool All = true;
-    for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
-         E = Ops.end(); I != E; ++I)
-      if (!isKnownNonNegative(*I)) {
-        All = false;
-        break;
-      }
-    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
-  }
+  Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
 
   // Sort by complexity, this groups all similar expression types together.
   GroupByComplexity(Ops, LI);
@@ -2174,20 +2191,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
            "SCEVMulExpr operand types don't match!");
 #endif
 
-  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  // And vice-versa.
-  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
-  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
-  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
-    bool All = true;
-    for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
-         E = Ops.end(); I != E; ++I)
-      if (!isKnownNonNegative(*I)) {
-        All = false;
-        break;
-      }
-    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
-  }
+  Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
 
   // Sort by complexity, this groups all similar expression types together.
   GroupByComplexity(Ops, LI);
@@ -2653,20 +2657,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
   // meaningful BE count at this point (and if we don't, we'd be stuck
   // with a SCEVCouldNotCompute as the cached BE count).
 
-  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
-  // And vice-versa.
-  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
-  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
-  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
-    bool All = true;
-    for (SmallVectorImpl<const SCEV *>::const_iterator I = Operands.begin(),
-         E = Operands.end(); I != E; ++I)
-      if (!isKnownNonNegative(*I)) {
-        All = false;
-        break;
-      }
-    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
-  }
+  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
 
   // Canonicalize nested AddRecs in by nesting them in order of loop depth.
   if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
@@ -3531,7 +3522,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
   // PHI's incoming blocks are in a different loop, in which case doing so
   // risks breaking LCSSA form. Instcombine would normally zap these, but
   // it doesn't have DominatorTree information, so it may miss cases.
-  if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AT))
+  if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AC))
     if (LI->replacementPreservesLCSSAForm(PN, V))
       return getSCEV(V);
 
@@ -3663,7 +3654,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
     // For a SCEVUnknown, ask ValueTracking.
     unsigned BitWidth = getTypeSizeInBits(U->getType());
     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
-    computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT);
+    computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
     return Zeros.countTrailingOnes();
   }
 
@@ -3834,7 +3825,7 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
 
     // For a SCEVUnknown, ask ValueTracking.
     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
-    computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AT, nullptr, DT);
+    computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
     if (Ones == ~Zeros + 1)
       return setUnsignedRange(U, ConservativeResult);
     return setUnsignedRange(U,
@@ -3991,7 +3982,7 @@ ScalarEvolution::getSignedRange(const SCEV *S) {
     // For a SCEVUnknown, ask ValueTracking.
     if (!U->getValue()->getType()->isIntegerTy() && !DL)
       return setSignedRange(U, ConservativeResult);
-    unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AT, nullptr, DT);
+    unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT);
     if (NS <= 1)
       return setSignedRange(U, ConservativeResult);
     return setSignedRange(U, ConservativeResult.intersectWith(
@@ -4098,8 +4089,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       unsigned TZ = A.countTrailingZeros();
       unsigned BitWidth = A.getBitWidth();
       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-      computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL,
-                       0, AT, nullptr, DT);
+      computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, 0, AC,
+                       nullptr, DT);
 
       APInt EffectiveMask =
           APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
@@ -6630,7 +6621,10 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
     return true;
 
   // Check conditions due to any @llvm.assume intrinsics.
-  for (auto &CI : AT->assumptions(F)) {
+  for (auto &AssumeVH : AC->assumptions()) {
+    if (!AssumeVH)
+      continue;
+    auto *CI = cast<CallInst>(AssumeVH);
     if (!DT->dominates(CI, Latch->getTerminator()))
       continue;
 
@@ -6675,7 +6669,10 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
   }
 
   // Check conditions due to any @llvm.assume intrinsics.
-  for (auto &CI : AT->assumptions(F)) {
+  for (auto &AssumeVH : AC->assumptions()) {
+    if (!AssumeVH)
+      continue;
+    auto *CI = cast<CallInst>(AssumeVH);
     if (!DT->dominates(CI, L->getHeader()))
       continue;
 
@@ -6886,6 +6883,85 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                      getNotSCEV(FoundLHS));
 }
 
+
+/// If Expr computes ~A, return A else return nullptr
+static const SCEV *MatchNotExpr(const SCEV *Expr) {
+  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
+  if (!Add || Add->getNumOperands() != 2) return nullptr;
+
+  const SCEVConstant *AddLHS = dyn_cast<SCEVConstant>(Add->getOperand(0));
+  if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue()))
+    return nullptr;
+
+  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
+  if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr;
+
+  const SCEVConstant *MulLHS = dyn_cast<SCEVConstant>(AddRHS->getOperand(0));
+  if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue()))
+    return nullptr;
+
+  return AddRHS->getOperand(1);
+}
+
+
+/// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
+template<typename MaxExprType>
+static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
+                              const SCEV *Candidate) {
+  const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
+  if (!MaxExpr) return false;
+
+  auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
+  return It != MaxExpr->op_end();
+}
+
+
+/// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
+template<typename MaxExprType>
+static bool IsMinConsistingOf(ScalarEvolution &SE,
+                              const SCEV *MaybeMinExpr,
+                              const SCEV *Candidate) {
+  const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr);
+  if (!MaybeMaxExpr)
+    return false;
+
+  return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate));
+}
+
+
+/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
+/// expression?
+static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
+                                        ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS) {
+  switch (Pred) {
+  default:
+    return false;
+
+  case ICmpInst::ICMP_SGE:
+    std::swap(LHS, RHS);
+    // fall through
+  case ICmpInst::ICMP_SLE:
+    return
+      // min(A, ...) <= A
+      IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) ||
+      // A <= max(A, ...)
+      IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
+
+  case ICmpInst::ICMP_UGE:
+    std::swap(LHS, RHS);
+    // fall through
+  case ICmpInst::ICMP_ULE:
+    return
+      // min(A, ...) <= A
+      IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) ||
+      // A <= max(A, ...)
+      IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
+  }
+
+  llvm_unreachable("covered switch fell through?!");
+}
+
 /// isImpliedCondOperandsHelper - Test whether the condition described by
 /// Pred, LHS, and RHS is true whenever the condition described by Pred,
 /// FoundLHS, and FoundRHS is true.
@@ -6894,6 +6970,12 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
                                              const SCEV *LHS, const SCEV *RHS,
                                              const SCEV *FoundLHS,
                                              const SCEV *FoundRHS) {
+  auto IsKnownPredicateFull =
+      [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
+    return isKnownPredicateWithRanges(Pred, LHS, RHS) ||
+        IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS);
+  };
+
   switch (Pred) {
   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
   case ICmpInst::ICMP_EQ:
@@ -6903,26 +6985,26 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
     break;
   case ICmpInst::ICMP_SLT:
   case ICmpInst::ICMP_SLE:
-    if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
-        isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+    if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+        IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_SGT:
   case ICmpInst::ICMP_SGE:
-    if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
-        isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS))
+    if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+        IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_ULT:
   case ICmpInst::ICMP_ULE:
-    if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
-        isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+    if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+        IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_UGT:
   case ICmpInst::ICMP_UGE:
-    if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
-        isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+    if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+        IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS))
       return true;
     break;
   }
@@ -7784,11 +7866,11 @@ ScalarEvolution::ScalarEvolution()
 
 bool ScalarEvolution::runOnFunction(Function &F) {
   this->F = &F;
-  AT = &getAnalysis<AssumptionTracker>();
+  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   LI = &getAnalysis<LoopInfo>();
   DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
   DL = DLP ? &DLP->getDataLayout() : nullptr;
-  TLI = &getAnalysis<TargetLibraryInfo>();
+  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   return false;
 }
@@ -7825,10 +7907,10 @@ void ScalarEvolution::releaseMemory() {
 
 void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesAll();
-  AU.addRequired<AssumptionTracker>();
+  AU.addRequired<AssumptionCacheTracker>();
   AU.addRequiredTransitive<LoopInfo>();
   AU.addRequiredTransitive<DominatorTreeWrapperPass>();
-  AU.addRequired<TargetLibraryInfo>();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
 }
 
 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {