Teach ScalarEvolution how to analyze loops with multiple exit
authorDan Gohman <gohman@apple.com>
Mon, 22 Jun 2009 00:31:57 +0000 (00:31 +0000)
committerDan Gohman <gohman@apple.com>
Mon, 22 Jun 2009 00:31:57 +0000 (00:31 +0000)
blocks, and also exit blocks with multiple conditions (combined
with (bitwise) ands and ors). It's often infeasible to compute an
exact trip count in such cases, but a useful upper bound can often
be found.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@73866 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
test/Analysis/ScalarEvolution/trip-count5.ll [new file with mode: 0644]

index 857895830e70549e25dea6b223c2bd79cb0a994e..37f25fcf0cd6c3e9975d33b0663482f5fd1a9953 100644 (file)
@@ -348,6 +348,31 @@ namespace llvm {
     /// loop will iterate.
     BackedgeTakenInfo ComputeBackedgeTakenCount(const Loop *L);
 
+    /// ComputeBackedgeTakenCountFromExit - Compute the number of times the
+    /// backedge of the specified loop will execute if it exits via the
+    /// specified block.
+    BackedgeTakenInfo ComputeBackedgeTakenCountFromExit(const Loop *L,
+                                                      BasicBlock *ExitingBlock);
+
+    /// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the
+    /// backedge of the specified loop will execute if its exit condition
+    /// were a conditional branch of ExitCond, TBB, and FBB.
+    BackedgeTakenInfo
+      ComputeBackedgeTakenCountFromExitCond(const Loop *L,
+                                            Value *ExitCond,
+                                            BasicBlock *TBB,
+                                            BasicBlock *FBB);
+
+    /// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of
+    /// times the backedge of the specified loop will execute if its exit
+    /// condition were a conditional branch of the ICmpInst ExitCond, TBB,
+    /// and FBB.
+    BackedgeTakenInfo
+      ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
+                                                ICmpInst *ExitCond,
+                                                BasicBlock *TBB,
+                                                BasicBlock *FBB);
+
     /// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition
     /// of 'icmp op load X, cst', try to see if we can compute the trip count.
     SCEVHandle
@@ -520,6 +545,12 @@ namespace llvm {
     /// specified signed integer value and return a SCEV for the constant.
     SCEVHandle getIntegerSCEV(int Val, const Type *Ty);
 
+    /// getUMaxFromMismatchedTypes - Promote the operands to the wider of
+    /// the types using zero-extension, and then perform a umax operation
+    /// with them.
+    SCEVHandle getUMaxFromMismatchedTypes(const SCEVHandle &LHS,
+                                          const SCEVHandle &RHS);
+
     /// hasSCEV - Return true if the SCEV for this value has already been
     /// computed.
     bool hasSCEV(Value *V) const;
index 15d1c247639e89a20eb60e82df0560e0b6ce2bb5..d85191377f1d1780e2937b249aeb9f7264fa84d8 100644 (file)
@@ -2128,6 +2128,22 @@ ScalarEvolution::getTruncateOrNoop(const SCEVHandle &V, const Type *Ty) {
   return getTruncateExpr(V, Ty);
 }
 
+/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
+/// the types using zero-extension, and then perform a umax operation
+/// with them.
+SCEVHandle ScalarEvolution::getUMaxFromMismatchedTypes(const SCEVHandle &LHS,
+                                                       const SCEVHandle &RHS) {
+  SCEVHandle PromotedLHS = LHS;
+  SCEVHandle PromotedRHS = RHS;
+
+  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
+    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
+  else
+    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
+
+  return getUMaxExpr(PromotedLHS, PromotedRHS);
+}
+
 /// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
 /// the specified instruction and replaces any references to the symbolic value
 /// SymName with the specified value.  This is used during PHI resolution.
@@ -2723,9 +2739,13 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
 
       // Update the value in the map.
       Pair.first->second = ItCount;
-    } else if (isa<PHINode>(L->getHeader()->begin())) {
-      // Only count loops that have phi nodes as not being computable.
-      ++NumTripCountsNotComputed;
+    } else {
+      if (ItCount.Max != CouldNotCompute)
+        // Update the value in the map.
+        Pair.first->second = ItCount;
+      if (isa<PHINode>(L->getHeader()->begin()))
+        // Only count loops that have phi nodes as not being computable.
+        ++NumTripCountsNotComputed;
     }
 
     // Now that we know more about the trip count for this loop, forget any
