speculatively teach OPC_CheckValueType and OPC_EmitNode to handle
[oota-llvm.git] / include / llvm / CodeGen / DAGISelHeader.h
index 9acd406792784ca9b62d916a7bc948d1922e9bb7..f04fe34c553637e80686f8bd26d98bddd16da135 100644 (file)
@@ -406,19 +406,30 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
     case OPC_CheckOpcode:
       if (N->getOpcode() != MatcherTable[MatcherIndex++]) break;
       continue;
-    case OPC_CheckType:
-      if (N.getValueType() !=
-          (MVT::SimpleValueType)MatcherTable[MatcherIndex++]) break;
+    case OPC_CheckType: {
+      MVT::SimpleValueType VT =
+        (MVT::SimpleValueType)MatcherTable[MatcherIndex++];
+      if (N.getValueType() != VT) {
+        // Handle the case when VT is iPTR.
+        if (VT != MVT::iPTR || N.getValueType() != TLI.getPointerTy())
+          break;
+      }
       continue;
+    }
     case OPC_CheckCondCode:
       if (cast<CondCodeSDNode>(N)->get() !=
           (ISD::CondCode)MatcherTable[MatcherIndex++]) break;
       continue;
-    case OPC_CheckValueType:
-      if (cast<VTSDNode>(N)->getVT() !=
-          (MVT::SimpleValueType)MatcherTable[MatcherIndex++]) break;
+    case OPC_CheckValueType: {
+      MVT::SimpleValueType VT =
+        (MVT::SimpleValueType)MatcherTable[MatcherIndex++];
+      if (cast<VTSDNode>(N)->getVT() != VT) {
+        // Handle the case when VT is iPTR.
+        if (VT != MVT::iPTR || cast<VTSDNode>(N)->getVT() != TLI.getPointerTy())
+          break;
+      }
       continue;
-
+    }
     case OPC_CheckInteger1:
       if (CheckInteger(N, GetInt1(MatcherTable, MatcherIndex))) break;
       continue;
@@ -525,11 +536,10 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
     }
         
     case OPC_EmitRegister: {
-      unsigned RegNo = MatcherTable[MatcherIndex++];
       MVT::SimpleValueType VT =
-      (MVT::SimpleValueType)MatcherTable[MatcherIndex++];
-      SDValue Reg = CurDAG->getRegister(RegNo, VT);
-      RecordedNodes.push_back(N);
+        (MVT::SimpleValueType)MatcherTable[MatcherIndex++];
+      unsigned RegNo = MatcherTable[MatcherIndex++];
+      RecordedNodes.push_back(CurDAG->getRegister(RegNo, VT));
       continue;
     }
         
@@ -638,8 +648,14 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       unsigned NumVTs = MatcherTable[MatcherIndex++];
       assert(NumVTs != 0 && "Invalid node result");
       SmallVector<EVT, 4> VTs;
-      for (unsigned i = 0; i != NumVTs; ++i)
-        VTs.push_back((MVT::SimpleValueType)MatcherTable[MatcherIndex++]);
+      for (unsigned i = 0; i != NumVTs; ++i) {
+        MVT::SimpleValueType VT =
+          (MVT::SimpleValueType)MatcherTable[MatcherIndex++];
+        if (VT == MVT::iPTR) VT = TLI.getPointerTy().SimpleTy;
+        VTs.push_back(VT);
+      }
+      
+      // FIXME: Use faster version for the common 'one VT' case?
       SDVTList VTList = CurDAG->getVTList(VTs.data(), VTs.size());
 
       // Get the operand list.
@@ -671,7 +687,7 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
       // If this has chain/flag inputs, add them.
       if (EmitNodeInfo & OPFL_Chain)
         Ops.push_back(InputChain);
-      if (EmitNodeInfo & OPFL_Flag)
+      if ((EmitNodeInfo & OPFL_Flag) && InputFlag.getNode() != 0)
         Ops.push_back(InputFlag);
       
       // Create the node.
@@ -715,7 +731,11 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
         unsigned ResSlot = MatcherTable[MatcherIndex++];
         assert(ResSlot < RecordedNodes.size() && "Invalid CheckSame");
         SDValue Res = RecordedNodes[ResSlot];
-        assert(NodeToMatch->getValueType(i) == Res.getValueType() &&
+        assert((NodeToMatch->getValueType(i) == Res.getValueType() ||
+                NodeToMatch->getValueType(i) == MVT::iPTR ||
+                Res.getValueType() == MVT::iPTR ||
+                NodeToMatch->getValueType(i).getSizeInBits() ==
+                    Res.getValueType().getSizeInBits()) &&
                "invalid replacement");
         ReplaceUses(SDValue(NodeToMatch, i), Res);
       }
@@ -759,6 +779,7 @@ SDNode *SelectCodeCommon(SDNode *NodeToMatch, const unsigned char *MatcherTable,
     const MatchScope &LastScope = MatchScopes.back();
     RecordedNodes.resize(LastScope.NumRecordedNodes);
     NodeStack.resize(LastScope.NodeStackSize);
+    N = NodeStack.back();
     
     if (LastScope.NumMatchedMemRefs != MatchedMemRefs.size())
       MatchedMemRefs.resize(LastScope.NumMatchedMemRefs);