Reapply "SLPVectorizer: Ignore users that are insertelements we can reschedule them"
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
index f2e629ff6399ec7998f8f84aea71783ccf35ac9c..58ec5867efb3e3f45ebf950c5e939b28a092a7a6 100644 (file)
@@ -41,6 +41,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/VectorUtils.h"
 #include <algorithm>
 #include <map>
 
@@ -365,13 +366,13 @@ public:
   /// A negative number means that this is profitable.
   int getTreeCost();
 
-  /// Construct a vectorizable tree that starts at \p Roots and is possibly
-  /// used by a reduction of \p RdxOps.
-  void buildTree(ArrayRef<Value *> Roots, ValueSet *RdxOps = 0);
+  /// Construct a vectorizable tree that starts at \p Roots, ignoring users for
+  /// the purpose of scheduling and extraction in the \p UserIgnoreLst.
+  void buildTree(ArrayRef<Value *> Roots,
+                 ArrayRef<Value *> UserIgnoreLst = None);
 
   /// Clear the internal data structures that are created by 'buildTree'.
   void deleteTree() {
-    RdxOps = 0;
     VectorizableTree.clear();
     ScalarToTreeEntry.clear();
     MustGather.clear();
@@ -527,8 +528,8 @@ private:
   /// Numbers instructions in different blocks.
   DenseMap<BasicBlock *, BlockNumbering> BlocksNumbers;
 
-  /// Reduction operators.
-  ValueSet *RdxOps;
+  /// List of users to ignore during scheduling and that don't need extracting.
+  ArrayRef<Value *> UserIgnoreList;
 
   // Analysis and block reference.
   Function *F;
@@ -542,9 +543,10 @@ private:
   IRBuilder<> Builder;
 };
 
-void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ValueSet *Rdx) {
+void BoUpSLP::buildTree(ArrayRef<Value *> Roots,
+                        ArrayRef<Value *> UserIgnoreLst) {
   deleteTree();
-  RdxOps = Rdx;
+  UserIgnoreList = UserIgnoreLst;
   if (!getSameType(Roots))
     return;
   buildTree_rec(Roots, 0);
@@ -576,8 +578,9 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ValueSet *Rdx) {
         if (!UserInst)
           continue;
 
-        // Ignore uses that are part of the reduction.
-        if (Rdx && std::find(Rdx->begin(), Rdx->end(), UserInst) != Rdx->end())
+        // Ignore users in the user ignore list.
+        if (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), UserInst) !=
+            UserIgnoreList.end())
           continue;
 
         DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane " <<
@@ -708,8 +711,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
         continue;
       }
 
-      // This user is part of the reduction.
-      if (RdxOps && RdxOps->count(UI))
+      // Ignore users in the user ignore list.
+      if (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), UI) !=
+          UserIgnoreList.end())
         continue;
 
       // Make sure that we can schedule this unknown user.
@@ -946,6 +950,41 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) {
       buildTree_rec(Operands, Depth + 1);
       return;
     }
+    case Instruction::Call: {
+      // Check if the calls are all to the same vectorizable intrinsic.
+      IntrinsicInst *II = dyn_cast<IntrinsicInst>(VL[0]);
+      Intrinsic::ID ID = II ? II->getIntrinsicID() : Intrinsic::not_intrinsic;
+
+      if (!isTriviallyVectorizable(ID)) {
+        newTreeEntry(VL, false);
+        DEBUG(dbgs() << "SLP: Non-vectorizable call.\n");
+        return;
+      }
+
+      Function *Int = II->getCalledFunction();
+
+      for (unsigned i = 1, e = VL.size(); i != e; ++i) {
+        IntrinsicInst *II2 = dyn_cast<IntrinsicInst>(VL[i]);
+        if (!II2 || II2->getCalledFunction() != Int) {
+          newTreeEntry(VL, false);
+          DEBUG(dbgs() << "SLP: mismatched calls:" << *II << "!=" << *VL[i]
+                       << "\n");
+          return;
+        }
+      }
+
+      newTreeEntry(VL, true);
+      for (unsigned i = 0, e = II->getNumArgOperands(); i != e; ++i) {
+        ValueList Operands;
+        // Prepare the operand vector.
+        for (unsigned j = 0; j < VL.size(); ++j) {
+          IntrinsicInst *II2 = dyn_cast<IntrinsicInst>(VL[j]);
+          Operands.push_back(II2->getArgOperand(i));
+        }
+        buildTree_rec(Operands, Depth + 1);
+      }
+      return;
+    }
     default:
       newTreeEntry(VL, false);
       DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n");