@@ -2781,19 +2801,58 @@ void ScalarEvolution::forgetLoopPHIs(const Loop *L) {
 /// of the specified loop will execute.
 ScalarEvolution::BackedgeTakenInfo
 ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
-  // If the loop has a non-one exit block count, we can't analyze it.
-  BasicBlock *ExitBlock = L->getExitBlock();
-  if (!ExitBlock)
-    return CouldNotCompute;
+  SmallVector<BasicBlock*, 8> ExitingBlocks;
+  L->getExitingBlocks(ExitingBlocks);
+
+  // Examine all exits and pick the most conservative values.
+  SCEVHandle BECount = CouldNotCompute;
+  SCEVHandle MaxBECount = CouldNotCompute;
+  bool CouldNotComputeBECount = false;
+  bool CouldNotComputeMaxBECount = false;
+  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
+    BackedgeTakenInfo NewBTI =
+      ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]);
+
+    if (NewBTI.Exact == CouldNotCompute) {
+      // We couldn't compute an exact value for this exit, so
+      // we don't be able to compute an exact value for the loop.
+      CouldNotComputeBECount = true;
+      BECount = CouldNotCompute;
+    } else if (!CouldNotComputeBECount) {
+      if (BECount == CouldNotCompute)
+        BECount = NewBTI.Exact;
+      else {
+        // TODO: More analysis could be done here. For example, a
+        // loop with a short-circuiting && operator has an exact count
+        // of the min of both sides.
+        CouldNotComputeBECount = true;
+        BECount = CouldNotCompute;
+      }
+    }
+    if (NewBTI.Max == CouldNotCompute) {
+      // We couldn't compute an maximum value for this exit, so
+      // we don't be able to compute an maximum value for the loop.
+      CouldNotComputeMaxBECount = true;
+      MaxBECount = CouldNotCompute;
+    } else if (!CouldNotComputeMaxBECount) {
+      if (MaxBECount == CouldNotCompute)
+        MaxBECount = NewBTI.Max;
+      else
+        MaxBECount = getUMaxFromMismatchedTypes(MaxBECount, NewBTI.Max);
+    }
+  }
 
