Add new SCEV, SCEVSMax. This allows LLVM to analyze do-while loops.
authorNick Lewycky <nicholas@mxc.ca>
Sun, 25 Nov 2007 22:41:31 +0000 (22:41 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Sun, 25 Nov 2007 22:41:31 +0000 (22:41 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@44319 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
include/llvm/Analysis/ScalarEvolutionExpander.h
include/llvm/Analysis/ScalarEvolutionExpressions.h
lib/Analysis/ScalarEvolution.cpp
lib/Analysis/ScalarEvolutionExpander.cpp
test/Analysis/ScalarEvolution/do-loop.ll [new file with mode: 0644]
test/Analysis/ScalarEvolution/smax.ll [new file with mode: 0644]
test/Transforms/IndVarsSimplify/loop_evaluate_2.ll

index a52f273b2c716f90b66777ec2dd89818c41f7c77..d0886e8c64cba59bacc49d6331fd5033786460f2 100644 (file)
@@ -235,6 +235,8 @@ namespace llvm {
       std::vector<SCEVHandle> NewOp(Operands);
       return getAddRecExpr(NewOp, L);
     }
+    SCEVHandle getSMaxExpr(const SCEVHandle &LHS, const SCEVHandle &RHS);
+    SCEVHandle getSMaxExpr(std::vector<SCEVHandle> Operands);
     SCEVHandle getUnknown(Value *V);
 
     /// getNegativeSCEV - Return the SCEV object corresponding to -V.
index 8582067e108f4ccb491f9f3297f14a7438a203c0..6529902ec811450eb4789710df0de41fef1aa8be 100644 (file)
@@ -134,6 +134,8 @@ namespace llvm {
 
     Value *visitAddRecExpr(SCEVAddRecExpr *S);
 
+    Value *visitSMaxExpr(SCEVSMaxExpr *S);
+
     Value *visitUnknown(SCEVUnknown *S) {
       return S->getValue();
     }
index fb69a90fe74774dc529b82fcb3f04659f99fa41f..f6392433a7a00a55601be3c4628a75c622ab70d3 100644 (file)
@@ -25,7 +25,7 @@ namespace llvm {
     // These should be ordered in terms of increasing complexity to make the
     // folders simpler.
     scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
-    scSDivExpr, scAddRecExpr, scUnknown, scCouldNotCompute
+    scSDivExpr, scAddRecExpr, scSMaxExpr, scUnknown, scCouldNotCompute
   };
 
   //===--------------------------------------------------------------------===//
@@ -274,7 +274,8 @@ namespace llvm {
     static inline bool classof(const SCEVCommutativeExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
       return S->getSCEVType() == scAddExpr ||
-             S->getSCEVType() == scMulExpr;
+             S->getSCEVType() == scMulExpr ||
+             S->getSCEVType() == scSMaxExpr;
     }
   };
 
@@ -459,6 +460,28 @@ namespace llvm {
     }
   };
 
+
+  //===--------------------------------------------------------------------===//
+  /// SCEVSMaxExpr - This class represents a signed maximum selection.
+  ///
+  class SCEVSMaxExpr : public SCEVCommutativeExpr {
+    friend class ScalarEvolution;
+
+    explicit SCEVSMaxExpr(const std::vector<SCEVHandle> &ops)
+      : SCEVCommutativeExpr(scSMaxExpr, ops) {
+    }
+
+  public:
+    virtual const char *getOperationStr() const { return " smax "; }
+
+    /// Methods for support type inquiry through isa, cast, and dyn_cast:
+    static inline bool classof(const SCEVSMaxExpr *S) { return true; }
+    static inline bool classof(const SCEV *S) {
+      return S->getSCEVType() == scSMaxExpr;
+    }
+  };
+
+
   //===--------------------------------------------------------------------===//
   /// SCEVUnknown - This means that we are dealing with an entirely unknown SCEV
   /// value, and only represent it as it's LLVM Value.  This is the "bottom"
@@ -521,6 +544,8 @@ namespace llvm {
         return ((SC*)this)->visitSDivExpr((SCEVSDivExpr*)S);
       case scAddRecExpr:
         return ((SC*)this)->visitAddRecExpr((SCEVAddRecExpr*)S);
+      case scSMaxExpr:
+        return ((SC*)this)->visitSMaxExpr((SCEVSMaxExpr*)S);
       case scUnknown:
         return ((SC*)this)->visitUnknown((SCEVUnknown*)S);
       case scCouldNotCompute:
index 27158e5dddeb01718cd1d8d5681aac2062fd0ecc..558da230d1931c6705dc0f7af3fbff90621f4f9d 100644 (file)
@@ -318,6 +318,8 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
         return SE.getAddExpr(NewOps);
       else if (isa<SCEVMulExpr>(this))
         return SE.getMulExpr(NewOps);
+      else if (isa<SCEVSMaxExpr>(this))
+        return SE.getSMaxExpr(NewOps);
       else
         assert(0 && "Unknown commutative expr!");
     }
@@ -1095,6 +1097,93 @@ SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands,
   return Result;
 }
 
+SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS,
+                                        const SCEVHandle &RHS) {
+  std::vector<SCEVHandle> Ops;
+  Ops.push_back(LHS);
+  Ops.push_back(RHS);
+  return getSMaxExpr(Ops);
+}
+
+SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) {
+  assert(!Ops.empty() && "Cannot get empty smax!");
+  if (Ops.size() == 1) return Ops[0];
+
+  // Sort by complexity, this groups all similar expression types together.
+  GroupByComplexity(Ops);
+
+  // If there are any constants, fold them together.
+  unsigned Idx = 0;
+  if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+    ++Idx;
+    assert(Idx < Ops.size());
+    while (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
+      // We found two constants, fold them together!
+      Constant *Fold = ConstantInt::get(
+                              APIntOps::smax(LHSC->getValue()->getValue(),
+                                             RHSC->getValue()->getValue()));
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Fold)) {
+        Ops[0] = getConstant(CI);
+        Ops.erase(Ops.begin()+1);  // Erase the folded element
+        if (Ops.size() == 1) return Ops[0];
+        LHSC = cast<SCEVConstant>(Ops[0]);
+      } else {
+        // If we couldn't fold the expression, move to the next constant.  Note
+        // that this is impossible to happen in practice because we always
+        // constant fold constant ints to constant ints.
+        ++Idx;
+      }
+    }
+
+    // If we are left with a constant -inf, strip it off.
+    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
+      Ops.erase(Ops.begin());
+      --Idx;
+    }
+  }
+
+  if (Ops.size() == 1) return Ops[0];
+
+  // Find the first SMax
+  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
+    ++Idx;
+
+  // Check to see if one of the operands is an SMax. If so, expand its operands
+  // onto our operand list, and recurse to simplify.
+  if (Idx < Ops.size()) {
+    bool DeletedSMax = false;
+    while (SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
+      Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
+      Ops.erase(Ops.begin()+Idx);
+      DeletedSMax = true;
+    }
+
+    if (DeletedSMax)
+      return getSMaxExpr(Ops);
+  }
+
+  // Okay, check to see if the same value occurs in the operand list twice.  If
+  // so, delete one.  Since we sorted the list, these values are required to
+  // be adjacent.
+  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
+    if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
+      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
+      --i; --e;
+    }
+
+  if (Ops.size() == 1) return Ops[0];
+
+  assert(!Ops.empty() && "Reduced smax down to nothing!");
+
+  // Okay, it looks like we really DO need an add expr.  Check to see if we
+  // already have one, otherwise create a new one.
+  std::vector<SCEV*> SCEVOps(Ops.begin(), Ops.end());
+  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr,
+                                                                 SCEVOps)];
+  if (Result == 0) Result = new SCEVSMaxExpr(Ops);
+  return Result;
+}
+
 SCEVHandle ScalarEvolution::getUnknown(Value *V) {
   if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
     return getConstant(CI);
@@ -1458,6 +1547,14 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) {
     return MinOpRes;
   }
 
