SmallVector: Resolve a long-standing fixme by using the existing unitialized_copy...
[oota-llvm.git] / include / llvm / ADT / SmallBitVector.h
index 7563e81e1cef5e46c01234f08bca20f1f16b4d42..22e8ccd8ea0fc6df22f187a5a8634e48b9029daa 100644 (file)
@@ -15,6 +15,7 @@
 #define LLVM_ADT_SMALLBITVECTOR_H
 
 #include "llvm/ADT/BitVector.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
 
@@ -52,6 +53,38 @@ class SmallBitVector {
     SmallNumDataBits = SmallNumRawBits - SmallNumSizeBits
   };
 
+  static_assert(NumBaseBits == 64 || NumBaseBits == 32,
+                "Unsupported word size");
+
+public:
+  typedef unsigned size_type;
+  // Encapsulation of a single bit.
+  class reference {
+    SmallBitVector &TheVector;
+    unsigned BitPos;
+
+  public:
+    reference(SmallBitVector &b, unsigned Idx) : TheVector(b), BitPos(Idx) {}
+
+    reference& operator=(reference t) {
+      *this = bool(t);
+      return *this;
+    }
+
+    reference& operator=(bool t) {
+      if (t)
+        TheVector.set(BitPos);
+      else
+        TheVector.reset(BitPos);
+      return *this;
+    }
+
+    operator bool() const {
+      return const_cast<const SmallBitVector &>(TheVector).operator[](BitPos);
+    }
+  };
+
+private:
   bool isSmall() const {
     return X & uintptr_t(1);
   }
@@ -81,7 +114,7 @@ class SmallBitVector {
 
   void setSmallRawBits(uintptr_t NewRawBits) {
     assert(isSmall());
-    X = NewRawBits << 1 | uintptr_t(1);
+    X = (NewRawBits << 1) | uintptr_t(1);
   }
 
   // Return the size.
@@ -99,7 +132,7 @@ class SmallBitVector {
   }
 
   void setSmallBits(uintptr_t NewBits) {
-    setSmallRawBits(NewBits & ~(~uintptr_t(0) << getSmallSize()) |
+    setSmallRawBits((NewBits & ~(~uintptr_t(0) << getSmallSize())) |
                     (getSmallSize() << SmallNumDataBits));
   }
 
@@ -124,6 +157,10 @@ public:
       switchToLarge(new BitVector(*RHS.getPointer()));
   }
 
+  SmallBitVector(SmallBitVector &&RHS) : X(RHS.X) {
+    RHS.X = 1;
+  }
+
   ~SmallBitVector() {
     if (!isSmall())
       delete getPointer();
@@ -140,14 +177,10 @@ public:
   }
 
   /// count - Returns the number of bits which are set.
-  unsigned count() const {
+  size_type count() const {
     if (isSmall()) {
       uintptr_t Bits = getSmallBits();
-      if (sizeof(uintptr_t) * CHAR_BIT == 32)
-        return CountPopulation_32(Bits);
-      if (sizeof(uintptr_t) * CHAR_BIT == 64)
-        return CountPopulation_64(Bits);
-      assert(0 && "Unsupported!");
+      return countPopulation(Bits);
     }
     return getPointer()->count();
   }
@@ -159,6 +192,13 @@ public:
     return getPointer()->any();
   }
 
+  /// all - Returns true if all bits are set.
+  bool all() const {
+    if (isSmall())
+      return getSmallBits() == (uintptr_t(1) << getSmallSize()) - 1;
+    return getPointer()->all();
+  }
+
   /// none - Returns true if none of the bits are set.
   bool none() const {
     if (isSmall())
@@ -171,14 +211,9 @@ public:
   int find_first() const {
     if (isSmall()) {
       uintptr_t Bits = getSmallBits();
-      if (sizeof(uintptr_t) * CHAR_BIT == 32) {
-        size_t FirstBit = CountTrailingZeros_32(Bits);
-        return FirstBit == 32 ? -1 : FirstBit;
-      } else if (sizeof(uintptr_t) * CHAR_BIT == 64) {
-        size_t FirstBit = CountTrailingZeros_64(Bits);
-        return FirstBit == 64 ? -1 : FirstBit;
-      }
-      assert(0 && "Unsupported!");
+      if (Bits == 0)
+        return -1;
+      return countTrailingZeros(Bits);
     }
     return getPointer()->find_first();
   }
@@ -190,14 +225,9 @@ public:
       uintptr_t Bits = getSmallBits();
       // Mask off previous bits.
       Bits &= ~uintptr_t(0) << (Prev + 1);
-      if (sizeof(uintptr_t) * CHAR_BIT == 32) {
-        size_t FirstBit = CountTrailingZeros_32(Bits);
-        return FirstBit == 32 ? -1 : FirstBit;
-      } else if (sizeof(uintptr_t) * CHAR_BIT == 64) {
-        size_t FirstBit = CountTrailingZeros_64(Bits);
-        return FirstBit == 64 ? -1 : FirstBit;
-      }
-      assert(0 && "Unsupported!");
+      if (Bits == 0 || Prev + 1 >= getSmallSize())
+        return -1;
+      return countTrailingZeros(Bits);
     }
     return getPointer()->find_next(Prev);
   }