-  // Okay, there is one exit block.  Try to find the condition that causes the
-  // loop to be exited.
-  BasicBlock *ExitingBlock = L->getExitingBlock();
-  if (!ExitingBlock)
-    return CouldNotCompute;   // More than one block exiting!
+  return BackedgeTakenInfo(BECount, MaxBECount);
+}
+
+/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge
+/// of the specified loop will execute if it exits via the specified block.
+ScalarEvolution::BackedgeTakenInfo
+ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
+                                                   BasicBlock *ExitingBlock) {
 
-  // Okay, we've computed the exiting block.  See what condition causes us to
-  // exit.
+  // Okay, we've chosen an exiting block.  See what condition causes us to
+  // exit at this block.
   //
   // FIXME: we should be able to handle switch instructions (with a single exit)
   BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
@@ -2808,23 +2867,154 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
   // Currently we check for this by checking to see if the Exit branch goes to
   // the loop header.  If so, we know it will always execute the same number of
   // times as the loop.  We also handle the case where the exit block *is* the
-  // loop header.  This is common for un-rotated loops.  More extensive analysis
-  // could be done to handle more cases here.
+  // loop header.  This is common for un-rotated loops.
+  //
+  // If both of those tests fail, walk up the unique predecessor chain to the
+  // header, stopping if there is an edge that doesn't exit the loop. If the
+  // header is reached, the execution count of the branch will be equal to the
+  // trip count of the loop.
+  //
+  //  More extensive analysis could be done to handle more cases here.
+  //
   if (ExitBr->getSuccessor(0) != L->getHeader() &&
       ExitBr->getSuccessor(1) != L->getHeader() &&
-      ExitBr->getParent() != L->getHeader())
-    return CouldNotCompute;
-  
-  ICmpInst *ExitCond = dyn_cast<ICmpInst>(ExitBr->getCondition());
+      ExitBr->getParent() != L->getHeader()) {
+    // The simple checks failed, try climbing the unique predecessor chain
+    // up to the header.
+    bool Ok = false;
+    for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
+      BasicBlock *Pred = BB->getUniquePredecessor();
+      if (!Pred)
+        return CouldNotCompute;
+      TerminatorInst *PredTerm = Pred->getTerminator();
+      for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
+        BasicBlock *PredSucc = PredTerm->getSuccessor(i);
+        if (PredSucc == BB)
+          continue;
+        // If the predecessor has a successor that isn't BB and isn't
+        // outside the loop, assume the worst.
+        if (L->contains(PredSucc))
+          return CouldNotCompute;
+      }
+      if (Pred == L->getHeader()) {
+        Ok = true;
+        break;
+      }
+      BB = Pred;
+    }
+    if (!Ok)
+      return CouldNotCompute;
+  }
+
+  // Procede to the next level to examine the exit condition expression.
+  return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(),
+                                               ExitBr->getSuccessor(0),
+                                               ExitBr->getSuccessor(1));
+}
+
+/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the
+/// backedge of the specified loop will execute if its exit condition
+/// were a conditional branch of ExitCond, TBB, and FBB.
+ScalarEvolution::BackedgeTakenInfo
+ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
+                                                       Value *ExitCond,
+                                                       BasicBlock *TBB,
+                                                       BasicBlock *FBB) {
+  // Check if the controlling expression for this loop is an and or or. In
+  // such cases, an exact backedge-taken count may be infeasible, but a
+  // maximum count may still be feasible.
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
+    if (BO->getOpcode() == Instruction::And) {
+      // Recurse on the operands of the and.
+      BackedgeTakenInfo BTI0 =
+        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
+      BackedgeTakenInfo BTI1 =
+        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
+      SCEVHandle BECount = CouldNotCompute;
+      SCEVHandle MaxBECount = CouldNotCompute;
+      if (L->contains(TBB)) {
+        // Both conditions must be true for the loop to continue executing.
+        // Choose the less conservative count.
+        // TODO: Take the minimum of the exact counts.
+        if (BTI0.Exact == BTI1.Exact)
+          BECount = BTI0.Exact;
+        // TODO: Take the minimum of the maximum counts.
+        if (BTI0.Max == CouldNotCompute)
+          MaxBECount = BTI1.Max;
+        else if (BTI1.Max == CouldNotCompute)
+          MaxBECount = BTI0.Max;
+        else if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(BTI0.Max))
+          if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(BTI1.Max))
+              MaxBECount = getConstant(APIntOps::umin(C0->getValue()->getValue(),
+                                                      C1->getValue()->getValue()));
+      } else {
+        // Both conditions must be true for the loop to exit.
+        assert(L->contains(FBB) && "Loop block has no successor in loop!");
+        if (BTI0.Exact != CouldNotCompute && BTI1.Exact != CouldNotCompute)
+          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
+        if (BTI0.Max != CouldNotCompute && BTI1.Max != CouldNotCompute)
+          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
+      }
+
+      return BackedgeTakenInfo(BECount, MaxBECount);
+    }
+    if (BO->getOpcode() == Instruction::Or) {
+      // Recurse on the operands of the or.
+      BackedgeTakenInfo BTI0 =
+        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
+      BackedgeTakenInfo BTI1 =
+        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
+      SCEVHandle BECount = CouldNotCompute;
+      SCEVHandle MaxBECount = CouldNotCompute;
+      if (L->contains(FBB)) {
+        // Both conditions must be false for the loop to continue executing.
+        // Choose the less conservative count.
+        // TODO: Take the minimum of the exact counts.
+        if (BTI0.Exact == BTI1.Exact)
+          BECount = BTI0.Exact;
+        // TODO: Take the minimum of the maximum counts.
+        if (BTI0.Max == CouldNotCompute)
+          MaxBECount = BTI1.Max;
+        else if (BTI1.Max == CouldNotCompute)
+          MaxBECount = BTI0.Max;
+        else if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(BTI0.Max))
+          if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(BTI1.Max))
+              MaxBECount = getConstant(APIntOps::umin(C0->getValue()->getValue(),
+                                                      C1->getValue()->getValue()));
+      } else {
+        // Both conditions must be false for the loop to exit.
+        assert(L->contains(TBB) && "Loop block has no successor in loop!");
+        if (BTI0.Exact != CouldNotCompute && BTI1.Exact != CouldNotCompute)
+          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
+        if (BTI0.Max != CouldNotCompute && BTI1.Max != CouldNotCompute)
+          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
+      }
+
+      return BackedgeTakenInfo(BECount, MaxBECount);
+    }
+  }
+
+  // With an icmp, it may be feasible to compute an exact backedge-taken count.
+  // Procede to the next level to examine the icmp.
+  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
+    return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB);
 
   // If it's not an integer or pointer comparison then compute it the hard way.
