Fix PR1798 - an error in the evaluation of SCEVAddRecExpr at an
authorWojciech Matyjewicz <wmatyjewicz@fastmail.fm>
Mon, 11 Feb 2008 11:03:14 +0000 (11:03 +0000)
committerWojciech Matyjewicz <wmatyjewicz@fastmail.fm>
Mon, 11 Feb 2008 11:03:14 +0000 (11:03 +0000)
arbitrary iteration.

The patch:
1) changes SCEVSDivExpr into SCEVUDivExpr,
2) replaces PartialFact() function with BinomialCoefficient(); the
computations (essentially, the division) in BinomialCoefficient() are
performed with the apprioprate bitwidth necessary to avoid overflow;
unsigned division is used instead of the signed one.

Computations in BinomialCoefficient() require support from the code
generator for APInts. Currently, we use a hack rounding up the
neccessary bitwidth to the nearest power of 2. The hack is easy to turn
off in future.

One remaining issue: we assume the divisor of the binomial coefficient
formula can be computed accurately using 16 bits. It means we can handle
AddRecs of length up to 9. In future, we should use APInts to evaluate
the divisor.

Thanks to Nicholas for cooperation!

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@46955 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/ScalarEvolution.h
include/llvm/Analysis/ScalarEvolutionExpander.h
include/llvm/Analysis/ScalarEvolutionExpressions.h
lib/Analysis/ScalarEvolution.cpp
test/Analysis/ScalarEvolution/2007-11-14-SignedAddRec.ll

index b1cd287c7e487a9af61391291a76fc7051447c06..ecf28adfc88fe7634a34560e9d5ad5c27bd148fd 100644 (file)
@@ -225,7 +225,7 @@ namespace llvm {
       Ops.push_back(RHS);
       return getMulExpr(Ops);
     }
