+/// This base class is used to represent MLOAD and MSTORE nodes
+class MaskedLoadStoreSDNode : public MemSDNode {
+ // Operands
+ SDUse Ops[4];
+public:
+ friend class SelectionDAG;
+ MaskedLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl,
+ SDValue *Operands, unsigned numOperands,
+ SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+ : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
+ InitOperands(Ops, Operands, numOperands);
+ }
+
+ // In the both nodes address is Op1, mask is Op2:
+ // MaskedLoadSDNode (Chain, ptr, mask, src0), src0 is a passthru value
+ // MaskedStoreSDNode (Chain, ptr, mask, data)
+ // Mask is a vector of i1 elements
+ const SDValue &getBasePtr() const { return getOperand(1); }
+ const SDValue &getMask() const { return getOperand(2); }
+
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == ISD::MLOAD ||
+ N->getOpcode() == ISD::MSTORE;
+ }
+};
+
+/// This class is used to represent an MLOAD node
+class MaskedLoadSDNode : public MaskedLoadStoreSDNode {
+public:
+ friend class SelectionDAG;
+ MaskedLoadSDNode(unsigned Order, DebugLoc dl, SDValue *Operands,
+ unsigned numOperands, SDVTList VTs, ISD::LoadExtType ETy,
+ EVT MemVT, MachineMemOperand *MMO)
+ : MaskedLoadStoreSDNode(ISD::MLOAD, Order, dl, Operands, numOperands,
+ VTs, MemVT, MMO) {
+ SubclassData |= (unsigned short)ETy;
+ }
+
+ ISD::LoadExtType getExtensionType() const {
+ return ISD::LoadExtType(SubclassData & 3);
+ }
+ const SDValue &getSrc0() const { return getOperand(3); }
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == ISD::MLOAD;
+ }
+};
+
+/// This class is used to represent an MSTORE node
+class MaskedStoreSDNode : public MaskedLoadStoreSDNode {
+
+public:
+ friend class SelectionDAG;
+ MaskedStoreSDNode(unsigned Order, DebugLoc dl, SDValue *Operands,
+ unsigned numOperands, SDVTList VTs, bool isTrunc, EVT MemVT,
+ MachineMemOperand *MMO)
+ : MaskedLoadStoreSDNode(ISD::MSTORE, Order, dl, Operands, numOperands,
+ VTs, MemVT, MMO) {
+ SubclassData |= (unsigned short)isTrunc;
+ }
+ /// Return true if the op does a truncation before store.
+ /// For integers this is the same as doing a TRUNCATE and storing the result.
+ /// For floats, it is the same as doing an FP_ROUND and storing the result.
+ bool isTruncatingStore() const { return SubclassData & 1; }
+
+ const SDValue &getValue() const { return getOperand(3); }
+
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == ISD::MSTORE;
+ }
+};
+
+/// This is a base class used to represent
+/// MGATHER and MSCATTER nodes
+///
+class MaskedGatherScatterSDNode : public MemSDNode {
+ // Operands
+ SDUse Ops[5];
+public:
+ friend class SelectionDAG;
+ MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl,
+ ArrayRef<SDValue> Operands, SDVTList VTs, EVT MemVT,
+ MachineMemOperand *MMO)
+ : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
+ assert(Operands.size() == 5 && "Incompatible number of operands");
+ InitOperands(Ops, Operands.data(), Operands.size());
+ }
+
+ // In the both nodes address is Op1, mask is Op2:
+ // MaskedGatherSDNode (Chain, src0, mask, base, index), src0 is a passthru value
+ // MaskedScatterSDNode (Chain, value, mask, base, index)
+ // Mask is a vector of i1 elements
+ const SDValue &getBasePtr() const { return getOperand(3); }
+ const SDValue &getIndex() const { return getOperand(4); }
+ const SDValue &getMask() const { return getOperand(2); }
+ const SDValue &getValue() const { return getOperand(1); }
+
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == ISD::MGATHER ||
+ N->getOpcode() == ISD::MSCATTER;
+ }
+};
+
+/// This class is used to represent an MGATHER node
+///
+class MaskedGatherSDNode : public MaskedGatherScatterSDNode {
+public:
+ friend class SelectionDAG;
+ MaskedGatherSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands,
+ SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+ : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, Operands, VTs, MemVT,
+ MMO) {
+ assert(getValue().getValueType() == getValueType(0) &&
+ "Incompatible type of the PathThru value in MaskedGatherSDNode");
+ assert(getMask().getValueType().getVectorNumElements() ==
+ getValueType(0).getVectorNumElements() &&
+ "Vector width mismatch between mask and data");
+ assert(getMask().getValueType().getScalarType() == MVT::i1 &&
+ "Vector width mismatch between mask and data");
+ }
+
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == ISD::MGATHER;
+ }
+};
+
+/// This class is used to represent an MSCATTER node
+///
+class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
+
+public:
+ friend class SelectionDAG;
+ MaskedScatterSDNode(unsigned Order, DebugLoc dl,ArrayRef<SDValue> Operands,
+ SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+ : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, Operands, VTs, MemVT,
+ MMO) {
+ assert(getMask().getValueType().getVectorNumElements() ==
+ getValue().getValueType().getVectorNumElements() &&
+ "Vector width mismatch between mask and data");
+ assert(getMask().getValueType().getScalarType() == MVT::i1 &&
+ "Vector width mismatch between mask and data");
+ }
+
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == ISD::MSCATTER;
+ }
+};
+
+/// An SDNode that represents everything that will be needed