Replace the original ad-hoc code for determining whether (v pred w) implies
authorDan Gohman <gohman@apple.com>
Tue, 21 Jul 2009 23:03:19 +0000 (23:03 +0000)
committerDan Gohman <gohman@apple.com>
Tue, 21 Jul 2009 23:03:19 +0000 (23:03 +0000)
(x pred y) with more thorough code that does more complete canonicalization
before resorting to range checks. This helps it find more cases where
the canonicalized expressions match.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@76671 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
lib/Analysis/ScalarEvolution.cpp
test/Analysis/ScalarEvolution/trip-count8.ll [new file with mode: 0644]

index 274a3b5f58b58a29e20aab753e837dfef84d26d5..d22f5e42af5f15b071df4b301b1757b03552e088 100644 (file)
@@ -337,19 +337,25 @@ namespace llvm {
     /// found.
     BasicBlock* getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB);
 
-    /// isNecessaryCond - Test whether the condition described by Pred, LHS,
-    /// and RHS is a necessary condition for the given Cond value to evaluate
-    /// to true.
-    bool isNecessaryCond(Value *Cond, ICmpInst::Predicate Pred,
-                         const SCEV *LHS, const SCEV *RHS,
-                         bool Inverse);
-
-    /// isNecessaryCondOperands - Test whether the condition described by Pred,
-    /// LHS, and RHS is a necessary condition for the condition described by
-    /// Pred, FoundLHS, and FoundRHS to evaluate to true.
-    bool isNecessaryCondOperands(ICmpInst::Predicate Pred,
-                                 const SCEV *LHS, const SCEV *RHS,
-                                 const SCEV *FoundLHS, const SCEV *FoundRHS);
+    /// isImpliedCond - Test whether the condition described by Pred, LHS,
+    /// and RHS is true whenever the given Cond value evaluates to true.
+    bool isImpliedCond(Value *Cond, ICmpInst::Predicate Pred,
+                       const SCEV *LHS, const SCEV *RHS,
+                       bool Inverse);
+
+    /// isImpliedCondOperands - Test whether the condition described by Pred,
+    /// LHS, and RHS is true whenever the condition desribed by Pred, FoundLHS,
+    /// and FoundRHS is true.
+    bool isImpliedCondOperands(ICmpInst::Predicate Pred,
+                               const SCEV *LHS, const SCEV *RHS,
+                               const SCEV *FoundLHS, const SCEV *FoundRHS);
+
+    /// isImpliedCondOperandsHelper - Test whether the condition described by
+    /// Pred, LHS, and RHS is true whenever the condition desribed by Pred,
+    /// FoundLHS, and FoundRHS is true.
+    bool isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+                                     const SCEV *LHS, const SCEV *RHS,
+                                     const SCEV *FoundLHS, const SCEV *FoundRHS);
 
     /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
     /// in the header of its containing loop, we know the loop executes a
index 20b2aa48c12fc92075ca85489a786cbe1602ec9b..62b20321074f3b30ccf2a65dc6751da741d8491f 100644 (file)
@@ -4358,9 +4358,8 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
       LoopContinuePredicate->isUnconditional())
     return false;
 
-  return
-    isNecessaryCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS,
-                    LoopContinuePredicate->getSuccessor(0) != L->getHeader());
+  return isImpliedCond(LoopContinuePredicate->getCondition(), Pred, LHS, RHS,
+                       LoopContinuePredicate->getSuccessor(0) != L->getHeader());
 }
 
 /// isLoopGuardedByCond - Test whether entry to the loop is protected
@@ -4390,122 +4389,55 @@ ScalarEvolution::isLoopGuardedByCond(const Loop *L,
         LoopEntryPredicate->isUnconditional())
       continue;
 
-    if (isNecessaryCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS,
-                        LoopEntryPredicate->getSuccessor(0) != PredecessorDest))
+    if (isImpliedCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS,
+                      LoopEntryPredicate->getSuccessor(0) != PredecessorDest))
       return true;
   }
 
   return false;
 }
 
