add methods to do equality checks and get hashes of Matchers
authorChris Lattner <sabre@nondot.org>
Thu, 25 Feb 2010 06:49:58 +0000 (06:49 +0000)
committerChris Lattner <sabre@nondot.org>
Thu, 25 Feb 2010 06:49:58 +0000 (06:49 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@97123 91177308-0d34-0410-b5e6-96231b3b80d8

utils/TableGen/DAGISelMatcher.cpp
utils/TableGen/DAGISelMatcher.h

index c38b2307b678114d45faebc0f6ef02828e957bb4..4b1ae82f97dbc9a582cea02aa6a170980f21d709 100644 (file)
@@ -12,6 +12,7 @@
 #include "CodeGenTarget.h"
 #include "Record.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/ADT/StringExtras.h"
 using namespace llvm;
 
 void Matcher::dump() const {
@@ -23,7 +24,6 @@ void Matcher::printNext(raw_ostream &OS, unsigned indent) const {
     return Next->print(OS, indent);
 }
 
-
 void ScopeMatcher::print(raw_ostream &OS, unsigned indent) const {
   OS.indent(indent) << "Scope\n";
   Check->print(OS, indent+2);
@@ -209,3 +209,69 @@ void CompleteMatchMatcher::print(raw_ostream &OS, unsigned indent) const {
   printNext(OS, indent);
 }
 
+// getHashImpl Implementation.
+
+unsigned CheckPatternPredicateMatcher::getHashImpl() const {
+  return HashString(Predicate);
+}
+
+unsigned CheckPredicateMatcher::getHashImpl() const {
+  return HashString(PredName);
+}
+
+unsigned CheckOpcodeMatcher::getHashImpl() const {
+  return HashString(OpcodeName);
+}
+
+unsigned CheckMultiOpcodeMatcher::getHashImpl() const {
+  unsigned Result = 0;
+  for (unsigned i = 0, e = OpcodeNames.size(); i != e; ++i)
+    Result |= HashString(OpcodeNames[i]);
+  return Result;
+}
+
+unsigned CheckCondCodeMatcher::getHashImpl() const {
+  return HashString(CondCodeName);
+}
+
+unsigned CheckValueTypeMatcher::getHashImpl() const {
+  return HashString(TypeName);
+}
+
+unsigned EmitStringIntegerMatcher::getHashImpl() const {
+  return HashString(Val) ^ VT;
+}
+
+template<typename It>
+static unsigned HashUnsigneds(It I, It E) {
+  unsigned Result = 0;
+  for (; I != E; ++I)
+    Result = (Result<<3) ^ *I;
+  return Result;
+}
+
+unsigned EmitMergeInputChainsMatcher::getHashImpl() const {
+  return HashUnsigneds(ChainNodes.begin(), ChainNodes.end());
+}
+
+bool EmitNodeMatcher::isEqualImpl(const Matcher *m) const {
+  const EmitNodeMatcher *M = cast<EmitNodeMatcher>(m);
+  return M->OpcodeName == OpcodeName && M->VTs == VTs &&
+         M->Operands == Operands && M->HasChain == HasChain &&
+         M->HasFlag == HasFlag && M->HasMemRefs == HasMemRefs &&
+         M->NumFixedArityOperands == NumFixedArityOperands;
+}
+
+unsigned EmitNodeMatcher::getHashImpl() const {
+  return (HashString(OpcodeName) << 4) | Operands.size();
+}
+
+
+unsigned MarkFlagResultsMatcher::getHashImpl() const {
+  return HashUnsigneds(FlagResultNodes.begin(), FlagResultNodes.end());
+}
+
+unsigned CompleteMatchMatcher::getHashImpl() const {
+  return HashUnsigneds(Results.begin(), Results.end()) ^ 
+          ((unsigned)(intptr_t)&Pattern << 8);
+}
index 9286b33dc5bef1aab9f7c2de45d3e247f12296d1..68132219bce6bee8562993885ad522a841ee289d 100644 (file)
@@ -94,10 +94,21 @@ public:
   
   static inline bool classof(const Matcher *) { return true; }
   
+  bool isEqual(const Matcher *M) const {
+    if (getKind() != M->getKind()) return false;
+    return isEqualImpl(M);
+  }
+  
+  unsigned getHash() const {
+    return (getHashImpl() << 4) ^ getKind();
+  }
+  
   virtual void print(raw_ostream &OS, unsigned indent = 0) const = 0;
   void dump() const;
 protected:
   void printNext(raw_ostream &OS, unsigned indent) const;
+  virtual bool isEqualImpl(const Matcher *M) const = 0;
+  virtual unsigned getHashImpl() const = 0;
 };
   
 /// ScopeMatcher - This pushes a failure scope on the stack and evaluates
@@ -120,7 +131,10 @@ public:
     return N->getKind() == Scope;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const { return false; }
+  virtual unsigned getHashImpl() const { return 0; }
 };
 
 /// RecordMatcher - Save the current node in the operand list.
@@ -138,7 +152,10 @@ public:
     return N->getKind() == RecordNode;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const { return true; }
+  virtual unsigned getHashImpl() const { return 0; }
 };
   
 /// RecordChildMatcher - Save a numbered child of the current node, or fail
