From 8dfda019a02e2b8a3f3f574acac129bbfc1fdf45 Mon Sep 17 00:00:00 2001 From: Elena Demikhovsky Date: Tue, 28 Apr 2015 07:57:37 +0000 Subject: [PATCH] Masked gather and scatter: Added code for SelectionDAG. All other patches, including tests will follow. http://reviews.llvm.org/D7665 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@235970 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/CodeGen/ISDOpcodes.h | 9 +- include/llvm/CodeGen/SelectionDAG.h | 4 + include/llvm/CodeGen/SelectionDAGNodes.h | 80 +++++++++- lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 49 ++++++ .../SelectionDAG/SelectionDAGBuilder.cpp | 150 ++++++++++++++++++ .../SelectionDAG/SelectionDAGBuilder.h | 4 + 6 files changed, 294 insertions(+), 2 deletions(-) diff --git a/include/llvm/CodeGen/ISDOpcodes.h b/include/llvm/CodeGen/ISDOpcodes.h index 2d1c8cd6fdd..9b7b44bc05a 100644 --- a/include/llvm/CodeGen/ISDOpcodes.h +++ b/include/llvm/CodeGen/ISDOpcodes.h @@ -687,9 +687,16 @@ namespace ISD { ATOMIC_LOAD_UMIN, ATOMIC_LOAD_UMAX, - // Masked load and store + // Masked load and store - consecutive vector load and store operations + // with additional mask operand that prevents memory accesses to the + // masked-off lanes. MLOAD, MSTORE, + // Masked gather and scatter - load and store operations for a vector of + // random addresses with additional mask operand that prevents memory + // accesses to the masked-off lanes. + MGATHER, MSCATTER, + /// This corresponds to the llvm.lifetime.* intrinsics. The first operand /// is the chain and the second operand is the alloca pointer. LIFETIME_START, LIFETIME_END, diff --git a/include/llvm/CodeGen/SelectionDAG.h b/include/llvm/CodeGen/SelectionDAG.h index 582febda085..874842b3f4a 100644 --- a/include/llvm/CodeGen/SelectionDAG.h +++ b/include/llvm/CodeGen/SelectionDAG.h @@ -856,6 +856,10 @@ public: SDValue getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val, SDValue Ptr, SDValue Mask, EVT MemVT, MachineMemOperand *MMO, bool IsTrunc); + SDValue getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl, + ArrayRef Ops, MachineMemOperand *MMO); + SDValue getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl, + ArrayRef Ops, MachineMemOperand *MMO); /// Construct a node to track a Value* through the backend. SDValue getSrcValue(const Value *v); diff --git a/include/llvm/CodeGen/SelectionDAGNodes.h b/include/llvm/CodeGen/SelectionDAGNodes.h index 66f060c6b5a..148eaa92df4 100644 --- a/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1151,6 +1151,8 @@ public: N->getOpcode() == ISD::ATOMIC_STORE || N->getOpcode() == ISD::MLOAD || N->getOpcode() == ISD::MSTORE || + N->getOpcode() == ISD::MGATHER || + N->getOpcode() == ISD::MSCATTER || N->isMemIntrinsic() || N->isTargetMemoryOpcode(); } @@ -1987,6 +1989,82 @@ public: } }; +/// This is a base class is used to represent +/// MGATHER and MSCATTER nodes +/// +class MaskedGatherScatterSDNode : public MemSDNode { + // Operands + SDUse Ops[5]; +public: + friend class SelectionDAG; + MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl, + ArrayRef Operands, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) { + assert(Operands.size() == 5 && "Incompatible number of operands"); + InitOperands(Ops, Operands.data(), Operands.size()); + } + + // In the both nodes address is Op1, mask is Op2: + // MaskedGatherSDNode (Chain, src0, mask, base, index), src0 is a passthru value + // MaskedScatterSDNode (Chain, value, mask, base, index) + // Mask is a vector of i1 elements + const SDValue &getBasePtr() const { return getOperand(3); } + const SDValue &getIndex() const { return getOperand(4); } + const SDValue &getMask() const { return getOperand(2); } + const SDValue &getValue() const { return getOperand(1); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::MGATHER || + N->getOpcode() == ISD::MSCATTER; + } +}; + +/// This class is used to represent an MGATHER node +/// +class MaskedGatherSDNode : public MaskedGatherScatterSDNode { +public: + friend class SelectionDAG; + MaskedGatherSDNode(unsigned Order, DebugLoc dl, ArrayRef Operands, + SDVTList VTs, EVT MemVT, MachineMemOperand *MMO) + : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, Operands, VTs, MemVT, + MMO) { + assert(getValue().getValueType() == getValueType(0) && + "Incompatible type of the PathThru value in MaskedGatherSDNode"); + assert(getMask().getValueType().getVectorNumElements() == + getValueType(0).getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(getMask().getValueType().getScalarType() == MVT::i1 && + "Vector width mismatch between mask and data"); + } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::MGATHER; + } +}; + +/// This class is used to represent an MSCATTER node +/// +class MaskedScatterSDNode : public MaskedGatherScatterSDNode { + +public: + friend class SelectionDAG; + MaskedScatterSDNode(unsigned Order, DebugLoc dl,ArrayRef Operands, + SDVTList VTs, EVT MemVT, MachineMemOperand *MMO) + : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, Operands, VTs, MemVT, + MMO) { + assert(getMask().getValueType().getVectorNumElements() == + getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(getMask().getValueType().getScalarType() == MVT::i1 && + "Vector width mismatch between mask and data"); + } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::MSCATTER; + } +}; + /// An SDNode that represents everything that will be needed /// to construct a MachineInstr. These nodes are created during the /// instruction selection proper phase. @@ -2078,7 +2156,7 @@ template <> struct GraphTraits { }; /// The largest SDNode class. -typedef AtomicSDNode LargestSDNode; +typedef MaskedGatherScatterSDNode LargestSDNode; /// The SDNode class with the greatest alignment requirement. typedef GlobalAddressSDNode MostAlignedSDNode; diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 5b766981545..39001764717 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -5097,6 +5097,55 @@ SDValue SelectionDAG::getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val, return SDValue(N, 0); } +SDValue +SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl, + ArrayRef Ops, + MachineMemOperand *MMO) { + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(encodeMemSDNodeFlags(ISD::NON_EXTLOAD, ISD::UNINDEXED, + MMO->isVolatile(), + MMO->isNonTemporal(), + MMO->isInvariant())); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = CSEMap.FindNodeOrInsertPos(ID, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + MaskedGatherSDNode *N = + new (NodeAllocator) MaskedGatherSDNode(dl.getIROrder(), dl.getDebugLoc(), + Ops, VTs, VT, MMO); + CSEMap.InsertNode(N, IP); + InsertNode(N); + return SDValue(N, 0); +} + +SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl, + ArrayRef Ops, + MachineMemOperand *MMO) { + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(encodeMemSDNodeFlags(false, ISD::UNINDEXED, MMO->isVolatile(), + MMO->isNonTemporal(), + MMO->isInvariant())); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = CSEMap.FindNodeOrInsertPos(ID, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + SDNode *N = + new (NodeAllocator) MaskedScatterSDNode(dl.getIROrder(), dl.getDebugLoc(), + Ops, VTs, VT, MMO); + CSEMap.InsertNode(N, IP); + InsertNode(N); + return SDValue(N, 0); +} + SDValue SelectionDAG::getVAArg(EVT VT, SDLoc dl, SDValue Chain, SDValue Ptr, SDValue SV, diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 71b221a02a1..cc97fbe2604 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -1059,6 +1059,12 @@ SDValue SelectionDAGBuilder::getValue(const Value *V) { return Val; } +// Return true if SDValue exists for the given Value +bool SelectionDAGBuilder::findValue(const Value *V) const { + return (NodeMap.find(V) != NodeMap.end()) || + (FuncInfo.ValueMap.find(V) != FuncInfo.ValueMap.end()); +} + /// getNonRegisterValue - Return an SDValue for the given Value, but /// don't look in FuncInfo.ValueMap for a virtual register. SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) { @@ -3026,6 +3032,92 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) { setValue(&I, StoreNode); } +// Gather/scatter receive a vector of pointers. +// This vector of pointers may be represented as a base pointer + vector of +// indices, it depends on GEP and instruction preceeding GEP +// that calculates indices +static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index, + SelectionDAGBuilder* SDB) { + + assert (Ptr->getType()->isVectorTy() && "Uexpected pointer type"); + GetElementPtrInst *Gep = dyn_cast(Ptr); + if (!Gep || Gep->getNumOperands() > 2) + return false; + ShuffleVectorInst *ShuffleInst = + dyn_cast(Gep->getPointerOperand()); + if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue() || + cast(ShuffleInst->getOperand(0))->getOpcode() != + Instruction::InsertElement) + return false; + + Ptr = cast(ShuffleInst->getOperand(0))->getOperand(1); + + SelectionDAG& DAG = SDB->DAG; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + // Check is the Ptr is inside current basic block + // If not, look for the shuffle instruction + if (SDB->findValue(Ptr)) + Base = SDB->getValue(Ptr); + else if (SDB->findValue(ShuffleInst)) { + SDValue ShuffleNode = SDB->getValue(ShuffleInst); + Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(ShuffleNode), + ShuffleNode.getValueType().getScalarType(), ShuffleNode, + DAG.getConstant(0, TLI.getVectorIdxTy())); + SDB->setValue(Ptr, Base); + } + else + return false; + + Value *IndexVal = Gep->getOperand(1); + if (SDB->findValue(IndexVal)) { + Index = SDB->getValue(IndexVal); + + if (SExtInst* Sext = dyn_cast(IndexVal)) { + IndexVal = Sext->getOperand(0); + if (SDB->findValue(IndexVal)) + Index = SDB->getValue(IndexVal); + } + return true; + } + return false; +} + +void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // llvm.masked.scatter.*(Src0, Ptrs, alignemt, Mask) + Value *Ptr = I.getArgOperand(1); + SDValue Src0 = getValue(I.getArgOperand(0)); + SDValue Mask = getValue(I.getArgOperand(3)); + EVT VT = Src0.getValueType(); + unsigned Alignment = (cast(I.getArgOperand(2)))->getZExtValue(); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + SDValue Base; + SDValue Index; + Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, this); + + Value *MemOpBasePtr = UniformBase ? BasePtr : NULL; + MachineMemOperand *MMO = DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + if (!UniformBase) { + Base = DAG.getTargetConstant(0, TLI.getPointerTy()); + Index = getValue(Ptr); + } + SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index }; + SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, Ops, MMO); + DAG.setRoot(Scatter); + setValue(&I, Scatter); +} + void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) { SDLoc sdl = getCurSDLoc(); @@ -3067,6 +3159,60 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) { setValue(&I, Load); } +void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) + Value *Ptr = I.getArgOperand(0); + SDValue Src0 = getValue(I.getArgOperand(3)); + SDValue Mask = getValue(I.getArgOperand(2)); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT VT = TLI.getValueType(I.getType()); + unsigned Alignment = (cast(I.getArgOperand(1)))->getZExtValue(); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + SDValue Root = DAG.getRoot(); + SDValue Base; + SDValue Index; + Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, this); + bool ConstantMemory = false; + if (UniformBase && AA->pointsToConstantMemory( + AliasAnalysis::Location(BasePtr, + AA->getTypeStoreSize(I.getType()), + AAInfo))) { + // Do not serialize (non-volatile) loads of constant memory with anything. + Root = DAG.getEntryNode(); + ConstantMemory = true; + } + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : NULL), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + if (!UniformBase) { + Base = DAG.getTargetConstant(0, TLI.getPointerTy()); + Index = getValue(Ptr); + } + + SDValue Ops[] = { Root, Src0, Mask, Base, Index }; + SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl, + Ops, MMO); + + SDValue OutChain = Gather.getValue(1); + if (!ConstantMemory) + PendingLoads.push_back(OutChain); + setValue(&I, Gather); +} + void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) { SDLoc dl = getCurSDLoc(); AtomicOrdering SuccessOrder = I.getSuccessOrdering(); @@ -4216,9 +4362,13 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) { return nullptr; } + case Intrinsic::masked_gather: + visitMaskedGather(I); case Intrinsic::masked_load: visitMaskedLoad(I); return nullptr; + case Intrinsic::masked_scatter: + visitMaskedScatter(I); case Intrinsic::masked_store: visitMaskedStore(I); return nullptr; diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h index ba6d9743d74..3e232e78c57 100644 --- a/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ b/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -667,6 +667,8 @@ public: // generate the debug data structures now that we've seen its definition. void resolveDanglingDebugInfo(const Value *V, SDValue Val); SDValue getValue(const Value *V); + bool findValue(const Value *V) const; + SDValue getNonRegisterValue(const Value *V); SDValue getValueImpl(const Value *V); @@ -814,6 +816,8 @@ private: void visitStore(const StoreInst &I); void visitMaskedLoad(const CallInst &I); void visitMaskedStore(const CallInst &I); + void visitMaskedGather(const CallInst &I); + void visitMaskedScatter(const CallInst &I); void visitAtomicCmpXchg(const AtomicCmpXchgInst &I); void visitAtomicRMW(const AtomicRMWInst &I); void visitFence(const FenceInst &I); -- 2.34.1