The new isel was not properly handling patterns that covered
[oota-llvm.git] / include / llvm / CodeGen / DAGISelHeader.h
index b4cc0d72308644722f218f62dad89cc2790dc952..7a6c1962f74bc59dc0d15f221d9509f13f594acc 100644 (file)
@@ -247,6 +247,7 @@ enum BuiltinOpcodes {
   OPC_EmitCopyToReg,
   OPC_EmitNodeXForm,
   OPC_EmitNode,
+  OPC_MarkFlagResults,
   OPC_CompleteMatch
 };
 
@@ -290,7 +291,7 @@ struct MatchScope {
   SDValue InputChain, InputFlag;
 
   /// HasChainNodesMatched - True if the ChainNodesMatched list is non-empty.
-  bool HasChainNodesMatched;
+  bool HasChainNodesMatched, HasFlagResultNodesMatched;
 };
 
 SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
@@ -354,6 +355,7 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
   // which ones they are.  The result is captured into this list so that we can
   // update the chain results when the pattern is complete.
   SmallVector<SDNode*, 3> ChainNodesMatched;
+  SmallVector<SDNode*, 3> FlagResultNodesMatched;
   
   DEBUG(errs() << "ISEL: Starting pattern match on root node: ";
         NodeToMatch->dump(CurDAG);
@@ -374,6 +376,7 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       NewEntry.InputChain = InputChain;
       NewEntry.InputFlag = InputFlag;
       NewEntry.HasChainNodesMatched = !ChainNodesMatched.empty();
+      NewEntry.HasFlagResultNodesMatched = !FlagResultNodesMatched.empty();
       MatchScopes.push_back(NewEntry);
       continue;
     }
@@ -387,6 +390,7 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       NewEntry.InputChain = InputChain;
       NewEntry.InputFlag = InputFlag;
       NewEntry.HasChainNodesMatched = !ChainNodesMatched.empty();
+      NewEntry.HasFlagResultNodesMatched = !FlagResultNodesMatched.empty();
       MatchScopes.push_back(NewEntry);
       continue;
     }
@@ -796,6 +800,21 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       DEBUG(errs() << "  Created node: "; Res->dump(CurDAG); errs() << "\n");
       continue;
     }
+        
+    case OPC_MarkFlagResults: {
+      unsigned NumNodes = MatcherTable[MatcherIndex++];
+      
+      // Read and remember all the flag-result nodes.
+      for (unsigned i = 0; i != NumNodes; ++i) {
+        unsigned RecNo = MatcherTable[MatcherIndex++];
+        if (RecNo & 128)
+          RecNo = GetVBR(RecNo, MatcherTable, MatcherIndex);
+
+        assert(RecNo < RecordedNodes.size() && "Invalid CheckSame");
+        FlagResultNodesMatched.push_back(RecordedNodes[RecNo].getNode());
+      }
+      continue;
+    }
       
     case OPC_CompleteMatch: {
       // The match has been completed, and any new nodes (if any) have been
@@ -844,12 +863,24 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
           ReplaceUses(ChainVal, InputChain);
         }
       }
-      // If the root node produces a flag, make sure to replace its flag
-      // result with the resultant flag.
-      if (NodeToMatch->getValueType(NodeToMatch->getNumValues()-1) ==
-            MVT::Flag)
-        ReplaceUses(SDValue(NodeToMatch, NodeToMatch->getNumValues()-1),
-                    InputFlag);
+
+      // If the result produces a flag, update any flag results in the matched
+      // pattern with the flag result.
+      if (InputFlag.getNode() != 0) {
+        // Handle the root node:
+        if (NodeToMatch->getValueType(NodeToMatch->getNumValues()-1) ==
+              MVT::Flag)
+          ReplaceUses(SDValue(NodeToMatch, NodeToMatch->getNumValues()-1),
+                      InputFlag);
+        
+        // Handle any interior nodes explicitly marked.
+        for (unsigned i = 0, e = FlagResultNodesMatched.size(); i != e; ++i) {
+          SDNode *FRN = FlagResultNodesMatched[i];
+          assert(FRN->getValueType(FRN->getNumValues()-1) == MVT::Flag &&
+                 "Doesn't have a flag result");
+          ReplaceUses(SDValue(FRN, FRN->getNumValues()-1), InputFlag);
+        }
+      }
       
       assert(NodeToMatch->use_empty() &&
              "Didn't replace all uses of the node?");
@@ -885,7 +916,9 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
     InputFlag = LastScope.InputFlag;
     if (!LastScope.HasChainNodesMatched)
       ChainNodesMatched.clear();
-    
+    if (!LastScope.HasFlagResultNodesMatched)
+      FlagResultNodesMatched.clear();
+
     MatchScopes.pop_back();
   }
 }