@@ -253,13 +283,32 @@ public:
   }
 
   SmallBitVector &set(unsigned Idx) {
-    if (isSmall())
+    if (isSmall()) {
+      assert(Idx <= static_cast<unsigned>(
+                        std::numeric_limits<uintptr_t>::digits) &&
+             "undefined behavior");
       setSmallBits(getSmallBits() | (uintptr_t(1) << Idx));
+    }
     else
       getPointer()->set(Idx);
     return *this;
   }
 
+  /// set - Efficiently set a range of bits in [I, E)
+  SmallBitVector &set(unsigned I, unsigned E) {
+    assert(I <= E && "Attempted to set backwards range!");
+    assert(E <= size() && "Attempted to set out-of-bounds range!");
+    if (I == E) return *this;
+    if (isSmall()) {
+      uintptr_t EMask = ((uintptr_t)1) << E;
+      uintptr_t IMask = ((uintptr_t)1) << I;
+      uintptr_t Mask = EMask - IMask;
+      setSmallBits(getSmallBits() | Mask);
+    } else
+      getPointer()->set(I, E);
+    return *this;
+  }
+
   SmallBitVector &reset() {
     if (isSmall())
       setSmallBits(0);
@@ -276,6 +325,21 @@ public:
     return *this;
   }
 
+  /// reset - Efficiently reset a range of bits in [I, E)
+  SmallBitVector &reset(unsigned I, unsigned E) {
+    assert(I <= E && "Attempted to reset backwards range!");
+    assert(E <= size() && "Attempted to reset out-of-bounds range!");
+    if (I == E) return *this;
+    if (isSmall()) {
+      uintptr_t EMask = ((uintptr_t)1) << E;
+      uintptr_t IMask = ((uintptr_t)1) << I;
+      uintptr_t Mask = EMask - IMask;
+      setSmallBits(getSmallBits() & ~Mask);
+    } else
+      getPointer()->reset(I, E);
+    return *this;
+  }
+
   SmallBitVector &flip() {
     if (isSmall())
       setSmallBits(~getSmallBits());
@@ -298,7 +362,11 @@ public:
   }
 
   // Indexing.
-  // TODO: Add an index operator which returns a "reference" (proxy class).
+  reference operator[](unsigned Idx) {
+    assert(Idx < size() && "Out-of-bounds Bit access.");
+    return reference(*this, Idx);
+  }
+
   bool operator[](unsigned Idx) const {
     assert(Idx < size() && "Out-of-bounds Bit access.");
     if (isSmall())
@@ -310,6 +378,19 @@ public:
     return (*this)[Idx];
   }
 
