PR2621: Improvements to the SCEV AddRec binomial expansion. This
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index bf6bc0b5dc288c2edfc43b3413cd81a66e1e9ce3..84e02e47a0c4fbb7879e51a126924aabe8a3782d 100644 (file)
@@ -507,91 +507,125 @@ SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS,
 }
 
 
-/// BinomialCoefficient - Compute BC(It, K).  The result is of the same type as
-/// It.  Assume, K > 0.
+/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
+// Assume, K > 0.
 static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
-                                      ScalarEvolution &SE) {
+                                      ScalarEvolution &SE,
+                                      const IntegerType* ResultTy) {
+  // Handle the simplest case efficiently.
+  if (K == 1)
+    return SE.getTruncateOrZeroExtend(It, ResultTy);
+
   // We are using the following formula for BC(It, K):
   //
   //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
   //
-  // Suppose, W is the bitwidth of It (and of the return value as well).  We
-  // must be prepared for overflow.  Hence, we must assure that the result of
-  // our computation is equal to the accurate one modulo 2^W.  Unfortunately,
-  // division isn't safe in modular arithmetic.  This means we must perform the
-  // whole computation accurately and then truncate the result to W bits.
+  // Suppose, W is the bitwidth of the return value.  We must be prepared for
+  // overflow.  Hence, we must assure that the result of our computation is
+  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
+  // safe in modular arithmetic.
+  //
+  // However, this code doesn't use exactly that formula; the formula it uses
+  // is something like the following, where T is the number of factors of 2 in 
+  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
+  // exponentiation:
+  //
+  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
   //
-  // The dividend of the formula is a multiplication of K integers of bitwidth
-  // W.  K*W bits suffice to compute it accurately.
+  // This formula is trivially equivalent to the previous formula.  However,
+  // this formula can be implemented much more efficiently.  The trick is that
+  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
+  // arithmetic.  To do exact division in modular arithmetic, all we have
+  // to do is multiply by the inverse.  Therefore, this step can be done at
+  // width W.
+  // 
+  // The next issue is how to safely do the division by 2^T.  The way this
+  // is done is by doing the multiplication step at a width of at least W + T
+  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
+  // when we perform the division by 2^T (which is equivalent to a right shift
+  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
+  // truncated out after the division by 2^T.
   //
-  // FIXME: We assume the divisor can be accurately computed using 16-bit
-  // unsigned integer type. It is true up to K = 8 (AddRecs of length 9). In
-  // future we may use APInt to use the minimum number of bits necessary to
-  // compute it accurately.
+  // In comparison to just directly using the first formula, this technique
+  // is much more efficient; using the first formula requires W * K bits,
+  // but this formula less than W + K bits. Also, the first formula requires
+  // a division step, whereas this formula only requires multiplies and shifts.
   //
-  // It is safe to use unsigned division here: the dividend is nonnegative and
-  // the divisor is positive.
+  // It doesn't matter whether the subtraction step is done in the calculation
+  // width or the input iteration count's width; if the subtraction overflows,
+  // the result must be zero anyway.  We prefer here to do it in the width of
+  // the induction variable because it helps a lot for certain cases; CodeGen
+  // isn't smart enough to ignore the overflow, which leads to much less
+  // efficient code if the width of the subtraction is wider than the native
+  // register width.
+  //
+  // (It's possible to not widen at all by pulling out factors of 2 before
+  // the multiplication; for example, K=2 can be calculated as
+  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
+  // extra arithmetic, so it's not an obvious win, and it gets
+  // much more complicated for K > 3.)
+
+  // Protection from insane SCEVs; this bound is conservative,
+  // but it probably doesn't matter.
+  if (K > 1000)
+    return new SCEVCouldNotCompute();
 
-  // Handle the simplest case efficiently.
-  if (K == 1)
-    return It;
+  unsigned W = ResultTy->getBitWidth();
+
+  // Calculate K! / 2^T and T; we divide out the factors of two before
+  // multiplying for calculating K! / 2^T to avoid overflow.
+  // Other overflow doesn't matter because we only care about the bottom
+  // W bits of the result.
+  APInt OddFactorial(W, 1);
+  unsigned T = 1;
+  for (unsigned i = 3; i <= K; ++i) {
+    APInt Mult(W, i);
+    unsigned TwoFactors = Mult.countTrailingZeros();
+    T += TwoFactors;
+    Mult = Mult.lshr(TwoFactors);
+    OddFactorial *= Mult;
+  }
 
-  assert(K < 9 && "We cannot handle such long AddRecs yet.");
-  
-  // FIXME: A temporary hack to remove in future.  Arbitrary precision integers
-  // aren't supported by the code generator yet.  For the dividend, the bitwidth
-  // we use is the smallest power of 2 greater or equal to K*W and less or equal
-  // to 64.  Note that setting the upper bound for bitwidth may still lead to
-  // miscompilation in some cases.
-  unsigned DividendBits = 1U << Log2_32_Ceil(K * It->getBitWidth());
-  if (DividendBits > 64)
-    DividendBits = 64;
-#if 0 // Waiting for the APInt support in the code generator...
-  unsigned DividendBits = K * It->getBitWidth();
-#endif
+  // We need at least W + T bits for the multiplication step
+  // FIXME: A temporary hack; we round up the bitwidths
+  // to the nearest power of 2 to be nice to the code generator.
+  unsigned CalculationBits = 1U << Log2_32_Ceil(W + T);
+  // FIXME: Temporary hack to avoid generating integers that are too wide.
+  // Although, it's not completely clear how to determine how much
+  // widening is safe; for example, on X86, we can't really widen
+  // beyond 64 because we need to be able to do multiplication
+  // that's CalculationBits wide, but on X86-64, we can safely widen up to
+  // 128 bits.
+  if (CalculationBits > 64)
+    return new SCEVCouldNotCompute();
 
-  const IntegerType *DividendTy = IntegerType::get(DividendBits);
-  const SCEVHandle ExIt = SE.getTruncateOrZeroExtend(It, DividendTy);
-
-  // The final number of bits we need to perform the division is the maximum of
-  // dividend and divisor bitwidths.
-  const IntegerType *DivisionTy =
-    IntegerType::get(std::max(DividendBits, 16U));
-
-  // Compute K!  We know K >= 2 here.
-  unsigned F = 2;
-  for (unsigned i = 3; i <= K; ++i)
-    F *= i;
-  APInt Divisor(DivisionTy->getBitWidth(), F);
-
-  // Handle this case efficiently, it is common to have constant iteration
-  // counts while computing loop exit values.
-  if (SCEVConstant *SC = dyn_cast<SCEVConstant>(ExIt)) {
-    const APInt& N = SC->getValue()->getValue();
-    APInt Dividend(N.getBitWidth(), 1);
-    for (; K; --K)
-      Dividend *= N-(K-1);
-    if (DividendTy != DivisionTy)
-      Dividend = Dividend.zext(DivisionTy->getBitWidth());
-
-    APInt Result = Dividend.udiv(Divisor);
-    if (Result.getBitWidth() != It->getBitWidth())
-      Result = Result.trunc(It->getBitWidth());
-
-    return SE.getConstant(Result);
+  // Calcuate 2^T, at width T+W.
+  APInt DivFactor = APInt(CalculationBits, 1).shl(T);
+
+  // Calculate the multiplicative inverse of K! / 2^T;
+  // this multiplication factor will perform the exact division by
+  // K! / 2^T.
+  APInt Mod = APInt::getSignedMinValue(W+1);
+  APInt MultiplyFactor = OddFactorial.zext(W+1);
+  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
+  MultiplyFactor = MultiplyFactor.trunc(W);
+
+  // Calculate the product, at width T+W
+  const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
+  SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
+  for (unsigned i = 1; i != K; ++i) {
+    SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
+    Dividend = SE.getMulExpr(Dividend,
+                             SE.getTruncateOrZeroExtend(S, CalculationTy));
   }
-  
-  SCEVHandle Dividend = ExIt;
-  for (unsigned i = 1; i != K; ++i)
-    Dividend =
-      SE.getMulExpr(Dividend,
-                    SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy)));
 
