Add new SCEV, SCEVSMax. This allows LLVM to analyze do-while loops.
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 27158e5dddeb01718cd1d8d5681aac2062fd0ecc..558da230d1931c6705dc0f7af3fbff90621f4f9d 100644 (file)
@@ -318,6 +318,8 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
         return SE.getAddExpr(NewOps);
       else if (isa<SCEVMulExpr>(this))
         return SE.getMulExpr(NewOps);
+      else if (isa<SCEVSMaxExpr>(this))
+        return SE.getSMaxExpr(NewOps);
       else
         assert(0 && "Unknown commutative expr!");
     }
@@ -1095,6 +1097,93 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands,
   return Result;
 }
 
+SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS,
+                                        const SCEVHandle &RHS) {
+  std::vector<SCEVHandle> Ops;
+  Ops.push_back(LHS);
+  Ops.push_back(RHS);
+  return getSMaxExpr(Ops);
+}
+
+SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
+  assert(!Ops.empty() && "Cannot get empty smax!");
+  if (Ops.size() == 1) return Ops[0];
+
+  // Sort by complexity, this groups all similar expression types together.
+  GroupByComplexity(Ops);
+
+  // If there are any constants, fold them together.
+  unsigned Idx = 0;
+  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+    ++Idx;
+    assert(Idx < Ops.size());
+    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
+      // We found two constants, fold them together!
+      Constant *Fold = ConstantInt::get(
+                              APIntOps::smax(LHSC->getValue()->getValue(),
+                                             RHSC->getValue()->getValue()));
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Fold)) {
+        Ops[0] = getConstant(CI);
+        Ops.erase(Ops.begin()+1);  // Erase the folded element
+        if (Ops.size() == 1) return Ops[0];
+        LHSC = cast<SCEVConstant>(Ops[0]);
+      } else {
+        // If we couldn't fold the expression, move to the next constant.  Note
+        // that this is impossible to happen in practice because we always
+        // constant fold constant ints to constant ints.
+        ++Idx;
+      }
+    }
+
+    // If we are left with a constant -inf, strip it off.
+    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
+      Ops.erase(Ops.begin());
+      --Idx;
+    }
+  }
+
+  if (Ops.size() == 1) return Ops[0];
+
+  // Find the first SMax
+  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
+    ++Idx;
+
+  // Check to see if one of the operands is an SMax. If so, expand its operands
+  // onto our operand list, and recurse to simplify.
+  if (Idx < Ops.size()) {
+    bool DeletedSMax = false;
+    while (SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
+      Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
+      Ops.erase(Ops.begin()+Idx);
+      DeletedSMax = true;
+    }
+
+    if (DeletedSMax)
+      return getSMaxExpr(Ops);
+  }
+
+  // Okay, check to see if the same value occurs in the operand list twice.  If
+  // so, delete one.  Since we sorted the list, these values are required to
+  // be adjacent.
+  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
+    if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
+      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
+      --i; --e;
+    }
+
+  if (Ops.size() == 1) return Ops[0];
+
+  assert(!Ops.empty() && "Reduced smax down to nothing!");
+
+  // Okay, it looks like we really DO need an add expr.  Check to see if we
+  // already have one, otherwise create a new one.
+  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
+  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
+                                                                 SCEVOps)];
+  if (Result == 0) Result = new SCEVSMaxExpr(Ops);
+  return Result;
+}
+
 SCEVHandle ScalarEvolution::getUnknown(Value *V) {
   if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
     return getConstant(CI);
@@ -1458,6 +1547,14 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) {
     return MinOpRes;
   }
 
+  if (SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
+    // The result is the min of all operands results.
+    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
+    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
+      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
+    return MinOpRes;
+  }
+
   // SCEVSDivExpr, SCEVUnknown
   return 0;
 }
@@ -1537,6 +1634,25 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
     case Instruction::PHI:
       return createNodeForPHI(cast<PHINode>(I));
 
+    case Instruction::Select:
+      // This could be an SCEVSMax that was lowered earlier. Try to recover it.
+      if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) {
+        Value *LHS = ICI->getOperand(0);
+        Value *RHS = ICI->getOperand(1);
+        switch (ICI->getPredicate()) {
+        case ICmpInst::ICMP_SLT:
+        case ICmpInst::ICMP_SLE:
+          std::swap(LHS, RHS);
+          // fall through
+        case ICmpInst::ICMP_SGT:
+        case ICmpInst::ICMP_SGE:
+          if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
+            return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
+        default:
+          break;
+        }
+      }
+
     default: // We cannot analyze this expression.
       break;
     }
@@ -2125,8 +2241,11 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
         }
         if (isa<SCEVAddExpr>(Comm))
           return SE.getAddExpr(NewOps);
