enhance the new isel to use SelectNodeTo for most patterns,
[oota-llvm.git] / include / llvm / CodeGen / DAGISelHeader.h
index 67b4155cf4bbf92bf87d8d49c672e833b48f5481..c4b6a7181257b09e8f4435f8bdc2ec7fd7e508ce 100644 (file)
@@ -217,6 +217,51 @@ GetVBR(unsigned Val, const unsigned char *MatcherTable, unsigned &Idx) {
   return Val;
 }
 
+/// UpdateChainsAndFlags - When a match is complete, this method updates uses of
+/// interior flag and chain results to use the new flag and chain results.
+void UpdateChainsAndFlags(SDNode *NodeToMatch, SDValue InputChain,
+                          const SmallVectorImpl<SDNode*> &ChainNodesMatched,
+                          SDValue InputFlag,
+                          const SmallVectorImpl<SDNode*>&FlagResultNodesMatched,
+                          bool isSelectNodeTo) {
+  // Now that all the normal results are replaced, we replace the chain and
+  // flag results if present.
+  if (!ChainNodesMatched.empty()) {
+    assert(InputChain.getNode() != 0 &&
+           "Matched input chains but didn't produce a chain");
+    // Loop over all of the nodes we matched that produced a chain result.
+    // Replace all the chain results with the final chain we ended up with.
+    for (unsigned i = 0, e = ChainNodesMatched.size(); i != e; ++i) {
+      SDNode *ChainNode = ChainNodesMatched[i];
+      
+      // Don't replace the results of the root node if we're doing a
+      // SelectNodeTo.
+      if (ChainNode == NodeToMatch && isSelectNodeTo)
+        continue;
+      
+      SDValue ChainVal = SDValue(ChainNode, ChainNode->getNumValues()-1);
+      if (ChainVal.getValueType() == MVT::Flag)
+        ChainVal = ChainVal.getValue(ChainVal->getNumValues()-2);
+      assert(ChainVal.getValueType() == MVT::Other && "Not a chain?");
+      ReplaceUses(ChainVal, InputChain);
+    }
+  }
+  
+  // If the result produces a flag, update any flag results in the matched
+  // pattern with the flag result.
+  if (InputFlag.getNode() != 0) {
+    // Handle any interior nodes explicitly marked.
+    for (unsigned i = 0, e = FlagResultNodesMatched.size(); i != e; ++i) {
+      SDNode *FRN = FlagResultNodesMatched[i];
+      assert(FRN->getValueType(FRN->getNumValues()-1) == MVT::Flag &&
+             "Doesn't have a flag result");
+      ReplaceUses(SDValue(FRN, FRN->getNumValues()-1), InputFlag);
+    }
+  }
+  
+  DEBUG(errs() << "ISEL: Match complete!\n");
+}
+
 
 enum BuiltinOpcodes {
   OPC_Scope,
@@ -252,6 +297,7 @@ enum BuiltinOpcodes {
   OPC_EmitCopyToReg,
   OPC_EmitNodeXForm,
   OPC_EmitNode,
+  OPC_SelectNodeTo,
   OPC_MarkFlagResults,
   OPC_CompleteMatch
 };
@@ -741,7 +787,8 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       continue;
     }
         