-    SCEVHandle getSDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS);
+    SCEVHandle getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS);
     SCEVHandle getAddRecExpr(const SCEVHandle &Start, const SCEVHandle &Step,
                              const Loop *L);
     SCEVHandle getAddRecExpr(std::vector<SCEVHandle> &Operands,
index 4470c9c1732dc89f5441505e78afc235831700a0..530ce378803b6987a44286d1732f0f2a5b89f705 100644 (file)
@@ -126,10 +126,10 @@ namespace llvm {
 
     Value *visitMulExpr(SCEVMulExpr *S);
 
-    Value *visitSDivExpr(SCEVSDivExpr *S) {
+    Value *visitUDivExpr(SCEVUDivExpr *S) {
       Value *LHS = expand(S->getLHS());
       Value *RHS = expand(S->getRHS());
-      return InsertBinop(Instruction::SDiv, LHS, RHS, InsertPt);
+      return InsertBinop(Instruction::UDiv, LHS, RHS, InsertPt);
     }
 
     Value *visitAddRecExpr(SCEVAddRecExpr *S);
index 6564d636b310a83f80d098b7c9f0798adde4baac..409ad9ecc407a00145c2e106c321eff3ffffd503 100644 (file)
@@ -25,7 +25,7 @@ namespace llvm {
     // These should be ordered in terms of increasing complexity to make the
     // folders simpler.
     scConstant, scTruncate, scZeroExtend, scSignExtend, scAddExpr, scMulExpr,
-    scSDivExpr, scAddRecExpr, scSMaxExpr, scUnknown, scCouldNotCompute
+    scUDivExpr, scAddRecExpr, scSMaxExpr, scUnknown, scCouldNotCompute
   };
 
   //===--------------------------------------------------------------------===//
@@ -322,16 +322,16 @@ namespace llvm {
 
 
   //===--------------------------------------------------------------------===//
-  /// SCEVSDivExpr - This class represents a binary signed division operation.
+  /// SCEVUDivExpr - This class represents a binary unsigned division operation.
   ///
-  class SCEVSDivExpr : public SCEV {
+  class SCEVUDivExpr : public SCEV {
     friend class ScalarEvolution;
 
     SCEVHandle LHS, RHS;
-    SCEVSDivExpr(const SCEVHandle &lhs, const SCEVHandle &rhs)
-      : SCEV(scSDivExpr), LHS(lhs), RHS(rhs) {}
+    SCEVUDivExpr(const SCEVHandle &lhs, const SCEVHandle &rhs)
+      : SCEV(scUDivExpr), LHS(lhs), RHS(rhs) {}
 
-    virtual ~SCEVSDivExpr();
+    virtual ~SCEVUDivExpr();
   public:
     const SCEVHandle &getLHS() const { return LHS; }
     const SCEVHandle &getRHS() const { return RHS; }
@@ -353,7 +353,7 @@ namespace llvm {
       if (L == LHS && R == RHS)
         return this;
       else
-        return SE.getSDivExpr(L, R);
+        return SE.getUDivExpr(L, R);
     }
 
 
@@ -363,9 +363,9 @@ namespace llvm {
     void print(std::ostream *OS) const { if (OS) print(*OS); }
 
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
-    static inline bool classof(const SCEVSDivExpr *S) { return true; }
+    static inline bool classof(const SCEVUDivExpr *S) { return true; }
     static inline bool classof(const SCEV *S) {
-      return S->getSCEVType() == scSDivExpr;
+      return S->getSCEVType() == scUDivExpr;
     }
   };
 
@@ -540,8 +540,8 @@ namespace llvm {
         return ((SC*)this)->visitAddExpr((SCEVAddExpr*)S);
       case scMulExpr:
         return ((SC*)this)->visitMulExpr((SCEVMulExpr*)S);
-      case scSDivExpr:
-        return ((SC*)this)->visitSDivExpr((SCEVSDivExpr*)S);
+      case scUDivExpr:
+        return ((SC*)this)->visitUDivExpr((SCEVUDivExpr*)S);
       case scAddRecExpr:
         return ((SC*)this)->visitAddRecExpr((SCEVAddRecExpr*)S);
       case scSMaxExpr:
index 10f05bc8dd0a88c3430518a80241cfa2e6d003dd..cbfc56373aea72c834376e2d902daa15b5cb9956 100644 (file)
@@ -328,21 +328,21 @@ replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym,
 }
 
 
-// SCEVSDivs - Only allow the creation of one SCEVSDivExpr for any particular
+// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular
 // input.  Don't use a SCEVHandle here, or else the object will never be
 // deleted!
 static ManagedStatic<std::map<std::pair<SCEV*, SCEV*>, 
-                     SCEVSDivExpr*> > SCEVSDivs;
+                     SCEVUDivExpr*> > SCEVUDivs;
 
-SCEVSDivExpr::~SCEVSDivExpr() {
-  SCEVSDivs->erase(std::make_pair(LHS, RHS));
+SCEVUDivExpr::~SCEVUDivExpr() {
+  SCEVUDivs->erase(std::make_pair(LHS, RHS));
 }
 
-void SCEVSDivExpr::print(std::ostream &OS) const {
-  OS << "(" << *LHS << " /s " << *RHS << ")";
+void SCEVUDivExpr::print(std::ostream &OS) const {
+  OS << "(" << *LHS << " /u " << *RHS << ")";
 }
 
-const Type *SCEVSDivExpr::getType() const {
+const Type *SCEVUDivExpr::getType() const {
   return LHS->getType();
 }
 
@@ -532,57 +532,110 @@ SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS,
 }
 
 
-/// PartialFact - Compute V!/(V-NumSteps)!
-static SCEVHandle PartialFact(SCEVHandle V, unsigned NumSteps,
-                              ScalarEvolution &SE) {
+/// BinomialCoefficient - Compute BC(It, K).  The result is of the same type as
+/// It.  Assume, K > 0.
+static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
+                                      ScalarEvolution &SE) {
+  // We are using the following formula for BC(It, K):
+  //
+  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
+  //
+  // Suppose, W is the bitwidth of It (and of the return value as well).  We
+  // must be prepared for overflow.  Hence, we must assure that the result of
+  // our computation is equal to the accurate one modulo 2^W.  Unfortunately,
+  // division isn't safe in modular arithmetic.  This means we must perform the
+  // whole computation accurately and then truncate the result to W bits.
+  //
+  // The dividend of the formula is a multiplication of K integers of bitwidth
+  // W.  K*W bits suffice to compute it accurately.
+  //
+  // FIXME: We assume the divisor can be accurately computed using 16-bit
+  // unsigned integer type. It is true up to K = 8 (AddRecs of length 9). In
+  // future we may use APInt to use the minimum number of bits necessary to
+  // compute it accurately.
+  //
+  // It is safe to use unsigned division here: the dividend is nonnegative and
+  // the divisor is positive.
+
+  // Handle the simplest case efficiently.
+  if (K == 1)
+    return It;
+
+  assert(K < 9 && "We cannot handle such long AddRecs yet.");
+  
+  // FIXME: A temporary hack to remove in future.  Arbitrary precision integers
+  // aren't supported by the code generator yet.  For the dividend, the bitwidth
+  // we use is the smallest power of 2 greater or equal to K*W and less or equal
+  // to 64.  Note that setting the upper bound for bitwidth may still lead to
+  // miscompilation in some cases.
+  unsigned DividendBits = 1U << Log2_32_Ceil(K * It->getBitWidth());
+  if (DividendBits > 64)
+    DividendBits = 64;
+#if 0 // Waiting for the APInt support in the code generator...
+  unsigned DividendBits = K * It->getBitWidth();
+#endif
+
+  const IntegerType *DividendTy = IntegerType::get(DividendBits);
+  const SCEVHandle ExIt = SE.getZeroExtendExpr(It, DividendTy);
+
+  // The final number of bits we need to perform the division is the maximum of
+  // dividend and divisor bitwidths.
+  const IntegerType *DivisionTy =
+    IntegerType::get(std::max(DividendBits, 16U));
+
+  // Compute K!  We know K >= 2 here.
+  unsigned F = 2;
+  for (unsigned i = 3; i <= K; ++i)
+    F *= i;
+  APInt Divisor(DivisionTy->getBitWidth(), F);
+
   // Handle this case efficiently, it is common to have constant iteration
   // counts while computing loop exit values.
-  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(V)) {
-    const APInt& Val = SC->getValue()->getValue();
-    APInt Result(Val.getBitWidth(), 1);
-    for (; NumSteps; --NumSteps)
-      Result *= Val-(NumSteps-1);
-    return SE.getConstant(Result);
+  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(ExIt)) {
+    const APInt& N = SC->getValue()->getValue();
+    APInt Dividend(N.getBitWidth(), 1);
+    for (; K; --K)
+      Dividend *= N-(K-1);
+    if (DividendTy != DivisionTy)
+      Dividend = Dividend.zext(DivisionTy->getBitWidth());
+    return SE.getConstant(Dividend.udiv(Divisor).trunc(It->getBitWidth()));
   }
-
-  const Type *Ty = V->getType();
-  if (NumSteps == 0)
-    return SE.getIntegerSCEV(1, Ty);
-
-  SCEVHandle Result = V;
-  for (unsigned i = 1; i != NumSteps; ++i)
-    Result = SE.getMulExpr(Result, SE.getMinusSCEV(V,
-                                                   SE.getIntegerSCEV(i, Ty)));
-  return Result;
+  
+  SCEVHandle Dividend = ExIt;
+  for (unsigned i = 1; i != K; ++i)
+    Dividend =
+      SE.getMulExpr(Dividend,
+                    SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy)));
+  if (DividendTy != DivisionTy)
+    Dividend = SE.getZeroExtendExpr(Dividend, DivisionTy);
+  return
+    SE.getTruncateExpr(SE.getUDivExpr(Dividend, SE.getConstant(Divisor)),
+                       It->getType());
 }
 
-
 /// evaluateAtIteration - Return the value of this chain of recurrences at
 /// the specified iteration number.  We can evaluate this recurrence by
 /// multiplying each element in the chain by the binomial coefficient
 /// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
 ///
-///   A*choose(It, 0) + B*choose(It, 1) + C*choose(It, 2) + D*choose(It, 3)
+///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
 ///
-/// FIXME/VERIFY: I don't trust that this is correct in the face of overflow.
-/// Is the binomial equation safe using modular arithmetic??
+/// where BC(It, k) stands for binomial coefficient.
 ///
 SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
                                                ScalarEvolution &SE) const {
   SCEVHandle Result = getStart();
-  int Divisor = 1;
-  const Type *Ty = It->getType();
   for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
-    SCEVHandle BC = PartialFact(It, i, SE);
-    Divisor *= i;
-    SCEVHandle Val = SE.getSDivExpr(SE.getMulExpr(BC, getOperand(i)),
-                                    SE.getIntegerSCEV(Divisor,Ty));
+    // The computation is correct in the face of overflow provided that the
+    // multiplication is performed _after_ the evaluation of the binomial
+    // coefficient.
+    SCEVHandle Val = SE.getMulExpr(getOperand(i),
+                                   BinomialCoefficient(It, i, SE));
     Result = SE.getAddExpr(Result, Val);
   }
   return Result;
 }
 
-
 //===----------------------------------------------------------------------===//
 //                    SCEV Expression folder implementations
 //===----------------------------------------------------------------------===//
@@ -1039,24 +1092,22 @@ SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) {
   return Result;
 }
 
-SCEVHandle ScalarEvolution::getSDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) {
+SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS, const SCEVHandle &RHS) {
   if (SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
     if (RHSC->getValue()->equalsInt(1))
-      return LHS;                            // X sdiv 1 --> x
-    if (RHSC->getValue()->isAllOnesValue())
-      return getNegativeSCEV(LHS);           // X sdiv -1  -->  -x
+      return LHS;                            // X udiv 1 --> x
 
     if (SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
       Constant *LHSCV = LHSC->getValue();
       Constant *RHSCV = RHSC->getValue();
-      return getUnknown(ConstantExpr::getSDiv(LHSCV, RHSCV));
+      return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV));
     }
   }
 
   // FIXME: implement folding of (X*4)/4 when we know X*4 doesn't overflow.
 