@@ -161,7 +178,12 @@ public:
     return N->getKind() == RecordChild;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<RecordChildMatcher>(M)->getChildNo() == getChildNo();
+  }
+  virtual unsigned getHashImpl() const { return getChildNo(); }
 };
   
 /// RecordMemRefMatcher - Save the current node's memref.
@@ -173,7 +195,10 @@ public:
     return N->getKind() == RecordMemRef;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const { return true; }
+  virtual unsigned getHashImpl() const { return 0; }
 };
 
   
@@ -187,7 +212,10 @@ public:
     return N->getKind() == CaptureFlagInput;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const { return true; }
+  virtual unsigned getHashImpl() const { return 0; }
 };
   
 /// MoveChildMatcher - This tells the interpreter to move into the
@@ -203,7 +231,12 @@ public:
     return N->getKind() == MoveChild;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<MoveChildMatcher>(M)->getChildNo() == getChildNo();
+  }
+  virtual unsigned getHashImpl() const { return getChildNo(); }
 };
   
 /// MoveParentMatcher - This tells the interpreter to move to the parent
@@ -216,7 +249,10 @@ public:
     return N->getKind() == MoveParent;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const { return true; }
+  virtual unsigned getHashImpl() const { return 0; }
 };
 
 /// CheckSameMatcher - This checks to see if this node is exactly the same
@@ -226,7 +262,7 @@ class CheckSameMatcher : public Matcher {
   unsigned MatchNumber;
 public:
   CheckSameMatcher(unsigned matchnumber)
-  : Matcher(CheckSame), MatchNumber(matchnumber) {}
+    : Matcher(CheckSame), MatchNumber(matchnumber) {}
   
   unsigned getMatchNumber() const { return MatchNumber; }
   
@@ -234,7 +270,12 @@ public:
     return N->getKind() == CheckSame;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckSameMatcher>(M)->getMatchNumber() == getMatchNumber();
+  }
+  virtual unsigned getHashImpl() const { return getMatchNumber(); }
 };
   
 /// CheckPatternPredicateMatcher - This checks the target-specific predicate
@@ -244,7 +285,7 @@ class CheckPatternPredicateMatcher : public Matcher {
   std::string Predicate;
 public:
   CheckPatternPredicateMatcher(StringRef predicate)
-  : Matcher(CheckPatternPredicate), Predicate(predicate) {}
+    : Matcher(CheckPatternPredicate), Predicate(predicate) {}
   
   StringRef getPredicate() const { return Predicate; }
   
@@ -252,7 +293,12 @@ public:
     return N->getKind() == CheckPatternPredicate;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckPatternPredicateMatcher>(M)->getPredicate() == Predicate;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
 /// CheckPredicateMatcher - This checks the target-specific predicate to
@@ -269,7 +315,12 @@ public:
     return N->getKind() == CheckPredicate;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckPredicateMatcher>(M)->PredName == PredName;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
   
@@ -287,7 +338,12 @@ public:
     return N->getKind() == CheckOpcode;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckOpcodeMatcher>(M)->OpcodeName == OpcodeName;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
 /// CheckMultiOpcodeMatcher - This checks to see if the current node has one
@@ -305,7 +361,12 @@ public:
     return N->getKind() == CheckMultiOpcode;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckMultiOpcodeMatcher>(M)->OpcodeNames == OpcodeNames;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
   
@@ -324,7 +385,12 @@ public:
     return N->getKind() == CheckType;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckTypeMatcher>(this)->Type == Type;
+  }
+  virtual unsigned getHashImpl() const { return Type; }
 };
   
 /// CheckChildTypeMatcher - This checks to see if a child node has the
