Optimization for Gather/Scatter with uniform base
authorElena Demikhovsky <elena.demikhovsky@intel.com>
Wed, 2 Sep 2015 08:39:13 +0000 (08:39 +0000)
committerElena Demikhovsky <elena.demikhovsky@intel.com>
Wed, 2 Sep 2015 08:39:13 +0000 (08:39 +0000)
Vector 'getelementptr' with scalar base is an opportunity for gather/scatter intrinsic to generate a better sequence.
While looking for uniform base, we want to use the scalar base pointer of GEP, if exists.

Differential Revision: http://reviews.llvm.org/D11121

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@246622 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
test/CodeGen/X86/masked_gather_scatter.ll

index 50f8c16..997fa1a 100644 (file)
@@ -3142,51 +3142,63 @@ void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) {
   setValue(&I, StoreNode);
 }
 
-// Gather/scatter receive a vector of pointers.
-// This vector of pointers may be represented as a base pointer + vector of
-// indices, it depends on GEP and instruction preceding GEP
-// that calculates indices
+// Get a uniform base for the Gather/Scatter intrinsic.
+// The first argument of the Gather/Scatter intrinsic is a vector of pointers.
+// We try to represent it as a base pointer + vector of indices.
+// Usually, the vector of pointers comes from a 'getelementptr' instruction.
+// The first operand of the GEP may be a single pointer or a vector of pointers
+// Example:
+//   %gep.ptr = getelementptr i32, <8 x i32*> %vptr, <8 x i32> %ind
+//  or
+//   %gep.ptr = getelementptr i32, i32* %ptr,        <8 x i32> %ind
+// %res = call <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> %gep.ptr, ..
+//
+// When the first GEP operand is a single pointer - it is the uniform base we
+// are looking for. If first operand of the GEP is a splat vector - we
+// extract the spalt value and use it as a uniform base.
+// In all other cases the function returns 'false'.
+//
 static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
                            SelectionDAGBuilder* SDB) {
 
-  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
+  SelectionDAG& DAG = SDB->DAG;
+  LLVMContext &Context = *DAG.getContext();
+
+  assert(Ptr->getType()->isVectorTy() && "Uexpected pointer type");
   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
   if (!GEP || GEP->getNumOperands() > 2)
     return false;
-  Value *GEPPtrs = GEP->getPointerOperand();
-  if (!(Ptr = getSplatValue(GEPPtrs)))
-    return false;
 
-  SelectionDAG& DAG = SDB->DAG;
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  // Check is the Ptr is inside current basic block
-  // If not, look for the shuffle instruction
-  if (SDB->findValue(Ptr))
-    Base = SDB->getValue(Ptr);
-  else if (SDB->findValue(GEPPtrs)) {
-    SDValue GEPPtrsVal = SDB->getValue(GEPPtrs);
-    SDLoc sdl = GEPPtrsVal;
-    EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
-    Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
-                       GEPPtrsVal.getValueType().getScalarType(), GEPPtrsVal,
-                       DAG.getConstant(0, sdl, IdxVT));
-    SDB->setValue(Ptr, Base);
-  }
-  else
+  Value *GEPPtr = GEP->getPointerOperand();
+  if (!GEPPtr->getType()->isVectorTy())
+    Ptr = GEPPtr;
+  else if (!(Ptr = getSplatValue(GEPPtr)))
     return false;
 
   Value *IndexVal = GEP->getOperand(1);
-  if (SDB->findValue(IndexVal)) {
-    Index = SDB->getValue(IndexVal);
 
-    if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
+  // The operands of the GEP may be defined in another basic block.
+  // In this case we'll not find nodes for the operands.
+  if (!SDB->findValue(Ptr) || !SDB->findValue(IndexVal))
+    return false;
+
+  Base = SDB->getValue(Ptr);
+  Index = SDB->getValue(IndexVal);
+
+  // Suppress sign extension.
+  if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
+    if (SDB->findValue(Sext->getOperand(0))) {
       IndexVal = Sext->getOperand(0);
-      if (SDB->findValue(IndexVal))
-        Index = SDB->getValue(IndexVal);
+      Index = SDB->getValue(IndexVal);
     }
-    return true;
   }
-  return false;
+  if (!Index.getValueType().isVector()) {
+    unsigned GEPWidth = GEP->getType()->getVectorNumElements();
+    EVT VT = EVT::getVectorVT(Context, Index.getValueType(), GEPWidth);
+    SmallVector<SDValue, 16> Ops(GEPWidth, Index);
+    Index = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(Index), VT, Ops);
+  }
+  return true;
 }
 
 void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
index de16e5d..11d87bb 100644 (file)
@@ -140,3 +140,108 @@ define <16 x i32> @test8(<16 x i32*> %ptr.random, <16 x i32> %ind, i16 %mask) {
   %res = add <16 x i32> %gt1, %gt2
   ret <16 x i32> %res
 }
