[TableGen] Use std::set_intersection to merge TypeSets. NFC
[oota-llvm.git] / utils / TableGen / DAGISelMatcherEmitter.cpp
index 04fe0d1824adf19db9f2f69d1e8cbaf2fccb3122..26f53dca63618c65783b208d11d5739403334150 100644 (file)
@@ -16,6 +16,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/TinyPtrVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FormattedStream.h"
 #include "llvm/TableGen/Record.h"
@@ -36,6 +37,10 @@ class MatcherTableEmitter {
   
   DenseMap<TreePattern *, unsigned> NodePredicateMap;
   std::vector<TreePredicateFn> NodePredicates;
+
+  // We de-duplicate the predicates by code string, and use this map to track
+  // all the patterns with "identical" predicates.
+  StringMap<TinyPtrVector<TreePattern *>> NodePredicatesByCodeToRun;
   
   StringMap<unsigned> PatternPredicateMap;
   std::vector<std::string> PatternPredicates;
@@ -62,10 +67,23 @@ private:
                        formatted_raw_ostream &OS);
 
   unsigned getNodePredicate(TreePredicateFn Pred) {
-    unsigned &Entry = NodePredicateMap[Pred.getOrigPatFragRecord()];
+    TreePattern *TP = Pred.getOrigPatFragRecord();
+    unsigned &Entry = NodePredicateMap[TP];
     if (Entry == 0) {
-      NodePredicates.push_back(Pred);
-      Entry = NodePredicates.size();
+      TinyPtrVector<TreePattern *> &SameCodePreds =
+          NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()];
+      if (SameCodePreds.empty()) {
+        // We've never seen a predicate with the same code: allocate an entry.
+        NodePredicates.push_back(Pred);
+        Entry = NodePredicates.size();
+      } else {
+        // We did see an identical predicate: re-use it.
+        Entry = NodePredicateMap[SameCodePreds.front()];
+        assert(Entry != 0);
+      }
+      // In both cases, we've never seen this particular predicate before, so
+      // mark it in the list of predicates sharing the same code.
+      SameCodePreds.push_back(TP);
     }
     return Entry-1;
   }
