}
-/// BinomialCoefficient - Compute BC(It, K). The result is of the same type as
-/// It. Assume, K > 0.
+/// BinomialCoefficient - Compute BC(It, K). The result has width W.
+// Assume, K > 0.
static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
- ScalarEvolution &SE) {
+ ScalarEvolution &SE,
+ const IntegerType* ResultTy) {
+ // Handle the simplest case efficiently.
+ if (K == 1)
+ return SE.getTruncateOrZeroExtend(It, ResultTy);
+
// We are using the following formula for BC(It, K):
//
// BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
//
- // Suppose, W is the bitwidth of It (and of the return value as well). We
- // must be prepared for overflow. Hence, we must assure that the result of
- // our computation is equal to the accurate one modulo 2^W. Unfortunately,
- // division isn't safe in modular arithmetic. This means we must perform the
- // whole computation accurately and then truncate the result to W bits.
+ // Suppose, W is the bitwidth of the return value. We must be prepared for
+ // overflow. Hence, we must assure that the result of our computation is
+ // equal to the accurate one modulo 2^W. Unfortunately, division isn't
+ // safe in modular arithmetic.
+ //
+ // However, this code doesn't use exactly that formula; the formula it uses
+ // is something like the following, where T is the number of factors of 2 in
+ // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
+ // exponentiation:
+ //
+ // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
//
- // The dividend of the formula is a multiplication of K integers of bitwidth
- // W. K*W bits suffice to compute it accurately.
+ // This formula is trivially equivalent to the previous formula. However,
+ // this formula can be implemented much more efficiently. The trick is that
+ // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
+ // arithmetic. To do exact division in modular arithmetic, all we have
+ // to do is multiply by the inverse. Therefore, this step can be done at
+ // width W.
+ //
+ // The next issue is how to safely do the division by 2^T. The way this
+ // is done is by doing the multiplication step at a width of at least W + T
+ // bits. This way, the bottom W+T bits of the product are accurate. Then,
+ // when we perform the division by 2^T (which is equivalent to a right shift
+ // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
+ // truncated out after the division by 2^T.
//
- // FIXME: We assume the divisor can be accurately computed using 16-bit
- // unsigned integer type. It is true up to K = 8 (AddRecs of length 9). In
- // future we may use APInt to use the minimum number of bits necessary to
- // compute it accurately.
+ // In comparison to just directly using the first formula, this technique
+ // is much more efficient; using the first formula requires W * K bits,
+ // but this formula less than W + K bits. Also, the first formula requires
+ // a division step, whereas this formula only requires multiplies and shifts.
//
- // It is safe to use unsigned division here: the dividend is nonnegative and
- // the divisor is positive.
+ // It doesn't matter whether the subtraction step is done in the calculation
+ // width or the input iteration count's width; if the subtraction overflows,
+ // the result must be zero anyway. We prefer here to do it in the width of
+ // the induction variable because it helps a lot for certain cases; CodeGen
+ // isn't smart enough to ignore the overflow, which leads to much less
+ // efficient code if the width of the subtraction is wider than the native
+ // register width.
+ //
+ // (It's possible to not widen at all by pulling out factors of 2 before
+ // the multiplication; for example, K=2 can be calculated as
+ // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
+ // extra arithmetic, so it's not an obvious win, and it gets
+ // much more complicated for K > 3.)
+
+ // Protection from insane SCEVs; this bound is conservative,
+ // but it probably doesn't matter.
+ if (K > 1000)
+ return new SCEVCouldNotCompute();
- // Handle the simplest case efficiently.
- if (K == 1)
- return It;
+ unsigned W = ResultTy->getBitWidth();
+
+ // Calculate K! / 2^T and T; we divide out the factors of two before
+ // multiplying for calculating K! / 2^T to avoid overflow.
+ // Other overflow doesn't matter because we only care about the bottom
+ // W bits of the result.
+ APInt OddFactorial(W, 1);
+ unsigned T = 1;
+ for (unsigned i = 3; i <= K; ++i) {
+ APInt Mult(W, i);
+ unsigned TwoFactors = Mult.countTrailingZeros();
+ T += TwoFactors;
+ Mult = Mult.lshr(TwoFactors);
+ OddFactorial *= Mult;
+ }
- assert(K < 9 && "We cannot handle such long AddRecs yet.");
-
- // FIXME: A temporary hack to remove in future. Arbitrary precision integers
- // aren't supported by the code generator yet. For the dividend, the bitwidth
- // we use is the smallest power of 2 greater or equal to K*W and less or equal
- // to 64. Note that setting the upper bound for bitwidth may still lead to
- // miscompilation in some cases.
- unsigned DividendBits = 1U << Log2_32_Ceil(K * It->getBitWidth());
- if (DividendBits > 64)
- DividendBits = 64;
-#if 0 // Waiting for the APInt support in the code generator...
- unsigned DividendBits = K * It->getBitWidth();
-#endif
+ // We need at least W + T bits for the multiplication step
+ // FIXME: A temporary hack; we round up the bitwidths
+ // to the nearest power of 2 to be nice to the code generator.
+ unsigned CalculationBits = 1U << Log2_32_Ceil(W + T);
+ // FIXME: Temporary hack to avoid generating integers that are too wide.
+ // Although, it's not completely clear how to determine how much
+ // widening is safe; for example, on X86, we can't really widen
+ // beyond 64 because we need to be able to do multiplication
+ // that's CalculationBits wide, but on X86-64, we can safely widen up to
+ // 128 bits.
+ if (CalculationBits > 64)
+ return new SCEVCouldNotCompute();
- const IntegerType *DividendTy = IntegerType::get(DividendBits);
- const SCEVHandle ExIt = SE.getTruncateOrZeroExtend(It, DividendTy);
-
- // The final number of bits we need to perform the division is the maximum of
- // dividend and divisor bitwidths.
- const IntegerType *DivisionTy =
- IntegerType::get(std::max(DividendBits, 16U));
-
- // Compute K! We know K >= 2 here.
- unsigned F = 2;
- for (unsigned i = 3; i <= K; ++i)
- F *= i;
- APInt Divisor(DivisionTy->getBitWidth(), F);
-
- // Handle this case efficiently, it is common to have constant iteration
- // counts while computing loop exit values.
- if (SCEVConstant *SC = dyn_cast<SCEVConstant>(ExIt)) {
- const APInt& N = SC->getValue()->getValue();
- APInt Dividend(N.getBitWidth(), 1);
- for (; K; --K)
- Dividend *= N-(K-1);
- if (DividendTy != DivisionTy)
- Dividend = Dividend.zext(DivisionTy->getBitWidth());
-
- APInt Result = Dividend.udiv(Divisor);
- if (Result.getBitWidth() != It->getBitWidth())
- Result = Result.trunc(It->getBitWidth());
-
- return SE.getConstant(Result);
+ // Calcuate 2^T, at width T+W.
+ APInt DivFactor = APInt(CalculationBits, 1).shl(T);
+
+ // Calculate the multiplicative inverse of K! / 2^T;
+ // this multiplication factor will perform the exact division by
+ // K! / 2^T.
+ APInt Mod = APInt::getSignedMinValue(W+1);
+ APInt MultiplyFactor = OddFactorial.zext(W+1);
+ MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
+ MultiplyFactor = MultiplyFactor.trunc(W);
+
+ // Calculate the product, at width T+W
+ const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
+ SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
+ for (unsigned i = 1; i != K; ++i) {
+ SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
+ Dividend = SE.getMulExpr(Dividend,
+ SE.getTruncateOrZeroExtend(S, CalculationTy));
}
-
- SCEVHandle Dividend = ExIt;
- for (unsigned i = 1; i != K; ++i)
- Dividend =
- SE.getMulExpr(Dividend,
- SE.getMinusSCEV(ExIt, SE.getIntegerSCEV(i, DividendTy)));
- return SE.getTruncateOrZeroExtend(
- SE.getUDivExpr(
- SE.getTruncateOrZeroExtend(Dividend, DivisionTy),
- SE.getConstant(Divisor)
- ), It->getType());
+ // Divide by 2^T
+ SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
+
+ // Truncate the result, and divide by K! / 2^T.
+
+ return SE.getMulExpr(SE.getConstant(MultiplyFactor),
+ SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}
/// evaluateAtIteration - Return the value of this chain of recurrences at
// The computation is correct in the face of overflow provided that the
// multiplication is performed _after_ the evaluation of the binomial
// coefficient.
- SCEVHandle Val = SE.getMulExpr(getOperand(i),
- BinomialCoefficient(It, i, SE));
+ SCEVHandle Val =
+ SE.getMulExpr(getOperand(i),
+ BinomialCoefficient(It, i, SE,
+ cast<IntegerType>(getType())));
Result = SE.getAddExpr(Result, Val);
}
return Result;
// loop iterates. Compute this now.
SCEVHandle IterationCount = getIterationCount(AddRec->getLoop());
if (IterationCount == UnknownValue) return UnknownValue;
- IterationCount = SE.getTruncateOrZeroExtend(IterationCount,
- AddRec->getType());
-
- // If the value is affine, simplify the expression evaluation to just
- // Start + Step*IterationCount.
- if (AddRec->isAffine())
- return SE.getAddExpr(AddRec->getStart(),
- SE.getMulExpr(IterationCount,
- AddRec->getOperand(1)));
- // Otherwise, evaluate it the hard way.
+ // Then, evaluate the AddRec.
return AddRec->evaluateAtIteration(IterationCount, SE);
}
return UnknownValue;