-/// isNecessaryCond - Test whether the condition described by Pred, LHS,
-/// and RHS is a necessary condition for the given Cond value to evaluate
-/// to true.
-bool ScalarEvolution::isNecessaryCond(Value *CondValue,
-                                      ICmpInst::Predicate Pred,
-                                      const SCEV *LHS, const SCEV *RHS,
-                                      bool Inverse) {
+/// isImpliedCond - Test whether the condition described by Pred, LHS,
+/// and RHS is true whenever the given Cond value evaluates to true.
+bool ScalarEvolution::isImpliedCond(Value *CondValue,
+                                    ICmpInst::Predicate Pred,
+                                    const SCEV *LHS, const SCEV *RHS,
+                                    bool Inverse) {
   // Recursivly handle And and Or conditions.
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CondValue)) {
     if (BO->getOpcode() == Instruction::And) {
       if (!Inverse)
-        return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
-               isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
+        return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
+               isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
     } else if (BO->getOpcode() == Instruction::Or) {
       if (Inverse)
-        return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
-               isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
+        return isImpliedCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
+               isImpliedCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
     }
   }
 
   ICmpInst *ICI = dyn_cast<ICmpInst>(CondValue);
   if (!ICI) return false;
 
-  // Now that we found a conditional branch that dominates the loop, check to
-  // see if it is the comparison we are looking for.
-  Value *PreCondLHS = ICI->getOperand(0);
-  Value *PreCondRHS = ICI->getOperand(1);
-  ICmpInst::Predicate FoundPred;
-  if (Inverse)
-    FoundPred = ICI->getInversePredicate();
-  else
-    FoundPred = ICI->getPredicate();
-
-  if (FoundPred == Pred)
-    ; // An exact match.
-  else if (!ICmpInst::isTrueWhenEqual(FoundPred) && Pred == ICmpInst::ICMP_NE) {
-    // The actual condition is beyond sufficient.
-    FoundPred = ICmpInst::ICMP_NE;
-    // NE is symmetric but the original comparison may not be. Swap
-    // the operands if necessary so that they match below.
-    if (isa<SCEVConstant>(LHS))
-      std::swap(PreCondLHS, PreCondRHS);
-  } else
-    // Check a few special cases.
-    switch (FoundPred) {
-    case ICmpInst::ICMP_UGT:
-      if (Pred == ICmpInst::ICMP_ULT) {
-        std::swap(PreCondLHS, PreCondRHS);
-        FoundPred = ICmpInst::ICMP_ULT;
-        break;
-      }
-      return false;
-    case ICmpInst::ICMP_SGT:
-      if (Pred == ICmpInst::ICMP_SLT) {
-        std::swap(PreCondLHS, PreCondRHS);
-        FoundPred = ICmpInst::ICMP_SLT;
-        break;
-      }
-      return false;
-    case ICmpInst::ICMP_NE:
-      // Expressions like (x >u 0) are often canonicalized to (x != 0),
-      // so check for this case by checking if the NE is comparing against
-      // a minimum or maximum constant.
-      if (!ICmpInst::isTrueWhenEqual(Pred))
-        if (const SCEVConstant *C = dyn_cast<SCEVConstant>(RHS)) {
-          const APInt &A = C->getValue()->getValue();
-          switch (Pred) {
-          case ICmpInst::ICMP_SLT:
-            if (A.isMaxSignedValue()) break;
-            return false;
-          case ICmpInst::ICMP_SGT:
-            if (A.isMinSignedValue()) break;
-            return false;
-          case ICmpInst::ICMP_ULT:
-            if (A.isMaxValue()) break;
-            return false;
-          case ICmpInst::ICMP_UGT:
-            if (A.isMinValue()) break;
-            return false;
-          default:
-            return false;
-          }
-          FoundPred = Pred;
-          // NE is symmetric but the original comparison may not be. Swap
-          // the operands if necessary so that they match below.
-          if (isa<SCEVConstant>(LHS))
-            std::swap(PreCondLHS, PreCondRHS);
-          break;
-        }
-      return false;
-    default:
-      // We weren't able to reconcile the condition.
-      return false;
-    }
-
-  assert(Pred == FoundPred && "Conditions were not reconciled!");
-
   // Bail if the ICmp's operands' types are wider than the needed type
   // before attempting to call getSCEV on them. This avoids infinite
   // recursion, since the analysis of widening casts can require loop
   // exit condition information for overflow checking, which would
   // lead back here.
   if (getTypeSizeInBits(LHS->getType()) <
-      getTypeSizeInBits(PreCondLHS->getType()))
+      getTypeSizeInBits(ICI->getOperand(0)->getType()))
     return false;
 
