Masked gather and scatter: Added code for SelectionDAG.
[oota-llvm.git] / lib / CodeGen / SelectionDAG / SelectionDAGBuilder.cpp
index 71b221a02a17bcc87f2663fac8e57c18d8165ae2..cc97fbe2604b526f7587ca8e4f9e47c56aa02db4 100644 (file)
@@ -1059,6 +1059,12 @@ SDValue SelectionDAGBuilder::getValue(const Value *V) {
   return Val;
 }
 
+// Return true if SDValue exists for the given Value
+bool SelectionDAGBuilder::findValue(const Value *V) const {
+  return (NodeMap.find(V) != NodeMap.end()) ||
+    (FuncInfo.ValueMap.find(V) != FuncInfo.ValueMap.end());
+}
+
 /// getNonRegisterValue - Return an SDValue for the given Value, but
 /// don't look in FuncInfo.ValueMap for a virtual register.
 SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) {
@@ -3026,6 +3032,92 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) {
   setValue(&I, StoreNode);
 }
 
+// Gather/scatter receive a vector of pointers.
+// This vector of pointers may be represented as a base pointer + vector of 
+// indices, it depends on GEP and instruction preceeding GEP
+// that calculates indices
+static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
+                           SelectionDAGBuilder* SDB) {
+
+  assert (Ptr->getType()->isVectorTy() && "Uexpected pointer type");
+  GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!Gep || Gep->getNumOperands() > 2)
+    return false;
+  ShuffleVectorInst *ShuffleInst = 
+    dyn_cast<ShuffleVectorInst>(Gep->getPointerOperand());
+  if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue() ||
+      cast<Instruction>(ShuffleInst->getOperand(0))->getOpcode() !=
+      Instruction::InsertElement)
+    return false;
+
+  Ptr = cast<InsertElementInst>(ShuffleInst->getOperand(0))->getOperand(1);
+
+  SelectionDAG& DAG = SDB->DAG;
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  // Check is the Ptr is inside current basic block
+  // If not, look for the shuffle instruction
+  if (SDB->findValue(Ptr))
+    Base = SDB->getValue(Ptr);
+  else if (SDB->findValue(ShuffleInst)) {
+    SDValue ShuffleNode = SDB->getValue(ShuffleInst);
+    Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(ShuffleNode),
+                       ShuffleNode.getValueType().getScalarType(), ShuffleNode,
+                       DAG.getConstant(0, TLI.getVectorIdxTy()));
+    SDB->setValue(Ptr, Base);
+  }
+  else
+    return false;
+
+  Value *IndexVal = Gep->getOperand(1);
+  if (SDB->findValue(IndexVal)) {
+    Index = SDB->getValue(IndexVal);
+
+    if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
+      IndexVal = Sext->getOperand(0);
+      if (SDB->findValue(IndexVal))
+        Index = SDB->getValue(IndexVal);
+    }
+    return true;
+  }
+  return false;
+}
+
+void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
+  SDLoc sdl = getCurSDLoc();
+
+  // llvm.masked.scatter.*(Src0, Ptrs, alignemt, Mask)
+  Value  *Ptr = I.getArgOperand(1);
+  SDValue Src0 = getValue(I.getArgOperand(0));
+  SDValue Mask = getValue(I.getArgOperand(3));
+  EVT VT = Src0.getValueType();
+  unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(2)))->getZExtValue();
+  if (!Alignment)
+    Alignment = DAG.getEVTAlignment(VT);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  AAMDNodes AAInfo;
+  I.getAAMetadata(AAInfo);
+
+  SDValue Base;
+  SDValue Index;
+  Value *BasePtr = Ptr;
+  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
+
+  Value *MemOpBasePtr = UniformBase ? BasePtr : NULL;
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
+                         MachineMemOperand::MOStore,  VT.getStoreSize(),
+                         Alignment, AAInfo);
+  if (!UniformBase) {
+    Base = DAG.getTargetConstant(0, TLI.getPointerTy());
+    Index = getValue(Ptr);
+  }
+  SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
+  SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, Ops, MMO);
+  DAG.setRoot(Scatter);
+  setValue(&I, Scatter);
+}
+
 void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
   SDLoc sdl = getCurSDLoc();
 
@@ -3067,6 +3159,60 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
   setValue(&I, Load);
 }
 
+void SelectionDAGBuilder::visitMaskedGather(const CallInst &I) {
+  SDLoc sdl = getCurSDLoc();
+
+  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
+  Value  *Ptr = I.getArgOperand(0);
+  SDValue Src0 = getValue(I.getArgOperand(3));
+  SDValue Mask = getValue(I.getArgOperand(2));
+
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  EVT VT = TLI.getValueType(I.getType());
+  unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(1)))->getZExtValue();
+  if (!Alignment)
+    Alignment = DAG.getEVTAlignment(VT);
+
+  AAMDNodes AAInfo;
+  I.getAAMetadata(AAInfo);
+  const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
+
+  SDValue Root = DAG.getRoot();
+  SDValue Base;
+  SDValue Index;
+  Value *BasePtr = Ptr;
+  bool UniformBase = getUniformBase(BasePtr, Base, Index, this);
+  bool ConstantMemory = false;
+  if (UniformBase && AA->pointsToConstantMemory(
+      AliasAnalysis::Location(BasePtr,
+                                   AA->getTypeStoreSize(I.getType()),
+                              AAInfo))) {
+    // Do not serialize (non-volatile) loads of constant memory with anything.
+    Root = DAG.getEntryNode();
+    ConstantMemory = true;
+  }
+
+  MachineMemOperand *MMO =
+    DAG.getMachineFunction().
+    getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : NULL),
+                          MachineMemOperand::MOLoad,  VT.getStoreSize(),
+                          Alignment, AAInfo, Ranges);
+
+  if (!UniformBase) {
+    Base = DAG.getTargetConstant(0, TLI.getPointerTy());
+    Index = getValue(Ptr);
+  }
+
+  SDValue Ops[] = { Root, Src0, Mask, Base, Index };
+  SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
+                                       Ops, MMO);
+
+  SDValue OutChain = Gather.getValue(1);
+  if (!ConstantMemory)
+    PendingLoads.push_back(OutChain);
+  setValue(&I, Gather);
+}
+
 void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) {
   SDLoc dl = getCurSDLoc();
   AtomicOrdering SuccessOrder = I.getSuccessOrdering();
@@ -4216,9 +4362,13 @@ SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I, unsigned Intrinsic) {
     return nullptr;
   }
 
+  case Intrinsic::masked_gather:
+    visitMaskedGather(I);
   case Intrinsic::masked_load:
     visitMaskedLoad(I);
     return nullptr;
+  case Intrinsic::masked_scatter:
+    visitMaskedScatter(I);
   case Intrinsic::masked_store:
     visitMaskedStore(I);
     return nullptr;