[C++] Use 'nullptr'.
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
index f6b5b12274296e09affcbcf28944e65eae4e8e84..21a727bfb3b0dd2a4a4ec78e79a061d01c68c8de 100644 (file)
@@ -15,9 +15,6 @@
 //  "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
 //
 //===----------------------------------------------------------------------===//
-#define SV_NAME "slp-vectorizer"
-#define DEBUG_TYPE "SLP"
-
 #include "llvm/Transforms/Vectorize.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/VectorUtils.h"
 #include <algorithm>
 #include <map>
 
 using namespace llvm;
 
+#define SV_NAME "slp-vectorizer"
+#define DEBUG_TYPE "SLP"
+
 static cl::opt<int>
     SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
                      cl::desc("Only vectorize if you gain more than this "
@@ -72,7 +73,7 @@ struct BlockNumbering {
 
   BlockNumbering(BasicBlock *Bb) : BB(Bb), Valid(false) {}
 
-  BlockNumbering() : BB(0), Valid(false) {}
+  BlockNumbering() : BB(nullptr), Valid(false) {}
 
   void numberInstructions() {
     unsigned Loc = 0;
@@ -120,15 +121,15 @@ private:
 static BasicBlock *getSameBlock(ArrayRef<Value *> VL) {
   Instruction *I0 = dyn_cast<Instruction>(VL[0]);
   if (!I0)
-    return 0;
+    return nullptr;
   BasicBlock *BB = I0->getParent();
   for (int i = 1, e = VL.size(); i < e; i++) {
     Instruction *I = dyn_cast<Instruction>(VL[i]);
     if (!I)
-      return 0;
+      return nullptr;
 
     if (BB != I->getParent())
-      return 0;
+      return nullptr;
   }
   return BB;
 }
@@ -180,7 +181,7 @@ static Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL) {
 
       switch (Kind) {
       default:
-        MD = 0; // Remove unknown metadata
+        MD = nullptr; // Remove unknown metadata
         break;
       case LLVMContext::MD_tbaa:
         MD = MDNode::getMostGenericTBAA(MD, IMD);
@@ -201,7 +202,7 @@ static Type* getSameType(ArrayRef<Value *> VL) {
   Type *Ty = VL[0]->getType();
   for (int i = 1, e = VL.size(); i < e; i++)
     if (VL[i]->getType() != Ty)
-      return 0;
+      return nullptr;
 
   return Ty;
 }
@@ -365,13 +366,13 @@ public:
   /// A negative number means that this is profitable.
   int getTreeCost();
 
-  /// Construct a vectorizable tree that starts at \p Roots and is possibly
-  /// used by a reduction of \p RdxOps.
-  void buildTree(ArrayRef<Value *> Roots, ValueSet *RdxOps = 0);
+  /// Construct a vectorizable tree that starts at \p Roots, ignoring users for
+  /// the purpose of scheduling and extraction in the \p UserIgnoreLst.
+  void buildTree(ArrayRef<Value *> Roots,
+                 ArrayRef<Value *> UserIgnoreLst = None);
 
   /// Clear the internal data structures that are created by 'buildTree'.
   void deleteTree() {
-    RdxOps = 0;
     VectorizableTree.clear();
     ScalarToTreeEntry.clear();
     MustGather.clear();
@@ -446,7 +447,7 @@ private:
   bool isFullyVectorizableTinyTree();
 
   struct TreeEntry {
-    TreeEntry() : Scalars(), VectorizedValue(0), LastScalarIndex(0),
+    TreeEntry() : Scalars(), VectorizedValue(nullptr), LastScalarIndex(0),
     NeedToGather(0) {}
 
     /// \returns true if the scalars in VL are equal to this entry.
@@ -527,8 +528,8 @@ private:
   /// Numbers instructions in different blocks.
   DenseMap<BasicBlock *, BlockNumbering> BlocksNumbers;
 
-  /// Reduction operators.
-  ValueSet *RdxOps;
+  /// List of users to ignore during scheduling and that don't need extracting.
+  ArrayRef<Value *> UserIgnoreList;
 
   // Analysis and block reference.
   Function *F;
@@ -542,9 +543,10 @@ private:
   IRBuilder<> Builder;
 };
 
-void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ValueSet *Rdx) {
+void BoUpSLP::buildTree(ArrayRef<Value *> Roots,
+                        ArrayRef<Value *> UserIgnoreLst) {
   deleteTree();
-  RdxOps = Rdx;
+  UserIgnoreList = UserIgnoreLst;
   if (!getSameType(Roots))
     return;
   buildTree_rec(Roots, 0);
@@ -576,8 +578,9 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ValueSet *Rdx) {
         if (!UserInst)
           continue;
 
-        // Ignore uses that are part of the reduction.
-        if (Rdx && std::find(Rdx->begin(), Rdx->end(), UserInst) != Rdx->end())
+        // Ignore users in the user ignore list.
+        if (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), UserInst) !=
+            UserIgnoreList.end())
           continue;
 
         DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane " <<
@@ -708,8 +711,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
         continue;
       }
 
-      // This user is part of the reduction.
-      if (RdxOps && RdxOps->count(UI))
+      // Ignore users in the user ignore list.
+      if (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), UI) !=
+          UserIgnoreList.end())
         continue;
 
       // Make sure that we can schedule this unknown user.
@@ -949,17 +953,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
     case Instruction::Call: {
       // Check if the calls are all to the same vectorizable intrinsic.
       IntrinsicInst *II = dyn_cast<IntrinsicInst>(VL[0]);
-      if (II==NULL) {
+      Intrinsic::ID ID = II ? II->getIntrinsicID() : Intrinsic::not_intrinsic;
+
+      if (!isTriviallyVectorizable(ID)) {
         newTreeEntry(VL, false);
         DEBUG(dbgs() << "SLP: Non-vectorizable call.\n");
         return;
       }
 
-      Intrinsic::ID ID = II->getIntrinsicID();
+      Function *Int = II->getCalledFunction();
 
       for (unsigned i = 1, e = VL.size(); i != e; ++i) {
         IntrinsicInst *II2 = dyn_cast<IntrinsicInst>(VL[i]);
-        if (!II2 || II2->getIntrinsicID() != ID) {
+        if (!II2 || II2->getCalledFunction() != Int) {
           newTreeEntry(VL, false);
           DEBUG(dbgs() << "SLP: mismatched calls:" << *II << "!=" << *VL[i]
                        << "\n");
@@ -1090,7 +1096,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
         // If instead not all operands are constants, then set the operand kind
         // to OK_AnyValue. If all operands are constants but not the same,
         // then set the operand kind to OK_NonUniformConstantValue.
-        ConstantInt *CInt = NULL;
+        ConstantInt *CInt = nullptr;
         for (unsigned i = 0; i < VL.size(); ++i) {
           const Instruction *I = cast<Instruction>(VL[i]);
           if (!isa<ConstantInt>(I->getOperand(1))) {
@@ -1244,7 +1250,7 @@ Value *BoUpSLP::getPointerOperand(Value *I) {
     return LI->getPointerOperand();
   if (StoreInst *SI = dyn_cast<StoreInst>(I))
     return SI->getPointerOperand();
-  return 0;
+  return nullptr;
 }
 
 unsigned BoUpSLP::getAddressSpaceOperand(Value *I) {
@@ -1318,7 +1324,7 @@ Value *BoUpSLP::getSinkBarrier(Instruction *Src, Instruction *Dst) {
     if (!A.Ptr || !B.Ptr || AA->alias(A, B))
       return I;
   }
-  return 0;
+  return nullptr;
 }
 
 int BoUpSLP::getLastIndex(ArrayRef<Value *> VL) {
@@ -1394,7 +1400,7 @@ Value *BoUpSLP::alreadyVectorized(ArrayRef<Value *> VL) const {
     if (En->isSame(VL) && En->VectorizedValue)
       return En->VectorizedValue;
   }
-  return 0;
+  return nullptr;
 }
 
 Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) {
@@ -1667,7 +1673,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
     default:
     llvm_unreachable("unknown inst");
   }
-  return 0;
+  return nullptr;
 }
 
 Value *BoUpSLP::vectorizeTree() {
@@ -1746,8 +1752,9 @@ Value *BoUpSLP::vectorizeTree() {
           DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n");
 
           assert((ScalarToTreeEntry.count(U) ||
-                  // It is legal to replace the reduction users by undef.
-                  (RdxOps && RdxOps->count(U))) &&
+                  // It is legal to replace users in the ignorelist by undef.
+                  (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), U) !=
+                   UserIgnoreList.end())) &&
                  "Replacing out-of-tree value with undef");
         }
 #endif
@@ -1835,7 +1842,7 @@ void BoUpSLP::optimizeGatherSequence() {
             DT->dominates((*v)->getParent(), In->getParent())) {
           In->replaceAllUsesWith(*v);
           In->eraseFromParent();
-          In = 0;
+          In = nullptr;
           break;
         }
       }
@@ -1874,7 +1881,7 @@ struct SLPVectorizer : public FunctionPass {
 
     SE = &getAnalysis<ScalarEvolution>();
     DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
-    DL = DLP ? &DLP->getDataLayout() : 0;
+    DL = DLP ? &DLP->getDataLayout() : nullptr;
     TTI = &getAnalysis<TargetTransformInfo>();
     AA = &getAnalysis<AliasAnalysis>();
     LI = &getAnalysis<LoopInfo>();
@@ -1951,8 +1958,11 @@ private:
   bool tryToVectorizePair(Value *A, Value *B, BoUpSLP &R);
 
   /// \brief Try to vectorize a list of operands.
+  /// \@param BuildVector A list of users to ignore for the purpose of
+  ///                     scheduling and that don't need extracting.
   /// \returns true if a value was vectorized.
-  bool tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R);
+  bool tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
+                          ArrayRef<Value *> BuildVector = None);
 
   /// \brief Try to vectorize a chain that may start at the operands of \V;
   bool tryToVectorize(BinaryOperator *V, BoUpSLP &R);
@@ -2125,7 +2135,8 @@ bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) {
   return tryToVectorizeList(VL, R);
 }
 
-bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R) {
+bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
+                                       ArrayRef<Value *> BuildVector) {
   if (VL.size() < 2)
     return false;
 
@@ -2153,7 +2164,7 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R) {
 
   bool Changed = false;
 
-  // Keep track of values that were delete by vectorizing in the loop below.
+  // Keep track of values that were deleted by vectorizing in the loop below.
   SmallVector<WeakVH, 8> TrackValues(VL.begin(), VL.end());
 
   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
@@ -2175,13 +2186,33 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R) {
                  << "\n");
     ArrayRef<Value *> Ops = VL.slice(i, OpsWidth);
 