-  const SCEV *FoundLHS = getSCEV(PreCondLHS);
-  const SCEV *FoundRHS = getSCEV(PreCondRHS);
+  // Now that we found a conditional branch that dominates the loop, check to
+  // see if it is the comparison we are looking for.
+  ICmpInst::Predicate FoundPred;
+  if (Inverse)
+    FoundPred = ICI->getInversePredicate();
+  else
+    FoundPred = ICI->getPredicate();
+
+  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
+  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
 
   // Balance the types. The case where FoundLHS' type is wider than
   // LHS' type is checked for above.
@@ -4520,21 +4452,182 @@ bool ScalarEvolution::isNecessaryCond(Value *CondValue,
     }
   }
 
-  return isNecessaryCondOperands(Pred, LHS, RHS,
-                                 FoundLHS, FoundRHS) ||
+  // Canonicalize the query to match the way instcombine will have
+  // canonicalized the comparison.
+  // First, put a constant operand on the right.
+  if (isa<SCEVConstant>(LHS)) {
+    std::swap(LHS, RHS);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+  // Then, canonicalize comparisons with boundary cases.
+  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
+    const APInt &RA = RC->getValue()->getValue();
+    switch (Pred) {
+    default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
+    case ICmpInst::ICMP_EQ:
+    case ICmpInst::ICMP_NE:
+      break;
+    case ICmpInst::ICMP_UGE:
+      if ((RA - 1).isMinValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        RHS = getConstant(RA - 1);
+        break;
+      }
+      if (RA.isMaxValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        break;
+      }
+      if (RA.isMinValue()) return true;
+      break;
+    case ICmpInst::ICMP_ULE:
+      if ((RA + 1).isMaxValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        RHS = getConstant(RA + 1);
+        break;
+      }
+      if (RA.isMinValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        break;
+      }
+      if (RA.isMaxValue()) return true;
+      break;
+    case ICmpInst::ICMP_SGE:
+      if ((RA - 1).isMinSignedValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        RHS = getConstant(RA - 1);
+        break;
+      }
+      if (RA.isMaxSignedValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        break;
+      }
+      if (RA.isMinSignedValue()) return true;
+      break;
+    case ICmpInst::ICMP_SLE:
+      if ((RA + 1).isMaxSignedValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        RHS = getConstant(RA + 1);
+        break;
+      }
+      if (RA.isMinSignedValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        break;
+      }
+      if (RA.isMaxSignedValue()) return true;
+      break;
+    case ICmpInst::ICMP_UGT:
+      if (RA.isMinValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        break;
+      }
+      if ((RA + 1).isMaxValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        RHS = getConstant(RA + 1);
+        break;
+      }
+      if (RA.isMaxValue()) return false;
+      break;
+    case ICmpInst::ICMP_ULT:
+      if (RA.isMaxValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        break;
+      }
+      if ((RA - 1).isMinValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        RHS = getConstant(RA - 1);
+        break;
+      }
+      if (RA.isMinValue()) return false;
+      break;
+    case ICmpInst::ICMP_SGT:
+      if (RA.isMinSignedValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        break;
+      }
+      if ((RA + 1).isMaxSignedValue()) {
+        Pred = ICmpInst::ICMP_EQ;
+        RHS = getConstant(RA + 1);
+        break;
+      }
+      if (RA.isMaxSignedValue()) return false;
+      break;
+    case ICmpInst::ICMP_SLT:
+      if (RA.isMaxSignedValue()) {
+        Pred = ICmpInst::ICMP_NE;
+        break;
+      }
+      if ((RA - 1).isMinSignedValue()) {
+       Pred = ICmpInst::ICMP_EQ;
+       RHS = getConstant(RA - 1);
+       break;
+      }
+      if (RA.isMinSignedValue()) return false;
+      break;
+    }
+  }
+
+  // Check to see if we can make the LHS or RHS match.
+  if (LHS == FoundRHS || RHS == FoundLHS) {
+    if (isa<SCEVConstant>(RHS)) {
+      std::swap(FoundLHS, FoundRHS);
+      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
+    } else {
+      std::swap(LHS, RHS);
+      Pred = ICmpInst::getSwappedPredicate(Pred);
+    }
+  }
+
+  // Check whether the found predicate is the same as the desired predicate.
+  if (FoundPred == Pred)
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+
+  // Check whether swapping the found predicate makes it the same as the
+  // desired predicate.
+  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
+    if (isa<SCEVConstant>(RHS))
+      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
+    else
+      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
+                                   RHS, LHS, FoundLHS, FoundRHS);
+  }
+
+  // Check whether the actual condition is beyond sufficient.
+  if (FoundPred == ICmpInst::ICMP_EQ)
+    if (ICmpInst::isTrueWhenEqual(Pred))
+      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
+        return true;
+  if (Pred == ICmpInst::ICMP_NE)
+    if (!ICmpInst::isTrueWhenEqual(FoundPred))
+      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
+        return true;
+
+  // Otherwise assume the worst.
+  return false;
+}
+
+/// isImpliedCondOperands - Test whether the condition described by Pred,
+/// LHS, and RHS is true whenever the condition desribed by Pred, FoundLHS,
+/// and FoundRHS is true.
+bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
+                                            const SCEV *LHS, const SCEV *RHS,
+                                            const SCEV *FoundLHS,
+                                            const SCEV *FoundRHS) {
+  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
+                                     FoundLHS, FoundRHS) ||
          // ~x < ~y --> x > y
