Added implementation of immutable (functional) maps and sets, as
[oota-llvm.git] / include / llvm / ADT / ImmutableSet.h
1 //===--- ImmutableSet.h - Immutable (functional) set interface --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Ted Kremenek and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the ImutAVLTree and ImmutableSet classes.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_ADT_IMSET_H
15 #define LLVM_ADT_IMSET_H
16
17 #include "llvm/Support/Allocator.h"
18 #include "llvm/ADT/FoldingSet.h"
19 #include <cassert>
20
21 namespace llvm {
22   
23 //===----------------------------------------------------------------------===//    
24 // Immutable AVL-Tree Definition.
25 //===----------------------------------------------------------------------===//
26
27 template <typename ImutInfo> class ImutAVLFactory;
28
29 template <typename ImutInfo >
30 class ImutAVLTree : public FoldingSetNode {
31   struct ComputeIsEqual;
32 public:
33   typedef typename ImutInfo::key_type_ref   key_type_ref;
34   typedef typename ImutInfo::value_type     value_type;
35   typedef typename ImutInfo::value_type_ref value_type_ref;
36   typedef ImutAVLFactory<ImutInfo>          Factory;
37   friend class ImutAVLFactory<ImutInfo>;
38   
39   //===----------------------------------------------------===//  
40   // Public Interface.
41   //===----------------------------------------------------===//  
42   
43   ImutAVLTree* getLeft() const { return reinterpret_cast<ImutAVLTree*>(Left); }  
44   
45   ImutAVLTree* getRight() const { return Right; }  
46   
47   unsigned getHeight() const { return Height; }  
48   
49   const value_type& getValue() const { return Value; }
50   
51   ImutAVLTree* find(key_type_ref K) {
52     ImutAVLTree *T = this;
53     
54     while (T) {
55       key_type_ref CurrentKey = ImutInfo::KeyOfValue(Value(T));
56       
57       if (ImutInfo::isEqual(K,CurrentKey))
58         return T;
59       else if (ImutInfo::isLess(K,CurrentKey))
60         T = T->getLeft();
61       else
62         T = T->getRight();
63     }
64     
65     return NULL;
66   }
67   
68   unsigned size() const {
69     unsigned n = 1;
70     
71     if (const ImutAVLTree* L = getLeft())  n += L->size();
72     if (const ImutAVLTree* R = getRight()) n += R->size();
73     
74     return n;
75   }
76   
77   
78   
79   bool isEqual(const ImutAVLTree& RHS) const {
80     // FIXME: Todo.
81     return true;    
82   }
83   
84   bool isNotEqual(const ImutAVLTree& RHS) const { return !isEqual(RHS); }
85   
86   bool contains(const key_type_ref K) { return (bool) find(K); }
87   
88   template <typename Callback>
89   void foreach(Callback& C) {
90     if (ImutAVLTree* L = getLeft()) L->foreach(C);
91     
92     C(Value);    
93     
94     if (ImutAVLTree* R = getRight()) R->foreach(C);
95   }
96   
97   unsigned verify() const {
98     unsigned HL = getLeft() ? getLeft()->verify() : 0;
99     unsigned HR = getRight() ? getRight()->verify() : 0;
100     
101     assert (getHeight() == ( HL > HR ? HL : HR ) + 1 
102             && "Height calculation wrong.");
103     
104     assert ((HL > HR ? HL-HR : HR-HL) <= 2
105             && "Balancing invariant violated.");
106     
107     
108     assert (!getLeft()
109             || ImutInfo::isLess(ImutInfo::KeyOfValue(getLeft()->getValue()),
110                                 ImutInfo::KeyOfValue(getValue()))
111             && "Value in left child is not less that current value.");
112     
113     
114     assert (!getRight()
115             || ImutInfo::isLess(ImutInfo::KeyOfValue(getValue()),
116                                 ImutInfo::KeyOfValue(getRight()->getValue()))
117             && "Current value is not less that value of right child.");
118     
119     return getHeight();
120   }  
121   
122   //===----------------------------------------------------===//  
123   // Internal Values.
124   //===----------------------------------------------------===//
125   
126 private:
127   uintptr_t        Left;
128   ImutAVLTree*     Right;
129   unsigned         Height;
130   value_type       Value;
131   
132   //===----------------------------------------------------===//  
133   // Profiling or FoldingSet.
134   //===----------------------------------------------------===//
135   
136   static inline
137   void Profile(FoldingSetNodeID& ID, ImutAVLTree* L, ImutAVLTree* R,
138                unsigned H, value_type_ref V) {    
139     ID.AddPointer(L);
140     ID.AddPointer(R);
141     ID.AddInteger(H);
142     ImutInfo::Profile(ID,V);
143   }
144   
145 public:
146   
147   void Profile(FoldingSetNodeID& ID) {
148     Profile(ID,getSafeLeft(),getRight(),getHeight(),getValue());    
149   }
150   
151   //===----------------------------------------------------===//    
152   // Internal methods (node manipulation; used by Factory).
153   //===----------------------------------------------------===//
154   
155 private:
156   
157   ImutAVLTree(ImutAVLTree* l, ImutAVLTree* r, value_type_ref v, unsigned height)
158   : Left(reinterpret_cast<uintptr_t>(l) | 0x1),
159   Right(r), Height(height), Value(v) {}
160   
161   bool isMutable() const { return Left & 0x1; }
162   
163   ImutAVLTree* getSafeLeft() const { 
164     return reinterpret_cast<ImutAVLTree*>(Left & ~0x1);
165   }
166   
167   // Mutating operations.  A tree root can be manipulated as long as
168   // its reference has not "escaped" from internal methods of a
169   // factory object (see below).  When a tree pointer is externally
170   // viewable by client code, the internal "mutable bit" is cleared
171   // to mark the tree immutable.  Note that a tree that still has
172   // its mutable bit set may have children (subtrees) that are themselves
173   // immutable.
174   
175   void RemoveMutableFlag() {
176     assert (Left & 0x1 && "Mutable flag already removed.");
177     Left &= ~0x1;
178   }
179   
180   void setLeft(ImutAVLTree* NewLeft) {
181     assert (isMutable());
182     Left = reinterpret_cast<uintptr_t>(NewLeft) | 0x1;
183   }
184   
185   void setRight(ImutAVLTree* NewRight) {
186     assert (isMutable());
187     Right = NewRight;
188   }
189   
190   void setHeight(unsigned h) {
191     assert (isMutable());
192     Height = h;
193   }
194 };
195
196 //===----------------------------------------------------------------------===//    
197 // Immutable AVL-Tree Factory class.
198 //===----------------------------------------------------------------------===//
199
200 template <typename ImutInfo >  
201 class ImutAVLFactory {
202   typedef ImutAVLTree<ImutInfo> TreeTy;
203   typedef typename TreeTy::value_type_ref value_type_ref;
204   typedef typename TreeTy::key_type_ref   key_type_ref;
205   
206   typedef FoldingSet<TreeTy> CacheTy;
207   
208   CacheTy Cache;  
209   BumpPtrAllocator Allocator;    
210   
211   //===--------------------------------------------------===//    
212   // Public interface.
213   //===--------------------------------------------------===//
214   
215 public:
216   ImutAVLFactory() {}
217   
218   TreeTy* Add(TreeTy* T, value_type_ref V) {
219     T = Add_internal(V,T);
220     MarkImmutable(T);
221     return T;
222   }
223   
224   TreeTy* Remove(TreeTy* T, key_type_ref V) {
225     T = Remove_internal(V,T);
226     MarkImmutable(T);
227     return T;
228   }
229   
230   TreeTy* GetEmptyTree() const { return NULL; }
231   
232   //===--------------------------------------------------===//    
233   // A bunch of quick helper functions used for reasoning
234   // about the properties of trees and their children.
235   // These have succinct names so that the balancing code
236   // is as terse (and readable) as possible.
237   //===--------------------------------------------------===//
238 private:
239   
240   bool isEmpty(TreeTy* T) const {
241     return !T;
242   }
243   
244   unsigned Height(TreeTy* T) const {
245     return T ? T->getHeight() : 0;
246   }
247   
248   TreeTy* Left(TreeTy* T) const {
249     assert (T);
250     return T->getSafeLeft();
251   }
252   
253   TreeTy* Right(TreeTy* T) const {
254     assert (T);
255     return T->getRight();
256   }
257   
258   value_type_ref Value(TreeTy* T) const {
259     assert (T);
260     return T->Value;
261   }
262   
263   unsigned IncrementHeight(TreeTy* L, TreeTy* R) const {
264     unsigned hl = Height(L);
265     unsigned hr = Height(R);
266     return ( hl > hr ? hl : hr ) + 1;
267   }
268   
269   //===--------------------------------------------------===//    
270   // "Create" is used to generate new tree roots that link
271   // to other trees.  The functon may also simply move links
272   // in an existing root if that root is still marked mutable.
273   // This is necessary because otherwise our balancing code
274   // would leak memory as it would create nodes that are
275   // then discarded later before the finished tree is
276   // returned to the caller.
277   //===--------------------------------------------------===//
278   
279   TreeTy* Create(TreeTy* L, value_type_ref V, TreeTy* R) {
280     FoldingSetNodeID ID;      
281     unsigned height = IncrementHeight(L,R);
282     
283     TreeTy::Profile(ID,L,R,height,V);      
284     void* InsertPos;
285     
286     if (TreeTy* T = Cache.FindNodeOrInsertPos(ID,InsertPos))
287       return T;
288     
289     assert (InsertPos != NULL);
290     
291     // FIXME: more intelligent calculation of alignment.
292     TreeTy* T = (TreeTy*) Allocator.Allocate(sizeof(*T),16);
293     new (T) TreeTy(L,R,V,height);
294     
295     Cache.InsertNode(T,InsertPos);
296     return T;      
297   }
298   
299   TreeTy* Create(TreeTy* L, TreeTy* OldTree, TreeTy* R) {      
300     assert (!isEmpty(OldTree));
301     
302     if (OldTree->isMutable()) {
303       OldTree->setLeft(L);
304       OldTree->setRight(R);
305       OldTree->setHeight(IncrementHeight(L,R));
306       return OldTree;
307     }
308     else return Create(L, Value(OldTree), R);
309   }
310   
311   /// Balance - Used by Add_internal and Remove_internal to
312   ///  balance a newly created tree.
313   TreeTy* Balance(TreeTy* L, value_type_ref V, TreeTy* R) {
314     
315     unsigned hl = Height(L);
316     unsigned hr = Height(R);
317     
318     if (hl > hr + 2) {
319       assert (!isEmpty(L) &&
320               "Left tree cannot be empty to have a height >= 2.");
321       
322       TreeTy* LL = Left(L);
323       TreeTy* LR = Right(L);
324       
325       if (Height(LL) >= Height(LR))
326         return Create(LL, L, Create(LR,V,R));
327       
328       assert (!isEmpty(LR) &&
329               "LR cannot be empty because it has a height >= 1.");
330       
331       TreeTy* LRL = Left(LR);
332       TreeTy* LRR = Right(LR);
333       
334       return Create(Create(LL,L,LRL), LR, Create(LRR,V,R));                              
335     }
336     else if (hr > hl + 2) {
337       assert (!isEmpty(R) &&
338               "Right tree cannot be empty to have a height >= 2.");
339       
340       TreeTy* RL = Left(R);
341       TreeTy* RR = Right(R);
342       
343       if (Height(RR) >= Height(RL))
344         return Create(Create(L,V,RL), R, RR);
345       
346       assert (!isEmpty(RL) &&
347               "RL cannot be empty because it has a height >= 1.");
348       
349       TreeTy* RLL = Left(RL);
350       TreeTy* RLR = Right(RL);
351       
352       return Create(Create(L,V,RLL), RL, Create(RLR,R,RR));
353     }
354     else
355       return Create(L,V,R);
356   }
357   
358   /// Add_internal - Creates a new tree that includes the specified
359   ///  data and the data from the original tree.  If the original tree
360   ///  already contained the data item, the original tree is returned.
361   TreeTy* Add_internal(value_type_ref V, TreeTy* T) {
362     if (isEmpty(T))
363       return Create(T, V, T);
364     
365     assert (!T->isMutable());
366     
367     key_type_ref K = ImutInfo::KeyOfValue(V);
368     key_type_ref KCurrent = ImutInfo::KeyOfValue(Value(T));
369     
370     if (ImutInfo::isEqual(K,KCurrent))
371       return Create(Left(T), V, Right(T));
372     else if (ImutInfo::isLess(K,KCurrent))
373       return Balance(Add_internal(V,Left(T)), Value(T), Right(T));
374     else
375       return Balance(Left(T), Value(T), Add_internal(V,Right(T)));
376   }
377   
378   /// Remove_interal - Creates a new tree that includes all the data
379   ///  from the original tree except the specified data.  If the
380   ///  specified data did not exist in the original tree, the original
381   ///  tree is returned.
382   TreeTy* Remove_internal(key_type_ref K, TreeTy* T) {
383     if (isEmpty(T))
384       return T;
385     
386     assert (!T->isMutable());
387     
388     key_type_ref KCurrent = ImutInfo::KeyOfValue(Value(T));
389     
390     if (ImutInfo::isEqual(K,KCurrent))
391       return CombineLeftRightTrees(Left(T),Right(T));
392     else if (ImutInfo::isLess(K,KCurrent))
393       return Balance(Remove_internal(K,Left(T)), Value(T), Right(T));
394     else
395       return Balance(Left(T), Value(T), Remove_internal(K,Right(T)));
396   }
397   
398   TreeTy* CombineLeftRightTrees(TreeTy* L, TreeTy* R) {
399     if (isEmpty(L)) return R;      
400     if (isEmpty(R)) return L;
401     
402     TreeTy* OldNode;          
403     TreeTy* NewRight = RemoveMinBinding(R,OldNode);
404     return Balance(L,Value(OldNode),NewRight);
405   }
406   
407   TreeTy* RemoveMinBinding(TreeTy* T, TreeTy*& NodeRemoved) {
408     assert (!isEmpty(T));
409     
410     if (isEmpty(Left(T))) {
411       NodeRemoved = T;
412       return Right(T);
413     }
414     
415     return Balance(RemoveMinBinding(Left(T),NodeRemoved),Value(T),Right(T));
416   }    
417   
418   /// MarkImmutable - Clears the mutable bits of a root and all of its
419   ///  descendants.
420   void MarkImmutable(TreeTy* T) {
421     if (!T || !T->isMutable())
422       return;
423     
424     T->RemoveMutableFlag();
425     MarkImmutable(Left(T));
426     MarkImmutable(Right(T));
427   }
428 };
429
430
431 //===----------------------------------------------------------------------===//    
432 // Trait classes for Profile information.
433 //===----------------------------------------------------------------------===//
434
435 /// Generic profile template.  The default behavior is to invoke the
436 /// profile method of an object.  Specializations for primitive integers
437 /// and generic handling of pointers is done below.
438 template <typename T>
439 struct ImutProfileInfo {
440   typedef const T  value_type;
441   typedef const T& value_type_ref;
442   
443   static inline void Profile(FoldingSetNodeID& ID, value_type_ref X) {
444     X.Profile(ID);
445   }  
446 };
447
448 /// Profile traits for integers.
449 template <typename T>
450 struct ImutProfileInteger {    
451   typedef const T  value_type;
452   typedef const T& value_type_ref;
453   
454   static inline void Profile(FoldingSetNodeID& ID, value_type_ref X) {
455     ID.AddInteger(X);
456   }  
457 };
458
459 #define PROFILE_INTEGER_INFO(X)\
460 template<> struct ImutProfileInfo<X> : ImutProfileInteger<X> {};
461
462 PROFILE_INTEGER_INFO(char)
463 PROFILE_INTEGER_INFO(unsigned char)
464 PROFILE_INTEGER_INFO(short)
465 PROFILE_INTEGER_INFO(unsigned short)
466 PROFILE_INTEGER_INFO(unsigned)
467 PROFILE_INTEGER_INFO(signed)
468 PROFILE_INTEGER_INFO(long)
469 PROFILE_INTEGER_INFO(unsigned long)
470 PROFILE_INTEGER_INFO(long long)
471 PROFILE_INTEGER_INFO(unsigned long long)
472
473 #undef PROFILE_INTEGER_INFO
474
475 /// Generic profile trait for pointer types.  We treat pointers as
476 /// references to unique objects.
477 template <typename T>
478 struct ImutProfileInfo<T*> {
479   typedef const T*   value_type;
480   typedef value_type value_type_ref;
481   
482   static inline void Profile(FoldingSetNodeID &ID, value_type_ref X) {
483     ID.AddPointer(X);
484   }
485 };
486
487 //===----------------------------------------------------------------------===//    
488 // Trait classes that contain element comparison operators and type
489 //  definitions used by ImutAVLTree, ImmutableSet, and ImmutableMap.  These
490 //  inherit from the profile traits (ImutProfileInfo) to include operations
491 //  for element profiling.
492 //===----------------------------------------------------------------------===//
493
494
495 /// ImutContainerInfo - Generic definition of comparison operations for
496 ///   elements of immutable containers that defaults to using
497 ///   std::equal_to<> and std::less<> to perform comparison of elements.
498 template <typename T>
499 struct ImutContainerInfo : public ImutProfileInfo<T> {
500   typedef typename ImutProfileInfo<T>::value_type      value_type;
501   typedef typename ImutProfileInfo<T>::value_type_ref  value_type_ref;
502   typedef value_type      key_type;
503   typedef value_type_ref  key_type_ref;
504   
505   static inline key_type_ref KeyOfValue(value_type_ref D) { return D; }
506   
507   static inline bool isEqual(key_type_ref LHS, key_type_ref RHS) { 
508     return std::equal_to<key_type>()(LHS,RHS);
509   }
510   
511   static inline bool isLess(key_type_ref LHS, key_type_ref RHS) {
512     return std::less<key_type>()(LHS,RHS);
513   }
514 };
515
516 /// ImutContainerInfo - Specialization for pointer values to treat pointers
517 ///  as references to unique objects.  Pointers are thus compared by
518 ///  their addresses.
519 template <typename T>
520 struct ImutContainerInfo<T*> : public ImutProfileInfo<T*> {
521   typedef typename ImutProfileInfo<T*>::value_type      value_type;
522   typedef typename ImutProfileInfo<T*>::value_type_ref  value_type_ref;
523   typedef value_type      key_type;
524   typedef value_type_ref  key_type_ref;
525   
526   static inline key_type_ref KeyOfValue(value_type_ref D) { return D; }
527   
528   static inline bool isEqual(key_type_ref LHS, key_type_ref RHS) {
529     return LHS == RHS;
530   }
531   
532   static inline bool isLess(key_type_ref LHS, key_type_ref RHS) {
533     return LHS < RHS;
534   }
535 };
536
537 //===----------------------------------------------------------------------===//    
538 // Immutable Set
539 //===----------------------------------------------------------------------===//
540
541 template <typename ValT, typename ValInfo = ImutContainerInfo<ValT> >
542 class ImmutableSet {
543 public:
544   typedef typename ValInfo::value_type      value_type;
545   typedef typename ValInfo::value_type_ref  value_type_ref;
546   
547 private:  
548   typedef ImutAVLTree<ValInfo> TreeTy;
549   TreeTy* Root;
550   
551   ImmutableSet(TreeTy* R) : Root(R) {}
552   
553 public:
554   
555   class Factory {
556     typename TreeTy::Factory F;
557     
558   public:
559     Factory() {}
560     
561     ImmutableSet GetEmptySet() { return ImmutableSet(F.GetEmptyTree()); }
562     
563     ImmutableSet Add(ImmutableSet Old, value_type_ref V) {
564       return ImmutableSet(F.Add(Old.Root,V));
565     }
566     
567     ImmutableSet Remove(ImmutableSet Old, value_type_ref V) {
568       return ImmutableSet(F.Remove(Old.Root,V));
569     }
570     
571   private:
572     Factory(const Factory& RHS) {};
573     void operator=(const Factory& RHS) {};    
574   };
575   
576   friend class Factory;
577   
578   bool contains(const value_type_ref V) const {
579     return Root ? Root->contains(V) : false;
580   }
581   
582   bool operator==(ImmutableSet RHS) const {
583     return Root && RHS.Root ? Root->isEqual(*RHS.Root) : Root == RHS.Root;
584   }
585   
586   bool operator!=(ImmutableSet RHS) const {
587     return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
588   }
589   
590   bool isEmpty() const { return !Root; }
591   
592   template <typename Callback>
593   void foreach(Callback& C) { if (Root) Root->foreach(C); }
594   
595   template <typename Callback>
596   void foreach() { if (Root) { Callback C; Root->foreach(C); } }
597   
598   //===--------------------------------------------------===//    
599   // For testing.
600   //===--------------------------------------------------===//  
601   
602   void verify() const { if (Root) Root->verify(); }
603   unsigned getHeight() const { return Root ? Root->getHeight() : 0; }
604 };
605
606 } // end namespace llvm
607
608 #endif