Initial support for multi-result patterns:
authorEvan Cheng <evan.cheng@apple.com>
Wed, 12 Sep 2007 23:30:14 +0000 (23:30 +0000)
committerEvan Cheng <evan.cheng@apple.com>
Wed, 12 Sep 2007 23:30:14 +0000 (23:30 +0000)
1.
[(set GR32:$dst, (add GR32:$src1, GR32:$src2)),
 (modify EFLAGS)]
This indicates the source pattern expects the instruction would produce 2 values. The first is the result of the addition. The second is an implicit definition in register EFLAGS.
2.
def : Pat<(parallel (addc GR32:$src1, GR32:$src2), (modify EFLAGS)), ()>
Similar to #1 except this is used for def : Pat patterns.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@41897 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/TargetSelectionDAG.td
utils/TableGen/DAGISelEmitter.cpp
utils/TableGen/DAGISelEmitter.h

index 0698e1ea359e1054c3fbf6121b4b72200d8b4a73..194f55f9674508c674bf746be5dff62b62d4e4e2 100644 (file)
@@ -197,6 +197,8 @@ class SDNode<string opcode, SDTypeProfile typeprof,
 }
 
 def set;
+def modify;
+def parallel;
 def node;
 def srcvalue;
 
index 575b701801cede0077cc1a1a3c0594c2865aa878..9aa424fe3e0d66665dfbb51c065c7930e1f77b47 100644 (file)
@@ -691,6 +691,13 @@ bool TreePatternNode::ApplyTypeConstraints(TreePattern &TP, bool NotRegisters) {
       MadeChange |= UpdateNodeType(MVT::isVoid, TP);
     }
     return MadeChange;
+  } else if (getOperator()->getName() == "modify" ||
+             getOperator()->getName() == "parallel") {
+    bool MadeChange = false;
+    for (unsigned i = 0; i < getNumChildren(); ++i)
+      MadeChange = getChild(i)->ApplyTypeConstraints(TP, NotRegisters);
+    MadeChange |= UpdateNodeType(MVT::isVoid, TP);
+    return MadeChange;
   } else if (getOperator() == ISE.get_intrinsic_void_sdnode() ||
              getOperator() == ISE.get_intrinsic_w_chain_sdnode() ||
              getOperator() == ISE.get_intrinsic_wo_chain_sdnode()) {
@@ -968,7 +975,9 @@ TreePatternNode *TreePattern::ParseTreePattern(DagInit *Dag) {
       !Operator->isSubClassOf("Instruction") && 
       !Operator->isSubClassOf("SDNodeXForm") &&
       !Operator->isSubClassOf("Intrinsic") &&
-      Operator->getName() != "set")
+      Operator->getName() != "set" &&
+      Operator->getName() != "modify" &&
+      Operator->getName() != "parallel")
     error("Unrecognized node '" + Operator->getName() + "'!");
   
   //  Check to see if this is something that is illegal in an input pattern.