@@ -343,7 +409,13 @@ public:
     return N->getKind() == CheckChildType;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckChildTypeMatcher>(M)->ChildNo == ChildNo &&
+           cast<CheckChildTypeMatcher>(M)->Type == Type;
+  }
+  virtual unsigned getHashImpl() const { return (Type << 3) | ChildNo; }
 };
   
 
@@ -361,7 +433,12 @@ public:
     return N->getKind() == CheckInteger;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckIntegerMatcher>(M)->Value == Value;
+  }
+  virtual unsigned getHashImpl() const { return Value; }
 };
   
 /// CheckCondCodeMatcher - This checks to see if the current node is a
@@ -378,7 +455,12 @@ public:
     return N->getKind() == CheckCondCode;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckCondCodeMatcher>(M)->CondCodeName == CondCodeName;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
 /// CheckValueTypeMatcher - This checks to see if the current node is a
@@ -395,7 +477,12 @@ public:
     return N->getKind() == CheckValueType;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckValueTypeMatcher>(M)->TypeName == TypeName;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
   
@@ -414,7 +501,14 @@ public:
     return N->getKind() == CheckComplexPat;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return &cast<CheckComplexPatMatcher>(M)->Pattern == &Pattern;
+  }
+  virtual unsigned getHashImpl() const {
+    return (unsigned)(intptr_t)&Pattern;
+  }
 };
   
 /// CheckAndImmMatcher - This checks to see if the current node is an 'and'
@@ -431,7 +525,12 @@ public:
     return N->getKind() == CheckAndImm;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckAndImmMatcher>(M)->Value == Value;
+  }
+  virtual unsigned getHashImpl() const { return Value; }
 };
 
 /// CheckOrImmMatcher - This checks to see if the current node is an 'and'
@@ -448,7 +547,12 @@ public:
     return N->getKind() == CheckOrImm;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckOrImmMatcher>(M)->Value == Value;
+  }
+  virtual unsigned getHashImpl() const { return Value; }
 };
 
 /// CheckFoldableChainNodeMatcher - This checks to see if the current node
@@ -462,7 +566,10 @@ public:
     return N->getKind() == CheckFoldableChainNode;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const { return true; }
+  virtual unsigned getHashImpl() const { return 0; }
 };
 
 /// CheckChainCompatibleMatcher - Verify that the current node's chain
@@ -479,7 +586,12 @@ public:
     return N->getKind() == CheckChainCompatible;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CheckChainCompatibleMatcher>(this)->PreviousOp == PreviousOp;
+  }
+  virtual unsigned getHashImpl() const { return PreviousOp; }
 };
   
 /// EmitIntegerMatcher - This creates a new TargetConstant.
@@ -497,7 +609,13 @@ public:
     return N->getKind() == EmitInteger;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<EmitIntegerMatcher>(M)->Val == Val &&
+           cast<EmitIntegerMatcher>(M)->VT == VT;
+  }
+  virtual unsigned getHashImpl() const { return (Val << 4) | VT; }
 };
 
 /// EmitStringIntegerMatcher - A target constant whose value is represented
@@ -516,7 +634,13 @@ public:
     return N->getKind() == EmitStringInteger;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<EmitStringIntegerMatcher>(M)->Val == Val &&
+           cast<EmitStringIntegerMatcher>(M)->VT == VT;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
 /// EmitRegisterMatcher - This creates a new TargetConstant.
