Add the ability to track HasNSW and HasNUW on more kinds of SCEV expressions.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index d639aee70993d0b7b6fb42e3cbb3c1e8f4f79118..9300de11685a4b9f107f3b04ed8e7e2878084d4e 100644 (file)
@@ -207,6 +207,10 @@ bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
   return Op->dominates(BB, DT);
 }
 
+bool SCEVCastExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
+  return Op->properlyDominates(BB, DT);
+}
+
 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeID &ID,
                                    const SCEV *op, const Type *ty)
   : SCEVCastExpr(ID, scTruncate, op, ty) {
@@ -260,10 +264,22 @@ bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
   return true;
 }
 
+bool SCEVNAryExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
+  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+    if (!getOperand(i)->properlyDominates(BB, DT))
+      return false;
+  }
+  return true;
+}
+
 bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
   return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
 }
 
+bool SCEVUDivExpr::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
+  return LHS->properlyDominates(BB, DT) && RHS->properlyDominates(BB, DT);
+}
+
 void SCEVUDivExpr::print(raw_ostream &OS) const {
   OS << "(" << *LHS << " /u " << *RHS << ")";
 }
@@ -328,6 +344,12 @@ bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
   return true;
 }
 
+bool SCEVUnknown::properlyDominates(BasicBlock *BB, DominatorTree *DT) const {
+  if (Instruction *I = dyn_cast<Instruction>(getValue()))
+    return DT->properlyDominates(I->getParent(), BB);
+  return true;
+}
+
 const Type *SCEVUnknown::getType() const {
   return V->getType();
 }
@@ -385,6 +407,10 @@ namespace {
     explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {}
 
     bool operator()(const SCEV *LHS, const SCEV *RHS) const {
+      // Fast-path: SCEVs are uniqued so we can do a quick equality check.
+      if (LHS == RHS)
+        return false;
+
       // Primarily, sort the SCEVs by their getSCEVType().
       if (LHS->getSCEVType() != RHS->getSCEVType())
         return LHS->getSCEVType() < RHS->getSCEVType();
@@ -1165,7 +1191,8 @@ namespace {
 
 /// getAddExpr - Get a canonical add expression, or something simpler if
 /// possible.
-const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops) {
+const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
+                                        bool HasNUW, bool HasNSW) {
   assert(!Ops.empty() && "Cannot get empty add!");
   if (Ops.size() == 1) return Ops[0];
 #ifndef NDEBUG
@@ -1215,7 +1242,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops) {
         return Mul;
       Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
       Ops.push_back(Mul);
-      return getAddExpr(Ops);
+      return getAddExpr(Ops, HasNUW, HasNSW);
     }
 
   // Check for truncates. If all the operands are truncated from the same
@@ -1270,7 +1297,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops) {
     }
     if (Ok) {
       // Evaluate the expression in the larger type.
-      const SCEV *Fold = getAddExpr(LargeOps);
+      const SCEV *Fold = getAddExpr(LargeOps, HasNUW, HasNSW);
       // If it folds to something simple, use it. Otherwise, don't.
       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
         return getTruncateExpr(Fold, DstType);
@@ -1490,16 +1517,19 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops) {
     ID.AddPointer(Ops[i]);
   void *IP = 0;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
-  SCEV *S = SCEVAllocator.Allocate<SCEVAddExpr>();
+  SCEVAddExpr *S = SCEVAllocator.Allocate<SCEVAddExpr>();
   new (S) SCEVAddExpr(ID, Ops);
   UniqueSCEVs.InsertNode(S, IP);
+  if (HasNUW) S->setHasNoUnsignedWrap(true);
+  if (HasNSW) S->setHasNoSignedWrap(true);
   return S;
 }
 
 
 /// getMulExpr - Get a canonical multiply expression, or something simpler if
 /// possible.
-const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops) {
+const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
+                                        bool HasNUW, bool HasNSW) {
   assert(!Ops.empty() && "Cannot get empty mul!");
 #ifndef NDEBUG
   for (unsigned i = 1, e = Ops.size(); i != e; ++i)
@@ -1662,9 +1692,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops) {
     ID.AddPointer(Ops[i]);
   void *IP = 0;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
-  SCEV *S = SCEVAllocator.Allocate<SCEVMulExpr>();
+  SCEVMulExpr *S = SCEVAllocator.Allocate<SCEVMulExpr>();
   new (S) SCEVMulExpr(ID, Ops);
   UniqueSCEVs.InsertNode(S, IP);
+  if (HasNUW) S->setHasNoUnsignedWrap(true);
+  if (HasNSW) S->setHasNoSignedWrap(true);
   return S;
 }
 
