Support: Return ScaledNumbers::MaxScale from getQuotient()
[oota-llvm.git] / include / llvm / Support / ScaledNumber.h
index 240d5b64aa85a901d26e52117be0aaf9e56c9acb..e7c329f7bffc9f2e5bf988fb0ec6a1919409de30 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "llvm/Support/MathExtras.h"
 
+#include <algorithm>
 #include <cstdint>
 #include <limits>
 #include <utility>
 namespace llvm {
 namespace ScaledNumbers {
 
+/// \brief Maximum scale; same as APFloat for easy debug printing.
+const int32_t MaxScale = 16383;
+
+/// \brief Maximum scale; same as APFloat for easy debug printing.
+const int32_t MinScale = -16382;
+
 /// \brief Get the width of a number.
 template <class DigitsT> inline int getWidth() { return sizeof(DigitsT) * 8; }
 
@@ -141,7 +148,7 @@ std::pair<uint32_t, int16_t> divide32(uint32_t Dividend, uint32_t Divisor);
 ///
 /// Implemented with one 64-bit integer divide/remainder pair.
 ///
-/// Returns \c (DigitsT_MAX, INT16_MAX) for divide-by-zero (0 for 0/0).
+/// Returns \c (DigitsT_MAX, MaxScale) for divide-by-zero (0 for 0/0).
 template <class DigitsT>
 std::pair<DigitsT, int16_t> getQuotient(DigitsT Dividend, DigitsT Divisor) {
   static_assert(!std::numeric_limits<DigitsT>::is_signed, "expected unsigned");
@@ -152,7 +159,7 @@ std::pair<DigitsT, int16_t> getQuotient(DigitsT Dividend, DigitsT Divisor) {
   if (!Dividend)
     return std::make_pair(0, 0);
   if (!Divisor)
-    return std::make_pair(std::numeric_limits<DigitsT>::max(), INT16_MAX);
+    return std::make_pair(std::numeric_limits<DigitsT>::max(), MaxScale);
 
   if (getWidth<DigitsT>() == 64)
     return divide64(Dividend, Divisor);
@@ -263,6 +270,146 @@ int compare(DigitsT LDigits, int16_t LScale, DigitsT RDigits, int16_t RScale) {
   return -compareImpl(RDigits, LDigits, LScale - RScale);
 }
 
+/// \brief Match scales of two numbers.
+///
+/// Given two scaled numbers, match up their scales.  Change the digits and
+/// scales in place.  Shift the digits as necessary to form equivalent numbers,
+/// losing precision only when necessary.
+///
+/// If the output value of \c LDigits (\c RDigits) is \c 0, the output value of
+/// \c LScale (\c RScale) is unspecified.
+///
+/// As a convenience, returns the matching scale.  If the output value of one
+/// number is zero, returns the scale of the other.  If both are zero, which
+/// scale is returned is unspecifed.
+template <class DigitsT>
+int16_t matchScales(DigitsT &LDigits, int16_t &LScale, DigitsT &RDigits,
+                    int16_t &RScale) {
+  static_assert(!std::numeric_limits<DigitsT>::is_signed, "expected unsigned");
+
+  if (LScale < RScale)
+    // Swap arguments.
+    return matchScales(RDigits, RScale, LDigits, LScale);
+  if (!LDigits)
+    return RScale;
+  if (!RDigits || LScale == RScale)
+    return LScale;
+
+  // Now LScale > RScale.  Get the difference.
+  int32_t ScaleDiff = int32_t(LScale) - RScale;
+  if (ScaleDiff >= 2 * getWidth<DigitsT>()) {
+    // Don't bother shifting.  RDigits will get zero-ed out anyway.
+    RDigits = 0;
+    return LScale;
+  }
+
+  // Shift LDigits left as much as possible, then shift RDigits right.
+  int32_t ShiftL = std::min<int32_t>(countLeadingZeros(LDigits), ScaleDiff);
+  assert(ShiftL < getWidth<DigitsT>() && "can't shift more than width");
+
+  int32_t ShiftR = ScaleDiff - ShiftL;
+  if (ShiftR >= getWidth<DigitsT>()) {
+    // Don't bother shifting.  RDigits will get zero-ed out anyway.
+    RDigits = 0;
+    return LScale;
+  }
+
+  LDigits <<= ShiftL;
+  RDigits >>= ShiftR;
+
+  LScale -= ShiftL;
+  RScale += ShiftR;
+  assert(LScale == RScale && "scales should match");
+  return LScale;
+}
+
+/// \brief Get the sum of two scaled numbers.
+///
+/// Get the sum of two scaled numbers with as much precision as possible.
+///
+/// \pre Adding 1 to \c LScale (or \c RScale) will not overflow INT16_MAX.
+template <class DigitsT>
+std::pair<DigitsT, int16_t> getSum(DigitsT LDigits, int16_t LScale,
+                                   DigitsT RDigits, int16_t RScale) {
+  static_assert(!std::numeric_limits<DigitsT>::is_signed, "expected unsigned");
+
+  // Check inputs up front.  This is only relevent if addition overflows, but
+  // testing here should catch more bugs.
+  assert(LScale < INT16_MAX && "scale too large");
+  assert(RScale < INT16_MAX && "scale too large");
+
+  // Normalize digits to match scales.
+  int16_t Scale = matchScales(LDigits, LScale, RDigits, RScale);
+
+  // Compute sum.
+  DigitsT Sum = LDigits + RDigits;
+  if (Sum >= RDigits)
+    return std::make_pair(Sum, Scale);
+
+  // Adjust sum after arithmetic overflow.
+  DigitsT HighBit = DigitsT(1) << (getWidth<DigitsT>() - 1);
+  return std::make_pair(HighBit | Sum >> 1, Scale + 1);
+}
+
+/// \brief Convenience helper for 32-bit sum.
+inline std::pair<uint32_t, int16_t> getSum32(uint32_t LDigits, int16_t LScale,
+                                             uint32_t RDigits, int16_t RScale) {
+  return getSum(LDigits, LScale, RDigits, RScale);
+}
+
+/// \brief Convenience helper for 64-bit sum.
+inline std::pair<uint64_t, int16_t> getSum64(uint64_t LDigits, int16_t LScale,
+                                             uint64_t RDigits, int16_t RScale) {
+  return getSum(LDigits, LScale, RDigits, RScale);
+}
+
+/// \brief Get the difference of two scaled numbers.
+///
+/// Get LHS minus RHS with as much precision as possible.
+///
+/// Returns \c (0, 0) if the RHS is larger than the LHS.
+template <class DigitsT>
+std::pair<DigitsT, int16_t> getDifference(DigitsT LDigits, int16_t LScale,
+                                          DigitsT RDigits, int16_t RScale) {
+  static_assert(!std::numeric_limits<DigitsT>::is_signed, "expected unsigned");
+
+  // Normalize digits to match scales.
+  const DigitsT SavedRDigits = RDigits;
+  const int16_t SavedRScale = RScale;
+  matchScales(LDigits, LScale, RDigits, RScale);
+
+  // Compute difference.
+  if (LDigits <= RDigits)
+    return std::make_pair(0, 0);
+  if (RDigits || !SavedRDigits)
+    return std::make_pair(LDigits - RDigits, LScale);
+
+  // Check if RDigits just barely lost its last bit.  E.g., for 32-bit:
+  //
+  //   1*2^32 - 1*2^0 == 0xffffffff != 1*2^32
+  const auto RLgFloor = getLgFloor(SavedRDigits, SavedRScale);
+  if (!compare(LDigits, LScale, DigitsT(1), RLgFloor + getWidth<DigitsT>()))
+    return std::make_pair(std::numeric_limits<DigitsT>::max(), RLgFloor);
+
+  return std::make_pair(LDigits, LScale);
+}
+
+/// \brief Convenience helper for 32-bit difference.
+inline std::pair<uint32_t, int16_t> getDifference32(uint32_t LDigits,
+                                                    int16_t LScale,
+                                                    uint32_t RDigits,
+                                                    int16_t RScale) {
+  return getDifference(LDigits, LScale, RDigits, RScale);
+}
+
+/// \brief Convenience helper for 64-bit difference.
+inline std::pair<uint64_t, int16_t> getDifference64(uint64_t LDigits,
+                                                    int16_t LScale,
+                                                    uint64_t RDigits,
+                                                    int16_t RScale) {
+  return getDifference(LDigits, LScale, RDigits, RScale);
+}
+
 } // end namespace ScaledNumbers
 } // end namespace llvm