Masked gather and scatter: Added code for SelectionDAG.
[oota-llvm.git] / include / llvm / CodeGen / SelectionDAGNodes.h
index 66f060c6b5a81030c9d83103619e2f9285e4d906..148eaa92df4da936a69b676fd1d87d6379748cd8 100644 (file)
@@ -1151,6 +1151,8 @@ public:
            N->getOpcode() == ISD::ATOMIC_STORE        ||
            N->getOpcode() == ISD::MLOAD               ||
            N->getOpcode() == ISD::MSTORE              ||
+           N->getOpcode() == ISD::MGATHER             ||
+           N->getOpcode() == ISD::MSCATTER            ||
            N->isMemIntrinsic()                        ||
            N->isTargetMemoryOpcode();
   }
@@ -1987,6 +1989,82 @@ public:
   }
 };
 
+/// This is a base class is used to represent
+/// MGATHER and MSCATTER nodes
+///
+class MaskedGatherScatterSDNode : public MemSDNode {
+  // Operands
+  SDUse Ops[5];
+public:
+  friend class SelectionDAG;
+  MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl,
+                            ArrayRef<SDValue> Operands, SDVTList VTs, EVT MemVT,
+                            MachineMemOperand *MMO)
+    : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
+    assert(Operands.size() == 5 && "Incompatible number of operands");
+    InitOperands(Ops, Operands.data(), Operands.size());
+  }
+
+  // In the both nodes address is Op1, mask is Op2:
+  // MaskedGatherSDNode  (Chain, src0, mask, base, index), src0 is a passthru value
+  // MaskedScatterSDNode (Chain, value, mask, base, index)
+  // Mask is a vector of i1 elements
+  const SDValue &getBasePtr() const { return getOperand(3); }
+  const SDValue &getIndex()   const { return getOperand(4); }
+  const SDValue &getMask()    const { return getOperand(2); }
+  const SDValue &getValue()   const { return getOperand(1); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MGATHER ||
+           N->getOpcode() == ISD::MSCATTER;
+  }
+};
+
+/// This class is used to represent an MGATHER node
+///
+class MaskedGatherSDNode : public MaskedGatherScatterSDNode {
+public:
+  friend class SelectionDAG;
+  MaskedGatherSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands, 
+                     SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+    : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, Operands, VTs, MemVT,
+                                MMO) {
+    assert(getValue().getValueType() == getValueType(0) &&
+           "Incompatible type of the PathThru value in MaskedGatherSDNode");
+    assert(getMask().getValueType().getVectorNumElements() == 
+           getValueType(0).getVectorNumElements() && 
+           "Vector width mismatch between mask and data");
+    assert(getMask().getValueType().getScalarType() == MVT::i1 && 
+           "Vector width mismatch between mask and data");
+  }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MGATHER;
+  }
+};
+
+/// This class is used to represent an MSCATTER node
+///
+class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
+
+public:
+  friend class SelectionDAG;
+  MaskedScatterSDNode(unsigned Order, DebugLoc dl,ArrayRef<SDValue> Operands,
+                      SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+    : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, Operands, VTs, MemVT,
+                                MMO) {
+    assert(getMask().getValueType().getVectorNumElements() == 
+           getValue().getValueType().getVectorNumElements() && 
+           "Vector width mismatch between mask and data");
+    assert(getMask().getValueType().getScalarType() == MVT::i1 && 
+           "Vector width mismatch between mask and data");
+  }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MSCATTER;
+  }
+};
+
 /// An SDNode that represents everything that will be needed
 /// to construct a MachineInstr. These nodes are created during the
 /// instruction selection proper phase.
@@ -2078,7 +2156,7 @@ template <> struct GraphTraits<SDNode*> {
 };
 
 /// The largest SDNode class.
-typedef AtomicSDNode LargestSDNode;
+typedef MaskedGatherScatterSDNode LargestSDNode;
 
 /// The SDNode class with the greatest alignment requirement.
 typedef GlobalAddressSDNode MostAlignedSDNode;