@@ -1771,7 +1803,8 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
 /// Simplify the expression as much as possible.
 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start,
-                                           const SCEV *Step, const Loop *L) {
+                                           const SCEV *Step, const Loop *L,
+                                           bool HasNUW, bool HasNSW) {
   SmallVector<const SCEV *, 4> Operands;
   Operands.push_back(Start);
   if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
@@ -1782,14 +1815,15 @@ const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start,
     }
 
   Operands.push_back(Step);
-  return getAddRecExpr(Operands, L);
+  return getAddRecExpr(Operands, L, HasNUW, HasNSW);
 }
 
 /// getAddRecExpr - Get an add recurrence expression for the specified loop.
 /// Simplify the expression as much as possible.
 const SCEV *
 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
-                               const Loop *L) {
+                               const Loop *L,
+                               bool HasNUW, bool HasNSW) {
   if (Operands.size() == 1) return Operands[0];
 #ifndef NDEBUG
   for (unsigned i = 1, e = Operands.size(); i != e; ++i)
@@ -1800,7 +1834,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
 
   if (Operands.back()->isZero()) {
     Operands.pop_back();
-    return getAddRecExpr(Operands, L);             // {X,+,0}  -->  X
+    return getAddRecExpr(Operands, L, HasNUW, HasNSW); // {X,+,0}  -->  X
   }
 
   // Canonicalize nested AddRecs in by nesting them in order of loop depth.
@@ -1829,7 +1863,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
           }
         if (AllInvariant)
           // Ok, both add recurrences are valid after the transformation.
-          return getAddRecExpr(NestedOperands, NestedLoop);
+          return getAddRecExpr(NestedOperands, NestedLoop, HasNUW, HasNSW);
       }
       // Reset Operands to its original state.
       Operands[0] = NestedAR;
@@ -1844,9 +1878,11 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
   ID.AddPointer(L);
   void *IP = 0;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
-  SCEV *S = SCEVAllocator.Allocate<SCEVAddRecExpr>();
+  SCEVAddRecExpr *S = SCEVAllocator.Allocate<SCEVAddRecExpr>();
   new (S) SCEVAddRecExpr(ID, Operands, L);
   UniqueSCEVs.InsertNode(S, IP);
+  if (HasNUW) S->setHasNoUnsignedWrap(true);
+  if (HasNSW) S->setHasNoSignedWrap(true);
   return S;
 }
 
@@ -2420,9 +2456,10 @@ ScalarEvolution::ForgetSymbolicName(Instruction *I, const SCEV *SymName) {
       // count information isn't going to change anything. In the later
       // case, createNodeForPHI will perform the necessary updates on its
       // own when it gets to that point.
-      if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second))
+      if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second)) {
+        ValuesAtScopes.erase(It->second);
         Scalars.erase(It);
-      ValuesAtScopes.erase(I);
+      }
     }
 
     PushDefUseChildren(I, Worklist);
