The new isel was not properly handling patterns that covered
[oota-llvm.git] / include / llvm / CodeGen / DAGISelHeader.h
index 99263b0773df894ee7eca805c443081db617ff22..7a6c1962f74bc59dc0d15f221d9509f13f594acc 100644 (file)
@@ -200,8 +200,26 @@ GetInt8(const unsigned char *MatcherTable, unsigned &Idx) {
   return Val;
 }
 
+/// GetVBR - decode a vbr encoding whose top bit is set.
+ALWAYS_INLINE static unsigned
+GetVBR(unsigned Val, const unsigned char *MatcherTable, unsigned &Idx) {
+  assert(Val >= 128 && "Not a VBR");
+  Val &= 127;  // Remove first vbr bit.
+  
+  unsigned Shift = 7;
+  unsigned NextBits;
+  do {
+    NextBits = GetInt1(MatcherTable, Idx);
+    Val |= (NextBits&127) << Shift;
+    Shift += 7;
+  } while (NextBits & 128);
+  
+  return Val;
+}
+
+
 enum BuiltinOpcodes {
-  OPC_Push,
+  OPC_Push, OPC_Push2,
   OPC_RecordNode,
   OPC_RecordMemRef,
   OPC_CaptureFlagInput,
@@ -211,6 +229,7 @@ enum BuiltinOpcodes {
   OPC_CheckPatternPredicate,
   OPC_CheckPredicate,
   OPC_CheckOpcode,
+  OPC_CheckMultiOpcode,
   OPC_CheckType,
   OPC_CheckInteger1, OPC_CheckInteger2, OPC_CheckInteger4, OPC_CheckInteger8,
   OPC_CheckCondCode,
@@ -228,6 +247,7 @@ enum BuiltinOpcodes {
   OPC_EmitCopyToReg,
   OPC_EmitNodeXForm,
   OPC_EmitNode,
+  OPC_MarkFlagResults,
   OPC_CompleteMatch
 };
 
@@ -271,7 +291,7 @@ struct MatchScope {
   SDValue InputChain, InputFlag;
 
   /// HasChainNodesMatched - True if the ChainNodesMatched list is non-empty.
-  bool HasChainNodesMatched;
+  bool HasChainNodesMatched, HasFlagResultNodesMatched;
 };
 
 SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
@@ -335,6 +355,11 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
   // which ones they are.  The result is captured into this list so that we can
   // update the chain results when the pattern is complete.
   SmallVector<SDNode*, 3> ChainNodesMatched;
+  SmallVector<SDNode*, 3> FlagResultNodesMatched;
+  
+  DEBUG(errs() << "ISEL: Starting pattern match on root node: ";
+        NodeToMatch->dump(CurDAG);
+        errs() << '\n');
   
   // Interpreter starts at opcode #0.
   unsigned MatcherIndex = 0;
@@ -351,6 +376,21 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       NewEntry.InputChain = InputChain;
       NewEntry.InputFlag = InputFlag;
       NewEntry.HasChainNodesMatched = !ChainNodesMatched.empty();
+      NewEntry.HasFlagResultNodesMatched = !FlagResultNodesMatched.empty();
+      MatchScopes.push_back(NewEntry);
+      continue;
+    }
+    case OPC_Push2: {
+      unsigned NumToSkip = GetInt2(MatcherTable, MatcherIndex);
+      MatchScope NewEntry;
+      NewEntry.FailIndex = MatcherIndex+NumToSkip;
+      NewEntry.NodeStackSize = NodeStack.size();
+      NewEntry.NumRecordedNodes = RecordedNodes.size();
+      NewEntry.NumMatchedMemRefs = MatchedMemRefs.size();
+      NewEntry.InputChain = InputChain;
+      NewEntry.InputFlag = InputFlag;
+      NewEntry.HasChainNodesMatched = !ChainNodesMatched.empty();
+      NewEntry.HasFlagResultNodesMatched = !FlagResultNodesMatched.empty();
       MatchScopes.push_back(NewEntry);
       continue;
     }
@@ -406,6 +446,16 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
     case OPC_CheckOpcode:
       if (N->getOpcode() != MatcherTable[MatcherIndex++]) break;
       continue;
+        
+    case OPC_CheckMultiOpcode: {
+      unsigned NumOps = MatcherTable[MatcherIndex++];
+      bool OpcodeEquals = false;
+      for (unsigned i = 0; i != NumOps; ++i)
+        OpcodeEquals |= N->getOpcode() == MatcherTable[MatcherIndex++];
+      if (!OpcodeEquals) break;
+      continue;
+    }
+        
     case OPC_CheckType: {
       MVT::SimpleValueType VT =
         (MVT::SimpleValueType)MatcherTable[MatcherIndex++];
@@ -572,25 +622,46 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       // the old nodes.
       unsigned NumChains = MatcherTable[MatcherIndex++];
       assert(NumChains != 0 && "Can't TF zero chains");
+
+      assert(ChainNodesMatched.empty() &&
+             "Should only have one EmitMergeInputChains per match");
+
+      // Handle the first chain.
+      unsigned RecNo = MatcherTable[MatcherIndex++];
+      assert(RecNo < RecordedNodes.size() && "Invalid CheckSame");
+      ChainNodesMatched.push_back(RecordedNodes[RecNo].getNode());
+      
+      // If the chained node is not the root, we can't fold it if it has
+      // multiple uses.
+      // FIXME: What if other value results of the node have uses not matched by
+      // this pattern?
+      if (ChainNodesMatched.back() != NodeToMatch &&
+          !RecordedNodes[RecNo].hasOneUse()) {
+        ChainNodesMatched.clear();
+        break;
+      }
       
       // The common case here is that we have exactly one chain, which is really
       // cheap to handle, just do it.
       if (NumChains == 1) {
-        unsigned RecNo = MatcherTable[MatcherIndex++];
-        assert(RecNo < RecordedNodes.size() && "Invalid CheckSame");
-        ChainNodesMatched.push_back(RecordedNodes[RecNo].getNode());
         InputChain = RecordedNodes[RecNo].getOperand(0);
         assert(InputChain.getValueType() == MVT::Other && "Not a chain");
         continue;
       }
       
       // Read all of the chained nodes.
-      assert(ChainNodesMatched.empty() &&
-             "Should only have one EmitMergeInputChains per match");
-      for (unsigned i = 0; i != NumChains; ++i) {
-        unsigned RecNo = MatcherTable[MatcherIndex++];
+      for (unsigned i = 1; i != NumChains; ++i) {
+        RecNo = MatcherTable[MatcherIndex++];
         assert(RecNo < RecordedNodes.size() && "Invalid CheckSame");
         ChainNodesMatched.push_back(RecordedNodes[RecNo].getNode());
+        
+        // FIXME: What if other value results of the node have uses not matched by
+        // this pattern?
+        if (ChainNodesMatched.back() != NodeToMatch &&
+            !RecordedNodes[RecNo].hasOneUse()) {
+          ChainNodesMatched.clear();
+          break;
+        }
       }
 
       // Walk all the chained nodes, adding the input chains if they are not in
@@ -663,7 +734,10 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       SmallVector<SDValue, 8> Ops;
       for (unsigned i = 0; i != NumOps; ++i) {
         unsigned RecNo = MatcherTable[MatcherIndex++];
-        assert(RecNo < RecordedNodes.size() && "Invalid CheckSame");
+        if (RecNo & 128)
+          RecNo = GetVBR(RecNo, MatcherTable, MatcherIndex);
+        
+        assert(RecNo < RecordedNodes.size() && "Invalid EmitNode");
         Ops.push_back(RecordedNodes[RecNo]);
       }
       
@@ -695,7 +769,11 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
                                                   NodeToMatch->getDebugLoc(),
                                                   VTList,
                                                   Ops.data(), Ops.size());
-      RecordedNodes.push_back(SDValue(Res, 0));
+      // Add all the non-flag/non-chain results to the RecordedNodes list.
+      for (unsigned i = 0, e = VTs.size(); i != e; ++i) {
+        if (VTs[i] == MVT::Other || VTs[i] == MVT::Flag) break;
+        RecordedNodes.push_back(SDValue(Res, i));
+      }
       
       // If the node had chain/flag results, update our notion of the current
       // chain and flag.
@@ -718,6 +796,23 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
         std::copy(MatchedMemRefs.begin(), MatchedMemRefs.end(), MemRefs);
         Res->setMemRefs(MemRefs, MemRefs + MatchedMemRefs.size());
       }
+      
+      DEBUG(errs() << "  Created node: "; Res->dump(CurDAG); errs() << "\n");
+      continue;
+    }
+        
+    case OPC_MarkFlagResults: {
+      unsigned NumNodes = MatcherTable[MatcherIndex++];
+      
+      // Read and remember all the flag-result nodes.
+      for (unsigned i = 0; i != NumNodes; ++i) {
+        unsigned RecNo = MatcherTable[MatcherIndex++];
+        if (RecNo & 128)
+          RecNo = GetVBR(RecNo, MatcherTable, MatcherIndex);
+
+        assert(RecNo < RecordedNodes.size() && "Invalid CheckSame");
+        FlagResultNodesMatched.push_back(RecordedNodes[RecNo].getNode());
+      }
       continue;
     }
       