@@ -142,7 +160,7 @@ EmitMatcher(const Matcher *N, unsigned Indent, unsigned CurrentIdx,
   switch (N->getKind()) {
   case Matcher::Scope: {
     const ScopeMatcher *SM = cast<ScopeMatcher>(N);
-    assert(SM->getNext() == 0 && "Shouldn't have next after scope");
+    assert(SM->getNext() == nullptr && "Shouldn't have next after scope");
 
     unsigned StartIdx = CurrentIdx;
 
@@ -188,7 +206,7 @@ EmitMatcher(const Matcher *N, unsigned Indent, unsigned CurrentIdx,
             << " children in Scope";
       }
 
-      OS << '\n' << TmpBuf.str();
+      OS << '\n' << TmpBuf;
       CurrentIdx += ChildSize;
     }
 
@@ -332,7 +350,6 @@ EmitMatcher(const Matcher *N, unsigned Indent, unsigned CurrentIdx,
       // Emit the VBR.
       CurrentIdx += EmitVBRValue(ChildSize, OS);
 
-      OS << ' ';
       if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
         OS << "TARGET_VAL(" << SOM->getCaseOpcode(i).getEnumName() << "),";
       else
@@ -343,7 +360,7 @@ EmitMatcher(const Matcher *N, unsigned Indent, unsigned CurrentIdx,
       if (!OmitComments)
         OS << "// ->" << CurrentIdx+ChildSize;
       OS << '\n';
-      OS << TmpBuf.str();
+      OS << TmpBuf;
       CurrentIdx += ChildSize;
     }
 
@@ -379,6 +396,14 @@ EmitMatcher(const Matcher *N, unsigned Indent, unsigned CurrentIdx,
     OS << '\n';
     return Bytes;
   }
+  case Matcher::CheckChildInteger: {
+    OS << "OPC_CheckChild" << cast<CheckChildIntegerMatcher>(N)->getChildNo()
+       << "Integer, ";
+    unsigned Bytes=1+EmitVBRValue(cast<CheckChildIntegerMatcher>(N)->getValue(),
+                                  OS);
+    OS << '\n';
+    return Bytes;
+  }
   case Matcher::CheckCondCode:
     OS << "OPC_CheckCondCode, ISD::"
        << cast<CheckCondCodeMatcher>(N)->getCondCodeName() << ",\n";
@@ -608,7 +633,7 @@ EmitMatcherList(const Matcher *N, unsigned Indent, unsigned CurrentIdx,
 void MatcherTableEmitter::EmitPredicateFunctions(formatted_raw_ostream &OS) {
   // Emit pattern predicates.
   if (!PatternPredicates.empty()) {
-    OS << "virtual bool CheckPatternPredicate(unsigned PredNo) const {\n";
+    OS << "bool CheckPatternPredicate(unsigned PredNo) const override {\n";
     OS << "  switch (PredNo) {\n";
     OS << "  default: llvm_unreachable(\"Invalid predicate in table?\");\n";
     for (unsigned i = 0, e = PatternPredicates.size(); i != e; ++i)
@@ -618,16 +643,9 @@ void MatcherTableEmitter::EmitPredicateFunctions(formatted_raw_ostream &OS) {
   }
 
   // Emit Node predicates.
-  // FIXME: Annoyingly, these are stored by name, which we never even emit. Yay?
-  StringMap<TreePattern*> PFsByName;
-
-  for (CodeGenDAGPatterns::pf_iterator I = CGP.pf_begin(), E = CGP.pf_end();
-       I != E; ++I)
-    PFsByName[I->first->getName()] = I->second;
-
   if (!NodePredicates.empty()) {
-    OS << "virtual bool CheckNodePredicate(SDNode *Node,\n";
-    OS << "                                unsigned PredNo) const {\n";
+    OS << "bool CheckNodePredicate(SDNode *Node,\n";
+    OS << "                        unsigned PredNo) const override {\n";
     OS << "  switch (PredNo) {\n";
     OS << "  default: llvm_unreachable(\"Invalid predicate in table?\");\n";
     for (unsigned i = 0, e = NodePredicates.size(); i != e; ++i) {
@@ -635,7 +653,10 @@ void MatcherTableEmitter::EmitPredicateFunctions(formatted_raw_ostream &OS) {
       TreePredicateFn PredFn = NodePredicates[i];
       
       assert(!PredFn.isAlwaysTrue() && "No code in this predicate");
-      OS << "  case " << i << ": { // " << NodePredicates[i].getFnName() <<'\n';
+      OS << "  case " << i << ": { \n";
+      for (auto *SimilarPred :
+           NodePredicatesByCodeToRun[PredFn.getCodeToRunOnSDNode()])
+        OS << "    // " << TreePredicateFn(SimilarPred).getFnName() <<'\n';
       
       OS << PredFn.getCodeToRunOnSDNode() << "\n  }\n";
     }
@@ -646,9 +667,9 @@ void MatcherTableEmitter::EmitPredicateFunctions(formatted_raw_ostream &OS) {
   // Emit CompletePattern matchers.
   // FIXME: This should be const.
   if (!ComplexPatterns.empty()) {
-    OS << "virtual bool CheckComplexPattern(SDNode *Root, SDNode *Parent,\n";
-    OS << "                                 SDValue N, unsigned PatternNo,\n";
-    OS << "         SmallVectorImpl<std::pair<SDValue, SDNode*> > &Result) {\n";
+    OS << "bool CheckComplexPattern(SDNode *Root, SDNode *Parent,\n";
+    OS << "                         SDValue N, unsigned PatternNo,\n";
+    OS << "         SmallVectorImpl<std::pair<SDValue, SDNode*> > &Result) override {\n";
     OS << "  unsigned NextRes = Result.size();\n";
     OS << "  switch (PatternNo) {\n";
     OS << "  default: llvm_unreachable(\"Invalid pattern # in table?\");\n";
@@ -687,7 +708,7 @@ void MatcherTableEmitter::EmitPredicateFunctions(formatted_raw_ostream &OS) {
   // Emit SDNodeXForm handlers.
   // FIXME: This should be const.
   if (!NodeXForms.empty()) {
-    OS << "virtual SDValue RunSDNodeXForm(SDValue V, unsigned XFormNo) {\n";
+    OS << "SDValue RunSDNodeXForm(SDValue V, unsigned XFormNo) override {\n";
     OS << "  switch (XFormNo) {\n";
     OS << "  default: llvm_unreachable(\"Invalid xform # in table?\");\n";
 
@@ -718,7 +739,7 @@ void MatcherTableEmitter::EmitPredicateFunctions(formatted_raw_ostream &OS) {
 }
 
 static void BuildHistogram(const Matcher *M, std::vector<unsigned> &OpcodeFreq){
-  for (; M != 0; M = M->getNext()) {
+  for (; M != nullptr; M = M->getNext()) {
     // Count this node.
     if (unsigned(M->getKind()) >= OpcodeFreq.size())
       OpcodeFreq.resize(M->getKind()+1);
@@ -769,6 +790,7 @@ void MatcherTableEmitter::EmitHistogram(const Matcher *M,
     case Matcher::SwitchType: OS << "OPC_SwitchType"; break;
     case Matcher::CheckChildType: OS << "OPC_CheckChildType"; break;
     case Matcher::CheckInteger: OS << "OPC_CheckInteger"; break;
+    case Matcher::CheckChildInteger: OS << "OPC_CheckChildInteger"; break;
     case Matcher::CheckCondCode: OS << "OPC_CheckCondCode"; break;
     case Matcher::CheckValueType: OS << "OPC_CheckValueType"; break;
     case Matcher::CheckComplexPat: OS << "OPC_CheckComplexPat"; break;
@@ -809,7 +831,7 @@ void llvm::EmitMatcherTable(const Matcher *TheMatcher,
   OS << "  // this.\n";
   OS << "  #define TARGET_VAL(X) X & 255, unsigned(X) >> 8\n";
   OS << "  static const unsigned char MatcherTable[] = {\n";
-  unsigned TotalSize = MatcherEmitter.EmitMatcherList(TheMatcher, 5, 0, OS);
+  unsigned TotalSize = MatcherEmitter.EmitMatcherList(TheMatcher, 6, 0, OS);
   OS << "    0\n  }; // Total Array size is " << (TotalSize+1) << " bytes\n\n";
 
   MatcherEmitter.EmitHistogram(TheMatcher, OS);