@@ -2967,8 +3004,20 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       const SCEV *LHS = getSCEV(U->getOperand(0));
       const APInt &CIVal = CI->getValue();
       if (GetMinTrailingZeros(LHS) >=
-          (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
-        return getAddExpr(LHS, getSCEV(U->getOperand(1)));
+          (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
+        // Build a plain add SCEV.
+        const SCEV *S = getAddExpr(LHS, getSCEV(CI));
+        // If the LHS of the add was an addrec and it has no-wrap flags,
+        // transfer the no-wrap flags, since an or won't introduce a wrap.
+        if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
+          const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
+          if (OldAR->hasNoUnsignedWrap())
+            const_cast<SCEVAddRecExpr *>(NewAR)->setHasNoUnsignedWrap(true);
+          if (OldAR->hasNoSignedWrap())
+            const_cast<SCEVAddRecExpr *>(NewAR)->setHasNoSignedWrap(true);
+        }
+        return S;
+      }
     }
     break;
   case Instruction::Xor:
@@ -3232,9 +3281,10 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
           // count information isn't going to change anything. In the later
           // case, createNodeForPHI will perform the necessary updates on its
           // own when it gets to that point.
-          if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second))
+          if (!isa<PHINode>(I) || !isa<SCEVUnknown>(It->second)) {
+            ValuesAtScopes.erase(It->second);
             Scalars.erase(It);
-          ValuesAtScopes.erase(I);
+          }
           if (PHINode *PN = dyn_cast<PHINode>(I))
             ConstantEvolutionLoopExitValue.erase(PN);
         }
@@ -3264,8 +3314,8 @@ void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) {
     std::map<SCEVCallbackVH, const SCEV*>::iterator It =
       Scalars.find(static_cast<Value *>(I));
     if (It != Scalars.end()) {
+      ValuesAtScopes.erase(It->second);
       Scalars.erase(It);
-      ValuesAtScopes.erase(I);
       if (PHINode *PN = dyn_cast<PHINode>(I))
         ConstantEvolutionLoopExitValue.erase(PN);
     }
@@ -3886,7 +3936,7 @@ ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L,
   return getCouldNotCompute();
 }
 
-/// getSCEVAtScope - Return a SCEV expression handle for the specified value
+/// getSCEVAtScope - Return a SCEV expression for the specified value
 /// at the specified scope in the program.  The L value specifies a loop
 /// nest to evaluate the expression at, where null is the top-level or a
 /// specified loop is immediately inside of the loop.
@@ -3897,8 +3947,20 @@ ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L,
 /// In the case that a relevant loop exit value cannot be computed, the
 /// original value V is returned.
 const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
-  // FIXME: this should be turned into a virtual method on SCEV!
+  // Check to see if we've folded this expression at this loop before.
+  std::map<const Loop *, const SCEV *> &Values = ValuesAtScopes[V];
+  std::pair<std::map<const Loop *, const SCEV *>::iterator, bool> Pair =
+    Values.insert(std::make_pair(L, static_cast<const SCEV *>(0)));
+  if (!Pair.second)
+    return Pair.first->second ? Pair.first->second : V;
 
+  // Otherwise compute it.
+  const SCEV *C = computeSCEVAtScope(V, L);
+  ValuesAtScopes[V][L] = C;
+  return C;
+}
+
+const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
   if (isa<SCEVConstant>(V)) return V;
 
   // If this instruction is evolved from a constant-evolving PHI, compute the
@@ -3931,13 +3993,6 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
       // the arguments into constants, and if so, try to constant propagate the
       // result.  This is particularly useful for computing loop exit values.
       if (CanConstantFold(I)) {
-        // Check to see if we've folded this instruction at this loop before.
-        std::map<const Loop *, Constant *> &Values = ValuesAtScopes[I];
-        std::pair<std::map<const Loop *, Constant *>::iterator, bool> Pair =
-          Values.insert(std::make_pair(L, static_cast<Constant *>(0)));
-        if (!Pair.second)
-          return Pair.first->second ? &*getSCEV(Pair.first->second) : V;
-
         std::vector<Constant*> Operands;
         Operands.reserve(I->getNumOperands());
         for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
@@ -3986,7 +4041,6 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
           C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
                                        &Operands[0], Operands.size(),
                                        getContext());
-        Pair.first->second = C;
         return getSCEV(C);
       }
     }
@@ -4785,7 +4839,8 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
 /// CouldNotCompute if an intermediate computation overflows.
 const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
                                         const SCEV *End,
