Renamed internal method "Create" of ImutAVLTree to "CreateNode".
[oota-llvm.git] / include / llvm / ADT / ImmutableSet.h
1 //===--- ImmutableSet.h - Immutable (functional) set interface --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Ted Kremenek and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the ImutAVLTree and ImmutableSet classes.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_ADT_IMSET_H
15 #define LLVM_ADT_IMSET_H
16
17 #include "llvm/Support/Allocator.h"
18 #include "llvm/ADT/FoldingSet.h"
19 #include <cassert>
20
21 namespace llvm {
22   
23 //===----------------------------------------------------------------------===//    
24 // Immutable AVL-Tree Definition.
25 //===----------------------------------------------------------------------===//
26
27 template <typename ImutInfo> class ImutAVLFactory;
28
29
30 template <typename ImutInfo >
31 class ImutAVLTree : public FoldingSetNode {
32   struct ComputeIsEqual;
33 public:
34   typedef typename ImutInfo::key_type_ref   key_type_ref;
35   typedef typename ImutInfo::value_type     value_type;
36   typedef typename ImutInfo::value_type_ref value_type_ref;
37   typedef ImutAVLFactory<ImutInfo>          Factory;
38   
39   friend class ImutAVLFactory<ImutInfo>;
40   
41   //===----------------------------------------------------===//  
42   // Public Interface.
43   //===----------------------------------------------------===//  
44   
45   ImutAVLTree* getLeft() const { return reinterpret_cast<ImutAVLTree*>(Left); }  
46   
47   ImutAVLTree* getRight() const { return Right; }  
48   
49   unsigned getHeight() const { return Height; }  
50   
51   const value_type& getValue() const { return Value; }
52   
53   ImutAVLTree* find(key_type_ref K) {
54     ImutAVLTree *T = this;
55     
56     while (T) {
57       key_type_ref CurrentKey = ImutInfo::KeyOfValue(T->getValue());
58       
59       if (ImutInfo::isEqual(K,CurrentKey))
60         return T;
61       else if (ImutInfo::isLess(K,CurrentKey))
62         T = T->getLeft();
63       else
64         T = T->getRight();
65     }
66     
67     return NULL;
68   }
69   
70   unsigned size() const {
71     unsigned n = 1;
72     
73     if (const ImutAVLTree* L = getLeft())  n += L->size();
74     if (const ImutAVLTree* R = getRight()) n += R->size();
75     
76     return n;
77   }
78   
79   
80   bool isEqual(const ImutAVLTree& RHS) const {
81     // FIXME: Todo.
82     return true;    
83   }
84   
85   bool isNotEqual(const ImutAVLTree& RHS) const { return !isEqual(RHS); }
86   
87   bool contains(const key_type_ref K) { return (bool) find(K); }
88   
89   template <typename Callback>
90   void foreach(Callback& C) {
91     if (ImutAVLTree* L = getLeft()) L->foreach(C);
92     
93     C(Value);    
94     
95     if (ImutAVLTree* R = getRight()) R->foreach(C);
96   }
97   
98   unsigned verify() const {
99     unsigned HL = getLeft() ? getLeft()->verify() : 0;
100     unsigned HR = getRight() ? getRight()->verify() : 0;
101     
102     assert (getHeight() == ( HL > HR ? HL : HR ) + 1 
103             && "Height calculation wrong.");
104     
105     assert ((HL > HR ? HL-HR : HR-HL) <= 2
106             && "Balancing invariant violated.");
107     
108     
109     assert (!getLeft()
110             || ImutInfo::isLess(ImutInfo::KeyOfValue(getLeft()->getValue()),
111                                 ImutInfo::KeyOfValue(getValue()))
112             && "Value in left child is not less that current value.");
113     
114     
115     assert (!getRight()
116             || ImutInfo::isLess(ImutInfo::KeyOfValue(getValue()),
117                                 ImutInfo::KeyOfValue(getRight()->getValue()))
118             && "Current value is not less that value of right child.");
119     
120     return getHeight();
121   }  
122   
123   //===----------------------------------------------------===//  
124   // Internal Values.
125   //===----------------------------------------------------===//
126   
127 private:
128   uintptr_t        Left;
129   ImutAVLTree*     Right;
130   unsigned         Height;
131   value_type       Value;
132   
133   //===----------------------------------------------------===//  
134   // Profiling or FoldingSet.
135   //===----------------------------------------------------===//
136   
137   static inline
138   void Profile(FoldingSetNodeID& ID, ImutAVLTree* L, ImutAVLTree* R,
139                unsigned H, value_type_ref V) {    
140     ID.AddPointer(L);
141     ID.AddPointer(R);
142     ID.AddInteger(H);
143     ImutInfo::Profile(ID,V);
144   }
145   
146 public:
147   
148   void Profile(FoldingSetNodeID& ID) {
149     Profile(ID,getSafeLeft(),getRight(),getHeight(),getValue());    
150   }
151   
152   //===----------------------------------------------------===//    
153   // Internal methods (node manipulation; used by Factory).
154   //===----------------------------------------------------===//
155   
156 private:
157   
158   ImutAVLTree(ImutAVLTree* l, ImutAVLTree* r, value_type_ref v, unsigned height)
159   : Left(reinterpret_cast<uintptr_t>(l) | 0x1),
160   Right(r), Height(height), Value(v) {}
161   
162   bool isMutable() const { return Left & 0x1; }
163   
164   ImutAVLTree* getSafeLeft() const { 
165     return reinterpret_cast<ImutAVLTree*>(Left & ~0x1);
166   }
167   
168   // Mutating operations.  A tree root can be manipulated as long as
169   // its reference has not "escaped" from internal methods of a
170   // factory object (see below).  When a tree pointer is externally
171   // viewable by client code, the internal "mutable bit" is cleared
172   // to mark the tree immutable.  Note that a tree that still has
173   // its mutable bit set may have children (subtrees) that are themselves
174   // immutable.
175   
176   void RemoveMutableFlag() {
177     assert (Left & 0x1 && "Mutable flag already removed.");
178     Left &= ~0x1;
179   }
180   
181   void setLeft(ImutAVLTree* NewLeft) {
182     assert (isMutable());
183     Left = reinterpret_cast<uintptr_t>(NewLeft) | 0x1;
184   }
185   
186   void setRight(ImutAVLTree* NewRight) {
187     assert (isMutable());
188     Right = NewRight;
189   }
190   
191   void setHeight(unsigned h) {
192     assert (isMutable());
193     Height = h;
194   }
195 };
196
197 //===----------------------------------------------------------------------===//    
198 // Immutable AVL-Tree Factory class.
199 //===----------------------------------------------------------------------===//
200
201 template <typename ImutInfo >  
202 class ImutAVLFactory {
203   typedef ImutAVLTree<ImutInfo> TreeTy;
204   typedef typename TreeTy::value_type_ref value_type_ref;
205   typedef typename TreeTy::key_type_ref   key_type_ref;
206   
207   typedef FoldingSet<TreeTy> CacheTy;
208   
209   CacheTy Cache;  
210   BumpPtrAllocator Allocator;    
211   
212   //===--------------------------------------------------===//    
213   // Public interface.
214   //===--------------------------------------------------===//
215   
216 public:
217   ImutAVLFactory() {}
218   
219   TreeTy* Add(TreeTy* T, value_type_ref V) {
220     T = Add_internal(V,T);
221     MarkImmutable(T);
222     return T;
223   }
224   
225   TreeTy* Remove(TreeTy* T, key_type_ref V) {
226     T = Remove_internal(V,T);
227     MarkImmutable(T);
228     return T;
229   }
230   
231   TreeTy* GetEmptyTree() const { return NULL; }
232   
233   //===--------------------------------------------------===//    
234   // A bunch of quick helper functions used for reasoning
235   // about the properties of trees and their children.
236   // These have succinct names so that the balancing code
237   // is as terse (and readable) as possible.
238   //===--------------------------------------------------===//
239 private:
240   
241   bool isEmpty(TreeTy* T) const {
242     return !T;
243   }
244   
245   unsigned Height(TreeTy* T) const {
246     return T ? T->getHeight() : 0;
247   }
248   
249   TreeTy* Left(TreeTy* T) const {
250     assert (T);
251     return T->getSafeLeft();
252   }
253   
254   TreeTy* Right(TreeTy* T) const {
255     assert (T);
256     return T->getRight();
257   }
258   
259   value_type_ref Value(TreeTy* T) const {
260     assert (T);
261     return T->Value;
262   }
263   
264   unsigned IncrementHeight(TreeTy* L, TreeTy* R) const {
265     unsigned hl = Height(L);
266     unsigned hr = Height(R);
267     return ( hl > hr ? hl : hr ) + 1;
268   }
269   
270   //===--------------------------------------------------===//    
271   // "CreateNode" is used to generate new tree roots that link
272   // to other trees.  The functon may also simply move links
273   // in an existing root if that root is still marked mutable.
274   // This is necessary because otherwise our balancing code
275   // would leak memory as it would create nodes that are
276   // then discarded later before the finished tree is
277   // returned to the caller.
278   //===--------------------------------------------------===//
279   
280   TreeTy* CreateNode(TreeTy* L, value_type_ref V, TreeTy* R) {
281     FoldingSetNodeID ID;      
282     unsigned height = IncrementHeight(L,R);
283     
284     TreeTy::Profile(ID,L,R,height,V);      
285     void* InsertPos;
286     
287     if (TreeTy* T = Cache.FindNodeOrInsertPos(ID,InsertPos))
288       return T;
289     
290     assert (InsertPos != NULL);
291     
292     // FIXME: more intelligent calculation of alignment.
293     TreeTy* T = (TreeTy*) Allocator.Allocate(sizeof(*T),16);
294     new (T) TreeTy(L,R,V,height);
295     
296     Cache.InsertNode(T,InsertPos);
297     return T;      
298   }
299   
300   TreeTy* CreateNode(TreeTy* L, TreeTy* OldTree, TreeTy* R) {      
301     assert (!isEmpty(OldTree));
302     
303     if (OldTree->isMutable()) {
304       OldTree->setLeft(L);
305       OldTree->setRight(R);
306       OldTree->setHeight(IncrementHeight(L,R));
307       return OldTree;
308     }
309     else return CreateNode(L, Value(OldTree), R);
310   }
311   
312   /// Balance - Used by Add_internal and Remove_internal to
313   ///  balance a newly created tree.
314   TreeTy* Balance(TreeTy* L, value_type_ref V, TreeTy* R) {
315     
316     unsigned hl = Height(L);
317     unsigned hr = Height(R);
318     
319     if (hl > hr + 2) {
320       assert (!isEmpty(L) &&
321               "Left tree cannot be empty to have a height >= 2.");
322       
323       TreeTy* LL = Left(L);
324       TreeTy* LR = Right(L);
325       
326       if (Height(LL) >= Height(LR))
327         return CreateNode(LL, L, CreateNode(LR,V,R));
328       
329       assert (!isEmpty(LR) &&
330               "LR cannot be empty because it has a height >= 1.");
331       
332       TreeTy* LRL = Left(LR);
333       TreeTy* LRR = Right(LR);
334       
335       return CreateNode(CreateNode(LL,L,LRL), LR, CreateNode(LRR,V,R));                              
336     }
337     else if (hr > hl + 2) {
338       assert (!isEmpty(R) &&
339               "Right tree cannot be empty to have a height >= 2.");
340       
341       TreeTy* RL = Left(R);
342       TreeTy* RR = Right(R);
343       
344       if (Height(RR) >= Height(RL))
345         return CreateNode(CreateNode(L,V,RL), R, RR);
346       
347       assert (!isEmpty(RL) &&
348               "RL cannot be empty because it has a height >= 1.");
349       
350       TreeTy* RLL = Left(RL);
351       TreeTy* RLR = Right(RL);
352       
353       return CreateNode(CreateNode(L,V,RLL), RL, CreateNode(RLR,R,RR));
354     }
355     else
356       return CreateNode(L,V,R);
357   }
358   
359   /// Add_internal - Creates a new tree that includes the specified
360   ///  data and the data from the original tree.  If the original tree
361   ///  already contained the data item, the original tree is returned.
362   TreeTy* Add_internal(value_type_ref V, TreeTy* T) {
363     if (isEmpty(T))
364       return CreateNode(T, V, T);
365     
366     assert (!T->isMutable());
367     
368     key_type_ref K = ImutInfo::KeyOfValue(V);
369     key_type_ref KCurrent = ImutInfo::KeyOfValue(Value(T));
370     
371     if (ImutInfo::isEqual(K,KCurrent))
372       return CreateNode(Left(T), V, Right(T));
373     else if (ImutInfo::isLess(K,KCurrent))
374       return Balance(Add_internal(V,Left(T)), Value(T), Right(T));
375     else
376       return Balance(Left(T), Value(T), Add_internal(V,Right(T)));
377   }
378   
379   /// Remove_interal - Creates a new tree that includes all the data
380   ///  from the original tree except the specified data.  If the
381   ///  specified data did not exist in the original tree, the original
382   ///  tree is returned.
383   TreeTy* Remove_internal(key_type_ref K, TreeTy* T) {
384     if (isEmpty(T))
385       return T;
386     
387     assert (!T->isMutable());
388     
389     key_type_ref KCurrent = ImutInfo::KeyOfValue(Value(T));
390     
391     if (ImutInfo::isEqual(K,KCurrent))
392       return CombineLeftRightTrees(Left(T),Right(T));
393     else if (ImutInfo::isLess(K,KCurrent))
394       return Balance(Remove_internal(K,Left(T)), Value(T), Right(T));
395     else
396       return Balance(Left(T), Value(T), Remove_internal(K,Right(T)));
397   }
398   
399   TreeTy* CombineLeftRightTrees(TreeTy* L, TreeTy* R) {
400     if (isEmpty(L)) return R;      
401     if (isEmpty(R)) return L;
402     
403     TreeTy* OldNode;          
404     TreeTy* NewRight = RemoveMinBinding(R,OldNode);
405     return Balance(L,Value(OldNode),NewRight);
406   }
407   
408   TreeTy* RemoveMinBinding(TreeTy* T, TreeTy*& NodeRemoved) {
409     assert (!isEmpty(T));
410     
411     if (isEmpty(Left(T))) {
412       NodeRemoved = T;
413       return Right(T);
414     }
415     
416     return Balance(RemoveMinBinding(Left(T),NodeRemoved),Value(T),Right(T));
417   }    
418   
419   /// MarkImmutable - Clears the mutable bits of a root and all of its
420   ///  descendants.
421   void MarkImmutable(TreeTy* T) {
422     if (!T || !T->isMutable())
423       return;
424     
425     T->RemoveMutableFlag();
426     MarkImmutable(Left(T));
427     MarkImmutable(Right(T));
428   }
429 };
430
431
432 //===----------------------------------------------------------------------===//    
433 // Trait classes for Profile information.
434 //===----------------------------------------------------------------------===//
435
436 /// Generic profile template.  The default behavior is to invoke the
437 /// profile method of an object.  Specializations for primitive integers
438 /// and generic handling of pointers is done below.
439 template <typename T>
440 struct ImutProfileInfo {
441   typedef const T  value_type;
442   typedef const T& value_type_ref;
443   
444   static inline void Profile(FoldingSetNodeID& ID, value_type_ref X) {
445     X.Profile(ID);
446   }  
447 };
448
449 /// Profile traits for integers.
450 template <typename T>
451 struct ImutProfileInteger {    
452   typedef const T  value_type;
453   typedef const T& value_type_ref;
454   
455   static inline void Profile(FoldingSetNodeID& ID, value_type_ref X) {
456     ID.AddInteger(X);
457   }  
458 };
459
460 #define PROFILE_INTEGER_INFO(X)\
461 template<> struct ImutProfileInfo<X> : ImutProfileInteger<X> {};
462
463 PROFILE_INTEGER_INFO(char)
464 PROFILE_INTEGER_INFO(unsigned char)
465 PROFILE_INTEGER_INFO(short)
466 PROFILE_INTEGER_INFO(unsigned short)
467 PROFILE_INTEGER_INFO(unsigned)
468 PROFILE_INTEGER_INFO(signed)
469 PROFILE_INTEGER_INFO(long)
470 PROFILE_INTEGER_INFO(unsigned long)
471 PROFILE_INTEGER_INFO(long long)
472 PROFILE_INTEGER_INFO(unsigned long long)
473
474 #undef PROFILE_INTEGER_INFO
475
476 /// Generic profile trait for pointer types.  We treat pointers as
477 /// references to unique objects.
478 template <typename T>
479 struct ImutProfileInfo<T*> {
480   typedef const T*   value_type;
481   typedef value_type value_type_ref;
482   
483   static inline void Profile(FoldingSetNodeID &ID, value_type_ref X) {
484     ID.AddPointer(X);
485   }
486 };
487
488 //===----------------------------------------------------------------------===//    
489 // Trait classes that contain element comparison operators and type
490 //  definitions used by ImutAVLTree, ImmutableSet, and ImmutableMap.  These
491 //  inherit from the profile traits (ImutProfileInfo) to include operations
492 //  for element profiling.
493 //===----------------------------------------------------------------------===//
494
495
496 /// ImutContainerInfo - Generic definition of comparison operations for
497 ///   elements of immutable containers that defaults to using
498 ///   std::equal_to<> and std::less<> to perform comparison of elements.
499 template <typename T>
500 struct ImutContainerInfo : public ImutProfileInfo<T> {
501   typedef typename ImutProfileInfo<T>::value_type      value_type;
502   typedef typename ImutProfileInfo<T>::value_type_ref  value_type_ref;
503   typedef value_type      key_type;
504   typedef value_type_ref  key_type_ref;
505   
506   static inline key_type_ref KeyOfValue(value_type_ref D) { return D; }
507   
508   static inline bool isEqual(key_type_ref LHS, key_type_ref RHS) { 
509     return std::equal_to<key_type>()(LHS,RHS);
510   }
511   
512   static inline bool isLess(key_type_ref LHS, key_type_ref RHS) {
513     return std::less<key_type>()(LHS,RHS);
514   }
515 };
516
517 /// ImutContainerInfo - Specialization for pointer values to treat pointers
518 ///  as references to unique objects.  Pointers are thus compared by
519 ///  their addresses.
520 template <typename T>
521 struct ImutContainerInfo<T*> : public ImutProfileInfo<T*> {
522   typedef typename ImutProfileInfo<T*>::value_type      value_type;
523   typedef typename ImutProfileInfo<T*>::value_type_ref  value_type_ref;
524   typedef value_type      key_type;
525   typedef value_type_ref  key_type_ref;
526   
527   static inline key_type_ref KeyOfValue(value_type_ref D) { return D; }
528   
529   static inline bool isEqual(key_type_ref LHS, key_type_ref RHS) {
530     return LHS == RHS;
531   }
532   
533   static inline bool isLess(key_type_ref LHS, key_type_ref RHS) {
534     return LHS < RHS;
535   }
536 };
537
538 //===----------------------------------------------------------------------===//    
539 // Immutable Set
540 //===----------------------------------------------------------------------===//
541
542 template <typename ValT, typename ValInfo = ImutContainerInfo<ValT> >
543 class ImmutableSet {
544 public:
545   typedef typename ValInfo::value_type      value_type;
546   typedef typename ValInfo::value_type_ref  value_type_ref;
547   
548 private:  
549   typedef ImutAVLTree<ValInfo> TreeTy;
550   TreeTy* Root;
551   
552   ImmutableSet(TreeTy* R) : Root(R) {}
553   
554 public:
555   
556   class Factory {
557     typename TreeTy::Factory F;
558     
559   public:
560     Factory() {}
561     
562     ImmutableSet GetEmptySet() { return ImmutableSet(F.GetEmptyTree()); }
563     
564     ImmutableSet Add(ImmutableSet Old, value_type_ref V) {
565       return ImmutableSet(F.Add(Old.Root,V));
566     }
567     
568     ImmutableSet Remove(ImmutableSet Old, value_type_ref V) {
569       return ImmutableSet(F.Remove(Old.Root,V));
570     }
571     
572   private:
573     Factory(const Factory& RHS) {};
574     void operator=(const Factory& RHS) {};    
575   };
576   
577   friend class Factory;
578   
579   bool contains(const value_type_ref V) const {
580     return Root ? Root->contains(V) : false;
581   }
582   
583   bool operator==(ImmutableSet RHS) const {
584     return Root && RHS.Root ? Root->isEqual(*RHS.Root) : Root == RHS.Root;
585   }
586   
587   bool operator!=(ImmutableSet RHS) const {
588     return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
589   }
590   
591   bool isEmpty() const { return !Root; }
592   
593   template <typename Callback>
594   void foreach(Callback& C) { if (Root) Root->foreach(C); }
595   
596   template <typename Callback>
597   void foreach() { if (Root) { Callback C; Root->foreach(C); } }
598   
599   //===--------------------------------------------------===//    
600   // For testing.
601   //===--------------------------------------------------===//  
602   
603   void verify() const { if (Root) Root->verify(); }
604   unsigned getHeight() const { return Root ? Root->getHeight() : 0; }
605 };
606
607 } // end namespace llvm
608
609 #endif