-  return SE.getTruncateOrZeroExtend(
-             SE.getUDivExpr(
-                 SE.getTruncateOrZeroExtend(Dividend, DivisionTy),
-                 SE.getConstant(Divisor)
-             ), It->getType());
+  // Divide by 2^T
+  SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
+
+  // Truncate the result, and divide by K! / 2^T.
+
+  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
+                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
 }
 
 /// evaluateAtIteration - Return the value of this chain of recurrences at
@@ -610,8 +644,10 @@ SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It,
     // The computation is correct in the face of overflow provided that the
     // multiplication is performed _after_ the evaluation of the binomial
     // coefficient.
-    SCEVHandle Val = SE.getMulExpr(getOperand(i),
-                                   BinomialCoefficient(It, i, SE));
+    SCEVHandle Val =
+      SE.getMulExpr(getOperand(i),
+                    BinomialCoefficient(It, i, SE,
+                                        cast<IntegerType>(getType())));
     Result = SE.getAddExpr(Result, Val);
   }
   return Result;
@@ -2441,17 +2477,8 @@ SCEVHandle ScalarEvolutionsImpl::getSCEVAtScope(SCEV *V, const Loop *L) {
       // loop iterates.  Compute this now.
       SCEVHandle IterationCount = getIterationCount(AddRec->getLoop());
       if (IterationCount == UnknownValue) return UnknownValue;
-      IterationCount = SE.getTruncateOrZeroExtend(IterationCount,
-                                                  AddRec->getType());
-
-      // If the value is affine, simplify the expression evaluation to just
-      // Start + Step*IterationCount.
-      if (AddRec->isAffine())
-        return SE.getAddExpr(AddRec->getStart(),
-                             SE.getMulExpr(IterationCount,
-                                           AddRec->getOperand(1)));
 
-      // Otherwise, evaluate it the hard way.
+      // Then, evaluate the AddRec.
       return AddRec->evaluateAtIteration(IterationCount, SE);
     }
     return UnknownValue;