-        assert(isa<SCEVMulExpr>(Comm) && "Only know about add and mul!");
-        return SE.getMulExpr(NewOps);
+        if (isa<SCEVMulExpr>(Comm))
+          return SE.getMulExpr(NewOps);
+        if (isa<SCEVSMaxExpr>(Comm))
+          return SE.getSMaxExpr(NewOps);
+        assert(0 && "Unknown commutative SCEV type!");
       }
     }
     // If we got here, all operands are loop invariant.
@@ -2343,90 +2462,21 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
     return UnknownValue;
 
   if (AddRec->isAffine()) {
+    // The number of iterations for "{n,+,1} < m", is m-n.  However, we don't
+    // know that m is >= n on input to the loop.  If it is, the condition
+    // returns true zero times.  To handle both cases, we return SMAX(0, m-n).
+
     // FORNOW: We only support unit strides.
-    SCEVHandle Zero = SE.getIntegerSCEV(0, RHS->getType());
     SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType());
     if (AddRec->getOperand(1) != One)
       return UnknownValue;
 
-    // The number of iterations for "{n,+,1} < m", is m-n.  However, we don't
-    // know that m is >= n on input to the loop.  If it is, the condition return
-    // true zero times.  What we really should return, for full generality, is
-    // SMAX(0, m-n).  Since we cannot check this, we will instead check for a
-    // canonical loop form: most do-loops will have a check that dominates the
-    // loop, that only enters the loop if (n-1)<m.  If we can find this check,
-    // we know that the SMAX will evaluate to m-n, because we know that m >= n.
-
-    // Search for the check.
-    BasicBlock *Preheader = L->getLoopPreheader();
-    BasicBlock *PreheaderDest = L->getHeader();
-    if (Preheader == 0) return UnknownValue;
-
-    BranchInst *LoopEntryPredicate =
-      dyn_cast<BranchInst>(Preheader->getTerminator());
-    if (!LoopEntryPredicate) return UnknownValue;
-
-    // This might be a critical edge broken out.  If the loop preheader ends in
-    // an unconditional branch to the loop, check to see if the preheader has a
-    // single predecessor, and if so, look for its terminator.
-    while (LoopEntryPredicate->isUnconditional()) {
-      PreheaderDest = Preheader;
-      Preheader = Preheader->getSinglePredecessor();
-      if (!Preheader) return UnknownValue;  // Multiple preds.
-      
-      LoopEntryPredicate =
-        dyn_cast<BranchInst>(Preheader->getTerminator());
-      if (!LoopEntryPredicate) return UnknownValue;
-    }
-
-    // Now that we found a conditional branch that dominates the loop, check to
-    // see if it is the comparison we are looking for.
-    if (ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition())){
-      Value *PreCondLHS = ICI->getOperand(0);
-      Value *PreCondRHS = ICI->getOperand(1);
-      ICmpInst::Predicate Cond;
-      if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest)
-        Cond = ICI->getPredicate();
-      else
-        Cond = ICI->getInversePredicate();
-    
-      switch (Cond) {
-      case ICmpInst::ICMP_UGT:
-        if (isSigned) return UnknownValue;
-        std::swap(PreCondLHS, PreCondRHS);
-        Cond = ICmpInst::ICMP_ULT;
-        break;
-      case ICmpInst::ICMP_SGT:
-        if (!isSigned) return UnknownValue;
-        std::swap(PreCondLHS, PreCondRHS);
-        Cond = ICmpInst::ICMP_SLT;
-        break;
-      case ICmpInst::ICMP_ULT:
-        if (isSigned) return UnknownValue;
-        break;
-      case ICmpInst::ICMP_SLT:
-        if (!isSigned) return UnknownValue;
-        break;
-      default:
-        return UnknownValue;
-      }
-
-      if (PreCondLHS->getType()->isInteger()) {
-        if (RHS != getSCEV(PreCondRHS))
-          return UnknownValue;  // Not a comparison against 'm'.
+    SCEVHandle Iters = SE.getMinusSCEV(RHS, AddRec->getOperand(0));
 
-        if (SE.getMinusSCEV(AddRec->getOperand(0), One)
-                    != getSCEV(PreCondLHS))
-          return UnknownValue;  // Not a comparison against 'n-1'.
-      }
-      else return UnknownValue;
-
-      // cerr << "Computed Loop Trip Count as: " 
-      //      << //  *SE.getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n";
-      return SE.getMinusSCEV(RHS, AddRec->getOperand(0));
-    }
-    else 
-      return UnknownValue;
+    if (isSigned)
+      return SE.getSMaxExpr(SE.getIntegerSCEV(0, RHS->getType()), Iters);
+    else
+      return Iters;
   }
 
   return UnknownValue;