Speculatively revert r97010, "Add an argument to PHITranslateValue to specify
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 82be9cd5c4e3f45e8105990c705abf8f4921f08a..c17f6f38c66f69a8b3ec24990c4cf96df99680a2 100644 (file)
@@ -214,8 +214,8 @@ bool SCEVCastExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeID &ID,
                                    const SCEV *op, const Type *ty)
   : SCEVCastExpr(ID, scTruncate, op, ty) {
-  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot truncate non-integer value!");
 }
 
@@ -226,8 +226,8 @@ void SCEVTruncateExpr::print(raw_ostream &OS) const {
 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeID &ID,
                                        const SCEV *op, const Type *ty)
   : SCEVCastExpr(ID, scZeroExtend, op, ty) {
-  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot zero extend non-integer value!");
 }
 
@@ -238,8 +238,8 @@ void SCEVZeroExtendExpr::print(raw_ostream &OS) const {
 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeID &ID,
                                        const SCEV *op, const Type *ty)
   : SCEVCastExpr(ID, scSignExtend, op, ty) {
-  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot sign extend non-integer value!");
 }
 
@@ -312,6 +312,21 @@ bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
   return true;
 }
 
+bool
+SCEVAddRecExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
+  return DT->dominates(L->getHeader(), BB) &&
+         SCEVNAryExpr::dominates(BB, DT);
+}
+
+bool
+SCEVAddRecExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
+  // This uses a "dominates" query instead of "properly dominates" query because
+  // the instruction which produces the addrec's value is a PHI, and a PHI
+  // effectively properly dominates its entire containing block.
+  return DT->dominates(L->getHeader(), BB) &&
+         SCEVNAryExpr::properlyDominates(BB, DT);
+}
+
 void SCEVAddRecExpr::print(raw_ostream &OS) const {
   OS << "{" << *Operands[0];
   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
@@ -379,7 +394,7 @@ bool SCEVUnknown::isAlignOf(const Type *&AllocTy) const {
               if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                 if (CI->isOne() &&
                     STy->getNumElements() == 2 &&
-                    STy->getElementType(0)->isInteger(1)) {
+                    STy->getElementType(0)->isIntegerTy(1)) {
                   AllocTy = STy->getElementType(1);
                   return true;
                 }
@@ -401,7 +416,7 @@ bool SCEVUnknown::isOffsetOf(const Type *&CTy, Constant *&FieldNo) const {
             cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
           // Ignore vector types here so that ScalarEvolutionExpander doesn't
           // emit getelementptrs that index into vectors.
-          if (isa<StructType>(Ty) || isa<ArrayType>(Ty)) {
+          if (Ty->isStructTy() || Ty->isArrayTy()) {
             CTy = Ty;
             FieldNo = CE->getOperand(2);
             return true;
@@ -503,9 +518,9 @@ namespace {
 
         // Order pointer values after integer values. This helps SCEVExpander
         // form GEPs.
-        if (isa<PointerType>(LU->getType()) && !isa<PointerType>(RU->getType()))
+        if (LU->getType()->isPointerTy() && !RU->getType()->isPointerTy())
           return false;
-        if (isa<PointerType>(RU->getType()) && !isa<PointerType>(LU->getType()))
+        if (RU->getType()->isPointerTy() && !LU->getType()->isPointerTy())
           return true;
 
         // Compare getValueID values.
@@ -1958,6 +1973,12 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
     return getAddRecExpr(Operands, L, HasNUW, HasNSW); // {X,+,0}  -->  X
   }
 
+  // It's tempting to want to call getMaxBackedgeTakenCount count here and
+  // use that information to infer NUW and NSW flags. However, computing a
+  // BE count requires calling getAddRecExpr, so we may not yet have a
+  // meaningful BE count at this point (and if we don't, we'd be stuck
+  // with a SCEVCouldNotCompute as the cached BE count).
+
   // If HasNSW is true and all the operands are non-negative, infer HasNUW.
   if (!HasNUW && HasNSW) {
     bool All = true;
@@ -2293,7 +2314,7 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
 /// has access to target-specific information.
 bool ScalarEvolution::isSCEVable(const Type *Ty) const {
   // Integers and pointers are always SCEVable.
-  return Ty->isInteger() || isa<PointerType>(Ty);
+  return Ty->isIntegerTy() || Ty->isPointerTy();
 }
 
 /// getTypeSizeInBits - Return the size in bits of the specified type,
@@ -2306,12 +2327,12 @@ uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
     return TD->getTypeSizeInBits(Ty);
 
   // Integer types have fixed sizes.
-  if (Ty->isInteger())
+  if (Ty->isIntegerTy())
     return Ty->getPrimitiveSizeInBits();
 
   // The only other support type is pointer. Without TargetData, conservatively
   // assume pointers are 64-bit.
-  assert(isa<PointerType>(Ty) && "isSCEVable permitted a non-SCEVable type!");
+  assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
   return 64;
 }
 
@@ -2322,11 +2343,11 @@ uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
 const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
   assert(isSCEVable(Ty) && "Type is not SCEVable!");
 
-  if (Ty->isInteger())
+  if (Ty->isIntegerTy())
     return Ty;
 
   // The only other support type is pointer.
-  assert(isa<PointerType>(Ty) && "Unexpected non-pointer non-integer type!");
+  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
   if (TD) return TD->getIntPtrType(getContext());
 
   // Without TargetData, conservatively assume pointers are 64-bit.
@@ -2397,8 +2418,8 @@ const SCEV *
 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V,
                                          const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot truncate or zero extend with non-integer arguments!");
   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
     return V;  // No conversion
@@ -2414,8 +2435,8 @@ const SCEV *
 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
                                          const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot truncate or zero extend with non-integer arguments!");
   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
     return V;  // No conversion
@@ -2430,8 +2451,8 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
 const SCEV *
 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot noop or zero extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
          "getNoopOrZeroExtend cannot truncate!");
@@ -2446,8 +2467,8 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) {
 const SCEV *
 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot noop or sign extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
          "getNoopOrSignExtend cannot truncate!");
@@ -2463,8 +2484,8 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) {
 const SCEV *
 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot noop or any extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
          "getNoopOrAnyExtend cannot truncate!");
@@ -2478,8 +2499,8 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) {
 const SCEV *
 ScalarEvolution::getTruncateOrNoop(const SCEV *V, const Type *Ty) {
   const Type *SrcTy = V->getType();
-  assert((SrcTy->isInteger() || isa<PointerType>(SrcTy)) &&
-         (Ty->isInteger() || isa<PointerType>(Ty)) &&
+  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
+         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
          "Cannot truncate or noop with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
          "getTruncateOrNoop cannot extend!");
@@ -2543,7 +2564,7 @@ ScalarEvolution::ForgetSymbolicName(Instruction *I, const SCEV *SymName) {
   SmallPtrSet<Instruction *, 8> Visited;
   Visited.insert(I);
   while (!Worklist.empty()) {
-    Instruction *I = Worklist.pop_back_val();
+    I = Worklist.pop_back_val();
     if (!Visited.insert(I)) continue;
 
     std::map<SCEVCallbackVH, const SCEV *>::iterator It =
@@ -2551,7 +2572,7 @@ ScalarEvolution::ForgetSymbolicName(Instruction *I, const SCEV *SymName) {
     if (It != Scalars.end()) {
       // Short-circuit the def-use traversal if the symbolic name
       // ceases to appear in expressions.
-      if (!It->second->hasOperand(SymName))
+      if (It->second != SymName && !It->second->hasOperand(SymName))
         continue;
 
       // SCEVUnknown for a PHI either means that it has an unrecognized
@@ -2921,7 +2942,6 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
 
   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
     // For a SCEVUnknown, ask ValueTracking.
-    unsigned BitWidth = getTypeSizeInBits(U->getType());
     APInt Mask = APInt::getAllOnesValue(BitWidth);
     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
     ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
@@ -3053,7 +3073,7 @@ ScalarEvolution::getSignedRange(const SCEV *S) {
 
   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
     // For a SCEVUnknown, ask ValueTracking.
-    if (!U->getValue()->getType()->isInteger() && !TD)
+    if (!U->getValue()->getType()->isIntegerTy() && !TD)
       return ConservativeResult;
     unsigned NS = ComputeNumSignBits(U->getValue(), TD);
     if (NS == 1)
@@ -3470,6 +3490,35 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
   }
 }
 
+/// forgetValue - This method should be called by the client when it has
+/// changed a value in a way that may effect its value, or which may
+/// disconnect it from a def-use chain linking it to a loop.
+void ScalarEvolution::forgetValue(Value *V) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) return;
+
+  // Drop information about expressions based on loop-header PHIs.
+  SmallVector<Instruction *, 16> Worklist;
+  Worklist.push_back(I);
+
+  SmallPtrSet<Instruction *, 8> Visited;
+  while (!Worklist.empty()) {
+    I = Worklist.pop_back_val();
+    if (!Visited.insert(I)) continue;
+
+    std::map<SCEVCallbackVH, const SCEV *>::iterator It =
+      Scalars.find(static_cast<Value *>(I));
+    if (It != Scalars.end()) {
+      ValuesAtScopes.erase(It->second);
+      Scalars.erase(It);
+      if (PHINode *PN = dyn_cast<PHINode>(I))
+        ConstantEvolutionLoopExitValue.erase(PN);
+    }
+
+    PushDefUseChildren(I, Worklist);
+  }
+}
+
 /// ComputeBackedgeTakenCount - Compute the number of times the backedge
 /// of the specified loop will execute.
 ScalarEvolution::BackedgeTakenInfo
@@ -3659,6 +3708,19 @@ ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
     return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB);
 
+  // Check for a constant condition. These are normally stripped out by
+  // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
+  // preserve the CFG and is temporarily leaving constant conditions
+  // in place.
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
+    if (L->contains(FBB) == !CI->getZExtValue())
+      // The backedge is always taken.
+      return getCouldNotCompute();
+    else
+      // The backedge is never taken.
+      return getIntegerSCEV(0, CI->getType());
+  }
+
   // If it's not an integer or pointer comparison then compute it the hard way.
   return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
 }
@@ -4435,7 +4497,7 @@ const SCEV *ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
                                             -StartC->getValue()->getValue(),
                                             *this);
     }
-  } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
+  } else if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
     // the quadratic equation to solve it.
     std::pair<const SCEV *,const SCEV *> Roots = SolveQuadraticEquation(AddRec,
@@ -5304,8 +5366,8 @@ ScalarEvolution::ScalarEvolution()
 bool ScalarEvolution::runOnFunction(Function &F) {
   this->F = &F;
   LI = &getAnalysis<LoopInfo>();
-  DT = &getAnalysis<DominatorTree>();
   TD = getAnalysisIfAvailable<TargetData>();
+  DT = &getAnalysis<DominatorTree>();
   return false;
 }