@@ -979,8 +1018,17 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       return 0;
     }
     case Instruction::ExtractElement: {
-      if (CanReuseExtract(VL))
-        return 0;
+      if (CanReuseExtract(VL)) {
+        int DeadCost = 0;
+        for (unsigned i = 0, e = VL.size(); i < e; ++i) {
+          ExtractElementInst *E = cast<ExtractElementInst>(VL[i]);
+          if (E->hasOneUse())
+            // Take credit for instruction that will become dead.
+            DeadCost +=
+                TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i);
+        }
+        return -DeadCost;
+      }
       return getGatherCost(VecTy);
     }
     case Instruction::ZExt:
@@ -1085,6 +1133,30 @@ int BoUpSLP::getEntryCost(TreeEntry *E) {
       int VecStCost = TTI->getMemoryOpCost(Instruction::Store, VecTy, 1, 0);
       return VecStCost - ScalarStCost;
     }
+    case Instruction::Call: {
+      CallInst *CI = cast<CallInst>(VL0);
+      IntrinsicInst *II = cast<IntrinsicInst>(CI);
+      Intrinsic::ID ID = II->getIntrinsicID();
+
+      // Calculate the cost of the scalar and vector calls.
+      SmallVector<Type*, 4> ScalarTys, VecTys;
+      for (unsigned op = 0, opc = II->getNumArgOperands(); op!= opc; ++op) {
+        ScalarTys.push_back(CI->getArgOperand(op)->getType());
+        VecTys.push_back(VectorType::get(CI->getArgOperand(op)->getType(),
+                                         VecTy->getNumElements()));
+      }
+
+      int ScalarCallCost = VecTy->getNumElements() *
+          TTI->getIntrinsicInstrCost(ID, ScalarTy, ScalarTys);
+
+      int VecCallCost = TTI->getIntrinsicInstrCost(ID, VecTy, VecTys);
+
+      DEBUG(dbgs() << "SLP: Call cost "<< VecCallCost - ScalarCallCost
+            << " (" << VecCallCost  << "-" <<  ScalarCallCost << ")"
+            << " for " << *II << "\n");
+
+      return VecCallCost - ScalarCallCost;
+    }
     default:
       llvm_unreachable("Unknown instruction");
   }
@@ -1572,6 +1644,32 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
       E->VectorizedValue = S;
       return propagateMetadata(S, E->Scalars);
     }
+    case Instruction::Call: {
+      CallInst *CI = cast<CallInst>(VL0);
+
+      setInsertPointAfterBundle(E->Scalars);
+      std::vector<Value *> OpVecs;
+      for (int j = 0, e = CI->getNumArgOperands(); j < e; ++j) {
+        ValueList OpVL;
+        for (int i = 0, e = E->Scalars.size(); i < e; ++i) {
+          CallInst *CEI = cast<CallInst>(E->Scalars[i]);
+          OpVL.push_back(CEI->getArgOperand(j));
+        }
+
+        Value *OpVec = vectorizeTree(OpVL);
+        DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n");
+        OpVecs.push_back(OpVec);
+      }
+
+      Module *M = F->getParent();
+      IntrinsicInst *II = cast<IntrinsicInst>(CI);
+      Intrinsic::ID ID = II->getIntrinsicID();
+      Type *Tys[] = { VectorType::get(CI->getType(), E->Scalars.size()) };
+      Function *CF = Intrinsic::getDeclaration(M, ID, Tys);
+      Value *V = Builder.CreateCall(CF, OpVecs);
+      E->VectorizedValue = V;
+      return V;
+    }
     default:
     llvm_unreachable("unknown inst");
   }
