Prune CRLF.
[oota-llvm.git] / include / llvm / ADT / SmallBitVector.h
index 346fb1ca43dcd2232c21581ea5d8df9a05eee5e4..ababf0f3cbe3e58626ab061835cc54c63894d077 100644 (file)
@@ -15,7 +15,7 @@
 #define LLVM_ADT_SMALLBITVECTOR_H
 
 #include "llvm/ADT/BitVector.h"
-#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
 
@@ -32,48 +32,86 @@ class SmallBitVector {
   // TODO: In "large" mode, a pointer to a BitVector is used, leading to an
   // unnecessary level of indirection. It would be more efficient to use a
   // pointer to memory containing size, allocation size, and the array of bits.
-  PointerIntPair<BitVector *, 1, uintptr_t> X;
+  uintptr_t X;
 
-  // The number of bits in this class.
-  static const size_t NumBaseBits = sizeof(uintptr_t) * CHAR_BIT;
+  enum {
+    // The number of bits in this class.
+    NumBaseBits = sizeof(uintptr_t) * CHAR_BIT,
 
-  // One bit is used to discriminate between small and large mode. The
-  // remaining bits are used for the small-mode representation.
-  static const size_t SmallNumRawBits = NumBaseBits - 1;
+    // One bit is used to discriminate between small and large mode. The
+    // remaining bits are used for the small-mode representation.
+    SmallNumRawBits = NumBaseBits - 1,
 
-  // A few more bits are used to store the size of the bit set in small mode.
-  // Theoretically this is a ceil-log2. These bits are encoded in the most
-  // significant bits of the raw bits.
-  static const size_t SmallNumSizeBits = (NumBaseBits == 32 ? 5 :
-                                          NumBaseBits == 64 ? 6 :
-                                          SmallNumRawBits);
+    // A few more bits are used to store the size of the bit set in small mode.
+    // Theoretically this is a ceil-log2. These bits are encoded in the most
+    // significant bits of the raw bits.
+    SmallNumSizeBits = (NumBaseBits == 32 ? 5 :
+                        NumBaseBits == 64 ? 6 :
+                        SmallNumRawBits),
 
-  // The remaining bits are used to store the actual set in small mode.
-  static const size_t SmallNumDataBits = SmallNumRawBits - SmallNumSizeBits;
+    // The remaining bits are used to store the actual set in small mode.
+    SmallNumDataBits = SmallNumRawBits - SmallNumSizeBits
+  };
 
+public:
+  typedef unsigned size_type;
+  // Encapsulation of a single bit.
+  class reference {
+    SmallBitVector &TheVector;
+    unsigned BitPos;
+
+  public:
+    reference(SmallBitVector &b, unsigned Idx) : TheVector(b), BitPos(Idx) {}
+
+    reference& operator=(reference t) {
+      *this = bool(t);
+      return *this;
+    }
+
+    reference& operator=(bool t) {
+      if (t)
+        TheVector.set(BitPos);
+      else
+        TheVector.reset(BitPos);
+      return *this;
+    }
+
+    operator bool() const {
+      return const_cast<const SmallBitVector &>(TheVector).operator[](BitPos);
+    }
+  };
+
+private:
   bool isSmall() const {
-    return X.getInt();
+    return X & uintptr_t(1);
+  }
+
+  BitVector *getPointer() const {
+    assert(!isSmall());
+    return reinterpret_cast<BitVector *>(X);
   }
 
   void switchToSmall(uintptr_t NewSmallBits, size_t NewSize) {
-    X.setInt(true);
+    X = 1;
     setSmallSize(NewSize);
     setSmallBits(NewSmallBits);
   }
 
   void switchToLarge(BitVector *BV) {
-    X.setInt(false);
-    X.setPointer(BV);
+    X = reinterpret_cast<uintptr_t>(BV);
+    assert(!isSmall() && "Tried to use an unaligned pointer");
   }
 
   // Return all the bits used for the "small" representation; this includes
   // bits for the size as well as the element bits.
   uintptr_t getSmallRawBits() const {
-    return reinterpret_cast<uintptr_t>(X.getPointer()) >> 1;
+    assert(isSmall());
+    return X >> 1;
   }
 
   void setSmallRawBits(uintptr_t NewRawBits) {
-    return X.setPointer(reinterpret_cast<BitVector *>(NewRawBits << 1));
+    assert(isSmall());
+    X = (NewRawBits << 1) | uintptr_t(1);
   }
 
   // Return the size.
@@ -87,22 +125,22 @@ class SmallBitVector {
 
   // Return the element bits.
   uintptr_t getSmallBits() const {
-    return getSmallRawBits() & ~(~uintptr_t(0) << SmallNumDataBits);
+    return getSmallRawBits() & ~(~uintptr_t(0) << getSmallSize());
   }
 
   void setSmallBits(uintptr_t NewBits) {
-    setSmallRawBits((getSmallRawBits() & (~uintptr_t(0) << SmallNumDataBits)) |
-                    (NewBits & ~(~uintptr_t(0) << getSmallSize())));
+    setSmallRawBits((NewBits & ~(~uintptr_t(0) << getSmallSize())) |
+                    (getSmallSize() << SmallNumDataBits));
   }
 
 public:
   /// SmallBitVector default ctor - Creates an empty bitvector.
-  SmallBitVector() : X(0, 1) {}
+  SmallBitVector() : X(1) {}
 
   /// SmallBitVector ctor - Creates a bitvector of specified number of bits. All
   /// bits are initialized to the specified value.
-  explicit SmallBitVector(unsigned s, bool t = false) : X(0, 1) {
-    if (s <= SmallNumRawBits)
+  explicit SmallBitVector(unsigned s, bool t = false) {
+    if (s <= SmallNumDataBits)
       switchToSmall(t ? ~uintptr_t(0) : 0, s);
     else
       switchToLarge(new BitVector(s, t));
@@ -113,49 +151,60 @@ public:
     if (RHS.isSmall())
       X = RHS.X;
     else
-      switchToLarge(new BitVector(*RHS.X.getPointer()));
+      switchToLarge(new BitVector(*RHS.getPointer()));
+  }
+
+  SmallBitVector(SmallBitVector &&RHS) : X(RHS.X) {
+    RHS.X = 1;
   }
 
   ~SmallBitVector() {
     if (!isSmall())
-      delete X.getPointer();
+      delete getPointer();
   }
 
   /// empty - Tests whether there are no bits in this bitvector.
   bool empty() const {
-    return isSmall() ? getSmallSize() == 0 : X.getPointer()->empty();
+    return isSmall() ? getSmallSize() == 0 : getPointer()->empty();
   }
 
   /// size - Returns the number of bits in this bitvector.
   size_t size() const {
-    return isSmall() ? getSmallSize() : X.getPointer()->size();
+    return isSmall() ? getSmallSize() : getPointer()->size();
   }
 
   /// count - Returns the number of bits which are set.
-  unsigned count() const {
+  size_type count() const {
     if (isSmall()) {
       uintptr_t Bits = getSmallBits();
-      if (sizeof(uintptr_t) * CHAR_BIT == 32)
+      if (NumBaseBits == 32)
         return CountPopulation_32(Bits);
-      if (sizeof(uintptr_t) * CHAR_BIT == 64)
+      if (NumBaseBits == 64)
         return CountPopulation_64(Bits);
-      assert(0 && "Unsupported!");
+      llvm_unreachable("Unsupported!");
     }
-    return X.getPointer()->count();
+    return getPointer()->count();
   }
 
   /// any - Returns true if any bit is set.
   bool any() const {
     if (isSmall())
       return getSmallBits() != 0;
-    return X.getPointer()->any();
+    return getPointer()->any();
+  }
+
+  /// all - Returns true if all bits are set.
+  bool all() const {
+    if (isSmall())
+      return getSmallBits() == (uintptr_t(1) << getSmallSize()) - 1;
+    return getPointer()->all();
   }
 
   /// none - Returns true if none of the bits are set.
   bool none() const {
     if (isSmall())
       return getSmallBits() == 0;
-    return X.getPointer()->none();
+    return getPointer()->none();
   }
 
   /// find_first - Returns the index of the first set bit, -1 if none
@@ -163,13 +212,15 @@ public:
   int find_first() const {
     if (isSmall()) {
       uintptr_t Bits = getSmallBits();
-      if (sizeof(uintptr_t) * CHAR_BIT == 32)
-        return CountTrailingZeros_32(Bits);
-      if (sizeof(uintptr_t) * CHAR_BIT == 64)
-        return CountTrailingZeros_64(Bits);
-      assert(0 && "Unsupported!");
+      if (Bits == 0)
+        return -1;
+      if (NumBaseBits == 32)
+        return countTrailingZeros(Bits);
+      if (NumBaseBits == 64)
+        return countTrailingZeros(Bits);
+      llvm_unreachable("Unsupported!");
     }
-    return X.getPointer()->find_first();
+    return getPointer()->find_first();
   }
 
   /// find_next - Returns the index of the next set bit following the
@@ -178,30 +229,33 @@ public:
     if (isSmall()) {
       uintptr_t Bits = getSmallBits();
       // Mask off previous bits.
-      Bits &= ~uintptr_t(0) << Prev;
-      if (sizeof(uintptr_t) * CHAR_BIT == 32)
-        return CountTrailingZeros_32(Bits);
-      if (sizeof(uintptr_t) * CHAR_BIT == 64)
-        return CountTrailingZeros_64(Bits);
-      assert(0 && "Unsupported!");
+      Bits &= ~uintptr_t(0) << (Prev + 1);
+      if (Bits == 0 || Prev + 1 >= getSmallSize())
+        return -1;
+      if (NumBaseBits == 32)
+        return countTrailingZeros(Bits);
+      if (NumBaseBits == 64)
+        return countTrailingZeros(Bits);
+      llvm_unreachable("Unsupported!");
     }
-    return X.getPointer()->find_next(Prev);
+    return getPointer()->find_next(Prev);
   }
 
   /// clear - Clear all bits.
   void clear() {
     if (!isSmall())
-      delete X.getPointer();
+      delete getPointer();
     switchToSmall(0, 0);
   }
 
   /// resize - Grow or shrink the bitvector.
   void resize(unsigned N, bool t = false) {
     if (!isSmall()) {
-      X.getPointer()->resize(N, t);
-    } else if (getSmallSize() >= N) {
+      getPointer()->resize(N, t);
+    } else if (SmallNumDataBits >= N) {
+      uintptr_t NewBits = t ? ~uintptr_t(0) << getSmallSize() : 0;
       setSmallSize(N);
-      setSmallBits(getSmallBits());
+      setSmallBits(NewBits | getSmallBits());
     } else {
       BitVector *BV = new BitVector(N, t);
       uintptr_t OldBits = getSmallBits();
@@ -224,7 +278,7 @@ public:
         switchToLarge(BV);
       }
     } else {
-      X.getPointer()->reserve(N);
+      getPointer()->reserve(N);
     }
   }
 
@@ -233,7 +287,7 @@ public:
     if (isSmall())
       setSmallBits(~uintptr_t(0));
     else
-      X.getPointer()->set();
+      getPointer()->set();
     return *this;
   }
 
@@ -241,7 +295,22 @@ public:
     if (isSmall())
       setSmallBits(getSmallBits() | (uintptr_t(1) << Idx));
     else
-      X.getPointer()->set(Idx);
+      getPointer()->set(Idx);
+    return *this;
+  }
+
+  /// set - Efficiently set a range of bits in [I, E)
+  SmallBitVector &set(unsigned I, unsigned E) {
+    assert(I <= E && "Attempted to set backwards range!");
+    assert(E <= size() && "Attempted to set out-of-bounds range!");
+    if (I == E) return *this;
+    if (isSmall()) {
+      uintptr_t EMask = ((uintptr_t)1) << E;
+      uintptr_t IMask = ((uintptr_t)1) << I;
+      uintptr_t Mask = EMask - IMask;
+      setSmallBits(getSmallBits() | Mask);
+    } else
+      getPointer()->set(I, E);
     return *this;
   }
 