-    case OPC_EmitNode: {
+    case OPC_EmitNode:
+    case OPC_SelectNodeTo: {
       uint16_t TargetOpc = GetInt2(MatcherTable, MatcherIndex);
       unsigned EmitNodeInfo = MatcherTable[MatcherIndex++];
       // Get the result VT list.
@@ -794,14 +841,54 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
         Ops.push_back(InputFlag);
       
       // Create the node.
-      MachineSDNode *Res = CurDAG->getMachineNode(TargetOpc,
-                                                  NodeToMatch->getDebugLoc(),
-                                                  VTList,
-                                                  Ops.data(), Ops.size());
-      // Add all the non-flag/non-chain results to the RecordedNodes list.
-      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
-        if (VTs[i] == MVT::Other || VTs[i] == MVT::Flag) break;
-        RecordedNodes.push_back(SDValue(Res, i));
+      SDNode *Res = 0;
+      if (Opcode == OPC_SelectNodeTo) {
+        // It is possible we're using SelectNodeTo to replace a node with no
+        // normal results with one that has a normal result (or we could be
+        // adding a chain) and the input could have flags and chains as well.
+        // In this case we need to shifting the operands down.
+        // FIXME: This is a horrible hack and broken in obscure cases, no worse
+        // than the old isel though.  We should sink this into SelectNodeTo.
+        int OldFlagResultNo = -1, OldChainResultNo = -1;
+        
+        unsigned NTMNumResults = NodeToMatch->getNumValues();
+        if (NodeToMatch->getValueType(NTMNumResults-1) == MVT::Flag) {
+          OldFlagResultNo = NTMNumResults-1;
+          if (NTMNumResults != 1 &&
+              NodeToMatch->getValueType(NTMNumResults-2) == MVT::Other)
+            OldChainResultNo = NTMNumResults-2;
+        } else if (NodeToMatch->getValueType(NTMNumResults-1) == MVT::Other)
+          OldChainResultNo = NTMNumResults-1;
+        
+        Res = CurDAG->SelectNodeTo(NodeToMatch, TargetOpc, VTList,
+                                   Ops.data(), Ops.size());
+        
+        // FIXME: Whether the selected node has a flag result should come from
+        // flags on the node.
+        unsigned ResNumResults = Res->getNumValues();
+        if (Res->getValueType(ResNumResults-1) == MVT::Flag) {
+          // Move the flag if needed.
+          if (OldFlagResultNo != -1 &&
+              (unsigned)OldFlagResultNo != ResNumResults-1)
+            ReplaceUses(SDValue(Res, OldFlagResultNo), 
+                        SDValue(Res, ResNumResults-1));
+          --ResNumResults;
+        }
+
+        // Move the chain reference if needed.
+        if ((EmitNodeInfo & OPFL_Chain) && OldChainResultNo != -1 &&
+            (unsigned)OldChainResultNo != ResNumResults-1)
+          ReplaceUses(SDValue(Res, OldChainResultNo), 
+                      SDValue(Res, ResNumResults-1));
+      } else {
+        Res = CurDAG->getMachineNode(TargetOpc, NodeToMatch->getDebugLoc(),
+                                     VTList, Ops.data(), Ops.size());
+      
+        // Add all the non-flag/non-chain results to the RecordedNodes list.
+        for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
+          if (VTs[i] == MVT::Other || VTs[i] == MVT::Flag) break;
+          RecordedNodes.push_back(SDValue(Res, i));
+        }
       }
       
       // If the node had chain/flag results, update our notion of the current
@@ -823,10 +910,22 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
         MachineSDNode::mmo_iterator MemRefs =
           MF->allocateMemRefsArray(MatchedMemRefs.size());
         std::copy(MatchedMemRefs.begin(), MatchedMemRefs.end(), MemRefs);
-        Res->setMemRefs(MemRefs, MemRefs + MatchedMemRefs.size());
+        cast<MachineSDNode>(Res)
+          ->setMemRefs(MemRefs, MemRefs + MatchedMemRefs.size());
+      }
+      
+      DEBUG(errs() << "  "
+                   << (Opcode == OPC_SelectNodeTo ? "Selected" : "Created")
+                   << " node: "; Res->dump(CurDAG); errs() << "\n");
+      
+      // If this was a SelectNodeTo then we're completely done!
+      if (Opcode == OPC_SelectNodeTo) {
+        // Update chain and flag uses.
+        UpdateChainsAndFlags(NodeToMatch, InputChain, ChainNodesMatched,
+                             InputFlag, FlagResultNodesMatched, true);
+        return Res;
       }
       
-      DEBUG(errs() << "  Created node: "; Res->dump(CurDAG); errs() << "\n");
       continue;
     }
         
@@ -875,47 +974,19 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
                "invalid replacement");
         ReplaceUses(SDValue(NodeToMatch, i), Res);
       }
-      
-      // Now that all the normal results are replaced, we replace the chain and
-      // flag results if present.
-      if (!ChainNodesMatched.empty()) {
-        assert(InputChain.getNode() != 0 &&
-               "Matched input chains but didn't produce a chain");
-        // Loop over all of the nodes we matched that produced a chain result.
-        // Replace all the chain results with the final chain we ended up with.
-        for (unsigned i = 0, e = ChainNodesMatched.size(); i != e; ++i) {
-          SDNode *ChainNode = ChainNodesMatched[i];
-          SDValue ChainVal = SDValue(ChainNode, ChainNode->getNumValues()-1);
-          if (ChainVal.getValueType() == MVT::Flag)
-            ChainVal = ChainVal.getValue(ChainVal->getNumValues()-2);
-          assert(ChainVal.getValueType() == MVT::Other && "Not a chain?");
-          ReplaceUses(ChainVal, InputChain);
-        }
-      }
 
-      // If the result produces a flag, update any flag results in the matched
-      // pattern with the flag result.
-      if (InputFlag.getNode() != 0) {
-        // Handle the root node:
-        if (NodeToMatch->getValueType(NodeToMatch->getNumValues()-1) ==
-              MVT::Flag)
-          ReplaceUses(SDValue(NodeToMatch, NodeToMatch->getNumValues()-1),
-                      InputFlag);
-        
-        // Handle any interior nodes explicitly marked.
-        for (unsigned i = 0, e = FlagResultNodesMatched.size(); i != e; ++i) {
-          SDNode *FRN = FlagResultNodesMatched[i];
-          assert(FRN->getValueType(FRN->getNumValues()-1) == MVT::Flag &&
-                 "Doesn't have a flag result");
-          ReplaceUses(SDValue(FRN, FRN->getNumValues()-1), InputFlag);
-        }
-      }
+      // If the root node defines a flag, add it to the flag nodes to update
+      // list.
+      if (NodeToMatch->getValueType(NodeToMatch->getNumValues()-1) == MVT::Flag)
+        FlagResultNodesMatched.push_back(NodeToMatch);
+      
+      // Update chain and flag uses.
+      UpdateChainsAndFlags(NodeToMatch, InputChain, ChainNodesMatched,
+                           InputFlag, FlagResultNodesMatched, false);
       
       assert(NodeToMatch->use_empty() &&
              "Didn't replace all uses of the node?");
       
-      DEBUG(errs() << "ISEL: Match complete!\n");
-      
       // FIXME: We just return here, which interacts correctly with SelectRoot
       // above.  We should fix this to not return an SDNode* anymore.
       return 0;