Use SmallVector instead of std::vector.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 473eadcd87864bac583aefbffb628813789757f4..5bae18cc4c33ab8c2932ba90896357349c209f11 100644 (file)
@@ -105,6 +105,7 @@ namespace {
   RegisterPass<ScalarEvolution>
   R("scalar-evolution", "Scalar Evolution Analysis");
 }
+char ScalarEvolution::ID = 0;
 
 //===----------------------------------------------------------------------===//
 //                           SCEV class definitions
@@ -182,6 +183,10 @@ SCEVHandle SCEVConstant::get(ConstantInt *V) {
   return R;
 }
 
+SCEVHandle SCEVConstant::get(const APInt& Val) {
+  return get(ConstantInt::get(Val));
+}
+
 ConstantRange SCEVConstant::getValueRange() const {
   return ConstantRange(V->getValue());
 }
@@ -244,6 +249,32 @@ void SCEVZeroExtendExpr::print(std::ostream &OS) const {
   OS << "(zeroextend " << *Op << " to " << *Ty << ")";
 }
 
+// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any
+// particular input.  Don't use a SCEVHandle here, or else the object will never
+// be deleted!
+static ManagedStatic<std::map<std::pair<SCEV*, const Type*>,
+                     SCEVSignExtendExpr*> > SCEVSignExtends;
+
+SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty)
+  : SCEV(scSignExtend), Op(op), Ty(ty) {
+  assert(Op->getType()->isInteger() && Ty->isInteger() &&
+         "Cannot sign extend non-integer value!");
+  assert(Op->getType()->getPrimitiveSizeInBits() < Ty->getPrimitiveSizeInBits()
+         && "This is not an extending conversion!");
+}
+
+SCEVSignExtendExpr::~SCEVSignExtendExpr() {
+  SCEVSignExtends->erase(std::make_pair(Op, Ty));
+}
+
+ConstantRange SCEVSignExtendExpr::getValueRange() const {
+  return getOperand()->getValueRange().signExtend(getBitWidth());
+}
+
+void SCEVSignExtendExpr::print(std::ostream &OS) const {
+  OS << "(signextend " << *Op << " to " << *Ty << ")";
+}
+
 // SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any
 // particular input.  Don't use a SCEVHandle here, or else the object will never
 // be deleted!
@@ -460,10 +491,6 @@ SCEVHandle SCEVUnknown::getIntegerSCEV(int Val, const Type *Ty) {
   return SCEVUnknown::get(C);
 }
 
-SCEVHandle SCEVUnknown::getIntegerSCEV(const APInt& Val) {
-  return SCEVUnknown::get(ConstantInt::get(Val));
-}
-
 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type.  If the type must be extended, it is zero
 /// extended.
@@ -504,7 +531,7 @@ static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps) {
     APInt Result(Val.getBitWidth(), 1);
     for (; NumSteps; --NumSteps)
       Result *= Val-(NumSteps-1);
-    return SCEVUnknown::get(ConstantInt::get(Result));
+    return SCEVConstant::get(Result);
   }
 
   const Type *Ty = V->getType();
@@ -587,6 +614,21 @@ SCEVHandle SCEVZeroExtendExpr::get(const SCEVHandle &Op, const Type *Ty) {
   return Result;
 }
 