@@ -1607,12 +1705,7 @@ Value *BoUpSLP::vectorizeTree() {
     Value *Lane = Builder.getInt32(it->Lane);
     // Generate extracts for out-of-tree users.
     // Find the insertion point for the extractelement lane.
-    if (PHINode *PN = dyn_cast<PHINode>(Vec)) {
-      Builder.SetInsertPoint(PN->getParent()->getFirstInsertionPt());
-      Value *Ex = Builder.CreateExtractElement(Vec, Lane);
-      CSEBlocks.insert(PN->getParent());
-      User->replaceUsesOfWith(Scalar, Ex);
-    } else if (isa<Instruction>(Vec)){
+    if (isa<Instruction>(Vec)){
       if (PHINode *PH = dyn_cast<PHINode>(User)) {
         for (int i = 0, e = PH->getNumIncomingValues(); i != e; ++i) {
           if (PH->getIncomingValue(i) == Scalar) {
@@ -1654,14 +1747,17 @@ Value *BoUpSLP::vectorizeTree() {
 
       Type *Ty = Scalar->getType();
       if (!Ty->isVoidTy()) {
+#ifndef NDEBUG
         for (User *U : Scalar->users()) {
           DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n");
 
           assert((ScalarToTreeEntry.count(U) ||
-                  // It is legal to replace the reduction users by undef.
-                  (RdxOps && RdxOps->count(U))) &&
+                  // It is legal to replace users in the ignorelist by undef.
+                  (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), U) !=
+                   UserIgnoreList.end())) &&
                  "Replacing out-of-tree value with undef");
         }
+#endif
         Value *Undef = UndefValue::get(Ty);
         Scalar->replaceAllUsesWith(Undef);
       }
@@ -1862,8 +1958,11 @@ private:
   bool tryToVectorizePair(Value *A, Value *B, BoUpSLP &R);
 
   /// \brief Try to vectorize a list of operands.
+  /// \@param BuildVector A list of users to ignore for the purpose of
+  ///                     scheduling and that don't need extracting.
   /// \returns true if a value was vectorized.
-  bool tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R);
+  bool tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
+                          ArrayRef<Value *> BuildVector = None);
 
   /// \brief Try to vectorize a chain that may start at the operands of \V;
   bool tryToVectorize(BinaryOperator *V, BoUpSLP &R);
@@ -1911,7 +2010,7 @@ bool SLPVectorizer::vectorizeStoreChain(ArrayRef<Value *> Chain,
   if (!isPowerOf2_32(Sz) || VF < 2)
     return false;
 
-  // Keep track of values that were delete by vectorizing in the loop below.
+  // Keep track of values that were deleted by vectorizing in the loop below.
   SmallVector<WeakVH, 8> TrackValues(Chain.begin(), Chain.end());
 
   bool Changed = false;
@@ -2036,7 +2135,8 @@ bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) {
   return tryToVectorizeList(VL, R);
 }
 
-bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R) {
+bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
+                                       ArrayRef<Value *> BuildVector) {
   if (VL.size() < 2)
     return false;
 
@@ -2064,7 +2164,7 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R) {
 
   bool Changed = false;
 
-  // Keep track of values that were delete by vectorizing in the loop below.
+  // Keep track of values that were deleted by vectorizing in the loop below.
   SmallVector<WeakVH, 8> TrackValues(VL.begin(), VL.end());
 
   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
@@ -2086,13 +2186,33 @@ bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R) {
                  << "\n");
     ArrayRef<Value *> Ops = VL.slice(i, OpsWidth);
 
-    R.buildTree(Ops);
+    ArrayRef<Value *> BuildVectorSlice;
+    if (!BuildVector.empty())
+      BuildVectorSlice = BuildVector.slice(i, OpsWidth);
+
+    R.buildTree(Ops, BuildVectorSlice);
     int Cost = R.getTreeCost();
 
     if (Cost < -SLPCostThreshold) {
-      DEBUG(dbgs() << "SLP: Vectorizing pair at cost:" << Cost << ".\n");
-      R.vectorizeTree();
-
+      DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n");
+      Value *VectorizedRoot = R.vectorizeTree();
+
+      // Reconstruct the build vector by extracting the vectorized root. This
+      // way we handle the case where some elements of the vector are undefined.
+      //  (return (inserelt <4 xi32> (insertelt undef (opd0) 0) (opd1) 2))
+      if (!BuildVectorSlice.empty()) {
+        Instruction *InsertAfter = cast<Instruction>(VectorizedRoot);
+        for (auto &V : BuildVectorSlice) {
+          InsertElementInst *IE = cast<InsertElementInst>(V);
+          IRBuilder<> Builder(++BasicBlock::iterator(InsertAfter));
+          Instruction *Extract = cast<Instruction>(
+              Builder.CreateExtractElement(VectorizedRoot, IE->getOperand(2)));
+          IE->setOperand(1, Extract);
+          IE->removeFromParent();
+          IE->insertAfter(Extract);
+          InsertAfter = IE;
+        }
+      }
       // Move to the next bundle.
       i += VF - 1;
       Changed = true;
@@ -2201,7 +2321,7 @@ static Value *createRdxShuffleMask(unsigned VecLen, unsigned NumEltsToRdx,
 ///   *p =
 ///
 class HorizontalReduction {
-  SmallPtrSet<Value *, 16> ReductionOps;
+  SmallVector<Value *, 16> ReductionOps;
   SmallVector<Value *, 32> ReducedVals;
 
   BinaryOperator *ReductionRoot;
@@ -2295,7 +2415,7 @@ public:
           // We need to be able to reassociate the adds.
           if (!TreeN->isAssociative())
             return false;
-          ReductionOps.insert(TreeN);
+          ReductionOps.push_back(TreeN);
         }
         // Retract.
         Stack.pop_back();
@@ -2332,7 +2452,7 @@ public:
 
     for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
       ArrayRef<Value *> ValsToReduce(&ReducedVals[i], ReduxWidth);
-      V.buildTree(ValsToReduce, &ReductionOps);
+      V.buildTree(ValsToReduce, ReductionOps);
 
       // Estimate cost.
       int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
@@ -2451,13 +2571,16 @@ private:
 ///
 /// Returns true if it matches
 ///
-static bool findBuildVector(InsertElementInst *IE,
-                            SmallVectorImpl<Value *> &Ops) {
-  if (!isa<UndefValue>(IE->getOperand(0)))
+static bool findBuildVector(InsertElementInst *FirstInsertElem,
+                            SmallVectorImpl<Value *> &BuildVector,
+                            SmallVectorImpl<Value *> &BuildVectorOpds) {
+  if (!isa<UndefValue>(FirstInsertElem->getOperand(0)))
     return false;
 
+  InsertElementInst *IE = FirstInsertElem;
   while (true) {
-    Ops.push_back(IE->getOperand(1));
+    BuildVector.push_back(IE);
+    BuildVectorOpds.push_back(IE->getOperand(1));
 
     if (IE->use_empty())
       return false;
@@ -2627,12 +2750,16 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
     }
 
     // Try to vectorize trees that start at insertelement instructions.
-    if (InsertElementInst *IE = dyn_cast<InsertElementInst>(it)) {
-      SmallVector<Value *, 8> Ops;
-      if (!findBuildVector(IE, Ops))
+    if (InsertElementInst *FirstInsertElem = dyn_cast<InsertElementInst>(it)) {
+      SmallVector<Value *, 16> BuildVector;
+      SmallVector<Value *, 16> BuildVectorOpds;
+      if (!findBuildVector(FirstInsertElem, BuildVector, BuildVectorOpds))
         continue;
 
-      if (tryToVectorizeList(Ops, R)) {
+      // Vectorize starting with the build vector operands ignoring the
+      // BuildVector instructions for the purpose of scheduling and user
+      // extraction.
+      if (tryToVectorizeList(BuildVectorOpds, R, BuildVector)) {
         Changed = true;
         it = BB->begin();
         e = BB->end();