-         isNecessaryCondOperands(Pred, LHS, RHS,
-                                 getNotSCEV(FoundRHS), getNotSCEV(FoundLHS));
+         isImpliedCondOperandsHelper(Pred, LHS, RHS,
+                                     getNotSCEV(FoundRHS),
+                                     getNotSCEV(FoundLHS));
 }
 
-/// isNecessaryCondOperands - Test whether the condition described by Pred,
-/// LHS, and RHS is a necessary condition for the condition described by
-/// Pred, FoundLHS, and FoundRHS to evaluate to true.
+/// isImpliedCondOperandsHelper - Test whether the condition described by
+/// Pred, LHS, and RHS is true whenever the condition desribed by Pred,
+/// FoundLHS, and FoundRHS is true.
 bool
-ScalarEvolution::isNecessaryCondOperands(ICmpInst::Predicate Pred,
-                                         const SCEV *LHS, const SCEV *RHS,
-                                         const SCEV *FoundLHS,
-                                         const SCEV *FoundRHS) {
+ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS,
+                                             const SCEV *FoundLHS,
+                                             const SCEV *FoundRHS) {
   switch (Pred) {
   default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
   case ICmpInst::ICMP_EQ:
diff --git a/test/Analysis/ScalarEvolution/trip-count8.ll b/test/Analysis/ScalarEvolution/trip-count8.ll
new file mode 100644 (file)
index 0000000..21ccc47
--- /dev/null
@@ -0,0 +1,37 @@
+; RUN: llvm-as < %s | opt -analyze -scalar-evolution -disable-output \
+; RUN:  | grep {Loop for\\.body: backedge-taken count is (-1 + \[%\]ecx)}
+; PR4599
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128"
+
+define i32 @foo(i32 %ecx) nounwind {
+entry:
+       %cmp2 = icmp eq i32 %ecx, 0             ; <i1> [#uses=1]
+       br i1 %cmp2, label %for.end, label %bb.nph
+
+for.cond:              ; preds = %for.inc
+       %cmp = icmp ult i32 %inc, %ecx          ; <i1> [#uses=1]
+       br i1 %cmp, label %for.body, label %for.cond.for.end_crit_edge
+
+for.cond.for.end_crit_edge:            ; preds = %for.cond
+       %phitmp = add i32 %i.01, 2              ; <i32> [#uses=1]
+       br label %for.end
+
+bb.nph:                ; preds = %entry
+       br label %for.body
+
+for.body:              ; preds = %bb.nph, %for.cond
+       %i.01 = phi i32 [ %inc, %for.cond ], [ 0, %bb.nph ]             ; <i32> [#uses=3]
+       %call = call i32 @bar(i32 %i.01) nounwind               ; <i32> [#uses=0]
+       br label %for.inc
+
+for.inc:               ; preds = %for.body
+       %inc = add i32 %i.01, 1         ; <i32> [#uses=2]
+       br label %for.cond
+
+for.end:               ; preds = %for.cond.for.end_crit_edge, %entry
+       %i.0.lcssa = phi i32 [ %phitmp, %for.cond.for.end_crit_edge ], [ 1, %entry ]            ; <i32> [#uses=1]
+       ret i32 %i.0.lcssa
+}
+
+declare i32 @bar(i32)