@@ -729,8 +824,20 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
 
       for (unsigned i = 0; i != NumResults; ++i) {
         unsigned ResSlot = MatcherTable[MatcherIndex++];
+        if (ResSlot & 128)
+          ResSlot = GetVBR(ResSlot, MatcherTable, MatcherIndex);
+        
         assert(ResSlot < RecordedNodes.size() && "Invalid CheckSame");
         SDValue Res = RecordedNodes[ResSlot];
+        
+        // FIXME2: Eliminate this horrible hack by fixing the 'Gen' program
+        // after (parallel) on input patterns are removed.  This would also
+        // allow us to stop encoding #results in OPC_CompleteMatch's table
+        // entry.
+        if (NodeToMatch->getNumValues() <= i ||
+            NodeToMatch->getValueType(i) == MVT::Other ||
+            NodeToMatch->getValueType(i) == MVT::Flag)
+          break;
         assert((NodeToMatch->getValueType(i) == Res.getValueType() ||
                 NodeToMatch->getValueType(i) == MVT::iPTR ||
                 Res.getValueType() == MVT::iPTR ||
@@ -756,12 +863,29 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
           ReplaceUses(ChainVal, InputChain);
         }
       }
-      // If the root node produces a flag, make sure to replace its flag
-      // result with the resultant flag.
-      if (NodeToMatch->getValueType(NodeToMatch->getNumValues()-1) ==
-            MVT::Flag)
-        ReplaceUses(SDValue(NodeToMatch, NodeToMatch->getNumValues()-1),
-                    InputFlag);
+
+      // If the result produces a flag, update any flag results in the matched
+      // pattern with the flag result.
+      if (InputFlag.getNode() != 0) {
+        // Handle the root node:
+        if (NodeToMatch->getValueType(NodeToMatch->getNumValues()-1) ==
+              MVT::Flag)
+          ReplaceUses(SDValue(NodeToMatch, NodeToMatch->getNumValues()-1),
+                      InputFlag);
+        
+        // Handle any interior nodes explicitly marked.
+        for (unsigned i = 0, e = FlagResultNodesMatched.size(); i != e; ++i) {
+          SDNode *FRN = FlagResultNodesMatched[i];
+          assert(FRN->getValueType(FRN->getNumValues()-1) == MVT::Flag &&
+                 "Doesn't have a flag result");
+          ReplaceUses(SDValue(FRN, FRN->getNumValues()-1), InputFlag);
+        }
+      }
+      
+      assert(NodeToMatch->use_empty() &&
+             "Didn't replace all uses of the node?");
+      
+      DEBUG(errs() << "ISEL: Match complete!\n");
       
       // FIXME: We just return here, which interacts correctly with SelectRoot
       // above.  We should fix this to not return an SDNode* anymore.
@@ -780,6 +904,9 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
     RecordedNodes.resize(LastScope.NumRecordedNodes);
     NodeStack.resize(LastScope.NodeStackSize);
     N = NodeStack.back();
+
+    DEBUG(errs() << "  Match failed at index " << MatcherIndex
+                 << " continuing at " << LastScope.FailIndex << "\n");
     
     if (LastScope.NumMatchedMemRefs != MatchedMemRefs.size())
       MatchedMemRefs.resize(LastScope.NumMatchedMemRefs);
@@ -789,7 +916,9 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
     InputFlag = LastScope.InputFlag;
     if (!LastScope.HasChainNodesMatched)
       ChainNodesMatched.clear();
-    
+    if (!LastScope.HasFlagResultNodesMatched)
+      FlagResultNodesMatched.clear();
+
     MatchScopes.pop_back();
   }
 }