-  if (ExitCond == 0)
-    return ComputeBackedgeTakenCountExhaustively(L, ExitBr->getCondition(),
-                                          ExitBr->getSuccessor(0) == ExitBlock);
+  return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
+}
+
+/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the
+/// backedge of the specified loop will execute if its exit condition
+/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
+ScalarEvolution::BackedgeTakenInfo
+ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
+                                                           ICmpInst *ExitCond,
+                                                           BasicBlock *TBB,
+                                                           BasicBlock *FBB) {
 
   // If the condition was exit on true, convert the condition to exit on false
   ICmpInst::Predicate Cond;
-  if (ExitBr->getSuccessor(1) == ExitBlock)
+  if (!L->contains(FBB))
     Cond = ExitCond->getPredicate();
   else
     Cond = ExitCond->getInversePredicate();
@@ -2834,7 +3024,12 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
       SCEVHandle ItCnt =
         ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
-      if (!isa<SCEVCouldNotCompute>(ItCnt)) return ItCnt;
+      if (!isa<SCEVCouldNotCompute>(ItCnt)) {
+        unsigned BitWidth = getTypeSizeInBits(ItCnt->getType());
+        return BackedgeTakenInfo(ItCnt,
+                                 isa<SCEVConstant>(ItCnt) ? ItCnt :
+                                   getConstant(APInt::getMaxValue(BitWidth)-1));
+      }
     }
 
   SCEVHandle LHS = getSCEV(ExitCond->getOperand(0));
@@ -2912,8 +3107,7 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
     break;
   }
   return
-    ComputeBackedgeTakenCountExhaustively(L, ExitCond,
-                                          ExitBr->getSuccessor(0) == ExitBlock);
+    ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
 }
 
 static ConstantInt *
diff --git a/test/Analysis/ScalarEvolution/trip-count5.ll b/test/Analysis/ScalarEvolution/trip-count5.ll
new file mode 100644 (file)
index 0000000..822dc26
--- /dev/null
@@ -0,0 +1,48 @@
+; RUN: llvm-as < %s | opt -analyze -scalar-evolution -disable-output > %t
+; RUN: grep sext %t | count 2
+; RUN: not grep {(sext} %t
+
+; ScalarEvolution should be able to compute a maximum trip count
+; value sufficient to fold away both sext casts.
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128"
+
+define float @t(float* %pTmp1, float* %peakWeight, float* %nrgReducePeakrate, i32 %bim) nounwind {
+entry:
+       %tmp3 = load float* %peakWeight, align 4                ; <float> [#uses=2]
+       %tmp2538 = icmp sgt i32 %bim, 0         ; <i1> [#uses=1]
+       br i1 %tmp2538, label %bb.nph, label %bb4
+
+bb.nph:                ; preds = %entry
+       br label %bb
+
+bb:            ; preds = %bb1, %bb.nph
+       %distERBhi.036 = phi float [ %tmp10, %bb1 ], [ 0.000000e+00, %bb.nph ]          ; <float> [#uses=1]
+       %hiPart.035 = phi i32 [ %tmp12, %bb1 ], [ 0, %bb.nph ]          ; <i32> [#uses=2]
+       %peakCount.034 = phi float [ %tmp19, %bb1 ], [ %tmp3, %bb.nph ]         ; <float> [#uses=1]
+       %tmp6 = sext i32 %hiPart.035 to i64             ; <i64> [#uses=1]
+       %tmp7 = getelementptr float* %pTmp1, i64 %tmp6          ; <float*> [#uses=1]
+       %tmp8 = load float* %tmp7, align 4              ; <float> [#uses=1]
+       %tmp10 = fadd float %tmp8, %distERBhi.036               ; <float> [#uses=3]
+       %tmp12 = add i32 %hiPart.035, 1         ; <i32> [#uses=3]
+       %tmp15 = sext i32 %tmp12 to i64         ; <i64> [#uses=1]
+       %tmp16 = getelementptr float* %peakWeight, i64 %tmp15           ; <float*> [#uses=1]
+       %tmp17 = load float* %tmp16, align 4            ; <float> [#uses=1]
+       %tmp19 = fadd float %tmp17, %peakCount.034              ; <float> [#uses=2]
+       br label %bb1
+
+bb1:           ; preds = %bb
+       %tmp21 = fcmp olt float %tmp10, 2.500000e+00            ; <i1> [#uses=1]
+       %tmp25 = icmp slt i32 %tmp12, %bim              ; <i1> [#uses=1]
+       %tmp27 = and i1 %tmp21, %tmp25          ; <i1> [#uses=1]
+       br i1 %tmp27, label %bb, label %bb1.bb4_crit_edge
+
+bb1.bb4_crit_edge:             ; preds = %bb1
+       br label %bb4
+
+bb4:           ; preds = %bb1.bb4_crit_edge, %entry
+       %distERBhi.0.lcssa = phi float [ %tmp10, %bb1.bb4_crit_edge ], [ 0.000000e+00, %entry ]         ; <float> [#uses=1]
+       %peakCount.0.lcssa = phi float [ %tmp19, %bb1.bb4_crit_edge ], [ %tmp3, %entry ]                ; <float> [#uses=1]
+       %tmp31 = fdiv float %peakCount.0.lcssa, %distERBhi.0.lcssa              ; <float> [#uses=1]
+       ret float %tmp31
+}