Several fixes:
[oota-llvm.git] / include / llvm / ADT / BitSetVector.h
1 //===-- BitVectorSet.h - A bit-vector representation of sets -----*- C++ -*--=//
2 //
3 // class BitVectorSet --
4 // 
5 // An implementation of the bit-vector representation of sets.
6 // Unlike vector<bool>, this allows much more efficient parallel set
7 // operations on bits, by using the bitset template .  The bitset template
8 // unfortunately can only represent sets with a size chosen at compile-time.
9 // We therefore use a vector of bitsets.  The maxmimum size of our sets
10 // (i.e., the size of the universal set) can be chosen at creation time.
11 //
12 // The size of each Bitset is defined by the macro WORDSIZE.
13 // 
14 // NOTE: The WORDSIZE macro should be made machine-dependent, in order to use
15 // 64-bit words or whatever gives most efficient Bitsets on each platform.
16 // 
17 // 
18 // External functions:
19 // 
20 // bool Disjoint(const BitSetVector& set1, const BitSetVector& set2):
21 //    Tests if two sets have an empty intersection.
22 //    This is more efficient than !(set1 & set2).any().
23 // 
24 //===----------------------------------------------------------------------===//
25
26 #ifndef LLVM_SUPPORT_BITVECTORSET_H
27 #define LLVM_SUPPORT_BITVECTORSET_H
28
29 #include <bitset>
30 #include <vector>
31 #include <functional>
32 #include <iostream>
33
34
35 #define WORDSIZE (32U)
36
37
38 class BitSetVector {
39   // Types used internal to the representation
40   typedef std::bitset<WORDSIZE> bitword;
41   typedef bitword::reference reference;
42   class iterator;
43
44   // Data used in the representation
45   std::vector<bitword> bitsetVec;
46   unsigned maxSize;
47
48 private:
49   // Utility functions for the representation
50   static unsigned NumWords(unsigned Size) { return (Size+WORDSIZE-1)/WORDSIZE;} 
51   static unsigned LastWordSize(unsigned Size) { return Size % WORDSIZE; }
52
53   // Clear the unused bits in the last word.
54   // The unused bits are the high (WORDSIZE - LastWordSize()) bits
55   void ClearUnusedBits() {
56     unsigned long usedBits = (1U << LastWordSize(size())) - 1;
57     bitsetVec.back() &= bitword(usedBits);
58   }
59
60   const bitword& getWord(unsigned i) const { return bitsetVec[i]; }
61         bitword& getWord(unsigned i)       { return bitsetVec[i]; }
62
63   friend bool Disjoint(const BitSetVector& set1,
64                        const BitSetVector& set2);
65
66   BitSetVector();                       // do not implement!
67
68 public:
69   /// 
70   /// Constructor: create a set of the maximum size maxSetSize.
71   /// The set is initialized to empty.
72   ///
73   BitSetVector(unsigned maxSetSize)
74     : bitsetVec(NumWords(maxSetSize)), maxSize(maxSetSize) { }
75
76   /// size - Return the number of bits tracked by this bit vector...
77   unsigned size() const { return maxSize; }
78
79   /// 
80   ///  Modifier methods: reset, set for entire set, operator[] for one element.
81   ///  
82   void reset() {
83     for (unsigned i=0, N = bitsetVec.size(); i < N; ++i)
84       bitsetVec[i].reset();
85   }
86   void set() {
87     for (unsigned i=0, N = bitsetVec.size(); i < N; ++i) // skip last word
88       bitsetVec[i].set();
89     ClearUnusedBits();
90   }
91   reference operator[](unsigned n) {
92     assert(n  < size() && "BitSetVector: Bit number out of range");
93     unsigned ndiv = n / WORDSIZE, nmod = n % WORDSIZE;
94     return bitsetVec[ndiv][nmod];
95   }
96   iterator begin() { return iterator::begin(*this); }
97   iterator end()   { return iterator::end(*this);   } 
98
99   /// 
100   ///  Comparison operations: equal, not equal
101   /// 
102   bool operator == (const BitSetVector& set2) const {
103     assert(maxSize == set2.maxSize && "Illegal == comparison");
104     for (unsigned i = 0; i < bitsetVec.size(); ++i)
105       if (getWord(i) != set2.getWord(i))
106         return false;
107     return true;
108   }
109   bool operator != (const BitSetVector& set2) const {
110     return ! (*this == set2);
111   }
112
113   /// 
114   ///  Set membership operations: single element, any, none, count
115   ///  
116   bool test(unsigned n) const {
117     assert(n  < size() && "BitSetVector: Bit number out of range");
118     unsigned ndiv = n / WORDSIZE, nmod = n % WORDSIZE;
119     return bitsetVec[ndiv].test(nmod);
120   }
121   bool any() const {
122     for (unsigned i = 0; i < bitsetVec.size(); ++i)
123       if (bitsetVec[i].any())
124         return true;
125     return false;
126   }
127   bool none() const {
128     return ! any();
129   }
130   unsigned count() const {
131     unsigned n = 0;
132     for (unsigned i = 0; i < bitsetVec.size(); ++i)
133       n += bitsetVec[i].count();
134     return n;
135   }
136   bool all() const {
137     return (count() == size());
138   }
139
140   /// 
141   ///  Set operations: intersection, union, disjoint union, complement.
142   ///  
143   BitSetVector operator& (const BitSetVector& set2) const {
144     assert(maxSize == set2.maxSize && "Illegal intersection");
145     BitSetVector result(maxSize);
146     for (unsigned i = 0; i < bitsetVec.size(); ++i)
147       result.getWord(i) = getWord(i) & set2.getWord(i);
148     return result;
149   }
150   BitSetVector operator| (const BitSetVector& set2) const {
151     assert(maxSize == set2.maxSize && "Illegal intersection");
152     BitSetVector result(maxSize);
153     for (unsigned i = 0; i < bitsetVec.size(); ++i)
154       result.getWord(i) = getWord(i) | set2.getWord(i);
155     return result;
156   }
157   BitSetVector operator^ (const BitSetVector& set2) const {
158     assert(maxSize == set2.maxSize && "Illegal intersection");
159     BitSetVector result(maxSize);
160     for (unsigned i = 0; i < bitsetVec.size(); ++i)
161       result.getWord(i) = getWord(i) ^ set2.getWord(i);
162     return result;
163   }
164   BitSetVector operator~ () const {
165     BitSetVector result(maxSize);
166     for (unsigned i = 0; i < bitsetVec.size(); ++i)
167       (result.getWord(i) = getWord(i)).flip();
168     result.ClearUnusedBits();
169     return result;
170   }
171
172   /// 
173   ///  Printing and debugging support
174   ///  
175   void print(std::ostream &O) const;
176   void dump() const { print(std::cerr); }
177
178 public:
179   // 
180   // An iterator to enumerate the bits in a BitSetVector.
181   // Eventually, this needs to inherit from bidirectional_iterator.
182   // But this iterator may not be as useful as I once thought and
183   // may just go away.
184   // 
185   class iterator {
186     unsigned   currentBit;
187     unsigned   currentWord;
188     BitSetVector* bitvec;
189     iterator(unsigned B, unsigned W, BitSetVector& _bitvec)
190       : currentBit(B), currentWord(W), bitvec(&_bitvec) { }
191   public:
192     iterator(BitSetVector& _bitvec)
193       : currentBit(0), currentWord(0), bitvec(&_bitvec) { }
194     iterator(const iterator& I)
195       : currentBit(I.currentBit),currentWord(I.currentWord),bitvec(I.bitvec) { }
196     iterator& operator=(const iterator& I) {
197       currentWord == I.currentWord;
198       currentBit == I.currentBit;
199       bitvec = I.bitvec;
200       return *this;
201     }
202
203     // Increment and decrement operators (pre and post)
204     iterator& operator++() {
205       if (++currentBit == WORDSIZE)
206         { currentBit = 0; if (currentWord < bitvec->maxSize) ++currentWord; }
207       return *this;
208     }
209     iterator& operator--() {
210       if (currentBit == 0) {
211         currentBit = WORDSIZE-1;
212         currentWord = (currentWord == 0)? bitvec->maxSize : --currentWord;
213       }
214       else
215         --currentBit;
216       return *this;
217     }
218     iterator operator++(int) { iterator copy(*this); ++*this; return copy; }
219     iterator operator--(int) { iterator copy(*this); --*this; return copy; }
220
221     // Dereferencing operators
222     reference operator*() {
223       assert(currentWord < bitvec->maxSize &&
224              "Dereferencing iterator past the end of a BitSetVector");
225       return bitvec->getWord(currentWord)[currentBit];
226     }
227
228     // Comparison operator
229     bool operator==(const iterator& I) {
230       return (I.bitvec == bitvec &&
231               I.currentWord == currentWord && I.currentBit == currentBit);
232     }
233
234   protected:
235     static iterator begin(BitSetVector& _bitvec) { return iterator(_bitvec); }
236     static iterator end(BitSetVector& _bitvec)   { return iterator(0,
237                                                     _bitvec.maxSize, _bitvec); }
238     friend class BitSetVector;
239   };
240 };
241
242
243 inline void BitSetVector::print(std::ostream& O) const
244 {
245   for (std::vector<bitword>::const_iterator
246          I=bitsetVec.begin(), E=bitsetVec.end(); I != E; ++I)
247     O << "<" << (*I) << ">" << (I+1 == E? "\n" : ", ");
248 }
249
250 inline std::ostream& operator<< (std::ostream& O, const BitSetVector& bset)
251 {
252   bset.print(O);
253   return O;
254 };
255
256
257 ///
258 /// Optimized versions of fundamental comparison operations
259 /// 
260 inline bool Disjoint(const BitSetVector& set1,
261                      const BitSetVector& set2)
262 {
263   assert(set1.size() == set2.size() && "Illegal intersection");
264   for (unsigned i = 0; i < set1.bitsetVec.size(); ++i)
265     if ((set1.getWord(i) & set2.getWord(i)).any())
266       return false;
267   return true;
268 }
269
270 #endif