Optimization for Gather/Scatter with uniform base
[oota-llvm.git] / lib / CodeGen / SelectionDAG / SelectionDAGBuilder.cpp
index 50f8c16309bc41792cba8e448f95c5471c8f7351..997fa1ae061841c792d4deb0f5502c8f71a2b70a 100644 (file)
@@ -3142,51 +3142,63 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) {
   setValue(&I, StoreNode);
 }
 
-// Gather/scatter receive a vector of pointers.
-// This vector of pointers may be represented as a base pointer + vector of
-// indices, it depends on GEP and instruction preceding GEP
-// that calculates indices
+// Get a uniform base for the Gather/Scatter intrinsic.
+// The first argument of the Gather/Scatter intrinsic is a vector of pointers.
+// We try to represent it as a base pointer + vector of indices.
+// Usually, the vector of pointers comes from a 'getelementptr' instruction.
+// The first operand of the GEP may be a single pointer or a vector of pointers
+// Example:
+//   %gep.ptr = getelementptr i32, <8 x i32*> %vptr, <8 x i32> %ind
+//  or
+//   %gep.ptr = getelementptr i32, i32* %ptr,        <8 x i32> %ind
+// %res = call <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> %gep.ptr, ..
+//
+// When the first GEP operand is a single pointer - it is the uniform base we
+// are looking for. If first operand of the GEP is a splat vector - we
+// extract the spalt value and use it as a uniform base.
+// In all other cases the function returns 'false'.
+//
 static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
                            SelectionDAGBuilder* SDB) {
 
-  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
+  SelectionDAG& DAG = SDB->DAG;
+  LLVMContext &Context = *DAG.getContext();
+
+  assert(Ptr->getType()->isVectorTy() && "Uexpected pointer type");
   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
   if (!GEP || GEP->getNumOperands() > 2)
     return false;
-  Value *GEPPtrs = GEP->getPointerOperand();
-  if (!(Ptr = getSplatValue(GEPPtrs)))
-    return false;
 
-  SelectionDAG& DAG = SDB->DAG;
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  // Check is the Ptr is inside current basic block
-  // If not, look for the shuffle instruction
-  if (SDB->findValue(Ptr))
-    Base = SDB->getValue(Ptr);
-  else if (SDB->findValue(GEPPtrs)) {
-    SDValue GEPPtrsVal = SDB->getValue(GEPPtrs);
-    SDLoc sdl = GEPPtrsVal;
-    EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
-    Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
-                       GEPPtrsVal.getValueType().getScalarType(), GEPPtrsVal,
-                       DAG.getConstant(0, sdl, IdxVT));
-    SDB->setValue(Ptr, Base);
-  }
-  else
+  Value *GEPPtr = GEP->getPointerOperand();
+  if (!GEPPtr->getType()->isVectorTy())
+    Ptr = GEPPtr;
+  else if (!(Ptr = getSplatValue(GEPPtr)))
     return false;
 
   Value *IndexVal = GEP->getOperand(1);
-  if (SDB->findValue(IndexVal)) {
-    Index = SDB->getValue(IndexVal);
 
-    if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
+  // The operands of the GEP may be defined in another basic block.
+  // In this case we'll not find nodes for the operands.
+  if (!SDB->findValue(Ptr) || !SDB->findValue(IndexVal))
+    return false;
+
+  Base = SDB->getValue(Ptr);
+  Index = SDB->getValue(IndexVal);
+
+  // Suppress sign extension.
+  if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
+    if (SDB->findValue(Sext->getOperand(0))) {
       IndexVal = Sext->getOperand(0);
-      if (SDB->findValue(IndexVal))
-        Index = SDB->getValue(IndexVal);
+      Index = SDB->getValue(IndexVal);
     }
-    return true;
   }
-  return false;
+  if (!Index.getValueType().isVector()) {
+    unsigned GEPWidth = GEP->getType()->getVectorNumElements();
+    EVT VT = EVT::getVectorVT(Context, Index.getValueType(), GEPWidth);
+    SmallVector<SDValue, 16> Ops(GEPWidth, Index);
+    Index = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(Index), VT, Ops);
+  }
+  return true;
 }
 
 void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {