Build the correct range for loops with unusual bounds. Fix from Jay Foad.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index ed8ea32767872f70d8d477ad14ebb00c3daa31ec..069f6ec714cc54a01d21df51ad5885a0b0387927 100644 (file)
@@ -183,6 +183,10 @@ SCEVHandle SCEVConstant::get(ConstantInt *V) {
   return R;
 }
 
+SCEVHandle SCEVConstant::get(const APInt& Val) {
+  return get(ConstantInt::get(Val));
+}
+
 ConstantRange SCEVConstant::getValueRange() const {
   return ConstantRange(V->getValue());
 }
@@ -481,16 +485,13 @@ SCEVHandle SCEVUnknown::getIntegerSCEV(int Val, const Type *Ty) {
   if (Val == 0)
     C = Constant::getNullValue(Ty);
   else if (Ty->isFloatingPoint())
-    C = ConstantFP::get(Ty, Val);
+    C = ConstantFP::get(Ty, APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : 
+                            APFloat::IEEEdouble, Val));
   else 
     C = ConstantInt::get(Ty, Val);
   return SCEVUnknown::get(C);
 }
 
-SCEVHandle SCEVUnknown::getIntegerSCEV(const APInt& Val) {
-  return SCEVUnknown::get(ConstantInt::get(Val));
-}
-
 /// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
 /// input value to the specified type.  If the type must be extended, it is zero
 /// extended.
@@ -531,7 +532,7 @@ static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps) {
     APInt Result(Val.getBitWidth(), 1);
     for (; NumSteps; --NumSteps)
       Result *= Val-(NumSteps-1);
-    return SCEVUnknown::get(ConstantInt::get(Result));
+    return SCEVConstant::get(Result);
   }
 
   const Type *Ty = V->getType();
@@ -1167,10 +1168,10 @@ namespace {
     /// loop without a loop-invariant iteration count.
     SCEVHandle getIterationCount(const Loop *L);
 
-    /// deleteInstructionFromRecords - This method should be called by the
-    /// client before it removes an instruction from the program, to make sure
+    /// deleteValueFromRecords - This method should be called by the
+    /// client before it removes a value from the program, to make sure
     /// that no dangling references are left around.
-    void deleteInstructionFromRecords(Instruction *I);
+    void deleteValueFromRecords(Value *V);
 
   private:
     /// createSCEV - We know that there is no SCEV for the specified value.
@@ -1220,8 +1221,9 @@ namespace {
 
     /// HowManyLessThans - Return the number of times a backedge containing the
     /// specified less-than comparison will execute.  If not computable, return
-    /// UnknownValue.
-    SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L);
+    /// UnknownValue. isSigned specifies whether the less-than is signed.
+    SCEVHandle HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L,
+                                bool isSigned);
 
     /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
     /// in the header of its containing loop, we know the loop executes a