@@ -249,7 +318,7 @@ public:
     if (isSmall())
       setSmallBits(0);
     else
-      X.getPointer()->reset();
+      getPointer()->reset();
     return *this;
   }
 
@@ -257,7 +326,22 @@ public:
     if (isSmall())
       setSmallBits(getSmallBits() & ~(uintptr_t(1) << Idx));
     else
-      X.getPointer()->reset(Idx);
+      getPointer()->reset(Idx);
+    return *this;
+  }
+
+  /// reset - Efficiently reset a range of bits in [I, E)
+  SmallBitVector &reset(unsigned I, unsigned E) {
+    assert(I <= E && "Attempted to reset backwards range!");
+    assert(E <= size() && "Attempted to reset out-of-bounds range!");
+    if (I == E) return *this;
+    if (isSmall()) {
+      uintptr_t EMask = ((uintptr_t)1) << E;
+      uintptr_t IMask = ((uintptr_t)1) << I;
+      uintptr_t Mask = EMask - IMask;
+      setSmallBits(getSmallBits() & ~Mask);
+    } else
+      getPointer()->reset(I, E);
     return *this;
   }
 
@@ -265,7 +349,7 @@ public:
     if (isSmall())
       setSmallBits(~getSmallBits());
     else
-      X.getPointer()->flip();
+      getPointer()->flip();
     return *this;
   }
 
@@ -273,7 +357,7 @@ public:
     if (isSmall())
       setSmallBits(getSmallBits() ^ (uintptr_t(1) << Idx));
     else
-      X.getPointer()->flip(Idx);
+      getPointer()->flip(Idx);
     return *this;
   }
 
@@ -283,18 +367,35 @@ public:
   }
 
   // Indexing.
-  // TODO: Add an index operator which returns a "reference" (proxy class).
+  reference operator[](unsigned Idx) {
+    assert(Idx < size() && "Out-of-bounds Bit access.");
+    return reference(*this, Idx);
+  }
+
   bool operator[](unsigned Idx) const {
     assert(Idx < size() && "Out-of-bounds Bit access.");
     if (isSmall())
       return ((getSmallBits() >> Idx) & 1) != 0;
-    return X.getPointer()->operator[](Idx);
+    return getPointer()->operator[](Idx);
   }
 
   bool test(unsigned Idx) const {
     return (*this)[Idx];
   }
 
+  /// Test if any common bits are set.
+  bool anyCommon(const SmallBitVector &RHS) const {
+    if (isSmall() && RHS.isSmall())
+      return (getSmallBits() & RHS.getSmallBits()) != 0;
+    if (!isSmall() && !RHS.isSmall())
+      return getPointer()->anyCommon(*RHS.getPointer());
+
+    for (unsigned i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
+      if (test(i) && RHS.test(i))
+        return true;
+    return false;
+  }
+
   // Comparison operators.
   bool operator==(const SmallBitVector &RHS) const {
     if (size() != RHS.size())
@@ -302,7 +403,7 @@ public:
     if (isSmall())
       return getSmallBits() == RHS.getSmallBits();
     else
-      return *X.getPointer() == *RHS.X.getPointer();
+      return *getPointer() == *RHS.getPointer();
   }
 
   bool operator!=(const SmallBitVector &RHS) const {
@@ -310,11 +411,81 @@ public:
   }
 
   // Intersection, union, disjoint union.
-  BitVector &operator&=(const SmallBitVector &RHS); // TODO: implement
+  SmallBitVector &operator&=(const SmallBitVector &RHS) {
+    resize(std::max(size(), RHS.size()));
+    if (isSmall())
+      setSmallBits(getSmallBits() & RHS.getSmallBits());
+    else if (!RHS.isSmall())
+      getPointer()->operator&=(*RHS.getPointer());
+    else {
+      SmallBitVector Copy = RHS;
+      Copy.resize(size());
+      getPointer()->operator&=(*Copy.getPointer());
+    }
+    return *this;
+  }
 
-  BitVector &operator|=(const SmallBitVector &RHS); // TODO: implement
+  /// reset - Reset bits that are set in RHS. Same as *this &= ~RHS.
+  SmallBitVector &reset(const SmallBitVector &RHS) {
+    if (isSmall() && RHS.isSmall())
+      setSmallBits(getSmallBits() & ~RHS.getSmallBits());
+    else if (!isSmall() && !RHS.isSmall())
+      getPointer()->reset(*RHS.getPointer());
+    else
+      for (unsigned i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
+        if (RHS.test(i))
+          reset(i);
+
+    return *this;
+  }
 
-  BitVector &operator^=(const SmallBitVector &RHS); // TODO: implement
+  /// test - Check if (This - RHS) is zero.
+  /// This is the same as reset(RHS) and any().
+  bool test(const SmallBitVector &RHS) const {
+    if (isSmall() && RHS.isSmall())
+      return (getSmallBits() & ~RHS.getSmallBits()) != 0;
+    if (!isSmall() && !RHS.isSmall())
+      return getPointer()->test(*RHS.getPointer());
+
+    unsigned i, e;
+    for (i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
+      if (test(i) && !RHS.test(i))
+        return true;
+
+    for (e = size(); i != e; ++i)
+      if (test(i))
+        return true;
+
+    return false;
+  }
+
+  SmallBitVector &operator|=(const SmallBitVector &RHS) {
+    resize(std::max(size(), RHS.size()));
+    if (isSmall())
+      setSmallBits(getSmallBits() | RHS.getSmallBits());
+    else if (!RHS.isSmall())
+      getPointer()->operator|=(*RHS.getPointer());
+    else {
+      SmallBitVector Copy = RHS;
+      Copy.resize(size());
+      getPointer()->operator|=(*Copy.getPointer());
+    }
+    return *this;
+  }
+
+  SmallBitVector &operator^=(const SmallBitVector &RHS) {
+    resize(std::max(size(), RHS.size()));
+    if (isSmall())
+      setSmallBits(getSmallBits() ^ RHS.getSmallBits());
+    else if (!RHS.isSmall())
+      getPointer()->operator^=(*RHS.getPointer());
+    else {
+      SmallBitVector Copy = RHS;
+      Copy.resize(size());
+      getPointer()->operator^=(*Copy.getPointer());
+    }
+    return *this;
+  }
 
   // Assignment operator.
   const SmallBitVector &operator=(const SmallBitVector &RHS) {
@@ -322,21 +493,82 @@ public:
       if (RHS.isSmall())
         X = RHS.X;
       else
-        switchToLarge(new BitVector(*RHS.X.getPointer()));
+        switchToLarge(new BitVector(*RHS.getPointer()));
     } else {
       if (!RHS.isSmall())
-        *X.getPointer() = *RHS.X.getPointer();
+        *getPointer() = *RHS.getPointer();
       else {
-        delete X.getPointer();
+        delete getPointer();
         X = RHS.X;
       }
     }
     return *this;
   }
 
+  const SmallBitVector &operator=(SmallBitVector &&RHS) {
+    if (this != &RHS) {
+      clear();
+      swap(RHS);
+    }
+    return *this;
+  }
+
   void swap(SmallBitVector &RHS) {
     std::swap(X, RHS.X);
   }
+
+  /// setBitsInMask - Add '1' bits from Mask to this vector. Don't resize.
+  /// This computes "*this |= Mask".
+  void setBitsInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
+    if (isSmall())
+      applyMask<true, false>(Mask, MaskWords);
+    else
+      getPointer()->setBitsInMask(Mask, MaskWords);
+  }
+
+  /// clearBitsInMask - Clear any bits in this vector that are set in Mask.
+  /// Don't resize. This computes "*this &= ~Mask".
+  void clearBitsInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
+    if (isSmall())
+      applyMask<false, false>(Mask, MaskWords);
+    else
+      getPointer()->clearBitsInMask(Mask, MaskWords);
+  }
+
+  /// setBitsNotInMask - Add a bit to this vector for every '0' bit in Mask.
+  /// Don't resize.  This computes "*this |= ~Mask".
+  void setBitsNotInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
+    if (isSmall())
+      applyMask<true, true>(Mask, MaskWords);
+    else
+      getPointer()->setBitsNotInMask(Mask, MaskWords);
+  }
+
+  /// clearBitsNotInMask - Clear a bit in this vector for every '0' bit in Mask.
+  /// Don't resize.  This computes "*this &= Mask".
+  void clearBitsNotInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
+    if (isSmall())
+      applyMask<false, true>(Mask, MaskWords);
+    else
+      getPointer()->clearBitsNotInMask(Mask, MaskWords);
+  }
+
+private:
+  template<bool AddBits, bool InvertMask>
+  void applyMask(const uint32_t *Mask, unsigned MaskWords) {
+    assert((NumBaseBits == 64 || NumBaseBits == 32) && "Unsupported word size");
+    if (NumBaseBits == 64 && MaskWords >= 2) {
+      uint64_t M = Mask[0] | (uint64_t(Mask[1]) << 32);
+      if (InvertMask) M = ~M;
+      if (AddBits) setSmallBits(getSmallBits() | M);
+      else         setSmallBits(getSmallBits() & ~M);
+    } else {
+      uint32_t M = Mask[0];
+      if (InvertMask) M = ~M;
+      if (AddBits) setSmallBits(getSmallBits() | M);
+      else         setSmallBits(getSmallBits() & ~M);
+    }
+  }
 };
 
 inline SmallBitVector