-                                        const SCEV *Step) {
+                                        const SCEV *Step,
+                                        bool NoWrap) {
   const Type *Ty = Start->getType();
   const SCEV *NegOne = getIntegerSCEV(-1, Ty);
   const SCEV *Diff = getMinusSCEV(End, Start);
@@ -4795,15 +4850,17 @@ const SCEV *ScalarEvolution::getBECount(const SCEV *Start,
   // the division will effectively round up.
   const SCEV *Add = getAddExpr(Diff, RoundUp);
 
-  // Check Add for unsigned overflow.
-  // TODO: More sophisticated things could be done here.
-  const Type *WideTy = IntegerType::get(getContext(),
-                                        getTypeSizeInBits(Ty) + 1);
-  const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
-  const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
-  const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
-  if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
-    return getCouldNotCompute();
+  if (!NoWrap) {
+    // Check Add for unsigned overflow.
+    // TODO: More sophisticated things could be done here.
+    const Type *WideTy = IntegerType::get(getContext(),
+                                          getTypeSizeInBits(Ty) + 1);
+    const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
+    const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
+    const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
+    if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
+      return getCouldNotCompute();
+  }
 
   return getUDivExpr(Add, Step);
 }
@@ -4821,6 +4878,10 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (!AddRec || AddRec->getLoop() != L)
     return getCouldNotCompute();
 
+  // Check to see if we have a flag which makes analysis easy.
+  bool NoWrap = isSigned ? AddRec->hasNoSignedWrap() :
+                           AddRec->hasNoUnsignedWrap();
+
   if (AddRec->isAffine()) {
     // FORNOW: We only support unit strides.
     unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
@@ -4833,7 +4894,10 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
     if (CStep->isOne()) {
       // With unit stride, the iteration never steps past the limit value.
     } else if (CStep->getValue()->getValue().isStrictlyPositive()) {
-      if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) {
+      if (NoWrap) {
+        // We know the iteration won't step past the maximum value for its type.
+        ;
+      } else if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) {
         // Test whether a positive iteration iteration can step past the limit
         // value and past the maximum value for its type in a single step.
         if (isSigned) {
@@ -4886,11 +4950,11 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
 
     // Finally, we subtract these two values and divide, rounding up, to get
     // the number of times the backedge is executed.
-    const SCEV *BECount = getBECount(Start, End, Step);
+    const SCEV *BECount = getBECount(Start, End, Step, NoWrap);
 
     // The maximum backedge count is similar, except using the minimum start
     // value and the maximum end value.
-    const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step);
+    const SCEV *MaxBECount = getBECount(MinStart, MaxEnd, Step, NoWrap);
 
     return BackedgeTakenInfo(BECount, MaxBECount);
   }
@@ -5031,8 +5095,6 @@ void ScalarEvolution::SCEVCallbackVH::deleted() {
   assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
   if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
     SE->ConstantEvolutionLoopExitValue.erase(PN);
-  if (Instruction *I = dyn_cast<Instruction>(getValPtr()))
-    SE->ValuesAtScopes.erase(I);
   SE->Scalars.erase(getValPtr());
   // this now dangles!
 }
@@ -5062,8 +5124,6 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) {
       continue;
     if (PHINode *PN = dyn_cast<PHINode>(U))
       SE->ConstantEvolutionLoopExitValue.erase(PN);
-    if (Instruction *I = dyn_cast<Instruction>(U))
-      SE->ValuesAtScopes.erase(I);
     SE->Scalars.erase(U);
     for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
          UI != UE; ++UI)
@@ -5073,8 +5133,6 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) {
   if (DeleteOld) {
     if (PHINode *PN = dyn_cast<PHINode>(Old))
       SE->ConstantEvolutionLoopExitValue.erase(PN);
-    if (Instruction *I = dyn_cast<Instruction>(Old))
-      SE->ValuesAtScopes.erase(I);
     SE->Scalars.erase(Old);
     // this now dangles!
   }