@@ -1236,27 +1238,27 @@ namespace {
 //            Basic SCEV Analysis and PHI Idiom Recognition Code
 //
 
-/// deleteInstructionFromRecords - This method should be called by the
+/// deleteValueFromRecords - This method should be called by the
 /// client before it removes an instruction from the program, to make sure
 /// that no dangling references are left around.
-void ScalarEvolutionsImpl::deleteInstructionFromRecords(Instruction *I) {
-  SmallVector<Instruction *, 16> Worklist;
+void ScalarEvolutionsImpl::deleteValueFromRecords(Value *V) {
+  SmallVector<Value *, 16> Worklist;
 
-  if (Scalars.erase(I)) {
-    if (PHINode *PN = dyn_cast<PHINode>(I))
+  if (Scalars.erase(V)) {
+    if (PHINode *PN = dyn_cast<PHINode>(V))
       ConstantEvolutionLoopExitValue.erase(PN);
-    Worklist.push_back(I);
+    Worklist.push_back(V);
   }
 
   while (!Worklist.empty()) {
-    Instruction *II = Worklist.back();
+    Value *VV = Worklist.back();
     Worklist.pop_back();
 
-    for (Instruction::use_iterator UI = II->use_begin(), UE = II->use_end();
+    for (Instruction::use_iterator UI = VV->use_begin(), UE = VV->use_end();
          UI != UE; ++UI) {
       Instruction *Inst = cast<Instruction>(*UI);
       if (Scalars.erase(Inst)) {
-        if (PHINode *PN = dyn_cast<PHINode>(II))
+        if (PHINode *PN = dyn_cast<PHINode>(VV))
           ConstantEvolutionLoopExitValue.erase(PN);
         Worklist.push_back(Inst);
       }
@@ -1568,7 +1570,7 @@ SCEVHandle ScalarEvolutionsImpl::getIterationCount(const Loop *L) {
 /// will iterate.
 SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
   // If the loop has a non-one exit block count, we can't analyze it.
-  std::vector<BasicBlock*> ExitBlocks;
+  SmallVector<BasicBlock*, 8> ExitBlocks;
   L->getExitBlocks(ExitBlocks);
   if (ExitBlocks.size() != 1) return UnknownValue;
 
@@ -1671,8 +1673,7 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
           ConstantRange CompRange(
               ICmpInst::makeConstantRange(Cond, CompVal->getValue()));
 
-          SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, 
-              false /*Always treat as unsigned range*/);
+          SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange);
           if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
         }
       }
@@ -1691,12 +1692,24 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
     break;
   }
   case ICmpInst::ICMP_SLT: {
-    SCEVHandle TC = HowManyLessThans(LHS, RHS, L);
+    SCEVHandle TC = HowManyLessThans(LHS, RHS, L, true);
     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
     break;
   }
   case ICmpInst::ICMP_SGT: {
-    SCEVHandle TC = HowManyLessThans(RHS, LHS, L);
+    SCEVHandle TC = HowManyLessThans(SCEV::getNegativeSCEV(LHS),
+                                     SCEV::getNegativeSCEV(RHS), L, true);
+    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
+    break;
+  }
+  case ICmpInst::ICMP_ULT: {
+    SCEVHandle TC = HowManyLessThans(LHS, RHS, L, false);
+    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
+    break;
+  }
+  case ICmpInst::ICMP_UGT: {
+    SCEVHandle TC = HowManyLessThans(SCEV::getNegativeSCEV(LHS),
+                                     SCEV::getNegativeSCEV(RHS), L, false);
     if (!isa<SCEVCouldNotCompute>(TC)) return TC;
     break;
   }
@@ -1716,8 +1729,8 @@ SCEVHandle ScalarEvolutionsImpl::ComputeIterationCount(const Loop *L) {
 }
 
 static ConstantInt *
-EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, Constant *C) {
-  SCEVHandle InVal = SCEVConstant::get(cast<ConstantInt>(C));
+EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C) {
+  SCEVHandle InVal = SCEVConstant::get(C);
   SCEVHandle Val = AddRec->evaluateAtIteration(InVal);
   assert(isa<SCEVConstant>(Val) &&
          "Evaluation of SCEV at constant didn't fold correctly?");
@@ -2199,8 +2212,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec) {
     ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
     ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
 
-    return std::make_pair(SCEVUnknown::get(Solution1), 
-                          SCEVUnknown::get(Solution2));
+    return std::make_pair(SCEVConstant::get(Solution1), 
+                          SCEVConstant::get(Solution2));
     } // end APIntOps namespace
 }
 
@@ -2310,7 +2323,7 @@ SCEVHandle ScalarEvolutionsImpl::HowFarToNonZero(SCEV *V, const Loop *L) {
 /// specified less-than comparison will execute.  If not computable, return
 /// UnknownValue.
 SCEVHandle ScalarEvolutionsImpl::
-HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) {
+HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
   // Only handle:  "ADDREC < LoopInvariant".
   if (!RHS->isLoopInvariant(L)) return UnknownValue;
 
@@ -2367,28 +2380,34 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) {
     
       switch (Cond) {
       case ICmpInst::ICMP_UGT:
+        if (isSigned) return UnknownValue;
         std::swap(PreCondLHS, PreCondRHS);
         Cond = ICmpInst::ICMP_ULT;
         break;
       case ICmpInst::ICMP_SGT:
+        if (!isSigned) return UnknownValue;
         std::swap(PreCondLHS, PreCondRHS);
         Cond = ICmpInst::ICMP_SLT;
         break;
-      default: break;
+      case ICmpInst::ICMP_ULT:
+        if (isSigned) return UnknownValue;
+        break;
+      case ICmpInst::ICMP_SLT:
+        if (!isSigned) return UnknownValue;
+        break;
+      default:
+        return UnknownValue;
       }
 
-      if (Cond == ICmpInst::ICMP_SLT) {
-        if (PreCondLHS->getType()->isInteger()) {
-          if (RHS != getSCEV(PreCondRHS))
-            return UnknownValue;  // Not a comparison against 'm'.
+      if (PreCondLHS->getType()->isInteger()) {
+        if (RHS != getSCEV(PreCondRHS))
+          return UnknownValue;  // Not a comparison against 'm'.
 
-          if (SCEV::getMinusSCEV(AddRec->getOperand(0), One)
-                      != getSCEV(PreCondLHS))
-            return UnknownValue;  // Not a comparison against 'n-1'.
-        }
-        else return UnknownValue;
-      } else if (Cond == ICmpInst::ICMP_ULT)
-        return UnknownValue;
+        if (SCEV::getMinusSCEV(AddRec->getOperand(0), One)
+                    != getSCEV(PreCondLHS))
+          return UnknownValue;  // Not a comparison against 'n-1'.
+      }
+      else return UnknownValue;
 
       // cerr << "Computed Loop Trip Count as: " 
       //      << //  *SCEV::getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n";
@@ -2406,8 +2425,7 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L) {
 /// this is that it returns the first iteration number where the value is not in
 /// the condition, thus computing the exit count. If the iteration count can't
 /// be computed, an instance of SCEVCouldNotCompute is returned.
-SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, 
-                                                   bool isSigned) const {
+SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range) const {
   if (Range.isFullSet())  // Infinite loop.
     return new SCEVCouldNotCompute();
 
@@ -2419,7 +2437,7 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
       SCEVHandle Shifted = SCEVAddRecExpr::get(Operands, getLoop());
       if (SCEVAddRecExpr *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
         return ShiftedAddRec->getNumIterationsInRange(
-                           Range.subtract(SC->getValue()->getValue()),isSigned);
+                           Range.subtract(SC->getValue()->getValue()));
       // This is strange and shouldn't happen.
       return new SCEVCouldNotCompute();
     }
@@ -2443,17 +2461,16 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
     // If this is an affine expression then we have this situation:
     //   Solve {0,+,A} in Range  ===  Ax in Range
 
-    // Since we know that zero is in the range, we know that the upper value of
-    // the range must be the first possible exit value.  Also note that we
-    // already checked for a full range.
-    const APInt &Upper = Range.getUpper();
-    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
+    // We know that zero is in the range.  If A is positive then we know that
+    // the upper value of the range must be the first possible exit value.
+    // If A is negative then the lower of the range is the last possible loop
+    // value.  Also note that we already checked for a full range.
     APInt One(getBitWidth(),1);
+    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
+    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
 
-    // The exit value should be (Upper+A-1)/A.
-    APInt ExitVal(Upper);
-    if (A != One)
-      ExitVal = (Upper + A - One).sdiv(A);
+    // The exit value should be (End+A)/A.
+    APInt ExitVal = (End + A).udiv(A);
     ConstantInt *ExitValue = ConstantInt::get(ExitVal);
 
     // Evaluate at the exit value.  If we really did fall out of the valid
@@ -2468,15 +2485,14 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
            EvaluateConstantChrecAtConstant(this, 
            ConstantInt::get(ExitVal - One))->getValue()) &&
            "Linear scev computation is off in a bad way!");