-    R.buildTree(Ops);
+    ArrayRef<Value *> BuildVectorSlice;
+    if (!BuildVector.empty())
+      BuildVectorSlice = BuildVector.slice(i, OpsWidth);
+
+    R.buildTree(Ops, BuildVectorSlice);
     int Cost = R.getTreeCost();
 
     if (Cost < -SLPCostThreshold) {
       DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n");
-      R.vectorizeTree();
-
+      Value *VectorizedRoot = R.vectorizeTree();
+
+      // Reconstruct the build vector by extracting the vectorized root. This
+      // way we handle the case where some elements of the vector are undefined.
+      //  (return (inserelt <4 xi32> (insertelt undef (opd0) 0) (opd1) 2))
+      if (!BuildVectorSlice.empty()) {
+        Instruction *InsertAfter = cast<Instruction>(VectorizedRoot);
+        for (auto &V : BuildVectorSlice) {
+          InsertElementInst *IE = cast<InsertElementInst>(V);
+          IRBuilder<> Builder(++BasicBlock::iterator(InsertAfter));
+          Instruction *Extract = cast<Instruction>(
+              Builder.CreateExtractElement(VectorizedRoot, IE->getOperand(2)));
+          IE->setOperand(1, Extract);
+          IE->removeFromParent();
+          IE->insertAfter(Extract);
+          InsertAfter = IE;
+        }
+      }
       // Move to the next bundle.
       i += VF - 1;
       Changed = true;
@@ -2290,7 +2321,7 @@ static Value *createRdxShuffleMask(unsigned VecLen, unsigned NumEltsToRdx,
 ///   *p =
 ///
 class HorizontalReduction {
-  SmallPtrSet<Value *, 16> ReductionOps;
+  SmallVector<Value *, 16> ReductionOps;
   SmallVector<Value *, 32> ReducedVals;
 
   BinaryOperator *ReductionRoot;
@@ -2308,7 +2339,7 @@ class HorizontalReduction {
 
 public:
   HorizontalReduction()
-    : ReductionRoot(0), ReductionPHI(0), ReductionOpcode(0),
+    : ReductionRoot(nullptr), ReductionPHI(nullptr), ReductionOpcode(0),
     ReducedValueOpcode(0), ReduxWidth(0), IsPairwiseReduction(false) {}
 
   /// \brief Try to find a reduction tree.
@@ -2323,10 +2354,10 @@ public:
     // In such a case start looking for a tree rooted in the first '+'.
     if (Phi) {
       if (B->getOperand(0) == Phi) {
-        Phi = 0;
+        Phi = nullptr;
         B = dyn_cast<BinaryOperator>(B->getOperand(1));
       } else if (B->getOperand(1) == Phi) {
-        Phi = 0;
+        Phi = nullptr;
         B = dyn_cast<BinaryOperator>(B->getOperand(0));
       }
     }
@@ -2384,7 +2415,7 @@ public:
           // We need to be able to reassociate the adds.
           if (!TreeN->isAssociative())
             return false;
-          ReductionOps.insert(TreeN);
+          ReductionOps.push_back(TreeN);
         }
         // Retract.
         Stack.pop_back();
@@ -2412,7 +2443,7 @@ public:
     if (NumReducedVals < ReduxWidth)
       return false;
 
-    Value *VectorizedTree = 0;
+    Value *VectorizedTree = nullptr;
     IRBuilder<> Builder(ReductionRoot);
     FastMathFlags Unsafe;
     Unsafe.setUnsafeAlgebra();
@@ -2421,7 +2452,7 @@ public:
 
     for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
       ArrayRef<Value *> ValsToReduce(&ReducedVals[i], ReduxWidth);
-      V.buildTree(ValsToReduce, &ReductionOps);
+      V.buildTree(ValsToReduce, ReductionOps);
 
       // Estimate cost.
       int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
@@ -2455,13 +2486,13 @@ public:
       }
       // Update users.
       if (ReductionPHI) {
-        assert(ReductionRoot != NULL && "Need a reduction operation");
+        assert(ReductionRoot && "Need a reduction operation");
         ReductionRoot->setOperand(0, VectorizedTree);
         ReductionRoot->setOperand(1, ReductionPHI);
       } else
         ReductionRoot->replaceAllUsesWith(VectorizedTree);
     }
-    return VectorizedTree != 0;
+    return VectorizedTree != nullptr;
   }
 
 private:
@@ -2540,13 +2571,16 @@ private:
 ///
 /// Returns true if it matches
 ///
-static bool findBuildVector(InsertElementInst *IE,
-                            SmallVectorImpl<Value *> &Ops) {
-  if (!isa<UndefValue>(IE->getOperand(0)))
+static bool findBuildVector(InsertElementInst *FirstInsertElem,
+                            SmallVectorImpl<Value *> &BuildVector,
+                            SmallVectorImpl<Value *> &BuildVectorOpds) {
+  if (!isa<UndefValue>(FirstInsertElem->getOperand(0)))
     return false;
 
+  InsertElementInst *IE = FirstInsertElem;
   while (true) {
-    Ops.push_back(IE->getOperand(1));
+    BuildVector.push_back(IE);
+    BuildVectorOpds.push_back(IE->getOperand(1));
 
     if (IE->use_empty())
       return false;
@@ -2641,7 +2675,8 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
       Value *Rdx =
           (P->getIncomingBlock(0) == BB
                ? (P->getIncomingValue(0))
-               : (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) : 0));
+               : (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1)
+                                               : nullptr));
       // Check if this is a Binary Operator.
       BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
       if (!BI)
@@ -2680,7 +2715,7 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
         if (BinaryOperator *BinOp =
                 dyn_cast<BinaryOperator>(SI->getValueOperand())) {
           HorizontalReduction HorRdx;
-          if (((HorRdx.matchAssociativeReduction(0, BinOp, DL) &&
+          if (((HorRdx.matchAssociativeReduction(nullptr, BinOp, DL) &&
                 HorRdx.tryToReduce(R, TTI)) ||
                tryToVectorize(BinOp, R))) {
             Changed = true;
@@ -2716,12 +2751,16 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
     }
 
     // Try to vectorize trees that start at insertelement instructions.
-    if (InsertElementInst *IE = dyn_cast<InsertElementInst>(it)) {
-      SmallVector<Value *, 8> Ops;
-      if (!findBuildVector(IE, Ops))
+    if (InsertElementInst *FirstInsertElem = dyn_cast<InsertElementInst>(it)) {
+      SmallVector<Value *, 16> BuildVector;
+      SmallVector<Value *, 16> BuildVectorOpds;
+      if (!findBuildVector(FirstInsertElem, BuildVector, BuildVectorOpds))
         continue;
 
-      if (tryToVectorizeList(Ops, R)) {
+      // Vectorize starting with the build vector operands ignoring the
+      // BuildVector instructions for the purpose of scheduling and user
+      // extraction.
+      if (tryToVectorizeList(BuildVectorOpds, R, BuildVector)) {
         Changed = true;
         it = BB->begin();
         e = BB->end();