+  if (SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
+    // The result is the min of all operands results.
+    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
+    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
+      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
+    return MinOpRes;
+  }
+
   // SCEVSDivExpr, SCEVUnknown
   return 0;
 }
@@ -1537,6 +1634,25 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
     case Instruction::PHI:
       return createNodeForPHI(cast<PHINode>(I));
 
+    case Instruction::Select:
+      // This could be an SCEVSMax that was lowered earlier. Try to recover it.
+      if (ICmpInst *ICI = dyn_cast<ICmpInst>(I->getOperand(0))) {
+        Value *LHS = ICI->getOperand(0);
+        Value *RHS = ICI->getOperand(1);
+        switch (ICI->getPredicate()) {
+        case ICmpInst::ICMP_SLT:
+        case ICmpInst::ICMP_SLE:
+          std::swap(LHS, RHS);
+          // fall through
+        case ICmpInst::ICMP_SGT:
+        case ICmpInst::ICMP_SGE:
+          if (LHS == I->getOperand(1) && RHS == I->getOperand(2))
+            return SE.getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
+        default:
+          break;
+        }
+      }
+
     default: // We cannot analyze this expression.
       break;
     }
@@ -2125,8 +2241,11 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
         }
         if (isa<SCEVAddExpr>(Comm))
           return SE.getAddExpr(NewOps);