-    return SCEVConstant::get(cast<ConstantInt>(ExitValue));
+    return SCEVConstant::get(ExitValue);
   } else if (isQuadratic()) {
     // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
     // quadratic equation to solve it.  To do this, we must frame our problem in
     // terms of figuring out when zero is crossed, instead of when
     // Range.getUpper() is crossed.
     std::vector<SCEVHandle> NewOps(op_begin(), op_end());
-    NewOps[0] = SCEV::getNegativeSCEV(SCEVUnknown::get(
-                                           ConstantInt::get(Range.getUpper())));
+    NewOps[0] = SCEV::getNegativeSCEV(SCEVConstant::get(Range.getUpper()));
     SCEVHandle NewAddRec = SCEVAddRecExpr::get(NewOps, getLoop());
 
     // Next, solve the constructed addrec
@@ -2499,17 +2515,17 @@ SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                                              R1->getValue());
         if (Range.contains(R1Val->getValue())) {
           // The next iteration must be out of the range...
-          Constant *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
+          ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);
 
           R1Val = EvaluateConstantChrecAtConstant(this, NextVal);
           if (!Range.contains(R1Val->getValue()))
-            return SCEVUnknown::get(NextVal);
+            return SCEVConstant::get(NextVal);
           return new SCEVCouldNotCompute();  // Something strange happened
         }
 
         // If R1 was not in the range, then it is a good return value.  Make
         // sure that R1-1 WAS in the range though, just in case.
-        Constant *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
+        ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
         R1Val = EvaluateConstantChrecAtConstant(this, NextVal);
         if (Range.contains(R1Val->getValue()))
           return R1;
@@ -2593,8 +2609,8 @@ SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) const {
   return ((ScalarEvolutionsImpl*)Impl)->getSCEVAtScope(getSCEV(V), L);
 }
 
-void ScalarEvolution::deleteInstructionFromRecords(Instruction *I) const {
-  return ((ScalarEvolutionsImpl*)Impl)->deleteInstructionFromRecords(I);
+void ScalarEvolution::deleteValueFromRecords(Value *V) const {
+  return ((ScalarEvolutionsImpl*)Impl)->deleteValueFromRecords(V);
 }
 
 static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE,
@@ -2605,7 +2621,7 @@ static void PrintLoopInfo(std::ostream &OS, const ScalarEvolution *SE,
 
   cerr << "Loop " << L->getHeader()->getName() << ": ";
 
-  std::vector<BasicBlock*> ExitBlocks;
+  SmallVector<BasicBlock*, 8> ExitBlocks;
   L->getExitBlocks(ExitBlocks);
   if (ExitBlocks.size() != 1)
     cerr << "<multiple exits> ";