+SCEVHandle SCEVSignExtendExpr::get(const SCEVHandle &Op, const Type *Ty) {
+  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
+    return SCEVUnknown::get(
+        ConstantExpr::getSExt(SC->getValue(), Ty));
+
+  // FIXME: If the input value is a chrec scev, and we can prove that the value
+  // did not overflow the old, smaller, value, we can sign extend all of the
+  // operands (often constants).  This would allow analysis of something like
+  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
+
+  SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)];
+  if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty);
+  return Result;
+}
+
 // get - Get a canonical add expression, or something simpler if possible.
 SCEVHandle SCEVAddExpr::get(std::vector<SCEVHandle> &Ops) {
   assert(!Ops.empty() && "Cannot get empty add!");
@@ -643,8 +685,11 @@ SCEVHandle SCEVAddExpr::get(std::vector<SCEVHandle> &Ops) {
       return SCEVAddExpr::get(Ops);
     }
 
-  // Okay, now we know the first non-constant operand.  If there are add
-  // operands they would be next.
+  // Now we know the first non-constant operand.  Skip past any cast SCEVs.
+  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
+    ++Idx;
+
+  // If there are add operands they would be next.
   if (Idx < Ops.size()) {
     bool DeletedAdd = false;
     while (SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
@@ -1122,10 +1167,10 @@ namespace {
     /// loop without a loop-invariant iteration count.
     SCEVHandle getIterationCount(const Loop *L);
 
-    /// deleteInstructionFromRecords - This method should be called by the
-    /// client before it removes an instruction from the program, to make sure
+    /// deleteValueFromRecords - This method should be called by the
+    /// client before it removes a value from the program, to make sure
     /// that no dangling references are left around.
-    void deleteInstructionFromRecords(Instruction *I);
+    void deleteValueFromRecords(Value *V);
 
   private:
     /// createSCEV - We know that there is no SCEV for the specified value.
@@ -1175,8 +1220,9 @@ namespace {
 
     /// HowManyLessThans - Return the number of times a backedge containing the
     /// specified less-than comparison will execute.  If not computable, return
-    /// UnknownValue.
-    SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L);
+    /// UnknownValue. isSigned specifies whether the less-than is signed.
+    SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L,
+                                bool isSigned);
 
     /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
     /// in the header of its containing loop, we know the loop executes a
@@ -1191,13 +1237,32 @@ namespace {
 //            Basic SCEV Analysis and PHI Idiom Recognition Code
 //
 
-/// deleteInstructionFromRecords - This method should be called by the
+/// deleteValueFromRecords - This method should be called by the
 /// client before it removes an instruction from the program, to make sure
 /// that no dangling references are left around.
-void ScalarEvolutionsImpl::deleteInstructionFromRecords(Instruction *I) {
-  Scalars.erase(I);
-  if (PHINode *PN = dyn_cast<PHINode>(I))
-    ConstantEvolutionLoopExitValue.erase(PN);
+void ScalarEvolutionsImpl::deleteValueFromRecords(Value *V) {
+  SmallVector<Value *, 16> Worklist;
+
+  if (Scalars.erase(V)) {
+    if (PHINode *PN = dyn_cast<PHINode>(V))
+      ConstantEvolutionLoopExitValue.erase(PN);
+    Worklist.push_back(V);
+  }
+
+  while (!Worklist.empty()) {
+    Value *VV = Worklist.back();
+    Worklist.pop_back();
+
+    for (Instruction::use_iterator UI = VV->use_begin(), UE = VV->use_end();
+         UI != UE; ++UI) {
+      Instruction *Inst = cast<Instruction>(*UI);
+      if (Scalars.erase(Inst)) {
+        if (PHINode *PN = dyn_cast<PHINode>(VV))
+          ConstantEvolutionLoopExitValue.erase(PN);
+        Worklist.push_back(Inst);
+      }
+    }
+  }
 }
 
 
@@ -1350,6 +1415,9 @@ static APInt GetConstantFactor(SCEVHandle S) {
   if (SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S))
     return GetConstantFactor(E->getOperand()).zext(
                                cast<IntegerType>(E->getType())->getBitWidth());
+  if (SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S))
+    return GetConstantFactor(E->getOperand()).sext(
+                               cast<IntegerType>(E->getType())->getBitWidth());
   
   if (SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
     // The result is the min of all operands.
@@ -1450,6 +1518,9 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
     case Instruction::ZExt:
       return SCEVZeroExtendExpr::get(getSCEV(I->getOperand(0)), I->getType());
 
+    case Instruction::SExt:
+      return SCEVSignExtendExpr::get(getSCEV(I->getOperand(0)), I->getType());
+
     case Instruction::BitCast:
       // BitCasts are no-op casts so we just eliminate the cast.
       if (I->getType()->isInteger() &&
@@ -1498,7 +1569,7 @@ SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) {
 /// will iterate.
 SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
   // If the loop has a non-one exit block count, we can't analyze it.
-  std::vector<BasicBlock*> ExitBlocks;
+  SmallVector<BasicBlock*, 8> ExitBlocks;
   L->getExitBlocks(ExitBlocks);
   if (ExitBlocks.size() != 1) return UnknownValue;
 
@@ -1601,8 +1672,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
           ConstantRange CompRange(
               ICmpInst::makeConstantRange(Cond, CompVal->getValue()));
 
-          SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, 
-              false /*Always treat as unsigned range*/);
+          SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange);
           if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
         }
       }
@@ -1621,12 +1691,24 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
     break;
   }
   case ICmpInst::ICMP_SLT: {
-    SCEVHandle TC = HowManyLessThans(LHS, RHS, L);
+    SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true);
     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
     break;
   }
   case ICmpInst::ICMP_SGT: {
-    SCEVHandle TC = HowManyLessThans(RHS, LHS, L);
+    SCEVHandle TC = HowManyLessThans(SCEV::getNegativeSCEV(LHS),
+                                     SCEV::getNegativeSCEV(RHS), L, true);
+    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
+    break;
+  }
+  case ICmpInst::ICMP_ULT: {
+    SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false);
+    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
+    break;
+  }
+  case ICmpInst::ICMP_UGT: {
+    SCEVHandle TC = HowManyLessThans(SCEV::getNegativeSCEV(LHS),
+                                     SCEV::getNegativeSCEV(RHS), L, false);
     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
     break;
   }