-        assert(isa<SCEVMulExpr>(Comm) && "Only know about add and mul!");
-        return SE.getMulExpr(NewOps);
+        if (isa<SCEVMulExpr>(Comm))
+          return SE.getMulExpr(NewOps);
+        if (isa<SCEVSMaxExpr>(Comm))
+          return SE.getSMaxExpr(NewOps);
+        assert(0 && "Unknown commutative SCEV type!");
       }
     }
     // If we got here, all operands are loop invariant.
@@ -2343,90 +2462,21 @@ HowManyLessThans(SCEV *LHS, SCEV *RHS, const Loop *L, bool isSigned) {
     return UnknownValue;
 
   if (AddRec->isAffine()) {
+    // The number of iterations for "{n,+,1} < m", is m-n.  However, we don't
+    // know that m is >= n on input to the loop.  If it is, the condition
+    // returns true zero times.  To handle both cases, we return SMAX(0, m-n).
+
     // FORNOW: We only support unit strides.
-    SCEVHandle Zero = SE.getIntegerSCEV(0, RHS->getType());
     SCEVHandle One = SE.getIntegerSCEV(1, RHS->getType());
     if (AddRec->getOperand(1) != One)
       return UnknownValue;
 
-    // The number of iterations for "{n,+,1} < m", is m-n.  However, we don't
-    // know that m is >= n on input to the loop.  If it is, the condition return
-    // true zero times.  What we really should return, for full generality, is
-    // SMAX(0, m-n).  Since we cannot check this, we will instead check for a
-    // canonical loop form: most do-loops will have a check that dominates the
-    // loop, that only enters the loop if (n-1)<m.  If we can find this check,
-    // we know that the SMAX will evaluate to m-n, because we know that m >= n.
-
-    // Search for the check.
-    BasicBlock *Preheader = L->getLoopPreheader();
-    BasicBlock *PreheaderDest = L->getHeader();
-    if (Preheader == 0) return UnknownValue;
-
-    BranchInst *LoopEntryPredicate =
-      dyn_cast<BranchInst>(Preheader->getTerminator());
-    if (!LoopEntryPredicate) return UnknownValue;
-
-    // This might be a critical edge broken out.  If the loop preheader ends in
-    // an unconditional branch to the loop, check to see if the preheader has a
-    // single predecessor, and if so, look for its terminator.
-    while (LoopEntryPredicate->isUnconditional()) {
-      PreheaderDest = Preheader;
-      Preheader = Preheader->getSinglePredecessor();
-      if (!Preheader) return UnknownValue;  // Multiple preds.
-      
-      LoopEntryPredicate =
-        dyn_cast<BranchInst>(Preheader->getTerminator());
-      if (!LoopEntryPredicate) return UnknownValue;
-    }
-
-    // Now that we found a conditional branch that dominates the loop, check to
-    // see if it is the comparison we are looking for.
-    if (ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition())){
-      Value *PreCondLHS = ICI->getOperand(0);
-      Value *PreCondRHS = ICI->getOperand(1);
-      ICmpInst::Predicate Cond;
-      if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest)
-        Cond = ICI->getPredicate();
-      else
-        Cond = ICI->getInversePredicate();
-    
-      switch (Cond) {
-      case ICmpInst::ICMP_UGT:
-        if (isSigned) return UnknownValue;
-        std::swap(PreCondLHS, PreCondRHS);
-        Cond = ICmpInst::ICMP_ULT;
-        break;
-      case ICmpInst::ICMP_SGT:
-        if (!isSigned) return UnknownValue;
-        std::swap(PreCondLHS, PreCondRHS);
-        Cond = ICmpInst::ICMP_SLT;
-        break;
-      case ICmpInst::ICMP_ULT:
-        if (isSigned) return UnknownValue;
-        break;
-      case ICmpInst::ICMP_SLT:
-        if (!isSigned) return UnknownValue;
-        break;
-      default:
-        return UnknownValue;
-      }
-
-      if (PreCondLHS->getType()->isInteger()) {
-        if (RHS != getSCEV(PreCondRHS))
-          return UnknownValue;  // Not a comparison against 'm'.
+    SCEVHandle Iters = SE.getMinusSCEV(RHS, AddRec->getOperand(0));
 
-        if (SE.getMinusSCEV(AddRec->getOperand(0), One)
-                    != getSCEV(PreCondLHS))
-          return UnknownValue;  // Not a comparison against 'n-1'.
-      }
-      else return UnknownValue;
-
-      // cerr << "Computed Loop Trip Count as: " 
-      //      << //  *SE.getMinusSCEV(RHS, AddRec->getOperand(0)) << "\n";
-      return SE.getMinusSCEV(RHS, AddRec->getOperand(0));
-    }
-    else 
-      return UnknownValue;
+    if (isSigned)
+      return SE.getSMaxExpr(SE.getIntegerSCEV(0, RHS->getType()), Iters);
+    else
+      return Iters;
   }
 
   return UnknownValue;