@@ -536,7 +660,15 @@ public:
     return N->getKind() == EmitRegister;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<EmitRegisterMatcher>(M)->Reg == Reg &&
+           cast<EmitRegisterMatcher>(M)->VT == VT;
+  }
+  virtual unsigned getHashImpl() const {
+    return ((unsigned)(intptr_t)Reg) << 4 | VT;
+  }
 };
 
 /// EmitConvertToTargetMatcher - Emit an operation that reads a specified
@@ -554,7 +686,12 @@ public:
     return N->getKind() == EmitConvertToTarget;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<EmitConvertToTargetMatcher>(M)->Slot == Slot;
+  }
+  virtual unsigned getHashImpl() const { return Slot; }
 };
   
 /// EmitMergeInputChainsMatcher - Emit a node that merges a list of input
@@ -578,7 +715,12 @@ public:
     return N->getKind() == EmitMergeInputChains;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<EmitMergeInputChainsMatcher>(M)->ChainNodes == ChainNodes;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
 /// EmitCopyToRegMatcher - Emit a CopyToReg node from a value to a physreg,
@@ -598,7 +740,15 @@ public:
     return N->getKind() == EmitCopyToReg;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<EmitCopyToRegMatcher>(M)->SrcSlot == SrcSlot &&
+           cast<EmitCopyToRegMatcher>(M)->DestPhysReg == DestPhysReg; 
+  }
+  virtual unsigned getHashImpl() const {
+    return SrcSlot ^ ((unsigned)(intptr_t)DestPhysReg << 4);
+  }
 };
   
     
@@ -619,7 +769,15 @@ public:
     return N->getKind() == EmitNodeXForm;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<EmitNodeXFormMatcher>(M)->Slot == Slot &&
+           cast<EmitNodeXFormMatcher>(M)->NodeXForm == NodeXForm; 
+  }
+  virtual unsigned getHashImpl() const {
+    return Slot ^ ((unsigned)(intptr_t)NodeXForm << 4);
+  }
 };
   
 /// EmitNodeMatcher - This signals a successful match and generates a node.
@@ -635,10 +793,10 @@ class EmitNodeMatcher : public Matcher {
   int NumFixedArityOperands;
 public:
   EmitNodeMatcher(const std::string &opcodeName,
-                      const MVT::SimpleValueType *vts, unsigned numvts,
-                      const unsigned *operands, unsigned numops,
-                      bool hasChain, bool hasFlag, bool hasmemrefs,
-                      int numfixedarityoperands)
+                  const MVT::SimpleValueType *vts, unsigned numvts,
+                  const unsigned *operands, unsigned numops,
+                  bool hasChain, bool hasFlag, bool hasmemrefs,
+                  int numfixedarityoperands)
     : Matcher(EmitNode), OpcodeName(opcodeName),
       VTs(vts, vts+numvts), Operands(operands, operands+numops),
       HasChain(hasChain), HasFlag(hasFlag), HasMemRefs(hasmemrefs),
@@ -667,7 +825,10 @@ public:
     return N->getKind() == EmitNode;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const;
+  virtual unsigned getHashImpl() const;
 };
   
 /// MarkFlagResultsMatcher - This node indicates which non-root nodes in the
@@ -690,7 +851,12 @@ public:
     return N->getKind() == MarkFlagResults;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<MarkFlagResultsMatcher>(M)->FlagResultNodes == FlagResultNodes;
+  }
+  virtual unsigned getHashImpl() const;
 };
 
 /// CompleteMatchMatcher - Complete a match by replacing the results of the
@@ -713,9 +879,14 @@ public:
     return N->getKind() == CompleteMatch;
   }
   
+private:
   virtual void print(raw_ostream &OS, unsigned indent = 0) const;
+  virtual bool isEqualImpl(const Matcher *M) const {
+    return cast<CompleteMatchMatcher>(M)->Results == Results &&
+          &cast<CompleteMatchMatcher>(M)->Pattern == &Pattern;
+  }
+  virtual unsigned getHashImpl() const;
 };
   
 } // end namespace llvm