-  SCEVSDivExpr *&Result = (*SCEVSDivs)[std::make_pair(LHS, RHS)];
-  if (Result == 0) Result = new SCEVSDivExpr(LHS, RHS);
+  SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)];
+  if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS);
   return Result;
 }
 
@@ -1555,7 +1606,7 @@ static uint32_t GetMinTrailingZeros(SCEVHandle S) {
     return MinOpRes;
   }
 
-  // SCEVSDivExpr, SCEVUnknown
+  // SCEVUDivExpr, SCEVUnknown
   return 0;
 }
 
@@ -1574,8 +1625,8 @@ SCEVHandle ScalarEvolutionsImpl::createSCEV(Value *V) {
     case Instruction::Mul:
       return SE.getMulExpr(getSCEV(I->getOperand(0)),
                            getSCEV(I->getOperand(1)));
-    case Instruction::SDiv:
-      return SE.getSDivExpr(getSCEV(I->getOperand(0)),
+    case Instruction::UDiv:
+      return SE.getUDivExpr(getSCEV(I->getOperand(0)),
                             getSCEV(I->getOperand(1)));
     case Instruction::Sub:
       return SE.getMinusSCEV(getSCEV(I->getOperand(0)),
@@ -2264,14 +2315,14 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
     return Comm;
   }
 
-  if (SCEVSDivExpr *Div = dyn_cast<SCEVSDivExpr>(V)) {
+  if (SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
     SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L);
     if (LHS == UnknownValue) return LHS;
     SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L);
     if (RHS == UnknownValue) return RHS;
     if (LHS == Div->getLHS() && RHS == Div->getRHS())
       return Div;   // must be loop invariant
-    return SE.getSDivExpr(LHS, RHS);
+    return SE.getUDivExpr(LHS, RHS);
   }
 
   // If this is a loop recurrence for a loop that does not contain L, then we
index 1bb6a20c546d2e9740417aa2bf418678a4ce2bc3..66ca7551c240e6c5d59d1c5c4814d82ebefeb3b6 100644 (file)
@@ -1,6 +1,5 @@
 ; RUN: llvm-as < %s | opt -indvars | llvm-dis | grep printd | grep 1206807378
 ; PR1798
-; XFAIL: *
 
 declare void @printd(i32)