Add a square root function.
authorReid Spencer <rspencer@reidspencer.com>
Thu, 1 Mar 2007 05:39:56 +0000 (05:39 +0000)
committerReid Spencer <rspencer@reidspencer.com>
Thu, 1 Mar 2007 05:39:56 +0000 (05:39 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@34775 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/ADT/APInt.h
lib/Support/APInt.cpp

index 702bd7156fd675a709ece42dd74fafe7cb0da620..0e127b8c44d73a640c36a7bf01a91555b3e15eb4 100644 (file)
@@ -686,6 +686,9 @@ public:
   double signedRoundToDouble() const {
     return roundToDouble(true);
   }
+
+  /// @brief Compute the square root
+  APInt sqrt() const;
 };
 
 inline bool operator==(uint64_t V1, const APInt& V2) {
index b56c561ef5d6bd870bd4c27b257c7c1b91e950e9..810fd9ee1703cc65a1a828eebf9bd938ae4b11c6 100644 (file)
@@ -1015,7 +1015,7 @@ APInt APInt::ashr(uint32_t shiftAmt) const {
   if (shiftAmt < APINT_BITS_PER_WORD) {
     uint64_t carry = 0;
     for (int i = getNumWords()-1; i >= 0; --i) {
-      val[i] = pVal[i] >> shiftAmt | carry;
+      val[i] = (pVal[i] >> shiftAmt) | carry;
       carry = pVal[i] << (APINT_BITS_PER_WORD - shiftAmt);
     }
     return APInt(val, BitWidth).clearUnusedBits();
@@ -1037,8 +1037,8 @@ APInt APInt::ashr(uint32_t shiftAmt) const {
   // Shift the low order words 
   uint32_t breakWord = getNumWords() - offset -1;
   for (uint32_t i = 0; i < breakWord; ++i)
-    val[i] = pVal[i+offset] >> wordShift |
-             pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift);
+    val[i] = (pVal[i+offset] >> wordShift) |
+             (pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift));
   // Shift the break word.
   uint32_t SignBit = APINT_BITS_PER_WORD - (BitWidth % APINT_BITS_PER_WORD);
   val[breakWord] = uint64_t(
@@ -1072,7 +1072,7 @@ APInt APInt::lshr(uint32_t shiftAmt) const {
   if (shiftAmt < APINT_BITS_PER_WORD) {
     uint64_t carry = 0;
     for (int i = getNumWords()-1; i >= 0; --i) {
-      val[i] = pVal[i] >> shiftAmt | carry;
+      val[i] = (pVal[i] >> shiftAmt) | carry;
       carry = pVal[i] << (APINT_BITS_PER_WORD - shiftAmt);
     }
     return APInt(val, BitWidth).clearUnusedBits();
@@ -1094,8 +1094,8 @@ APInt APInt::lshr(uint32_t shiftAmt) const {
   // Shift the low order words 
   uint32_t breakWord = getNumWords() - offset -1;
   for (uint32_t i = 0; i < breakWord; ++i)
-    val[i] = pVal[i+offset] >> wordShift |
-             pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift);
+    val[i] = (pVal[i+offset] >> wordShift) |
+             (pVal[i+offset+1] << (APINT_BITS_PER_WORD - wordShift));
   // Shift the break word.
   val[breakWord] = pVal[breakWord+offset] >> wordShift;
 
@@ -1158,6 +1158,87 @@ APInt APInt::shl(uint32_t shiftAmt) const {
   return APInt(val, BitWidth).clearUnusedBits();
 }
 
+
+// Square Root - this method computes and returns the square root of "this".
+// Three mechanisms are used for computation. For small values (<= 5 bits),
+// a table lookup is done. This gets some performance for common cases. For
+// values using less than 52 bits, the value is converted to double and then
+// the libc sqrt function is called. The result is rounded and then converted
+// back to a uint64_t which is then used to construct the result. Finally,
+// the Babylonian method for computing square roots is used. 
+APInt APInt::sqrt() const {
+
+  // Determine the magnitude of the value.
+  uint32_t magnitude = getActiveBits();
+
+  // Use a fast table for some small values. This also gets rid of some
+  // rounding errors in libc sqrt for small values.
+  if (magnitude <= 5) {
+    uint64_t result = 0;
+    switch (isSingleWord() ? VAL : pVal[0]) {
+      case 0 : break;
+      case 1 : case 2 : result = 1; break;
+      case 3 : case 4 : case 5: case 6: result = 2; break;
+      case 7 : case 8 : case 9: case 10: case 11: case 12: 
+        result = 3; break;
+      case 13: case 14: case 15: case 16: case 17: case 18: case 19: case 20:
+        result = 4; break;
+      case 21: case 22: case 23: case 24: case 25: case 26: case 27: case 28:
+      case 29: case 30: result = 5; break;
+      case 31: result = 6; break;
+    }
+    return APInt(BitWidth, result);
+  }
+
+  // If the magnitude of the value fits in less than 52 bits (the precision of
+  // an IEEE double precision floating point value), then we can use the
+  // libc sqrt function which will probably use a hardware sqrt computation.
+  // This should be faster than the algorithm below.
+  if (magnitude < 52)
+    return APInt(BitWidth, 
+                 uint64_t(::round(::sqrt(double(isSingleWord()?VAL:pVal[0])))));
+
+  // Okay, all the short cuts are exhausted. We must compute it. The following
+  // is a classical Babylonian method for computing the square root. This code
+  // was adapted to APINt from a wikipedia article on such computations.
+  // See http://www.wikipedia.org/ and go to the page named
+  // Calculate_an_integer_square_root. 
+  uint32_t nbits = BitWidth, i = 4;
+  APInt testy(BitWidth, 16);
+  APInt x_old(BitWidth, 1);
+  APInt x_new(BitWidth, 0);
+  APInt two(BitWidth, 2);
+
+  // Select a good starting value using binary logarithms.
+  for (;; i += 2, testy = testy.shl(2)) 
+    if (i >= nbits || this->ule(testy)) {
+      x_old = x_old.shl(i / 2);
+      break;
+    }
+
+  // Use the Babylonian method to arrive at the integer square root: 
+  for (;;) {
+    x_new = (this->udiv(x_old) + x_old).udiv(two);
+    if (x_old.ule(x_new))
+      break;
+    x_old = x_new;
+  }
+
+  // Make sure we return the closest approximation
+  APInt square(x_old * x_old);
+  APInt nextSquare((x_old + 1) * (x_old +1));
+  if (this->ult(square))
+    return x_old;
+  else if (this->ule(nextSquare))
+    if ((nextSquare - *this).ult(*this - square))
+      return x_old + 1;
+    else
+      return x_old;
+  else
+    assert(0 && "Error in APInt::sqrt computation");
+  return x_old + 1;
+}
+
 /// Implementation of Knuth's Algorithm D (Division of nonnegative integers)
 /// from "Art of Computer Programming, Volume 2", section 4.3.1, p. 272. The
 /// variables here have the same names as in the algorithm. Comments explain