index 3bac3024c4530dd0997219e593bbc7250b3ec522..88fd0aaf8fbc86fb757d72b5babb50edd55d6026 100644 (file)
@@ -208,6 +208,16 @@ Value *SCEVExpander::visitAddRecExpr(SCEVAddRecExpr *S) {
   return expand(V);
 }
 
+Value *SCEVExpander::visitSMaxExpr(SCEVSMaxExpr *S) {
+  Value *LHS = expand(S->getOperand(0));
+  for (unsigned i = 1; i < S->getNumOperands(); ++i) {
+    Value *RHS = expand(S->getOperand(i));
+    Value *ICmp = new ICmpInst(ICmpInst::ICMP_SGT, LHS, RHS, "tmp", InsertPt);
+    LHS = new SelectInst(ICmp, LHS, RHS, "smax", InsertPt);
+  }
+  return LHS;
+}
+
 Value *SCEVExpander::expand(SCEV *S) {
   // Check to see if we already expanded this.
   std::map<SCEVHandle, Value*>::iterator I = InsertedExpressions.find(S);
diff --git a/test/Analysis/ScalarEvolution/do-loop.ll b/test/Analysis/ScalarEvolution/do-loop.ll
new file mode 100644 (file)
index 0000000..c6b3298
--- /dev/null
@@ -0,0 +1,18 @@
+; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep smax
+; PR1614
+
+define i32 @f(i32 %x, i32 %y) {
+entry:
+       br label %bb
+
+bb:            ; preds = %bb, %entry
+       %indvar = phi i32 [ 0, %entry ], [ %indvar.next, %bb ]          ; <i32> [#uses=2]
+       %x_addr.0 = add i32 %indvar, %x         ; <i32> [#uses=1]
+       %tmp2 = add i32 %x_addr.0, 1            ; <i32> [#uses=2]
+       %tmp5 = icmp slt i32 %tmp2, %y          ; <i1> [#uses=1]
+       %indvar.next = add i32 %indvar, 1               ; <i32> [#uses=1]
+       br i1 %tmp5, label %bb, label %bb7
+
+bb7:           ; preds = %bb
+       ret i32 %tmp2
+}
diff --git a/test/Analysis/ScalarEvolution/smax.ll b/test/Analysis/ScalarEvolution/smax.ll
new file mode 100644 (file)
index 0000000..157d54f
--- /dev/null
@@ -0,0 +1,12 @@
+; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep smax | count 2
+; RUN: llvm-as < %s | opt -analyze -scalar-evolution | grep \
+; RUN:     "%. smax  %. smax  %."
+; PR1614
+
+define i32 @x(i32 %a, i32 %b, i32 %c) {
+  %A = icmp sgt i32 %a, %b
+  %B = select i1 %A, i32 %a, i32 %b
+  %C = icmp sle i32 %c, %B
+  %D = select i1 %C, i32 %B, i32 %c
+  ret i32 %D
+}
index c7426918462f46d6656828695b2d9f69726ddb79..635950aed661897cc8f61b22d23125610d906e70 100644 (file)
@@ -1,5 +1,5 @@
-; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | llvm-dis | grep select
-; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | llvm-dis | not grep br
+; RUN: llvm-as < %s | opt -indvars -adce -simplifycfg | opt \
+; RUN:     -analyze -loops | not grep "^Loop Containing" 
 ; PR1179
 
 define i32 @ltst(i32 %x) {