+  /// Test if any common bits are set.
+  bool anyCommon(const SmallBitVector &RHS) const {
+    if (isSmall() && RHS.isSmall())
+      return (getSmallBits() & RHS.getSmallBits()) != 0;
+    if (!isSmall() && !RHS.isSmall())
+      return getPointer()->anyCommon(*RHS.getPointer());
+
+    for (unsigned i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
+      if (test(i) && RHS.test(i))
+        return true;
+    return false;
+  }
+
   // Comparison operators.
   bool operator==(const SmallBitVector &RHS) const {
     if (size() != RHS.size())
@@ -339,6 +420,40 @@ public:
     return *this;
   }
 
+  /// reset - Reset bits that are set in RHS. Same as *this &= ~RHS.
+  SmallBitVector &reset(const SmallBitVector &RHS) {
+    if (isSmall() && RHS.isSmall())
+      setSmallBits(getSmallBits() & ~RHS.getSmallBits());
+    else if (!isSmall() && !RHS.isSmall())
+      getPointer()->reset(*RHS.getPointer());
+    else
+      for (unsigned i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
+        if (RHS.test(i))
+          reset(i);
+
+    return *this;
+  }
+
+  /// test - Check if (This - RHS) is zero.
+  /// This is the same as reset(RHS) and any().
+  bool test(const SmallBitVector &RHS) const {
+    if (isSmall() && RHS.isSmall())
+      return (getSmallBits() & ~RHS.getSmallBits()) != 0;
+    if (!isSmall() && !RHS.isSmall())
+      return getPointer()->test(*RHS.getPointer());
+
+    unsigned i, e;
+    for (i = 0, e = std::min(size(), RHS.size()); i != e; ++i)
+      if (test(i) && !RHS.test(i))
+        return true;
+
+    for (e = size(); i != e; ++i)
+      if (test(i))
+        return true;
+
+    return false;
+  }
+
   SmallBitVector &operator|=(const SmallBitVector &RHS) {
     resize(std::max(size(), RHS.size()));
     if (isSmall())
@@ -385,9 +500,69 @@ public:
     return *this;
   }
 
+  const SmallBitVector &operator=(SmallBitVector &&RHS) {
+    if (this != &RHS) {
+      clear();
+      swap(RHS);
+    }
+    return *this;
+  }
+
   void swap(SmallBitVector &RHS) {
     std::swap(X, RHS.X);
   }
+
+  /// setBitsInMask - Add '1' bits from Mask to this vector. Don't resize.
+  /// This computes "*this |= Mask".
+  void setBitsInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
+    if (isSmall())
+      applyMask<true, false>(Mask, MaskWords);
+    else
+      getPointer()->setBitsInMask(Mask, MaskWords);
+  }
+
+  /// clearBitsInMask - Clear any bits in this vector that are set in Mask.
+  /// Don't resize. This computes "*this &= ~Mask".
+  void clearBitsInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
+    if (isSmall())
+      applyMask<false, false>(Mask, MaskWords);
+    else
+      getPointer()->clearBitsInMask(Mask, MaskWords);
+  }
+
+  /// setBitsNotInMask - Add a bit to this vector for every '0' bit in Mask.
+  /// Don't resize.  This computes "*this |= ~Mask".
+  void setBitsNotInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
+    if (isSmall())
+      applyMask<true, true>(Mask, MaskWords);
+    else
+      getPointer()->setBitsNotInMask(Mask, MaskWords);
+  }
+
+  /// clearBitsNotInMask - Clear a bit in this vector for every '0' bit in Mask.
+  /// Don't resize.  This computes "*this &= Mask".
+  void clearBitsNotInMask(const uint32_t *Mask, unsigned MaskWords = ~0u) {
+    if (isSmall())
+      applyMask<false, true>(Mask, MaskWords);
+    else
+      getPointer()->clearBitsNotInMask(Mask, MaskWords);
+  }
+
+private:
+  template<bool AddBits, bool InvertMask>
+  void applyMask(const uint32_t *Mask, unsigned MaskWords) {
+    if (NumBaseBits == 64 && MaskWords >= 2) {
+      uint64_t M = Mask[0] | (uint64_t(Mask[1]) << 32);
+      if (InvertMask) M = ~M;
+      if (AddBits) setSmallBits(getSmallBits() | M);
+      else         setSmallBits(getSmallBits() & ~M);
+    } else {
+      uint32_t M = Mask[0];
+      if (InvertMask) M = ~M;
+      if (AddBits) setSmallBits(getSmallBits() | M);
+      else         setSmallBits(getSmallBits() & ~M);
+    }
+  }
 };
 
 inline SmallBitVector