+
+%struct.RT = type { i8, [10 x [20 x i32]], i8 }
+%struct.ST = type { i32, double, %struct.RT }
+
+; Masked gather for agregate types
+; Test9 and Test10 should give the same result (scalar and vector indices in GEP)
+
+; KNL-LABEL: test9
+; KNL: vpbroadcastq    %rdi, %zmm
+; KNL: vpmovsxdq
+; KNL: vpbroadcastq
+; KNL: vpmuludq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpgatherqd      (,%zmm
+
+define <8 x i32> @test9(%struct.ST* %base, <8 x i64> %ind1, <8 x i32>%ind5) {
+entry:
+  %broadcast.splatinsert = insertelement <8 x %struct.ST*> undef, %struct.ST* %base, i32 0
+  %broadcast.splat = shufflevector <8 x %struct.ST*> %broadcast.splatinsert, <8 x %struct.ST*> undef, <8 x i32> zeroinitializer
+
+  %arrayidx = getelementptr  %struct.ST, <8 x %struct.ST*> %broadcast.splat, <8 x i64> %ind1, <8 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>, <8 x i32><i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> %ind5, <8 x i64> <i64 13, i64 13, i64 13, i64 13, i64 13, i64 13, i64 13, i64 13>
+  %res = call <8 x i32 >  @llvm.masked.gather.v8i32(<8 x i32*>%arrayidx, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  ret <8 x i32> %res
+}
+
+; KNL-LABEL: test10
+; KNL: vpbroadcastq    %rdi, %zmm
+; KNL: vpmovsxdq
+; KNL: vpbroadcastq
+; KNL: vpmuludq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpgatherqd      (,%zmm
+define <8 x i32> @test10(%struct.ST* %base, <8 x i64> %i1, <8 x i32>%ind5) {
+entry:
+  %broadcast.splatinsert = insertelement <8 x %struct.ST*> undef, %struct.ST* %base, i32 0
+  %broadcast.splat = shufflevector <8 x %struct.ST*> %broadcast.splatinsert, <8 x %struct.ST*> undef, <8 x i32> zeroinitializer
+
+  %arrayidx = getelementptr  %struct.ST, <8 x %struct.ST*> %broadcast.splat, <8 x i64> %i1, i32 2, i32 1, <8 x i32> %ind5, i64 13
+  %res = call <8 x i32 >  @llvm.masked.gather.v8i32(<8 x i32*>%arrayidx, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  ret <8 x i32> %res
+}
+
+; Splat index in GEP, requires broadcast
+; KNL-LABEL: test11
+; KNL: vpbroadcastd    %esi, %zmm
+; KNL: vgatherdps      (%rdi,%zmm
+define <16 x float> @test11(float* %base, i32 %ind) {
+
+  %broadcast.splatinsert = insertelement <16 x float*> undef, float* %base, i32 0
+  %broadcast.splat = shufflevector <16 x float*> %broadcast.splatinsert, <16 x float*> undef, <16 x i32> zeroinitializer
+
+  %gep.random = getelementptr float, <16 x float*> %broadcast.splat, i32 %ind
+
+  %res = call <16 x float> @llvm.masked.gather.v16f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
+  ret <16 x float>%res
+}
+
+; We are checking the uniform base here. It is taken directly from input to vgatherdps
+; KNL-LABEL: test12
+; KNL: kxnorw  %k1, %k1, %k1
+; KNL: vgatherdps      (%rdi,%zmm
+define <16 x float> @test12(float* %base, <16 x i32> %ind) {
+
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr float, float *%base, <16 x i64> %sext_ind
+
+  %res = call <16 x float> @llvm.masked.gather.v16f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
+  ret <16 x float>%res
+}
+
+; The same as the previous, but the mask is undefined
+; KNL-LABEL: test13
+; KNL-NOT: kxnorw
+; KNL: vgatherdps      (%rdi,%zmm
+define <16 x float> @test13(float* %base, <16 x i32> %ind) {
+
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr float, float *%base, <16 x i64> %sext_ind
+
+  %res = call <16 x float> @llvm.masked.gather.v16f32(<16 x float*> %gep.random, i32 4, <16 x i1> undef, <16 x float> undef)
+  ret <16 x float>%res
+}
+
+; The base pointer is not splat, can't find unform base
+; KNL-LABEL: test14
+; KNL: vgatherqps      (,%zmm0)
+; KNL: vgatherqps      (,%zmm0)
+define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) {
+
+  %broadcast.splatinsert = insertelement <16 x float*> %vec, float* %base, i32 1
+  %broadcast.splat = shufflevector <16 x float*> %broadcast.splatinsert, <16 x float*> undef, <16 x i32> zeroinitializer
+
+  %gep.random = getelementptr float, <16 x float*> %broadcast.splat, i32 %ind
+
+  %res = call <16 x float> @llvm.masked.gather.v16f32(<16 x float*> %gep.random, i32 4, <16 x i1> undef, <16 x float> undef)
+  ret <16 x float>%res
+}
+
+