@@ -1646,8 +1728,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
 }
 
 static ConstantInt *
-EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, Constant *C) {
-  SCEVHandle InVal = SCEVConstant::get(cast<ConstantInt>(C));
+EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C) {
+  SCEVHandle InVal = SCEVConstant::get(C);
   SCEVHandle Val = AddRec->evaluateAtIteration(InVal);
   assert(isa<SCEVConstant>(Val) &&
          "Evaluation of SCEV at constant didn't fold correctly?");
@@ -2129,8 +2211,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec) {
     ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
     ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
 
-    return std::make_pair(SCEVUnknown::get(Solution1), 
-                          SCEVUnknown::get(Solution2));
+    return std::make_pair(SCEVConstant::get(Solution1), 
+                          SCEVConstant::get(Solution2));
     } // end APIntOps namespace
 }
 
@@ -2240,7 +2322,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) {
 /// specified less-than comparison will execute.  If not computable, return
 /// UnknownValue.
 SCEVHandle ScalarEvolutionsImpl::
-HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) {
+HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
   // Only handle:  "ADDREC < LoopInvariant".
   if (!RHS->isLoopInvariant(L)) return UnknownValue;
 
@@ -2297,28 +2379,34 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) {
     
       switch (Cond) {
       case ICmpInst::ICMP_UGT:
+        if (isSigned) return UnknownValue;
         std::swap(PreCondLHS, PreCondRHS);
         Cond = ICmpInst::ICMP_ULT;
         break;
       case ICmpInst::ICMP_SGT:
+        if (!isSigned) return UnknownValue;
         std::swap(PreCondLHS, PreCondRHS);
         Cond = ICmpInst::ICMP_SLT;
         break;
-      default: break;
+      case ICmpInst::ICMP_ULT:
+        if (isSigned) return UnknownValue;
+        break;
+      case ICmpInst::ICMP_SLT:
+        if (!isSigned) return UnknownValue;
+        break;
+      default:
+        return UnknownValue;
       }
 
-      if (Cond == ICmpInst::ICMP_SLT) {
-        if (PreCondLHS->getType()->isInteger()) {
-          if (RHS != getSCEV(PreCondRHS))
-            return UnknownValue;  // Not a comparison against 'm'.
+      if (PreCondLHS->getType()->isInteger()) {
+        if (RHS != getSCEV(PreCondRHS))
+          return UnknownValue;  // Not a comparison against 'm'.
 
-          if (SCEV::getMinusSCEV(AddRec->getOperand(0), One)
-                      != getSCEV(PreCondLHS))
-            return UnknownValue;  // Not a comparison against 'n-1'.
-        }
-        else return UnknownValue;
-      } else if (Cond == ICmpInst::ICMP_ULT)
-        return UnknownValue;
+        if (SCEV::getMinusSCEV(AddRec->getOperand(0), One)
+                    != getSCEV(PreCondLHS))
+          return UnknownValue;  // Not a comparison against 'n-1'.
+      }
+      else return UnknownValue;
 
       // cerr << "Computed Loop Trip Count as: " 
       //      << //  *SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n";
@@ -2336,8 +2424,7 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) {
 /// this is that it returns the first iteration number where the value is not in
 /// the condition, thus computing the exit count. If the iteration count can't
 /// be computed, an instance of SCEVCouldNotCompute is returned.
-SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, 
-                                                   bool isSigned) const {
+SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const {
   if (Range.isFullSet())  // Infinite loop.
     return new SCEVCouldNotCompute();
 
@@ -2349,7 +2436,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
       SCEVHandle Shifted = SCEVAddRecExpr::get(Operands, getLoop());
       if (SCEVAddRecExpr *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
         return ShiftedAddRec->getNumIterationsInRange(
-                           Range.subtract(SC->getValue()->getValue()),isSigned);
+                           Range.subtract(SC->getValue()->getValue()));
       // This is strange and shouldn't happen.
       return new SCEVCouldNotCompute();
     }
@@ -2373,17 +2460,16 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
     // If this is an affine expression then we have this situation:
     //   Solve {0,+,A} in Range  ===  Ax in Range
 
-    // Since we know that zero is in the range, we know that the upper value of
-    // the range must be the first possible exit value.  Also note that we
-    // already checked for a full range.
-    const APInt &Upper = Range.getUpper();
-    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
+    // We know that zero is in the range.  If A is positive then we know that
+    // the upper value of the range must be the first possible exit value.
+    // If A is negative then the lower of the range is the last possible loop
+    // value.  Also note that we already checked for a full range.
     APInt One(getBitWidth(),1);
+    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
+    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
 
-    // The exit value should be (Upper+A-1)/A.
-    APInt ExitVal(Upper);
-    if (A != One)
-      ExitVal = (Upper + A - One).sdiv(A);
+    // The exit value should be (End+A)/A.
+    APInt ExitVal = (End + A).sdiv(A);
     ConstantInt *ExitValue = ConstantInt::get(ExitVal);
 
     // Evaluate at the exit value.  If we really did fall out of the valid
@@ -2398,15 +2484,14 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
            EvaluateConstantChrecAtConstant(this, 
            ConstantInt::get(ExitVal - One))->getValue()) &&
            "Linear scev computation is off in a bad way!");
-    return SCEVConstant::get(cast<ConstantInt>(ExitValue));
+    return SCEVConstant::get(ExitValue);
   } else if (isQuadratic()) {
     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
     // quadratic equation to solve it.  To do this, we must frame our problem in
     // terms of figuring out when zero is crossed, instead of when
     // Range.getUpper() is crossed.
     std::vector<SCEVHandle> NewOps(op_begin(), op_end());
-    NewOps[0] = SCEV::getNegativeSCEV(SCEVUnknown::get(
-                                           ConstantInt::get(Range.getUpper())));
+    NewOps[0] = SCEV::getNegativeSCEV(SCEVConstant::get(Range.getUpper()));
     SCEVHandle NewAddRec = SCEVAddRecExpr::get(NewOps, getLoop());
 
     // Next, solve the constructed addrec
@@ -2429,17 +2514,17 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                                              R1->getValue());
         if (Range.contains(R1Val->getValue())) {
           // The next iteration must be out of the range...
-          Constant *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
+          ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
 
           R1Val = EvaluateConstantChrecAtConstant(this, NextVal);
           if (!Range.contains(R1Val->getValue()))
-            return SCEVUnknown::get(NextVal);
+            return SCEVConstant::get(NextVal);
           return new SCEVCouldNotCompute();  // Something strange happened
         }
 
         // If R1 was not in the range, then it is a good return value.  Make
         // sure that R1-1 WAS in the range though, just in case.
-        Constant *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
+        ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
         R1Val = EvaluateConstantChrecAtConstant(this, NextVal);
         if (Range.contains(R1Val->getValue()))
           return R1;
@@ -2523,8 +2608,8 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const {
   return ((ScalarEvolutionsImpl*)Impl)->getSCEVAtScope(getSCEV(V), L);
 }
 
-void ScalarEvolution::deleteInstructionFromRecords(Instruction *I) const {
-  return ((ScalarEvolutionsImpl*)Impl)->deleteInstructionFromRecords(I);
+void ScalarEvolution::deleteValueFromRecords(Value *V) const {
+  return ((ScalarEvolutionsImpl*)Impl)->deleteValueFromRecords(V);
 }
 
 static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE,
@@ -2535,7 +2620,7 @@ static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE,
 
   cerr << "Loop " << L->getHeader()->getName() << ": ";
 
-  std::vector<BasicBlock*> ExitBlocks;
+  SmallVector<BasicBlock*, 8> ExitBlocks;
   L->getExitBlocks(ExitBlocks);
   if (ExitBlocks.size() != 1)
     cerr << "<multiple exits> ";