fa2598a285635aafae6f571a042213edbcdec3d6
[oota-llvm.git] / include / Support / BitSetVector.h
1 //===-- BitVectorSet.h - A bit-vector representation of sets ----*- C++ -*-===//
2 //
3 // This is an implementation of the bit-vector representation of sets.  Unlike
4 // vector<bool>, this allows much more efficient parallel set operations on
5 // bits, by using the bitset template.  The bitset template unfortunately can
6 // only represent sets with a size chosen at compile-time.  We therefore use a
7 // vector of bitsets.  The maxmimum size of our sets (i.e., the size of the
8 // universal set) can be chosen at creation time.
9 //
10 // External functions:
11 // 
12 // bool Disjoint(const BitSetVector& set1, const BitSetVector& set2):
13 //    Tests if two sets have an empty intersection.
14 //    This is more efficient than !(set1 & set2).any().
15 // 
16 //===----------------------------------------------------------------------===//
17
18 #ifndef SUPPORT_BITSETVECTOR_H
19 #define SUPPORT_BITSETVECTOR_H
20
21 #include <bitset>
22 #include <vector>
23 #include <functional>
24 #include <iostream>
25
26 class BitSetVector {
27   enum { BITSET_WORDSIZE = sizeof(long)*8 };
28
29   // Types used internal to the representation
30   typedef std::bitset<BITSET_WORDSIZE> bitword;
31   typedef bitword::reference reference;
32   class iterator;
33
34   // Data used in the representation
35   std::vector<bitword> bitsetVec;
36   unsigned maxSize;
37
38 private:
39   // Utility functions for the representation
40   static unsigned NumWords(unsigned Size) {
41     return (Size+BITSET_WORDSIZE-1)/BITSET_WORDSIZE;
42   } 
43   static unsigned LastWordSize(unsigned Size) { return Size % BITSET_WORDSIZE; }
44
45   // Clear the unused bits in the last word.
46   // The unused bits are the high (BITSET_WORDSIZE - LastWordSize()) bits
47   void ClearUnusedBits() {
48     unsigned long usedBits = (1U << LastWordSize(size())) - 1;
49     bitsetVec.back() &= bitword(usedBits);
50   }
51
52   const bitword& getWord(unsigned i) const { return bitsetVec[i]; }
53         bitword& getWord(unsigned i)       { return bitsetVec[i]; }
54
55   friend bool Disjoint(const BitSetVector& set1,
56                        const BitSetVector& set2);
57
58   BitSetVector();                       // do not implement!
59
60 public:
61   /// 
62   /// Constructor: create a set of the maximum size maxSetSize.
63   /// The set is initialized to empty.
64   ///
65   BitSetVector(unsigned maxSetSize)
66     : bitsetVec(NumWords(maxSetSize)), maxSize(maxSetSize) { }
67
68   /// size - Return the number of bits tracked by this bit vector...
69   unsigned size() const { return maxSize; }
70
71   /// 
72   ///  Modifier methods: reset, set for entire set, operator[] for one element.
73   ///  
74   void reset() {
75     for (unsigned i=0, N = bitsetVec.size(); i < N; ++i)
76       bitsetVec[i].reset();
77   }
78   void set() {
79     for (unsigned i=0, N = bitsetVec.size(); i < N; ++i) // skip last word
80       bitsetVec[i].set();
81     ClearUnusedBits();
82   }
83   reference operator[](unsigned n) {
84     assert(n  < size() && "BitSetVector: Bit number out of range");
85     unsigned ndiv = n / BITSET_WORDSIZE, nmod = n % BITSET_WORDSIZE;
86     return bitsetVec[ndiv][nmod];
87   }
88   iterator begin() { return iterator::begin(*this); }
89   iterator end()   { return iterator::end(*this);   } 
90
91   /// 
92   ///  Comparison operations: equal, not equal
93   /// 
94   bool operator == (const BitSetVector& set2) const {
95     assert(maxSize == set2.maxSize && "Illegal == comparison");
96     for (unsigned i = 0; i < bitsetVec.size(); ++i)
97       if (getWord(i) != set2.getWord(i))
98         return false;
99     return true;
100   }
101   bool operator != (const BitSetVector& set2) const {
102     return ! (*this == set2);
103   }
104
105   /// 
106   ///  Set membership operations: single element, any, none, count
107   ///  
108   bool test(unsigned n) const {
109     assert(n  < size() && "BitSetVector: Bit number out of range");
110     unsigned ndiv = n / BITSET_WORDSIZE, nmod = n % BITSET_WORDSIZE;
111     return bitsetVec[ndiv].test(nmod);
112   }
113   bool any() const {
114     for (unsigned i = 0; i < bitsetVec.size(); ++i)
115       if (bitsetVec[i].any())
116         return true;
117     return false;
118   }
119   bool none() const {
120     return ! any();
121   }
122   unsigned count() const {
123     unsigned n = 0;
124     for (unsigned i = 0; i < bitsetVec.size(); ++i)
125       n += bitsetVec[i].count();
126     return n;
127   }
128   bool all() const {
129     return (count() == size());
130   }
131
132   /// 
133   ///  Set operations: intersection, union, disjoint union, complement.
134   ///  
135   BitSetVector operator& (const BitSetVector& set2) const {
136     assert(maxSize == set2.maxSize && "Illegal intersection");
137     BitSetVector result(maxSize);
138     for (unsigned i = 0; i < bitsetVec.size(); ++i)
139       result.getWord(i) = getWord(i) & set2.getWord(i);
140     return result;
141   }
142   BitSetVector operator| (const BitSetVector& set2) const {
143     assert(maxSize == set2.maxSize && "Illegal intersection");
144     BitSetVector result(maxSize);
145     for (unsigned i = 0; i < bitsetVec.size(); ++i)
146       result.getWord(i) = getWord(i) | set2.getWord(i);
147     return result;
148   }
149   BitSetVector operator^ (const BitSetVector& set2) const {
150     assert(maxSize == set2.maxSize && "Illegal intersection");
151     BitSetVector result(maxSize);
152     for (unsigned i = 0; i < bitsetVec.size(); ++i)
153       result.getWord(i) = getWord(i) ^ set2.getWord(i);
154     return result;
155   }
156   BitSetVector operator~ () const {
157     BitSetVector result(maxSize);
158     for (unsigned i = 0; i < bitsetVec.size(); ++i)
159       (result.getWord(i) = getWord(i)).flip();
160     result.ClearUnusedBits();
161     return result;
162   }
163
164   /// 
165   ///  Printing and debugging support
166   ///  
167   void print(std::ostream &O) const;
168   void dump() const { print(std::cerr); }
169
170 public:
171   // 
172   // An iterator to enumerate the bits in a BitSetVector.
173   // Eventually, this needs to inherit from bidirectional_iterator.
174   // But this iterator may not be as useful as I once thought and
175   // may just go away.
176   // 
177   class iterator {
178     unsigned   currentBit;
179     unsigned   currentWord;
180     BitSetVector* bitvec;
181     iterator(unsigned B, unsigned W, BitSetVector& _bitvec)
182       : currentBit(B), currentWord(W), bitvec(&_bitvec) { }
183   public:
184     iterator(BitSetVector& _bitvec)
185       : currentBit(0), currentWord(0), bitvec(&_bitvec) { }
186     iterator(const iterator& I)
187       : currentBit(I.currentBit),currentWord(I.currentWord),bitvec(I.bitvec) { }
188     iterator& operator=(const iterator& I) {
189       currentWord = I.currentWord;
190       currentBit = I.currentBit;
191       bitvec = I.bitvec;
192       return *this;
193     }
194
195     // Increment and decrement operators (pre and post)
196     iterator& operator++() {
197       if (++currentBit == BITSET_WORDSIZE)
198         { currentBit = 0; if (currentWord < bitvec->size()) ++currentWord; }
199       return *this;
200     }
201     iterator& operator--() {
202       if (currentBit == 0) {
203         currentBit = BITSET_WORDSIZE-1;
204         currentWord = (currentWord == 0)? bitvec->size() : --currentWord;
205       }
206       else
207         --currentBit;
208       return *this;
209     }
210     iterator operator++(int) { iterator copy(*this); ++*this; return copy; }
211     iterator operator--(int) { iterator copy(*this); --*this; return copy; }
212
213     // Dereferencing operators
214     reference operator*() {
215       assert(currentWord < bitvec->size() &&
216              "Dereferencing iterator past the end of a BitSetVector");
217       return bitvec->getWord(currentWord)[currentBit];
218     }
219
220     // Comparison operator
221     bool operator==(const iterator& I) {
222       return (I.bitvec == bitvec &&
223               I.currentWord == currentWord && I.currentBit == currentBit);
224     }
225
226   protected:
227     static iterator begin(BitSetVector& _bitvec) { return iterator(_bitvec); }
228     static iterator end(BitSetVector& _bitvec)   { return iterator(0,
229                                                     _bitvec.size(), _bitvec); }
230     friend class BitSetVector;
231   };
232 };
233
234
235 inline void BitSetVector::print(std::ostream& O) const
236 {
237   for (std::vector<bitword>::const_iterator
238          I=bitsetVec.begin(), E=bitsetVec.end(); I != E; ++I)
239     O << "<" << (*I) << ">" << (I+1 == E? "\n" : ", ");
240 }
241
242 inline std::ostream& operator<< (std::ostream& O, const BitSetVector& bset)
243 {
244   bset.print(O);
245   return O;
246 };
247
248
249 ///
250 /// Optimized versions of fundamental comparison operations
251 /// 
252 inline bool Disjoint(const BitSetVector& set1,
253                      const BitSetVector& set2)
254 {
255   assert(set1.size() == set2.size() && "Illegal intersection");
256   for (unsigned i = 0; i < set1.bitsetVec.size(); ++i)
257     if ((set1.getWord(i) & set2.getWord(i)).any())
258       return false;
259   return true;
260 }
261
262 #endif