Add missing include (for ptrdiff_t).
[oota-llvm.git] / include / llvm / ADT / ImmutableSet.h
index 841b4ab6371bb23c98176efa85a6bc53ed06fc38..c351771c6da22413c9f4b88f0199ae6d2d6d8d6a 100644 (file)
@@ -2,8 +2,8 @@
 //
 //                     The LLVM Compiler Infrastructure
 //
-// This file was developed by Ted Kremenek and is distributed under
-// the University of Illinois Open Source License. See LICENSE.TXT for details.
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
 //
@@ -16,7 +16,9 @@
 
 #include "llvm/Support/Allocator.h"
 #include "llvm/ADT/FoldingSet.h"
+#include "llvm/Support/DataTypes.h"
 #include <cassert>
+#include <functional>
 
 namespace llvm {
   
@@ -25,31 +27,50 @@ namespace llvm {
 //===----------------------------------------------------------------------===//
 
 template <typename ImutInfo> class ImutAVLFactory;
-
-
+template <typename ImutInfo> class ImutAVLTreeInOrderIterator;
+template <typename ImutInfo> class ImutAVLTreeGenericIterator;
+  
 template <typename ImutInfo >
 class ImutAVLTree : public FoldingSetNode {
-  struct ComputeIsEqual;
 public:
   typedef typename ImutInfo::key_type_ref   key_type_ref;
   typedef typename ImutInfo::value_type     value_type;
   typedef typename ImutInfo::value_type_ref value_type_ref;
+
   typedef ImutAVLFactory<ImutInfo>          Factory;
-  
   friend class ImutAVLFactory<ImutInfo>;
   
+  friend class ImutAVLTreeGenericIterator<ImutInfo>;
+  friend class FoldingSet<ImutAVLTree>;
+  
+  typedef ImutAVLTreeInOrderIterator<ImutInfo>  iterator;
+  
   //===----------------------------------------------------===//  
   // Public Interface.
   //===----------------------------------------------------===//  
   
-  ImutAVLTree* getLeft() const { return reinterpret_cast<ImutAVLTree*>(Left); }  
+  /// getLeft - Returns a pointer to the left subtree.  This value
+  ///  is NULL if there is no left subtree.
+  ImutAVLTree* getLeft() const { 
+    assert (!isMutable() && "Node is incorrectly marked mutable.");
+    
+    return reinterpret_cast<ImutAVLTree*>(Left);
+  }
   
+  /// getRight - Returns a pointer to the right subtree.  This value is
+  ///  NULL if there is no right subtree.
   ImutAVLTree* getRight() const { return Right; }  
   
+  
+  /// getHeight - Returns the height of the tree.  A tree with no subtrees
+  ///  has a height of 1.
   unsigned getHeight() const { return Height; }  
   
+  /// getValue - Returns the data value associated with the tree node.
   const value_type& getValue() const { return Value; }
   
+  /// find - Finds the subtree associated with the specified key value.
+  ///  This method returns NULL if no matching subtree is found.
   ImutAVLTree* find(key_type_ref K) {
     ImutAVLTree *T = this;
     
@@ -67,6 +88,8 @@ public:
     return NULL;
   }
   
+  /// size - Returns the number of nodes in the tree, which includes
+  ///  both leaves and non-leaf nodes.
   unsigned size() const {
     unsigned n = 1;
     
@@ -76,16 +99,72 @@ public:
     return n;
   }
   
+  /// begin - Returns an iterator that iterates over the nodes of the tree
+  ///  in an inorder traversal.  The returned iterator thus refers to the
+  ///  the tree node with the minimum data element.
+  iterator begin() const { return iterator(this); }
   
-  bool isEqual(const ImutAVLTree& RHS) const {
-    // FIXME: Todo.
-    return true;    
+  /// end - Returns an iterator for the tree that denotes the end of an
+  ///  inorder traversal.
+  iterator end() const { return iterator(); }
+    
+  bool ElementEqual(value_type_ref V) const {
+    // Compare the keys.
+    if (!ImutInfo::isEqual(ImutInfo::KeyOfValue(getValue()),
+                           ImutInfo::KeyOfValue(V)))
+      return false;
+    
+    // Also compare the data values.
+    if (!ImutInfo::isDataEqual(ImutInfo::DataOfValue(getValue()),
+                               ImutInfo::DataOfValue(V)))
+      return false;
+    
+    return true;
+  }
+  
+  bool ElementEqual(const ImutAVLTree* RHS) const {
+    return ElementEqual(RHS->getValue());
   }
   
+  /// isEqual - Compares two trees for structural equality and returns true
+  ///   if they are equal.  This worst case performance of this operation is
+  //    linear in the sizes of the trees.
+  bool isEqual(const ImutAVLTree& RHS) const {
+    if (&RHS == this)
+      return true;
+    
+    iterator LItr = begin(), LEnd = end();
+    iterator RItr = RHS.begin(), REnd = RHS.end();
+    
+    while (LItr != LEnd && RItr != REnd) {
+      if (*LItr == *RItr) {
+        LItr.SkipSubTree();
+        RItr.SkipSubTree();
+        continue;
+      }
+      
+      if (!LItr->ElementEqual(*RItr))
+        return false;
+      
+      ++LItr;
+      ++RItr;
+    }
+    
+    return LItr == LEnd && RItr == REnd;
+  }
+
+  /// isNotEqual - Compares two trees for structural inequality.  Performance
+  ///  is the same is isEqual.
   bool isNotEqual(const ImutAVLTree& RHS) const { return !isEqual(RHS); }
   
+  /// contains - Returns true if this tree contains a subtree (node) that
+  ///  has an data element that matches the specified key.  Complexity
+  ///  is logarithmic in the size of the tree.
   bool contains(const key_type_ref K) { return (bool) find(K); }
   
+  /// foreach - A member template the accepts invokes operator() on a functor
+  ///  object (specifed by Callback) for every node/subtree in the tree.
+  ///  Nodes are visited using an inorder traversal.
   template <typename Callback>
   void foreach(Callback& C) {
     if (ImutAVLTree* L = getLeft()) L->foreach(C);
@@ -95,6 +174,12 @@ public:
     if (ImutAVLTree* R = getRight()) R->foreach(C);
   }
   
+  /// verify - A utility method that checks that the balancing and
+  ///  ordering invariants of the tree are satisifed.  It is a recursive
+  ///  method that returns the height of the tree, which is then consumed
+  ///  by the enclosing verify call.  External callers should ignore the
+  ///  return value.  An invalid tree will cause an assertion to fire in
+  ///  a debug build.
   unsigned verify() const {
     unsigned HL = getLeft() ? getLeft()->verify() : 0;
     unsigned HR = getRight() ? getRight()->verify() : 0;
@@ -118,7 +203,12 @@ public:
             && "Current value is not less that value of right child.");
     
     return getHeight();
-  }  
+  }
+  
+  /// Profile - Profiling for ImutAVLTree.
+  void Profile(llvm::FoldingSetNodeID& ID) {
+    ID.AddInteger(ComputeDigest());
+  }
   
   //===----------------------------------------------------===//  
   // Internal Values.
@@ -129,69 +219,109 @@ private:
   ImutAVLTree*     Right;
   unsigned         Height;
   value_type       Value;
-  
-  //===----------------------------------------------------===//  
-  // Profiling or FoldingSet.
-  //===----------------------------------------------------===//
-  
-  static inline
-  void Profile(FoldingSetNodeID& ID, ImutAVLTree* L, ImutAVLTree* R,
-               unsigned H, value_type_ref V) {    
-    ID.AddPointer(L);
-    ID.AddPointer(R);
-    ID.AddInteger(H);
-    ImutInfo::Profile(ID,V);
-  }
-  
-public:
-  
-  void Profile(FoldingSetNodeID& ID) {
-    Profile(ID,getSafeLeft(),getRight(),getHeight(),getValue());    
-  }
+  unsigned         Digest;
   
   //===----------------------------------------------------===//    
   // Internal methods (node manipulation; used by Factory).
   //===----------------------------------------------------===//
-  
+
 private:
   
+  enum { Mutable = 0x1 };
+
+  /// ImutAVLTree - Internal constructor that is only called by
+  ///   ImutAVLFactory.
   ImutAVLTree(ImutAVLTree* l, ImutAVLTree* r, value_type_ref v, unsigned height)
-  : Left(reinterpret_cast<uintptr_t>(l) | 0x1),
-  Right(r), Height(height), Value(v) {}
+  : Left(reinterpret_cast<uintptr_t>(l) | Mutable),
+    Right(r), Height(height), Value(v), Digest(0) {}
   
-  bool isMutable() const { return Left & 0x1; }
   
+  /// isMutable - Returns true if the left and right subtree references
+  ///  (as well as height) can be changed.  If this method returns false,
+  ///  the tree is truly immutable.  Trees returned from an ImutAVLFactory
+  ///  object should always have this method return true.  Further, if this
+  ///  method returns false for an instance of ImutAVLTree, all subtrees
+  ///  will also have this method return false.  The converse is not true.
+  bool isMutable() const { return Left & Mutable; }
+  
+  /// getSafeLeft - Returns the pointer to the left tree by always masking
+  ///  out the mutable bit.  This is used internally by ImutAVLFactory,
+  ///  as no trees returned to the client should have the mutable flag set.
   ImutAVLTree* getSafeLeft() const { 
-    return reinterpret_cast<ImutAVLTree*>(Left & ~0x1);
+    return reinterpret_cast<ImutAVLTree*>(Left & ~Mutable);
   }
   
-  // Mutating operations.  A tree root can be manipulated as long as
-  // its reference has not "escaped" from internal methods of a
-  // factory object (see below).  When a tree pointer is externally
-  // viewable by client code, the internal "mutable bit" is cleared
-  // to mark the tree immutable.  Note that a tree that still has
-  // its mutable bit set may have children (subtrees) that are themselves
+  //===----------------------------------------------------===//    
+  // Mutating operations.  A tree root can be manipulated as
+  // long as its reference has not "escaped" from internal 
+  // methods of a factory object (see below).  When a tree
+  // pointer is externally viewable by client code, the 
+  // internal "mutable bit" is cleared to mark the tree 
+  // immutable.  Note that a tree that still has its mutable
+  // bit set may have children (subtrees) that are themselves
   // immutable.
+  //===----------------------------------------------------===//
   
-  void RemoveMutableFlag() {
-    assert (Left & 0x1 && "Mutable flag already removed.");
-    Left &= ~0x1;
+  
+  /// MarkImmutable - Clears the mutable flag for a tree.  After this happens,
+  ///   it is an error to call setLeft(), setRight(), and setHeight().  It
+  ///   is also then safe to call getLeft() instead of getSafeLeft().  
+  void MarkImmutable() {
+    assert (isMutable() && "Mutable flag already removed.");
+    Left &= ~Mutable;
   }
   
+  /// setLeft - Changes the reference of the left subtree.  Used internally
+  ///   by ImutAVLFactory.
   void setLeft(ImutAVLTree* NewLeft) {
-    assert (isMutable());
-    Left = reinterpret_cast<uintptr_t>(NewLeft) | 0x1;
+    assert (isMutable() && 
+            "Only a mutable tree can have its left subtree changed.");
+    
+    Left = reinterpret_cast<uintptr_t>(NewLeft) | Mutable;
   }
   
+  /// setRight - Changes the reference of the right subtree.  Used internally
+  ///  by ImutAVLFactory.
   void setRight(ImutAVLTree* NewRight) {
-    assert (isMutable());
+    assert (isMutable() &&
+            "Only a mutable tree can have its right subtree changed.");
+    
     Right = NewRight;
   }
   
+  /// setHeight - Changes the height of the tree.  Used internally by
+  ///  ImutAVLFactory.
   void setHeight(unsigned h) {
-    assert (isMutable());
+    assert (isMutable() && "Only a mutable tree can have its height changed.");
     Height = h;
   }
+  
+  
+  static inline
+  unsigned ComputeDigest(ImutAVLTree* L, ImutAVLTree* R, value_type_ref V) {
+    unsigned digest = 0;
+    
+    if (L) digest += L->ComputeDigest();
+    
+    { // Compute digest of stored data.
+      FoldingSetNodeID ID;
+      ImutInfo::Profile(ID,V);
+      digest += ID.ComputeHash();
+    }
+    
+    if (R) digest += R->ComputeDigest();
+    
+    return digest;
+  }
+  
+  inline unsigned ComputeDigest() {
+    if (Digest) return Digest;
+    
+    unsigned X = ComputeDigest(getSafeLeft(), getRight(), getValue());
+    if (!isMutable()) Digest = X;
+    
+    return X;
+  }
 };
 
 //===----------------------------------------------------------------------===//    
@@ -206,15 +336,31 @@ class ImutAVLFactory {
   
   typedef FoldingSet<TreeTy> CacheTy;
   
-  CacheTy Cache;  
-  BumpPtrAllocator Allocator;    
+  CacheTy Cache;
+  uintptr_t Allocator;
+  
+  bool ownsAllocator() const {
+    return Allocator & 0x1 ? false : true;
+  }
+
+  BumpPtrAllocator& getAllocator() const { 
+    return *reinterpret_cast<BumpPtrAllocator*>(Allocator & ~0x1);
+  }
   
   //===--------------------------------------------------===//    
   // Public interface.
   //===--------------------------------------------------===//
   
 public:
-  ImutAVLFactory() {}
+  ImutAVLFactory()
+    : Allocator(reinterpret_cast<uintptr_t>(new BumpPtrAllocator())) {}
+  
+  ImutAVLFactory(BumpPtrAllocator& Alloc)
+    : Allocator(reinterpret_cast<uintptr_t>(&Alloc) | 0x1) {}
+  
+  ~ImutAVLFactory() {
+    if (ownsAllocator()) delete &getAllocator();
+  }
   
   TreeTy* Add(TreeTy* T, value_type_ref V) {
     T = Add_internal(V,T);
@@ -238,28 +384,11 @@ public:
   //===--------------------------------------------------===//
 private:
   
-  bool isEmpty(TreeTy* T) const {
-    return !T;
-  }
-  
-  unsigned Height(TreeTy* T) const {
-    return T ? T->getHeight() : 0;
-  }
-  
-  TreeTy* Left(TreeTy* T) const {
-    assert (T);
-    return T->getSafeLeft();
-  }
-  
-  TreeTy* Right(TreeTy* T) const {
-    assert (T);
-    return T->getRight();
-  }
-  
-  value_type_ref Value(TreeTy* T) const {
-    assert (T);
-    return T->Value;
-  }
+  bool           isEmpty(TreeTy* T) const { return !T; }
+  unsigned        Height(TreeTy* T) const { return T ? T->getHeight() : 0; }  
+  TreeTy*           Left(TreeTy* T) const { return T->getSafeLeft(); }
+  TreeTy*          Right(TreeTy* T) const { return T->getRight(); }  
+  value_type_ref   Value(TreeTy* T) const { return T->Value; }
   
   unsigned IncrementHeight(TreeTy* L, TreeTy* R) const {
     unsigned hl = Height(L);
@@ -267,6 +396,20 @@ private:
     return ( hl > hr ? hl : hr ) + 1;
   }
   
+  
+  static bool CompareTreeWithSection(TreeTy* T,
+                                     typename TreeTy::iterator& TI,
+                                     typename TreeTy::iterator& TE) {
+    
+    typename TreeTy::iterator I = T->begin(), E = T->end();
+    
+    for ( ; I!=E ; ++I, ++TI)
+      if (TI == TE || !I->ElementEqual(*TI))
+        return false;
+
+    return true;
+  }                     
+  
   //===--------------------------------------------------===//    
   // "CreateNode" is used to generate new tree roots that link
   // to other trees.  The functon may also simply move links
@@ -278,23 +421,62 @@ private:
   //===--------------------------------------------------===//
   
   TreeTy* CreateNode(TreeTy* L, value_type_ref V, TreeTy* R) {
-    FoldingSetNodeID ID;      
-    unsigned height = IncrementHeight(L,R);
+    // Search the FoldingSet bucket for a Tree with the same digest.
+    FoldingSetNodeID ID;
+    unsigned digest = TreeTy::ComputeDigest(L, R, V);
+    ID.AddInteger(digest);
+    unsigned hash = ID.ComputeHash();
     
-    TreeTy::Profile(ID,L,R,height,V);      
-    void* InsertPos;
+    typename CacheTy::bucket_iterator I = Cache.bucket_begin(hash);
+    typename CacheTy::bucket_iterator E = Cache.bucket_end(hash);
     
-    if (TreeTy* T = Cache.FindNodeOrInsertPos(ID,InsertPos))
+    for (; I != E; ++I) {
+      TreeTy* T = &*I;
+
+      if (T->ComputeDigest() != digest)
+        continue;
+      
+      // We found a collision.  Perform a comparison of Contents('T')
+      // with Contents('L')+'V'+Contents('R').
+      
+      typename TreeTy::iterator TI = T->begin(), TE = T->end();
+      
+      // First compare Contents('L') with the (initial) contents of T.
+      if (!CompareTreeWithSection(L, TI, TE))
+        continue;
+      
+      // Now compare the new data element.
+      if (TI == TE || !TI->ElementEqual(V))
+        continue;
+      
+      ++TI;
+
+      // Now compare the remainder of 'T' with 'R'.
+      if (!CompareTreeWithSection(R, TI, TE))
+        continue;
+      
+      if (TI != TE) // Contents('R') did not match suffix of 'T'.
+        continue;
+      
+      // Trees did match!  Return 'T'.
       return T;
+    }
     
-    assert (InsertPos != NULL);
-    
-    // FIXME: more intelligent calculation of alignment.
-    TreeTy* T = (TreeTy*) Allocator.Allocate(sizeof(*T),16);
-    new (T) TreeTy(L,R,V,height);
-    
-    Cache.InsertNode(T,InsertPos);
-    return T;      
+    // No tree with the contents: Contents('L')+'V'+Contents('R').
+    // Create it.
+
+    // Allocate the new tree node and insert it into the cache.
+    BumpPtrAllocator& A = getAllocator();
+    TreeTy* T = (TreeTy*) A.Allocate<TreeTy>();
+    new (T) TreeTy(L,R,V,IncrementHeight(L,R));
+
+    // We do not insert 'T' into the FoldingSet here.  This is because
+    // this tree is still mutable and things may get rebalanced.
+    // Because our digest is associative and based on the contents of
+    // the set, this should hopefully not cause any strange bugs.
+    // 'T' is inserted by 'MarkImmutable'.
+
+    return T;
   }
   
   TreeTy* CreateNode(TreeTy* L, TreeTy* OldTree, TreeTy* R) {      
@@ -422,13 +604,206 @@ private:
     if (!T || !T->isMutable())
       return;
     
-    T->RemoveMutableFlag();
+    T->MarkImmutable();
     MarkImmutable(Left(T));
     MarkImmutable(Right(T));
+        
+    // Now that the node is immutable it can safely be inserted
+    // into the node cache.
+    llvm::FoldingSetNodeID ID;
+    ID.AddInteger(T->ComputeDigest());
+    Cache.InsertNode(T, (void*) &*Cache.bucket_end(ID.ComputeHash()));
+  }
+};
+  
+  
+//===----------------------------------------------------------------------===//    
+// Immutable AVL-Tree Iterators.
+//===----------------------------------------------------------------------===//  
+
+template <typename ImutInfo>
+class ImutAVLTreeGenericIterator {
+  SmallVector<uintptr_t,20> stack;
+public:
+  enum VisitFlag { VisitedNone=0x0, VisitedLeft=0x1, VisitedRight=0x3, 
+                   Flags=0x3 };
+  
+  typedef ImutAVLTree<ImutInfo> TreeTy;      
+  typedef ImutAVLTreeGenericIterator<ImutInfo> _Self;
+
+  inline ImutAVLTreeGenericIterator() {}
+  inline ImutAVLTreeGenericIterator(const TreeTy* Root) {
+    if (Root) stack.push_back(reinterpret_cast<uintptr_t>(Root));
+  }  
+  
+  TreeTy* operator*() const {
+    assert (!stack.empty());    
+    return reinterpret_cast<TreeTy*>(stack.back() & ~Flags);
+  }
+  
+  uintptr_t getVisitState() {
+    assert (!stack.empty());
+    return stack.back() & Flags;
+  }
+  
+  
+  bool AtEnd() const { return stack.empty(); }
+
+  bool AtBeginning() const { 
+    return stack.size() == 1 && getVisitState() == VisitedNone;
+  }
+  
+  void SkipToParent() {
+    assert (!stack.empty());
+    stack.pop_back();
+    
+    if (stack.empty())
+      return;
+    
+    switch (getVisitState()) {
+      case VisitedNone:
+        stack.back() |= VisitedLeft;
+        break;
+      case VisitedLeft:
+        stack.back() |= VisitedRight;
+        break;
+      default:
+        assert (false && "Unreachable.");            
+    }
+  }
+  
+  inline bool operator==(const _Self& x) const {
+    if (stack.size() != x.stack.size())
+      return false;
+    
+    for (unsigned i = 0 ; i < stack.size(); i++)
+      if (stack[i] != x.stack[i])
+        return false;
+    
+    return true;
+  }
+  
+  inline bool operator!=(const _Self& x) const { return !operator==(x); }  
+  
+  _Self& operator++() {
+    assert (!stack.empty());
+    
+    TreeTy* Current = reinterpret_cast<TreeTy*>(stack.back() & ~Flags);
+    assert (Current);
+    
+    switch (getVisitState()) {
+      case VisitedNone:
+        if (TreeTy* L = Current->getSafeLeft())
+          stack.push_back(reinterpret_cast<uintptr_t>(L));
+        else
+          stack.back() |= VisitedLeft;
+        
+        break;
+        
+      case VisitedLeft:
+        if (TreeTy* R = Current->getRight())
+          stack.push_back(reinterpret_cast<uintptr_t>(R));
+        else
+          stack.back() |= VisitedRight;
+        
+        break;
+        
+      case VisitedRight:
+        SkipToParent();        
+        break;
+        
+      default:
+        assert (false && "Unreachable.");
+    }
+    
+    return *this;
+  }
+  
+  _Self& operator--() {
+    assert (!stack.empty());
+    
+    TreeTy* Current = reinterpret_cast<TreeTy*>(stack.back() & ~Flags);
+    assert (Current);
+    
+    switch (getVisitState()) {
+      case VisitedNone:
+        stack.pop_back();
+        break;
+        
+      case VisitedLeft:                
+        stack.back() &= ~Flags; // Set state to "VisitedNone."
+        
+        if (TreeTy* L = Current->getLeft())
+          stack.push_back(reinterpret_cast<uintptr_t>(L) | VisitedRight);
+          
+        break;
+        
+      case VisitedRight:        
+        stack.back() &= ~Flags;
+        stack.back() |= VisitedLeft;
+        
+        if (TreeTy* R = Current->getRight())
+          stack.push_back(reinterpret_cast<uintptr_t>(R) | VisitedRight);
+          
+        break;
+        
+      default:
+        assert (false && "Unreachable.");
+    }
+    
+    return *this;
   }
 };
+  
+template <typename ImutInfo>
+class ImutAVLTreeInOrderIterator {
+  typedef ImutAVLTreeGenericIterator<ImutInfo> InternalIteratorTy;
+  InternalIteratorTy InternalItr;
 
+public:
+  typedef ImutAVLTree<ImutInfo> TreeTy;
+  typedef ImutAVLTreeInOrderIterator<ImutInfo> _Self;
 
+  ImutAVLTreeInOrderIterator(const TreeTy* Root) : InternalItr(Root) { 
+    if (Root) operator++(); // Advance to first element.
+  }
+  
+  ImutAVLTreeInOrderIterator() : InternalItr() {}
+
+  inline bool operator==(const _Self& x) const {
+    return InternalItr == x.InternalItr;
+  }
+  
+  inline bool operator!=(const _Self& x) const { return !operator==(x); }  
+  
+  inline TreeTy* operator*() const { return *InternalItr; }
+  inline TreeTy* operator->() const { return *InternalItr; }
+  
+  inline _Self& operator++() { 
+    do ++InternalItr;
+    while (!InternalItr.AtEnd() && 
+           InternalItr.getVisitState() != InternalIteratorTy::VisitedLeft);
+
+    return *this;
+  }
+  
+  inline _Self& operator--() { 
+    do --InternalItr;
+    while (!InternalItr.AtBeginning() && 
+           InternalItr.getVisitState() != InternalIteratorTy::VisitedLeft);
+    
+    return *this;
+  }
+  
+  inline void SkipSubTree() {
+    InternalItr.SkipToParent();
+    
+    while (!InternalItr.AtEnd() &&
+           InternalItr.getVisitState() != InternalIteratorTy::VisitedLeft)
+      ++InternalItr;        
+  }
+};
+    
 //===----------------------------------------------------------------------===//    
 // Trait classes for Profile information.
 //===----------------------------------------------------------------------===//
@@ -442,8 +817,8 @@ struct ImutProfileInfo {
   typedef const T& value_type_ref;
   
   static inline void Profile(FoldingSetNodeID& ID, value_type_ref X) {
-    X.Profile(ID);
-  }  
+    FoldingSetTrait<T>::Profile(X,ID);
+  }
 };
 
 /// Profile traits for integers.
@@ -502,8 +877,11 @@ struct ImutContainerInfo : public ImutProfileInfo<T> {
   typedef typename ImutProfileInfo<T>::value_type_ref  value_type_ref;
   typedef value_type      key_type;
   typedef value_type_ref  key_type_ref;
+  typedef bool            data_type;
+  typedef bool            data_type_ref;
   
   static inline key_type_ref KeyOfValue(value_type_ref D) { return D; }
+  static inline data_type_ref DataOfValue(value_type_ref) { return true; }
   
   static inline bool isEqual(key_type_ref LHS, key_type_ref RHS) { 
     return std::equal_to<key_type>()(LHS,RHS);
@@ -512,6 +890,8 @@ struct ImutContainerInfo : public ImutProfileInfo<T> {
   static inline bool isLess(key_type_ref LHS, key_type_ref RHS) {
     return std::less<key_type>()(LHS,RHS);
   }
+  
+  static inline bool isDataEqual(data_type_ref,data_type_ref) { return true; }
 };
 
 /// ImutContainerInfo - Specialization for pointer values to treat pointers
@@ -523,8 +903,11 @@ struct ImutContainerInfo<T*> : public ImutProfileInfo<T*> {
   typedef typename ImutProfileInfo<T*>::value_type_ref  value_type_ref;
   typedef value_type      key_type;
   typedef value_type_ref  key_type_ref;
+  typedef bool            data_type;
+  typedef bool            data_type_ref;
   
   static inline key_type_ref KeyOfValue(value_type_ref D) { return D; }
+  static inline data_type_ref DataOfValue(value_type_ref) { return true; }
   
   static inline bool isEqual(key_type_ref LHS, key_type_ref RHS) {
     return LHS == RHS;
@@ -533,6 +916,8 @@ struct ImutContainerInfo<T*> : public ImutProfileInfo<T*> {
   static inline bool isLess(key_type_ref LHS, key_type_ref RHS) {
     return LHS < RHS;
   }
+  
+  static inline bool isDataEqual(data_type_ref,data_type_ref) { return true; }
 };
 
 //===----------------------------------------------------------------------===//    
@@ -544,14 +929,17 @@ class ImmutableSet {
 public:
   typedef typename ValInfo::value_type      value_type;
   typedef typename ValInfo::value_type_ref  value_type_ref;
-  
-private:  
   typedef ImutAVLTree<ValInfo> TreeTy;
+
+private:  
   TreeTy* Root;
-  
-  ImmutableSet(TreeTy* R) : Root(R) {}
-  
+
 public:
+  /// Constructs a set from a pointer to a tree root.  In general one
+  /// should use a Factory object to create sets instead of directly
+  /// invoking the constructor, but there are cases where make this
+  /// constructor public is useful.
+  explicit ImmutableSet(TreeTy* R) : Root(R) {}
   
   class Factory {
     typename TreeTy::Factory F;
@@ -559,23 +947,44 @@ public:
   public:
     Factory() {}
     
+    Factory(BumpPtrAllocator& Alloc)
+      : F(Alloc) {}
+    
+    /// GetEmptySet - Returns an immutable set that contains no elements.
     ImmutableSet GetEmptySet() { return ImmutableSet(F.GetEmptyTree()); }
     
+    /// Add - Creates a new immutable set that contains all of the values
+    ///  of the original set with the addition of the specified value.  If
+    ///  the original set already included the value, then the original set is
+    ///  returned and no memory is allocated.  The time and space complexity
+    ///  of this operation is logarithmic in the size of the original set.
+    ///  The memory allocated to represent the set is released when the
+    ///  factory object that created the set is destroyed.
     ImmutableSet Add(ImmutableSet Old, value_type_ref V) {
       return ImmutableSet(F.Add(Old.Root,V));
     }
     
+    /// Remove - Creates a new immutable set that contains all of the values
+    ///  of the original set with the exception of the specified value.  If
+    ///  the original set did not contain the value, the original set is
+    ///  returned and no memory is allocated.  The time and space complexity
+    ///  of this operation is logarithmic in the size of the original set.
+    ///  The memory allocated to represent the set is released when the
+    ///  factory object that created the set is destroyed.
     ImmutableSet Remove(ImmutableSet Old, value_type_ref V) {
       return ImmutableSet(F.Remove(Old.Root,V));
     }
     
+    BumpPtrAllocator& getAllocator() { return F.getAllocator(); }
+
   private:
     Factory(const Factory& RHS) {};
     void operator=(const Factory& RHS) {};    
   };
   
-  friend class Factory;
-  
+  friend class Factory;  
+
+  /// contains - Returns true if the set contains the specified value.
   bool contains(const value_type_ref V) const {
     return Root ? Root->contains(V) : false;
   }
@@ -588,6 +997,9 @@ public:
     return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
   }
   
+  TreeTy* getRoot() const { return Root; }
+  
+  /// isEmpty - Return true if the set contains no elements.
   bool isEmpty() const { return !Root; }
   
   template <typename Callback>
@@ -595,13 +1007,49 @@ public:
   
   template <typename Callback>
   void foreach() { if (Root) { Callback C; Root->foreach(C); } }
+    
+  //===--------------------------------------------------===//    
+  // Iterators.
+  //===--------------------------------------------------===//  
+
+  class iterator {
+    typename TreeTy::iterator itr;
+    
+    iterator() {}
+    iterator(TreeTy* t) : itr(t) {}
+    friend class ImmutableSet<ValT,ValInfo>;
+  public:
+    inline value_type_ref operator*() const { return itr->getValue(); }
+    inline iterator& operator++() { ++itr; return *this; }
+    inline iterator  operator++(int) { iterator tmp(*this); ++itr; return tmp; }
+    inline iterator& operator--() { --itr; return *this; }
+    inline iterator  operator--(int) { iterator tmp(*this); --itr; return tmp; }
+    inline bool operator==(const iterator& RHS) const { return RHS.itr == itr; }
+    inline bool operator!=(const iterator& RHS) const { return RHS.itr != itr; }        
+  };
+  
+  iterator begin() const { return iterator(Root); }
+  iterator end() const { return iterator(); }  
+  
+  //===--------------------------------------------------===//    
+  // Utility methods.
+  //===--------------------------------------------------===//  
+  
+  inline unsigned getHeight() const { return Root ? Root->getHeight() : 0; }
+  
+  static inline void Profile(FoldingSetNodeID& ID, const ImmutableSet& S) {
+    ID.AddPointer(S.Root);
+  }
+  
+  inline void Profile(FoldingSetNodeID& ID) const {
+    return Profile(ID,*this);
+  }
   
   //===--------------------------------------------------===//    
   // For testing.
   //===--------------------------------------------------===//  
   
   void verify() const { if (Root) Root->verify(); }
-  unsigned getHeight() const { return Root ? Root->getHeight() : 0; }
 };
 
 } // end namespace llvm