From af8fb1984674db462bc6923ed54db0275c78b711 Mon Sep 17 00:00:00 2001 From: Reid Spencer Date: Thu, 1 Mar 2007 05:39:56 +0000 Subject: [PATCH] Add a square root function. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@34775 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/ADT/APInt.h | 3 ++ lib/Support/APInt.cpp | 93 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/include/llvm/ADT/APInt.h b/include/llvm/ADT/APInt.h index 702bd7156fd..0e127b8c44d 100644 --- a/include/llvm/ADT/APInt.h +++ b/include/llvm/ADT/APInt.h @@ -686,6 +686,9 @@ public: double signedRoundToDouble() const { return roundToDouble(true); } + + /// @brief Compute the square root + APInt sqrt() const; }; inline bool operator==(uint64_t V1, const APInt& V2) { diff --git a/lib/Support/APInt.cpp b/lib/Support/APInt.cpp index b56c561ef5d..810fd9ee170 100644 --- a/lib/Support/APInt.cpp +++ b/lib/Support/APInt.cpp @@ -1015,7 +1015,7 @@ APInt APInt::ashr(uint32_t shiftAmt) const { if (shiftAmt < APINT_BITS_PER_WORD) { uint64_t carry = 0; for (int i = getNumWords()-1; i >= 0; --i) { - val[i] = pVal[i] >> shiftAmt | carry; + val[i] = (pVal[i] >> shiftAmt) | carry; carry = pVal[i] << (APINT_BITS_PER_WORD - shiftAmt); } return APInt(val, BitWidth).clearUnusedBits(); @@ -1037,8 +1037,8 @@ APInt APInt::ashr(uint32_t shiftAmt) const { // Shift the low order words uint32_t breakWord = getNumWords() - offset -1; for (uint32_t i = 0; i < breakWord; ++i) - val[i] = pVal[i+offset] >> wordShift | - pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift); + val[i] = (pVal[i+offset] >> wordShift) | + (pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift)); // Shift the break word. 
uint32_t SignBit = APINT_BITS_PER_WORD - (BitWidth % APINT_BITS_PER_WORD); val[breakWord] = uint64_t( @@ -1072,7 +1072,7 @@ APInt APInt::lshr(uint32_t shiftAmt) const { if (shiftAmt < APINT_BITS_PER_WORD) { uint64_t carry = 0; for (int i = getNumWords()-1; i >= 0; --i) { - val[i] = pVal[i] >> shiftAmt | carry; + val[i] = (pVal[i] >> shiftAmt) | carry; carry = pVal[i] << (APINT_BITS_PER_WORD - shiftAmt); } return APInt(val, BitWidth).clearUnusedBits(); @@ -1094,8 +1094,8 @@ APInt APInt::lshr(uint32_t shiftAmt) const { // Shift the low order words uint32_t breakWord = getNumWords() - offset -1; for (uint32_t i = 0; i < breakWord; ++i) - val[i] = pVal[i+offset] >> wordShift | - pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift); + val[i] = (pVal[i+offset] >> wordShift) | + (pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift)); // Shift the break word. val[breakWord] = pVal[breakWord+offset] >> wordShift; @@ -1158,6 +1158,87 @@ APInt APInt::shl(uint32_t shiftAmt) const { return APInt(val, BitWidth).clearUnusedBits(); } + +// Square Root - this method computes and returns the square root of "this". +// Three mechanisms are used for computation. For small values (<= 5 bits), +// a table lookup is done. This gets some performance for common cases. For +// values using less than 52 bits, the value is converted to double and then +// the libc sqrt function is called. The result is rounded and then converted +// back to a uint64_t which is then used to construct the result. Finally, +// the Babylonian method for computing square roots is used. +APInt APInt::sqrt() const { + + // Determine the magnitude of the value. + uint32_t magnitude = getActiveBits(); + + // Use a fast table for some small values. This also gets rid of some + // rounding errors in libc sqrt for small values. + if (magnitude <= 5) { + uint64_t result = 0; + switch (isSingleWord() ? 
VAL : pVal[0]) { + case 0 : break; + case 1 : case 2 : result = 1; break; + case 3 : case 4 : case 5: case 6: result = 2; break; + case 7 : case 8 : case 9: case 10: case 11: case 12: + result = 3; break; + case 13: case 14: case 15: case 16: case 17: case 18: case 19: case 20: + result = 4; break; + case 21: case 22: case 23: case 24: case 25: case 26: case 27: case 28: + case 29: case 30: result = 5; break; + case 31: result = 6; break; + } + return APInt(BitWidth, result); + } + + // If the magnitude of the value fits in less than 52 bits (the precision of + // an IEEE double precision floating point value), then we can use the + // libc sqrt function which will probably use a hardware sqrt computation. + // This should be faster than the algorithm below. + if (magnitude < 52) + return APInt(BitWidth, + uint64_t(::round(::sqrt(double(isSingleWord()?VAL:pVal[0]))))); + + // Okay, all the short cuts are exhausted. We must compute it. The following + // is a classical Babylonian method for computing the square root. This code + // was adapted to APInt from a Wikipedia article on such computations. + // See http://www.wikipedia.org/ and go to the page named + // Calculate_an_integer_square_root. + uint32_t nbits = BitWidth, i = 4; + APInt testy(BitWidth, 16); + APInt x_old(BitWidth, 1); + APInt x_new(BitWidth, 0); + APInt two(BitWidth, 2); + + // Select a good starting value using binary logarithms.
+ for (;; i += 2, testy = testy.shl(2)) + if (i >= nbits || this->ule(testy)) { + x_old = x_old.shl(i / 2); + break; + } + + // Use the Babylonian method to arrive at the integer square root: + for (;;) { + x_new = (this->udiv(x_old) + x_old).udiv(two); + if (x_old.ule(x_new)) + break; + x_old = x_new; + } + + // Make sure we return the closest approximation + APInt square(x_old * x_old); + APInt nextSquare((x_old + 1) * (x_old +1)); + if (this->ult(square)) + return x_old; + else if (this->ule(nextSquare)) + if ((nextSquare - *this).ult(*this - square)) + return x_old + 1; + else + return x_old; + else + assert(0 && "Error in APInt::sqrt computation"); + return x_old + 1; +} + /// Implementation of Knuth's Algorithm D (Division of nonnegative integers) /// from "Art of Computer Programming, Volume 2", section 4.3.1, p. 272. The /// variables here have the same names as in the algorithm. Comments explain -- 2.34.1