@@ -1376,6 +1385,18 @@ FindPatternInputsAndOutputs(TreePattern *I, TreePatternNode *Pat,
     if (!isUse && Pat->getTransformFn())
       I->error("Cannot specify a transform function for a non-input value!");
     return;
+  } else if (Pat->getOperator()->getName() == "modify") {
+    for (unsigned i = 0, e = Pat->getNumChildren(); i != e; ++i) {
+      TreePatternNode *Dest = Pat->getChild(i);
+      if (!Dest->isLeaf())
+        I->error("modify value should be a register!");
+    
+      DefInit *Val = dynamic_cast<DefInit*>(Dest->getLeafValue());
+      if (!Val || !Val->getDef()->isSubClassOf("Register"))
+        I->error("modify value should be a register!");
+      InstImpResults.push_back(Val->getDef());
+    }
+    return;
   } else if (Pat->getOperator()->getName() != "set") {
     // If this is not a set, verify that the children nodes are not void typed,
     // and recurse.
@@ -1424,7 +1445,6 @@ FindPatternInputsAndOutputs(TreePattern *I, TreePatternNode *Pat,
       InstResults[Dest->getName()] = Dest;
     } else if (Val->getDef()->isSubClassOf("Register")) {
       InstImpResults.push_back(Val->getDef());
-      ;
     } else {
       I->error("set destination should be a register!");
     }
@@ -1621,6 +1641,8 @@ void DAGISelEmitter::ParseInstructions() {
       ResultPattern->setTypes(Res0Node->getExtTypes());
 
     // Create and insert the instruction.
+    // FIXME: InstImpResults and InstImpInputs should not be part of
+    // DAGInstruction.
     DAGInstruction TheInst(I, Results, Operands, InstImpResults, InstImpInputs);
     Instructions.insert(std::make_pair(I->getRecord(), TheInst));
 
@@ -1643,10 +1665,8 @@ void DAGISelEmitter::ParseInstructions() {
     TreePattern *I = TheInst.getPattern();
     if (I == 0) continue;  // No pattern.
 
-    if (I->getNumTrees() != 1) {
-      cerr << "CANNOT HANDLE: " << I->getRecord()->getName() << " yet!";
-      continue;
-    }
+    // FIXME: Assume only the first tree is the pattern. The others are clobber
+    // nodes.
     TreePatternNode *Pattern = I->getTree(0);
     TreePatternNode *SrcPattern;
     if (Pattern->getOperator()->getName() == "set") {
@@ -1664,7 +1684,7 @@ void DAGISelEmitter::ParseInstructions() {
     TreePatternNode *DstPattern = TheInst.getResultPattern();
     PatternsToMatch.
       push_back(PatternToMatch(Instr->getValueAsListInit("Predicates"),
-                               SrcPattern, DstPattern,
+                               SrcPattern, DstPattern, TheInst.getImpResults(),
                                Instr->getValueAsInt("AddedComplexity")));
   }
 }
@@ -1674,7 +1694,18 @@ void DAGISelEmitter::ParsePatterns() {
 
   for (unsigned i = 0, e = Patterns.size(); i != e; ++i) {
     DagInit *Tree = Patterns[i]->getValueAsDag("PatternToMatch");
-    TreePattern *Pattern = new TreePattern(Patterns[i], Tree, true, *this);
+    DefInit *OpDef = dynamic_cast<DefInit*>(Tree->getOperator());
+    Record *Operator = OpDef->getDef();
+    TreePattern *Pattern;
+    if (Operator->getName() != "parallel")
+      Pattern = new TreePattern(Patterns[i], Tree, true, *this);
+    else {
+      std::vector<Init*> Values;
+      for (unsigned j = 0, ee = Tree->getNumArgs(); j != ee; ++j)
+        Values.push_back(Tree->getArg(j));
+      ListInit *LI = new ListInit(Values);
+      Pattern = new TreePattern(Patterns[i], LI, true, *this);
+    }
 
     // Inline pattern fragments into it.
     Pattern->InlinePatternFragments();
@@ -1707,10 +1738,10 @@ void DAGISelEmitter::ParsePatterns() {
       // resolve cases where the input type is known to be a pointer type (which
       // is considered resolved), but the result knows it needs to be 32- or
       // 64-bits.  Infer the other way for good measure.
-      IterateInference = Pattern->getOnlyTree()->
-        UpdateNodeType(Result->getOnlyTree()->getExtTypes(), *Result);
-      IterateInference |= Result->getOnlyTree()->
-        UpdateNodeType(Pattern->getOnlyTree()->getExtTypes(), *Result);
+      IterateInference = Pattern->getTree(0)->
+        UpdateNodeType(Result->getTree(0)->getExtTypes(), *Result);
+      IterateInference |= Result->getTree(0)->
+        UpdateNodeType(Pattern->getTree(0)->getExtTypes(), *Result);
     } while (IterateInference);
 
     // Verify that we inferred enough types that we can do something with the
@@ -1721,19 +1752,18 @@ void DAGISelEmitter::ParsePatterns() {
       Result->error("Could not infer all types in pattern result!");
     
     // Validate that the input pattern is correct.
-    {
-      std::map<std::string, TreePatternNode*> InstInputs;
-      std::map<std::string, TreePatternNode*> InstResults;
-      std::vector<Record*> InstImpInputs;
-      std::vector<Record*> InstImpResults;
-      FindPatternInputsAndOutputs(Pattern, Pattern->getOnlyTree(),
+    std::map<std::string, TreePatternNode*> InstInputs;
+    std::map<std::string, TreePatternNode*> InstResults;
+    std::vector<Record*> InstImpInputs;
+    std::vector<Record*> InstImpResults;
+    for (unsigned j = 0, ee = Pattern->getNumTrees(); j != ee; ++j)
+      FindPatternInputsAndOutputs(Pattern, Pattern->getTree(j),
                                   InstInputs, InstResults,
                                   InstImpInputs, InstImpResults);
-    }
 
     // Promote the xform function to be an explicit node if set.
-    std::vector<TreePatternNode*> ResultNodeOperands;
     TreePatternNode *DstPattern = Result->getOnlyTree();
+    std::vector<TreePatternNode*> ResultNodeOperands;
     for (unsigned ii = 0, ee = DstPattern->getNumChildren(); ii != ee; ++ii) {
       TreePatternNode *OpNode = DstPattern->getChild(ii);
       if (Record *Xform = OpNode->getTransformFn()) {
@@ -1753,13 +1783,13 @@ void DAGISelEmitter::ParsePatterns() {
     Temp.InferAllTypes();
 
     std::string Reason;
-    if (!Pattern->getOnlyTree()->canPatternMatch(Reason, *this))
+    if (!Pattern->getTree(0)->canPatternMatch(Reason, *this))
       Pattern->error("Pattern can never match: " + Reason);
     
     PatternsToMatch.
       push_back(PatternToMatch(Patterns[i]->getValueAsListInit("Predicates"),
-                               Pattern->getOnlyTree(),
-                               Temp.getOnlyTree(),
+                               Pattern->getTree(0),
+                               Temp.getOnlyTree(), InstImpResults,
                                Patterns[i]->getValueAsInt("AddedComplexity")));
   }
 }
@@ -2017,6 +2047,7 @@ void DAGISelEmitter::GenerateVariants() {
       PatternsToMatch.
         push_back(PatternToMatch(PatternsToMatch[i].getPredicates(),
                                  Variant, PatternsToMatch[i].getDstPattern(),
+                                 PatternsToMatch[i].getDstRegs(),
                                  PatternsToMatch[i].getAddedComplexity()));
     }
 
@@ -2617,7 +2648,8 @@ public:
   /// EmitResultCode - Emit the action for a pattern.  Now that it has matched
   /// we actually have to build a DAG!
   std::vector<std::string>
-  EmitResultCode(TreePatternNode *N, bool RetSelected,
+  EmitResultCode(TreePatternNode *N, std::vector<Record*> DstRegs,
+                 bool RetSelected,
                  bool InFlagDecled, bool ResNodeDecled,
                  bool LikeLeaf = false, bool isRoot = false) {
     // List of arguments of getTargetNode() or SelectNodeTo().
@@ -2758,15 +2790,17 @@ public:
       CodeGenInstruction &II = CGT.getInstruction(Op->getName());
       const DAGInstruction &Inst = ISE.getInstruction(Op);
       TreePattern *InstPat = Inst.getPattern();
+      // FIXME: Assume actual pattern comes before "modify".
       TreePatternNode *InstPatNode =
-        isRoot ? (InstPat ? InstPat->getOnlyTree() : Pattern)
-               : (InstPat ? InstPat->getOnlyTree() : NULL);
+        isRoot ? (InstPat ? InstPat->getTree(0) : Pattern)
+               : (InstPat ? InstPat->getTree(0) : NULL);
       if (InstPatNode && InstPatNode->getOperator()->getName() == "set") {
         InstPatNode = InstPatNode->getChild(InstPatNode->getNumChildren()-1);
       }
       bool HasVarOps     = isRoot && II.hasVariableNumberOfOperands;
+      // FIXME: fix how we deal with physical register operands.
       bool HasImpInputs  = isRoot && Inst.getNumImpOperands() > 0;
-      bool HasImpResults = isRoot && Inst.getNumImpResults() > 0;
+      bool HasImpResults = isRoot && DstRegs.size() > 0;
       bool NodeHasOptInFlag = isRoot &&
         PatternHasProperty(Pattern, SDNPOptInFlag, ISE);
       bool NodeHasInFlag  = isRoot &&
@@ -2778,6 +2812,7 @@ public:
       bool InputHasChain = isRoot &&
         NodeHasProperty(Pattern, SDNPHasChain, ISE);
       unsigned NumResults = Inst.getNumResults();    
+      unsigned NumDstRegs = HasImpResults ? DstRegs.size() : 0;
 
       if (NodeHasOptInFlag) {
         emitCode("bool HasInFlag = "
@@ -2787,11 +2822,11 @@ public:
         emitCode("SmallVector<SDOperand, 8> Ops" + utostr(OpcNo) + ";");
 
       // How many results is this pattern expected to produce?
-      unsigned PatResults = 0;
+      unsigned NumPatResults = 0;
       for (unsigned i = 0, e = Pattern->getExtTypes().size(); i != e; i++) {
         MVT::ValueType VT = Pattern->getTypeNum(i);
         if (VT != MVT::isVoid && VT != MVT::Flag)
-          PatResults++;
+          NumPatResults++;
       }
 
       if (OrigChains.size() > 0) {
@@ -2832,7 +2867,7 @@ public:
         if ((!OperandNode->isSubClassOf("PredicateOperand") &&
              !OperandNode->isSubClassOf("OptionalDefOperand")) ||
             ISE.getDefaultOperand(OperandNode).DefaultOps.empty()) {
-          Ops = EmitResultCode(N->getChild(ChildNo), RetSelected, 
+          Ops = EmitResultCode(N->getChild(ChildNo), DstRegs, RetSelected, 
                                InFlagDecled, ResNodeDecled);
           AllOps.insert(AllOps.end(), Ops.begin(), Ops.end());
           ++ChildNo;
@@ -2842,7 +2877,7 @@ public:
           const DAGDefaultOperand &DefaultOp =
             ISE.getDefaultOperand(II.OperandList[InstOpNo].Rec);
           for (unsigned i = 0, e = DefaultOp.DefaultOps.size(); i != e; ++i) {
-            Ops = EmitResultCode(DefaultOp.DefaultOps[i], RetSelected, 
+            Ops = EmitResultCode(DefaultOp.DefaultOps[i], DstRegs, RetSelected, 
                                  InFlagDecled, ResNodeDecled);
             AllOps.insert(AllOps.end(), Ops.begin(), Ops.end());
             NumEAInputs += Ops.size();
@@ -2888,7 +2923,7 @@ public:
             Code2 = NodeName + " = ";
         }
 
-        Code = "CurDAG->getTargetNode(Opc" + utostr(OpcNo);
+        Code += "CurDAG->getTargetNode(Opc" + utostr(OpcNo);
         unsigned OpsNo = OpcNo;
         emitOpcode(II.Namespace + "::" + II.TheDef->getName());
 
@@ -2900,14 +2935,11 @@ public:
         }
         // Add types for implicit results in physical registers, scheduler will
         // care of adding copyfromreg nodes.
-        if (HasImpResults) {
-          for (unsigned i = 0, e = Inst.getNumImpResults(); i < e; i++) {
-            Record *RR = Inst.getImpResult(i);
-            if (RR->isSubClassOf("Register")) {
-              MVT::ValueType RVT = getRegisterValueType(RR, CGT);
-              Code += ", " + getEnumName(RVT);
-              ++NumResults;
-            }
+        for (unsigned i = 0; i < NumDstRegs; i++) {
+          Record *RR = DstRegs[i];
+          if (RR->isSubClassOf("Register")) {
+            MVT::ValueType RVT = getRegisterValueType(RR, CGT);
+            Code += ", " + getEnumName(RVT);
           }
         }
         if (NodeHasChain)
@@ -2961,7 +2993,7 @@ public:
           Code += ", &Ops" + utostr(OpsNo) + "[0], Ops" + utostr(OpsNo) +
             ".size()";
         } else if (NodeHasInFlag || NodeHasOptInFlag || HasImpInputs)
-            AllOps.push_back("InFlag");
+          AllOps.push_back("InFlag");
 
         unsigned NumOps = AllOps.size();
         if (NumOps) {
@@ -2993,10 +3025,10 @@ public:
           // Remember which op produces the chain.
           if (!isRoot)
             emitCode(ChainName + " = SDOperand(" + NodeName +
-                     ".Val, " + utostr(PatResults) + ");");
+                     ".Val, " + utostr(NumResults+NumDstRegs) + ");");
           else
             emitCode(ChainName + " = SDOperand(" + NodeName +
-                     ", " + utostr(PatResults) + ");");
+                     ", " + utostr(NumResults+NumDstRegs) + ");");
 
         if (!isRoot) {
           NodeOps.push_back("Tmp" + utostr(ResNo));
@@ -3007,11 +3039,11 @@ public:
         if (NodeHasOutFlag) {
           if (!InFlagDecled) {
             emitCode("SDOperand InFlag(ResNode, " + 
-                     utostr(NumResults + (unsigned)NodeHasChain) + ");");
+                     utostr(NumResults+NumDstRegs+(unsigned)NodeHasChain) + ");");
             InFlagDecled = true;
           } else
             emitCode("InFlag = SDOperand(ResNode, " + 
-                     utostr(NumResults + (unsigned)NodeHasChain) + ");");
+                     utostr(NumResults+NumDstRegs+(unsigned)NodeHasChain) + ");");
         }
 
         if (FoldedChains.size() > 0) {
@@ -3020,23 +3052,23 @@ public:
             emitCode("ReplaceUses(SDOperand(" +
                      FoldedChains[j].first + ".Val, " + 
                      utostr(FoldedChains[j].second) + "), SDOperand(ResNode, " +
-                     utostr(NumResults) + "));");
+                     utostr(NumResults+NumDstRegs) + "));");
           NeedReplace = true;
         }
 
         if (NodeHasOutFlag) {
           emitCode("ReplaceUses(SDOperand(N.Val, " +
-                   utostr(PatResults + (unsigned)InputHasChain) +"), InFlag);");
+                   utostr(NumPatResults + (unsigned)InputHasChain) +"), InFlag);");
           NeedReplace = true;
         }
 
         if (NeedReplace) {
-          for (unsigned i = 0; i < NumResults; i++)
+          for (unsigned i = 0; i < NumPatResults; i++)
             emitCode("ReplaceUses(SDOperand(N.Val, " +
                      utostr(i) + "), SDOperand(ResNode, " + utostr(i) + "));");
           if (InputHasChain)
             emitCode("ReplaceUses(SDOperand(N.Val, " + 
-                     utostr(PatResults) + "), SDOperand(" + ChainName + ".Val, "
+                     utostr(NumPatResults) + "), SDOperand(" + ChainName + ".Val, "
                      + ChainName + ".ResNo" + "));");
         } else
           RetSelected = true;
@@ -3047,12 +3079,12 @@ public:
         } else if (InputHasChain && !NodeHasChain) {
           // One of the inner node produces a chain.
           if (NodeHasOutFlag)
-           emitCode("ReplaceUses(SDOperand(N.Val, " + utostr(PatResults+1) +
+           emitCode("ReplaceUses(SDOperand(N.Val, " + utostr(NumPatResults+1) +
                     "), SDOperand(ResNode, N.ResNo-1));");
-         for (unsigned i = 0; i < PatResults; ++i)
+         for (unsigned i = 0; i < NumPatResults; ++i)
            emitCode("ReplaceUses(SDOperand(N.Val, " + utostr(i) +
                     "), SDOperand(ResNode, " + utostr(i) + "));");
-         emitCode("ReplaceUses(SDOperand(N.Val, " + utostr(PatResults) +
+         emitCode("ReplaceUses(SDOperand(N.Val, " + utostr(NumPatResults) +
                   "), " + ChainName + ");");
          RetSelected = false;
         }
@@ -3101,7 +3133,7 @@ public:
       // PatLeaf node - the operand may or may not be a leaf node. But it should
       // behave like one.
       std::vector<std::string> Ops =
-        EmitResultCode(N->getChild(0), RetSelected, InFlagDecled,
+        EmitResultCode(N->getChild(0), DstRegs, RetSelected, InFlagDecled,
                        ResNodeDecled, true);
       unsigned ResNo = TmpNo++;
       emitCode("SDOperand Tmp" + utostr(ResNo) + " = Transform_" + Op->getName()
@@ -3267,7 +3299,7 @@ void DAGISelEmitter::GenerateCodeForPattern(PatternToMatch &Pattern,
     // otherwise we are done.
   } while (Emitter.InsertOneTypeCheck(Pat, Pattern.getSrcPattern(), "N", true));
 
-  Emitter.EmitResultCode(Pattern.getDstPattern(),
+  Emitter.EmitResultCode(Pattern.getDstPattern(), Pattern.getDstRegs(),
                          false, false, false, false, true);
   delete Pat;
 }
@@ -3924,8 +3956,15 @@ OS << "  unsigned NumKilled = ISelKilled.size();\n";
   OS << "  setSelected(F.Val->getNodeId());\n";
   OS << "  RemoveKilled();\n";
   OS << "}\n";
-  OS << "inline void ReplaceUses(SDNode *F, SDNode *T) {\n";
-  OS << "  CurDAG->ReplaceAllUsesWith(F, T, &ISelKilled);\n";
+  OS << "void ReplaceUses(SDNode *F, SDNode *T) DISABLE_INLINE {\n";
+  OS << "  unsigned NumVals = F->getNumValues();\n";
+  OS << "  if (NumVals < T->getNumValues()) {\n";
+  OS << "    for (unsigned i = 0; i < NumVals; ++i)\n";
+  OS << "      CurDAG->ReplaceAllUsesOfValueWith(SDOperand(F, i), "
+     << "SDOperand(T, i), ISelKilled);\n";
+  OS << "  } else {\n";
+  OS << "    CurDAG->ReplaceAllUsesWith(F, T, &ISelKilled);\n";
+  OS << "  }\n";
   OS << "  setSelected(F->getNodeId());\n";
   OS << "  RemoveKilled();\n";
   OS << "}\n\n";
index 7511c4ea6a39ab67dc43ba77051f506b24505c75..a3b9010c93fa4efdc5f3f42b60df6569bcc45baf 100644 (file)
@@ -371,6 +371,7 @@ namespace llvm {
     unsigned getNumOperands() const { return Operands.size(); }
     unsigned getNumImpResults() const { return ImpResults.size(); }
     unsigned getNumImpOperands() const { return ImpOperands.size(); }
+    const std::vector<Record*>& getImpResults() const { return ImpResults; }
     
     void setResultPattern(TreePatternNode *R) { ResultPattern = R; }
     
@@ -402,18 +403,21 @@ namespace llvm {
 struct PatternToMatch {
   PatternToMatch(ListInit *preds,
                  TreePatternNode *src, TreePatternNode *dst,
+                 const std::vector<Record*> &dstregs,
                  unsigned complexity):
-    Predicates(preds), SrcPattern(src), DstPattern(dst),
+    Predicates(preds), SrcPattern(src), DstPattern(dst), Dstregs(dstregs),
     AddedComplexity(complexity) {};
 
   ListInit        *Predicates;  // Top level predicate conditions to match.
   TreePatternNode *SrcPattern;  // Source pattern to match.
   TreePatternNode *DstPattern;  // Resulting pattern.
+  std::vector<Record*> Dstregs; // Physical register defs being matched.
   unsigned         AddedComplexity; // Add to matching pattern complexity.
 
   ListInit        *getPredicates() const { return Predicates; }
   TreePatternNode *getSrcPattern() const { return SrcPattern; }
   TreePatternNode *getDstPattern() const { return DstPattern; }
+  const std::vector<Record*> &getDstRegs() const { return Dstregs; }
   unsigned         getAddedComplexity() const { return AddedComplexity; }
 };