Use SmallVectorImpl::iterator/const_iterator instead of SmallVector to avoid specifyi...
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
1 //===- SLPVectorizer.cpp - A bottom up SLP Vectorizer ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass implements the Bottom Up SLP vectorizer. It detects consecutive
10 // stores that can be put together into vector-stores. Next, it attempts to
11 // construct vectorizable tree using the use-def chains. If a profitable tree
12 // was found, the SLP vectorizer performs vectorization on the tree.
13 //
14 // The pass is inspired by the work described in the paper:
15 //  "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
16 //
17 //===----------------------------------------------------------------------===//
18 #define SV_NAME "slp-vectorizer"
19 #define DEBUG_TYPE "SLP"
20
21 #include "llvm/Transforms/Vectorize.h"
22 #include "llvm/ADT/MapVector.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/Analysis/AliasAnalysis.h"
26 #include "llvm/Analysis/ScalarEvolution.h"
27 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
28 #include "llvm/Analysis/AliasAnalysis.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/Analysis/Verifier.h"
31 #include "llvm/Analysis/LoopInfo.h"
32 #include "llvm/IR/DataLayout.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/IntrinsicInst.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/Module.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43 #include <algorithm>
44 #include <map>
45
46 using namespace llvm;
47
48 static cl::opt<int>
49     SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
50                      cl::desc("Only vectorize if you gain more than this "
51                               "number "));
52 namespace {
53
54 static const unsigned MinVecRegSize = 128;
55
56 static const unsigned RecursionMaxDepth = 12;
57
58 /// RAII pattern to save the insertion point of the IR builder.
59 class BuilderLocGuard {
60 public:
61   BuilderLocGuard(IRBuilder<> &B) : Builder(B), Loc(B.GetInsertPoint()) {}
62   ~BuilderLocGuard() { Builder.SetInsertPoint(Loc); }
63
64 private:
65   // Prevent copying.
66   BuilderLocGuard(const BuilderLocGuard &);
67   BuilderLocGuard &operator=(const BuilderLocGuard &);
68   IRBuilder<> &Builder;
69   BasicBlock::iterator Loc;
70 };
71
72 /// A helper class for numbering instructions in multible blocks.
73 /// Numbers starts at zero for each basic block.
74 struct BlockNumbering {
75
76   BlockNumbering(BasicBlock *Bb) : BB(Bb), Valid(false) {}
77
78   BlockNumbering() : BB(0), Valid(false) {}
79
80   void numberInstructions() {
81     unsigned Loc = 0;
82     InstrIdx.clear();
83     InstrVec.clear();
84     // Number the instructions in the block.
85     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
86       InstrIdx[it] = Loc++;
87       InstrVec.push_back(it);
88       assert(InstrVec[InstrIdx[it]] == it && "Invalid allocation");
89     }
90     Valid = true;
91   }
92
93   int getIndex(Instruction *I) {
94     if (!Valid)
95       numberInstructions();
96     assert(InstrIdx.count(I) && "Unknown instruction");
97     return InstrIdx[I];
98   }
99
100   Instruction *getInstruction(unsigned loc) {
101     if (!Valid)
102       numberInstructions();
103     assert(InstrVec.size() > loc && "Invalid Index");
104     return InstrVec[loc];
105   }
106
107   void forget() { Valid = false; }
108
109 private:
110   /// The block we are numbering.
111   BasicBlock *BB;
112   /// Is the block numbered.
113   bool Valid;
114   /// Maps instructions to numbers and back.
115   SmallDenseMap<Instruction *, int> InstrIdx;
116   /// Maps integers to Instructions.
117   std::vector<Instruction *> InstrVec;
118 };
119
120 class FuncSLP {
121   typedef SmallVector<Value *, 8> ValueList;
122   typedef SmallVector<Instruction *, 16> InstrList;
123   typedef SmallPtrSet<Value *, 16> ValueSet;
124   typedef SmallVector<StoreInst *, 8> StoreList;
125
126 public:
127   static const int MAX_COST = INT_MIN;
128
129   FuncSLP(Function *Func, ScalarEvolution *Se, DataLayout *Dl,
130           TargetTransformInfo *Tti, AliasAnalysis *Aa, LoopInfo *Li, 
131           DominatorTree *Dt) :
132     F(Func), SE(Se), DL(Dl), TTI(Tti), AA(Aa), LI(Li), DT(Dt),
133     Builder(Se->getContext()) {
134     for (Function::iterator it = F->begin(), e = F->end(); it != e; ++it) {
135       BasicBlock *BB = it;
136       BlocksNumbers[BB] = BlockNumbering(BB);
137     }
138   }
139
140   /// \brief Take the pointer operand from the Load/Store instruction.
141   /// \returns NULL if this is not a valid Load/Store instruction.
142   static Value *getPointerOperand(Value *I);
143
144   /// \brief Take the address space operand from the Load/Store instruction.
145   /// \returns -1 if this is not a valid Load/Store instruction.
146   static unsigned getAddressSpaceOperand(Value *I);
147
148   /// \returns true if the memory operations A and B are consecutive.
149   bool isConsecutiveAccess(Value *A, Value *B);
150
151   /// \brief Vectorize the tree that starts with the elements in \p VL.
152   /// \returns the vectorized value.
153   Value *vectorizeTree(ArrayRef<Value *> VL);
154
155   /// \returns the vectorization cost of the subtree that starts at \p VL.
156   /// A negative number means that this is profitable.
157   int getTreeCost(ArrayRef<Value *> VL);
158
159   /// \returns the scalarization cost for this list of values. Assuming that
160   /// this subtree gets vectorized, we may need to extract the values from the
161   /// roots. This method calculates the cost of extracting the values.
162   int getGatherCost(ArrayRef<Value *> VL);
163
164   /// \brief Attempts to order and vectorize a sequence of stores. This
165   /// function does a quadratic scan of the given stores.
166   /// \returns true if the basic block was modified.
167   bool vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold);
168
169   /// \brief Vectorize a group of scalars into a vector tree.
170   /// \returns the vectorized value.
171   Value *vectorizeArith(ArrayRef<Value *> Operands);
172
173   /// \brief This method contains the recursive part of getTreeCost.
174   int getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth);
175
176   /// \brief This recursive method looks for vectorization hazards such as
177   /// values that are used by multiple users and checks that values are used
178   /// by only one vector lane. It updates the variables LaneMap, MultiUserVals.
179   void getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth);
180
181   /// \brief This method contains the recursive part of vectorizeTree.
182   Value *vectorizeTree_rec(ArrayRef<Value *> VL);
183
184   ///  \brief Vectorize a sorted sequence of stores.
185   bool vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold);
186
187   /// \returns the scalarization cost for this type. Scalarization in this
188   /// context means the creation of vectors from a group of scalars.
189   int getGatherCost(Type *Ty);
190
191   /// \returns the AA location that is being access by the instruction.
192   AliasAnalysis::Location getLocation(Instruction *I);
193
194   /// \brief Checks if it is possible to sink an instruction from
195   /// \p Src to \p Dst.
196   /// \returns the pointer to the barrier instruction if we can't sink.
197   Value *getSinkBarrier(Instruction *Src, Instruction *Dst);
198
199   /// \returns the index of the last instrucion in the BB from \p VL.
200   int getLastIndex(ArrayRef<Value *> VL);
201
202   /// \returns the Instrucion in the bundle \p VL.
203   Instruction *getLastInstruction(ArrayRef<Value *> VL);
204
205   /// \returns the Instruction at index \p Index which is in Block \p BB.
206   Instruction *getInstructionForIndex(unsigned Index, BasicBlock *BB);
207
208   /// \returns the index of the first User of \p VL.
209   int getFirstUserIndex(ArrayRef<Value *> VL);
210
211   /// \returns a vector from a collection of scalars in \p VL.
212   Value *Gather(ArrayRef<Value *> VL, VectorType *Ty);
213
214   /// \brief Perform LICM and CSE on the newly generated gather sequences.
215   void optimizeGatherSequence();
216
217   bool needToGatherAny(ArrayRef<Value *> VL) {
218     for (int i = 0, e = VL.size(); i < e; ++i)
219       if (MustGather.count(VL[i]))
220         return true;
221     return false;
222   }
223
224   void forgetNumbering() {
225     for (Function::iterator it = F->begin(), e = F->end(); it != e; ++it)
226       BlocksNumbers[it].forget();
227   }
228
229   /// -- Vectorization State --
230
231   /// Maps values in the tree to the vector lanes that uses them. This map must
232   /// be reset between runs of getCost.
233   std::map<Value *, int> LaneMap;
234   /// A list of instructions to ignore while sinking
235   /// memory instructions. This map must be reset between runs of getCost.
236   ValueSet MemBarrierIgnoreList;
237
238   /// Maps between the first scalar to the vector. This map must be reset
239   /// between runs.
240   DenseMap<Value *, Value *> VectorizedValues;
241
242   /// Contains values that must be gathered because they are used
243   /// by multiple lanes, or by users outside the tree.
244   /// NOTICE: The vectorization methods also use this set.
245   ValueSet MustGather;
246
247   /// Contains PHINodes that are being processed. We use this data structure
248   /// to stop cycles in the graph.
249   ValueSet VisitedPHIs;
250
251   /// Contains a list of values that are used outside the current tree, the
252   /// first element in the bundle and the insertion point for extracts. This
253   /// set must be reset between runs.
254   struct UseInfo{
255     UseInfo(Instruction *VL0, int I) :
256       Leader(VL0), LastIndex(I) {}
257     UseInfo() : Leader(0), LastIndex(0) {}
258     /// The first element in the bundle.
259     Instruction *Leader;
260     /// The insertion index.
261     int LastIndex;
262   };
263   MapVector<Instruction*, UseInfo> MultiUserVals;
264   SetVector<Instruction*> ExtractedLane;
265
266   /// Holds all of the instructions that we gathered.
267   SetVector<Instruction *> GatherSeq;
268
269   /// Numbers instructions in different blocks.
270   std::map<BasicBlock *, BlockNumbering> BlocksNumbers;
271
272   // Analysis and block reference.
273   Function *F;
274   ScalarEvolution *SE;
275   DataLayout *DL;
276   TargetTransformInfo *TTI;
277   AliasAnalysis *AA;
278   LoopInfo *LI;
279   DominatorTree *DT;
280   /// Instruction builder to construct the vectorized tree.
281   IRBuilder<> Builder;
282 };
283
284 int FuncSLP::getGatherCost(Type *Ty) {
285   int Cost = 0;
286   for (unsigned i = 0, e = cast<VectorType>(Ty)->getNumElements(); i < e; ++i)
287     Cost += TTI->getVectorInstrCost(Instruction::InsertElement, Ty, i);
288   return Cost;
289 }
290
291 int FuncSLP::getGatherCost(ArrayRef<Value *> VL) {
292   // Find the type of the operands in VL.
293   Type *ScalarTy = VL[0]->getType();
294   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
295     ScalarTy = SI->getValueOperand()->getType();
296   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
297   // Find the cost of inserting/extracting values from the vector.
298   return getGatherCost(VecTy);
299 }
300
301 AliasAnalysis::Location FuncSLP::getLocation(Instruction *I) {
302   if (StoreInst *SI = dyn_cast<StoreInst>(I))
303     return AA->getLocation(SI);
304   if (LoadInst *LI = dyn_cast<LoadInst>(I))
305     return AA->getLocation(LI);
306   return AliasAnalysis::Location();
307 }
308
309 Value *FuncSLP::getPointerOperand(Value *I) {
310   if (LoadInst *LI = dyn_cast<LoadInst>(I))
311     return LI->getPointerOperand();
312   if (StoreInst *SI = dyn_cast<StoreInst>(I))
313     return SI->getPointerOperand();
314   return 0;
315 }
316
317 unsigned FuncSLP::getAddressSpaceOperand(Value *I) {
318   if (LoadInst *L = dyn_cast<LoadInst>(I))
319     return L->getPointerAddressSpace();
320   if (StoreInst *S = dyn_cast<StoreInst>(I))
321     return S->getPointerAddressSpace();
322   return -1;
323 }
324
325 bool FuncSLP::isConsecutiveAccess(Value *A, Value *B) {
326   Value *PtrA = getPointerOperand(A);
327   Value *PtrB = getPointerOperand(B);
328   unsigned ASA = getAddressSpaceOperand(A);
329   unsigned ASB = getAddressSpaceOperand(B);
330
331   // Check that the address spaces match and that the pointers are valid.
332   if (!PtrA || !PtrB || (ASA != ASB))
333     return false;
334
335   // Check that A and B are of the same type.
336   if (PtrA->getType() != PtrB->getType())
337     return false;
338
339   // Calculate the distance.
340   const SCEV *PtrSCEVA = SE->getSCEV(PtrA);
341   const SCEV *PtrSCEVB = SE->getSCEV(PtrB);
342   const SCEV *OffsetSCEV = SE->getMinusSCEV(PtrSCEVA, PtrSCEVB);
343   const SCEVConstant *ConstOffSCEV = dyn_cast<SCEVConstant>(OffsetSCEV);
344
345   // Non constant distance.
346   if (!ConstOffSCEV)
347     return false;
348
349   int64_t Offset = ConstOffSCEV->getValue()->getSExtValue();
350   Type *Ty = cast<PointerType>(PtrA->getType())->getElementType();
351   // The Instructions are connsecutive if the size of the first load/store is
352   // the same as the offset.
353   int64_t Sz = DL->getTypeStoreSize(Ty);
354   return ((-Offset) == Sz);
355 }
356
357 Value *FuncSLP::getSinkBarrier(Instruction *Src, Instruction *Dst) {
358   assert(Src->getParent() == Dst->getParent() && "Not the same BB");
359   BasicBlock::iterator I = Src, E = Dst;
360   /// Scan all of the instruction from SRC to DST and check if
361   /// the source may alias.
362   for (++I; I != E; ++I) {
363     // Ignore store instructions that are marked as 'ignore'.
364     if (MemBarrierIgnoreList.count(I))
365       continue;
366     if (Src->mayWriteToMemory()) /* Write */ {
367       if (!I->mayReadOrWriteMemory())
368         continue;
369     } else /* Read */ {
370       if (!I->mayWriteToMemory())
371         continue;
372     }
373     AliasAnalysis::Location A = getLocation(&*I);
374     AliasAnalysis::Location B = getLocation(Src);
375
376     if (!A.Ptr || !B.Ptr || AA->alias(A, B))
377       return I;
378   }
379   return 0;
380 }
381
382 static BasicBlock *getSameBlock(ArrayRef<Value *> VL) {
383   BasicBlock *BB = 0;
384   for (int i = 0, e = VL.size(); i < e; i++) {
385     Instruction *I = dyn_cast<Instruction>(VL[i]);
386     if (!I)
387       return 0;
388
389     if (!BB) {
390       BB = I->getParent();
391       continue;
392     }
393
394     if (BB != I->getParent())
395       return 0;
396   }
397   return BB;
398 }
399
400 static bool allConstant(ArrayRef<Value *> VL) {
401   for (unsigned i = 0, e = VL.size(); i < e; ++i)
402     if (!isa<Constant>(VL[i]))
403       return false;
404   return true;
405 }
406
407 static bool isSplat(ArrayRef<Value *> VL) {
408   for (unsigned i = 1, e = VL.size(); i < e; ++i)
409     if (VL[i] != VL[0])
410       return false;
411   return true;
412 }
413
414 static unsigned getSameOpcode(ArrayRef<Value *> VL) {
415   unsigned Opcode = 0;
416   for (int i = 0, e = VL.size(); i < e; i++) {
417     if (Instruction *I = dyn_cast<Instruction>(VL[i])) {
418       if (!Opcode) {
419         Opcode = I->getOpcode();
420         continue;
421       }
422       if (Opcode != I->getOpcode())
423         return 0;
424     }
425   }
426   return Opcode;
427 }
428
429 static bool CanReuseExtract(ArrayRef<Value *> VL, unsigned VF,
430                             VectorType *VecTy) {
431   assert(Instruction::ExtractElement == getSameOpcode(VL) && "Invalid opcode");
432   // Check if all of the extracts come from the same vector and from the
433   // correct offset.
434   Value *VL0 = VL[0];
435   ExtractElementInst *E0 = cast<ExtractElementInst>(VL0);
436   Value *Vec = E0->getOperand(0);
437
438   // We have to extract from the same vector type.
439   if (Vec->getType() != VecTy)
440     return false;
441
442   // Check that all of the indices extract from the correct offset.
443   ConstantInt *CI = dyn_cast<ConstantInt>(E0->getOperand(1));
444   if (!CI || CI->getZExtValue())
445     return false;
446
447   for (unsigned i = 1, e = VF; i < e; ++i) {
448     ExtractElementInst *E = cast<ExtractElementInst>(VL[i]);
449     ConstantInt *CI = dyn_cast<ConstantInt>(E->getOperand(1));
450
451     if (!CI || CI->getZExtValue() != i || E->getOperand(0) != Vec)
452       return false;
453   }
454
455   return true;
456 }
457
458 void FuncSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
459   if (Depth == RecursionMaxDepth)
460     return MustGather.insert(VL.begin(), VL.end());
461
462   // Don't handle vectors.
463   if (VL[0]->getType()->isVectorTy())
464     return;
465
466   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
467     if (SI->getValueOperand()->getType()->isVectorTy())
468       return;
469
470   // If all of the operands are identical or constant we have a simple solution.
471   if (allConstant(VL) || isSplat(VL) || !getSameBlock(VL))
472     return MustGather.insert(VL.begin(), VL.end());
473
474   // Stop the scan at unknown IR.
475   Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
476   assert(VL0 && "Invalid instruction");
477
478   // Mark instructions with multiple users.
479   int LastIndex = getLastIndex(VL);
480   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
481     if (PHINode *PN = dyn_cast<PHINode>(VL[i])) {
482       unsigned NumUses = 0;
483       // Check that PHINodes have only one external (non-self) use.
484       for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
485            U != UE; ++U) {
486         // Don't count self uses.
487         if (*U == PN)
488           continue;
489         NumUses++;
490       }
491       if (NumUses > 1) {
492         DEBUG(dbgs() << "SLP: Adding PHI to MultiUserVals "
493               "because it has " << NumUses << " users:" << *PN << " \n");
494         UseInfo UI(VL0, 0);
495         MultiUserVals[PN] = UI;
496       }
497       continue;
498     }
499
500     Instruction *I = dyn_cast<Instruction>(VL[i]);
501     // Remember to check if all of the users of this instruction are vectorized
502     // within our tree. At depth zero we have no local users, only external
503     // users that we don't care about.
504     if (Depth && I && I->getNumUses() > 1) {
505       DEBUG(dbgs() << "SLP: Adding to MultiUserVals "
506             "because it has " << I->getNumUses() << " users:" << *I << " \n");
507       UseInfo UI(VL0, LastIndex);
508       MultiUserVals[I] = UI;
509     }
510   }
511
512   // Check that the instruction is only used within one lane.
513   for (int i = 0, e = VL.size(); i < e; ++i) {
514     if (LaneMap.count(VL[i]) && LaneMap[VL[i]] != i) {
515       DEBUG(dbgs() << "SLP: Value used by multiple lanes:" << *VL[i] << "\n");
516       return MustGather.insert(VL.begin(), VL.end());
517     }
518     // Make this instruction as 'seen' and remember the lane.
519     LaneMap[VL[i]] = i;
520   }
521
522   unsigned Opcode = getSameOpcode(VL);
523   if (!Opcode)
524     return MustGather.insert(VL.begin(), VL.end());
525
526   switch (Opcode) {
527   case Instruction::PHI: {
528     PHINode *PH = dyn_cast<PHINode>(VL0);
529
530     // Stop self cycles.
531     if (VisitedPHIs.count(PH))
532         return;
533
534     VisitedPHIs.insert(PH);
535     for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
536       ValueList Operands;
537       // Prepare the operand vector.
538       for (unsigned j = 0; j < VL.size(); ++j)
539         Operands.push_back(cast<PHINode>(VL[j])->getIncomingValue(i));
540
541       getTreeUses_rec(Operands, Depth + 1);
542     }
543     return;
544   }
545   case Instruction::ExtractElement: {
546     VectorType *VecTy = VectorType::get(VL[0]->getType(), VL.size());
547     // No need to follow ExtractElements that are going to be optimized away.
548     if (CanReuseExtract(VL, VL.size(), VecTy))
549       return;
550     // Fall through.
551   }
552   case Instruction::Load:
553     return;
554   case Instruction::ZExt:
555   case Instruction::SExt:
556   case Instruction::FPToUI:
557   case Instruction::FPToSI:
558   case Instruction::FPExt:
559   case Instruction::PtrToInt:
560   case Instruction::IntToPtr:
561   case Instruction::SIToFP:
562   case Instruction::UIToFP:
563   case Instruction::Trunc:
564   case Instruction::FPTrunc:
565   case Instruction::BitCast:
566   case Instruction::Select:
567   case Instruction::ICmp:
568   case Instruction::FCmp:
569   case Instruction::Add:
570   case Instruction::FAdd:
571   case Instruction::Sub:
572   case Instruction::FSub:
573   case Instruction::Mul:
574   case Instruction::FMul:
575   case Instruction::UDiv:
576   case Instruction::SDiv:
577   case Instruction::FDiv:
578   case Instruction::URem:
579   case Instruction::SRem:
580   case Instruction::FRem:
581   case Instruction::Shl:
582   case Instruction::LShr:
583   case Instruction::AShr:
584   case Instruction::And:
585   case Instruction::Or:
586   case Instruction::Xor: {
587     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
588       ValueList Operands;
589       // Prepare the operand vector.
590       for (unsigned j = 0; j < VL.size(); ++j)
591         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
592
593       getTreeUses_rec(Operands, Depth + 1);
594     }
595     return;
596   }
597   case Instruction::Store: {
598     ValueList Operands;
599     for (unsigned j = 0; j < VL.size(); ++j)
600       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
601     getTreeUses_rec(Operands, Depth + 1);
602     return;
603   }
604   default:
605     return MustGather.insert(VL.begin(), VL.end());
606   }
607 }
608
609 int FuncSLP::getLastIndex(ArrayRef<Value *> VL) {
610   BasicBlock *BB = cast<Instruction>(VL[0])->getParent();
611   assert(BB == getSameBlock(VL) && BlocksNumbers.count(BB) && "Invalid block");
612   BlockNumbering &BN = BlocksNumbers[BB];
613
614   int MaxIdx = BN.getIndex(BB->getFirstNonPHI());
615   for (unsigned i = 0, e = VL.size(); i < e; ++i)
616     MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i])));
617   return MaxIdx;
618 }
619
620 Instruction *FuncSLP::getLastInstruction(ArrayRef<Value *> VL) {
621   BasicBlock *BB = cast<Instruction>(VL[0])->getParent();
622   assert(BB == getSameBlock(VL) && BlocksNumbers.count(BB) && "Invalid block");
623   BlockNumbering &BN = BlocksNumbers[BB];
624
625   int MaxIdx = BN.getIndex(cast<Instruction>(VL[0]));
626   for (unsigned i = 1, e = VL.size(); i < e; ++i)
627     MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i])));
628   return BN.getInstruction(MaxIdx);
629 }
630
631 Instruction *FuncSLP::getInstructionForIndex(unsigned Index, BasicBlock *BB) {
632   BlockNumbering &BN = BlocksNumbers[BB];
633   return BN.getInstruction(Index);
634 }
635
636 int FuncSLP::getFirstUserIndex(ArrayRef<Value *> VL) {
637   BasicBlock *BB = getSameBlock(VL);
638   assert(BB && "All instructions must come from the same block");
639   BlockNumbering &BN = BlocksNumbers[BB];
640
641   // Find the first user of the values.
642   int FirstUser = BN.getIndex(BB->getTerminator());
643   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
644     for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
645          U != UE; ++U) {
646       Instruction *Instr = dyn_cast<Instruction>(*U);
647
648       if (!Instr || Instr->getParent() != BB)
649         continue;
650
651       FirstUser = std::min(FirstUser, BN.getIndex(Instr));
652     }
653   }
654   return FirstUser;
655 }
656
657 int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
658   Type *ScalarTy = VL[0]->getType();
659
660   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
661     ScalarTy = SI->getValueOperand()->getType();
662
663   /// Don't mess with vectors.
664   if (ScalarTy->isVectorTy())
665     return FuncSLP::MAX_COST;
666
667   if (allConstant(VL))
668     return 0;
669
670   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
671
672   if (isSplat(VL))
673     return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0);
674
675   int GatherCost = getGatherCost(VecTy);
676   if (Depth == RecursionMaxDepth || needToGatherAny(VL))
677     return GatherCost;
678
679   BasicBlock *BB = getSameBlock(VL);
680   unsigned Opcode = getSameOpcode(VL);
681   assert(Opcode && BB && "Invalid Instruction Value");
682
683   // Check if it is safe to sink the loads or the stores.
684   if (Opcode == Instruction::Load || Opcode == Instruction::Store) {
685     int MaxIdx = getLastIndex(VL);
686     Instruction *Last = getInstructionForIndex(MaxIdx, BB);
687
688     for (unsigned i = 0, e = VL.size(); i < e; ++i) {
689       if (VL[i] == Last)
690         continue;
691       Value *Barrier = getSinkBarrier(cast<Instruction>(VL[i]), Last);
692       if (Barrier) {
693         DEBUG(dbgs() << "SLP: Can't sink " << *VL[i] << "\n down to " << *Last
694                      << "\n because of " << *Barrier << "\n");
695         return MAX_COST;
696       }
697     }
698   }
699
700   // Calculate the extract cost.
701   unsigned ExternalUserExtractCost = 0;
702   for (unsigned i = 0, e = VL.size(); i < e; ++i)
703     if (ExtractedLane.count(cast<Instruction>(VL[i])))
704       ExternalUserExtractCost +=
705         TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i);
706
707   Instruction *VL0 = cast<Instruction>(VL[0]);
708   switch (Opcode) {
709   case Instruction::PHI: {
710     PHINode *PH = dyn_cast<PHINode>(VL0);
711
712     // Stop self cycles.
713     if (VisitedPHIs.count(PH))
714         return 0;
715
716     VisitedPHIs.insert(PH);
717     int TotalCost = 0;
718     // Calculate the cost of all of the operands.
719     for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {      
720       ValueList Operands;
721       // Prepare the operand vector.
722       for (unsigned j = 0; j < VL.size(); ++j)
723         Operands.push_back(cast<PHINode>(VL[j])->getIncomingValue(i));
724
725       int Cost = getTreeCost_rec(Operands, Depth + 1);
726       if (Cost == MAX_COST)
727         return MAX_COST;
728       TotalCost += TotalCost;
729     }
730
731     if (TotalCost > GatherCost) {
732       MustGather.insert(VL.begin(), VL.end());
733       return GatherCost;
734     }
735
736     return TotalCost + ExternalUserExtractCost;
737   }
738   case Instruction::ExtractElement: {
739     if (CanReuseExtract(VL, VL.size(), VecTy))
740       return 0;
741     return getGatherCost(VecTy);
742   }
743   case Instruction::ZExt:
744   case Instruction::SExt:
745   case Instruction::FPToUI:
746   case Instruction::FPToSI:
747   case Instruction::FPExt:
748   case Instruction::PtrToInt:
749   case Instruction::IntToPtr:
750   case Instruction::SIToFP:
751   case Instruction::UIToFP:
752   case Instruction::Trunc:
753   case Instruction::FPTrunc:
754   case Instruction::BitCast: {
755     ValueList Operands;
756     Type *SrcTy = VL0->getOperand(0)->getType();
757     // Prepare the operand vector.
758     for (unsigned j = 0; j < VL.size(); ++j) {
759       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
760       // Check that the casted type is the same for all users.
761       if (cast<Instruction>(VL[j])->getOperand(0)->getType() != SrcTy)
762         return getGatherCost(VecTy);
763     }
764
765     int Cost = getTreeCost_rec(Operands, Depth + 1);
766     if (Cost == MAX_COST)
767       return MAX_COST;
768
769     // Calculate the cost of this instruction.
770     int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
771                                                        VL0->getType(), SrcTy);
772
773     VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
774     int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
775     Cost += (VecCost - ScalarCost);
776
777     if (Cost > GatherCost) {
778       MustGather.insert(VL.begin(), VL.end());
779       return GatherCost;
780     }
781
782     return Cost + ExternalUserExtractCost;
783   }
784   case Instruction::FCmp:
785   case Instruction::ICmp: {
786     // Check that all of the compares have the same predicate.
787     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
788     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
789       CmpInst *Cmp = cast<CmpInst>(VL[i]);
790       if (Cmp->getPredicate() != P0)
791         return getGatherCost(VecTy);
792     }
793     // Fall through.
794   }
795   case Instruction::Select:
796   case Instruction::Add:
797   case Instruction::FAdd:
798   case Instruction::Sub:
799   case Instruction::FSub:
800   case Instruction::Mul:
801   case Instruction::FMul:
802   case Instruction::UDiv:
803   case Instruction::SDiv:
804   case Instruction::FDiv:
805   case Instruction::URem:
806   case Instruction::SRem:
807   case Instruction::FRem:
808   case Instruction::Shl:
809   case Instruction::LShr:
810   case Instruction::AShr:
811   case Instruction::And:
812   case Instruction::Or:
813   case Instruction::Xor: {
814     int TotalCost = 0;
815     // Calculate the cost of all of the operands.
816     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
817       ValueList Operands;
818       // Prepare the operand vector.
819       for (unsigned j = 0; j < VL.size(); ++j)
820         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
821
822       int Cost = getTreeCost_rec(Operands, Depth + 1);
823       if (Cost == MAX_COST)
824         return MAX_COST;
825       TotalCost += Cost;
826     }
827
828     // Calculate the cost of this instruction.
829     int ScalarCost = 0;
830     int VecCost = 0;
831     if (Opcode == Instruction::FCmp || Opcode == Instruction::ICmp ||
832         Opcode == Instruction::Select) {
833       VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size());
834       ScalarCost =
835           VecTy->getNumElements() *
836           TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty());
837       VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy);
838     } else {
839       ScalarCost = VecTy->getNumElements() *
840                    TTI->getArithmeticInstrCost(Opcode, ScalarTy);
841       VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy);
842     }
843     TotalCost += (VecCost - ScalarCost);
844
845     if (TotalCost > GatherCost) {
846       MustGather.insert(VL.begin(), VL.end());
847       return GatherCost;
848     }
849
850     return TotalCost + ExternalUserExtractCost;
851   }
852   case Instruction::Load: {
853     // If we are scalarize the loads, add the cost of forming the vector.
854     for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
855       if (!isConsecutiveAccess(VL[i], VL[i + 1]))
856         return getGatherCost(VecTy);
857
858     // Cost of wide load - cost of scalar loads.
859     int ScalarLdCost = VecTy->getNumElements() *
860                        TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
861     int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
862     int TotalCost = VecLdCost - ScalarLdCost;
863
864     if (TotalCost > GatherCost) {
865       MustGather.insert(VL.begin(), VL.end());
866       return GatherCost;
867     }
868
869     return TotalCost + ExternalUserExtractCost;
870   }
871   case Instruction::Store: {
872     // We know that we can merge the stores. Calculate the cost.
873     int ScalarStCost = VecTy->getNumElements() *
874                        TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
875     int VecStCost = TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
876     int StoreCost = VecStCost - ScalarStCost;
877
878     ValueList Operands;
879     for (unsigned j = 0; j < VL.size(); ++j) {
880       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
881       MemBarrierIgnoreList.insert(VL[j]);
882     }
883
884     int Cost = getTreeCost_rec(Operands, Depth + 1);
885     if (Cost == MAX_COST)
886       return MAX_COST;
887
888     int TotalCost = StoreCost + Cost;
889     return TotalCost + ExternalUserExtractCost;
890   }
891   default:
892     // Unable to vectorize unknown instructions.
893     return getGatherCost(VecTy);
894   }
895 }
896
897 int FuncSLP::getTreeCost(ArrayRef<Value *> VL) {
898   // Get rid of the list of stores that were removed, and from the
899   // lists of instructions with multiple users.
900   MemBarrierIgnoreList.clear();
901   LaneMap.clear();
902   MultiUserVals.clear();
903   ExtractedLane.clear();
904   MustGather.clear();
905   VisitedPHIs.clear();
906
907   if (!getSameBlock(VL))
908     return MAX_COST;
909
910   // Find the location of the last root.
911   int LastRootIndex = getLastIndex(VL);
912   int FirstUserIndex = getFirstUserIndex(VL);
913
914   // Don't vectorize if there are users of the tree roots inside the tree
915   // itself.
916   if (LastRootIndex > FirstUserIndex)
917     return MAX_COST;
918
919   // Scan the tree and find which value is used by which lane, and which values
920   // must be scalarized.
921   getTreeUses_rec(VL, 0);
922
923   // Check that instructions with multiple users can be vectorized. Mark
924   // unsafe instructions.
925   for (MapVector<Instruction *, UseInfo>::iterator UI = MultiUserVals.begin(),
926        e = MultiUserVals.end(); UI != e; ++UI) {
927     Instruction *Scalar = UI->first;
928
929     if (MustGather.count(Scalar))
930       continue;
931
932     assert(LaneMap.count(Scalar) && "Unknown scalar");
933     int ScalarLane = LaneMap[Scalar];
934
935     bool ExternalUse = false;
936     // Check that all of the users of this instr are within the tree.
937     for (Value::use_iterator Usr = Scalar->use_begin(),
938          UE = Scalar->use_end(); Usr != UE; ++Usr) {
939       // If this user is within the tree, make sure it is from the same lane.
940       // Notice that we have both in-tree and out-of-tree users.
941       if (LaneMap.count(*Usr)) {
942         if (LaneMap[*Usr] != ScalarLane) {
943           DEBUG(dbgs() << "SLP: Adding to MustExtract "
944                 "because of an out-of-lane usage.\n");
945           MustGather.insert(Scalar);
946           break;
947         }
948         continue;
949       }
950
951       // We have an out-of-tree user. Check if we can place an 'extract'.
952       Instruction *User = cast<Instruction>(*Usr);
953       // We care about the order only if the user is in the same block.
954       if (User->getParent() == Scalar->getParent()) {
955         int LastLoc = UI->second.LastIndex;
956         BlockNumbering &BN = BlocksNumbers[User->getParent()];
957         int UserIdx = BN.getIndex(User);
958         if (UserIdx <= LastLoc) {
959           DEBUG(dbgs() << "SLP: Adding to MustExtract because of an external "
960                 "user that we can't schedule.\n");
961           MustGather.insert(Scalar);
962           break;
963         }
964       }
965       // We have an external user.
966       ExternalUse = true;
967     }
968
969     if (ExternalUse) {
970       // Items that are left in MultiUserVals are to be extracted.
971       // ExtractLane is used for the lookup.
972       ExtractedLane.insert(Scalar);
973     }
974
975   }
976
977   // Now calculate the cost of vectorizing the tree.
978   return getTreeCost_rec(VL, 0);
979 }
980 bool FuncSLP::vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold) {
981   unsigned ChainLen = Chain.size();
982   DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen
983                << "\n");
984   Type *StoreTy = cast<StoreInst>(Chain[0])->getValueOperand()->getType();
985   unsigned Sz = DL->getTypeSizeInBits(StoreTy);
986   unsigned VF = MinVecRegSize / Sz;
987
988   if (!isPowerOf2_32(Sz) || VF < 2)
989     return false;
990
991   bool Changed = false;
992   // Look for profitable vectorizable trees at all offsets, starting at zero.
993   for (unsigned i = 0, e = ChainLen; i < e; ++i) {
994     if (i + VF > e)
995       break;
996     DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i
997                  << "\n");
998     ArrayRef<Value *> Operands = Chain.slice(i, VF);
999
1000     int Cost = getTreeCost(Operands);
1001     if (Cost == FuncSLP::MAX_COST)
1002       continue;
1003     DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n");
1004     if (Cost < CostThreshold) {
1005       DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n");
1006       vectorizeTree(Operands);
1007
1008       // Remove the scalar stores.
1009       for (int j = 0, e = VF; j < e; ++j)
1010         cast<Instruction>(Operands[j])->eraseFromParent();
1011
1012       // Move to the next bundle.
1013       i += VF - 1;
1014       Changed = true;
1015     }
1016   }
1017
1018   if (Changed || ChainLen > VF)
1019     return Changed;
1020
1021   // Handle short chains. This helps us catch types such as <3 x float> that
1022   // are smaller than vector size.
1023   int Cost = getTreeCost(Chain);
1024   if (Cost == FuncSLP::MAX_COST)
1025     return false;
1026   if (Cost < CostThreshold) {
1027     DEBUG(dbgs() << "SLP: Found store chain cost = " << Cost
1028                  << " for size = " << ChainLen << "\n");
1029     vectorizeTree(Chain);
1030
1031     // Remove all of the scalar stores.
1032     for (int i = 0, e = Chain.size(); i < e; ++i)
1033       cast<Instruction>(Chain[i])->eraseFromParent();
1034
1035     return true;
1036   }
1037
1038   return false;
1039 }
1040
1041 bool FuncSLP::vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold) {
1042   SetVector<Value *> Heads, Tails;
1043   SmallDenseMap<Value *, Value *> ConsecutiveChain;
1044
1045   // We may run into multiple chains that merge into a single chain. We mark the
1046   // stores that we vectorized so that we don't visit the same store twice.
1047   ValueSet VectorizedStores;
1048   bool Changed = false;
1049
1050   // Do a quadratic search on all of the given stores and find
1051   // all of the pairs of loads that follow each other.
1052   for (unsigned i = 0, e = Stores.size(); i < e; ++i)
1053     for (unsigned j = 0; j < e; ++j) {
1054       if (i == j)
1055         continue;
1056
1057       if (isConsecutiveAccess(Stores[i], Stores[j])) {
1058         Tails.insert(Stores[j]);
1059         Heads.insert(Stores[i]);
1060         ConsecutiveChain[Stores[i]] = Stores[j];
1061       }
1062     }
1063
1064   // For stores that start but don't end a link in the chain:
1065   for (SetVector<Value *>::iterator it = Heads.begin(), e = Heads.end();
1066        it != e; ++it) {
1067     if (Tails.count(*it))
1068       continue;
1069
1070     // We found a store instr that starts a chain. Now follow the chain and try
1071     // to vectorize it.
1072     ValueList Operands;
1073     Value *I = *it;
1074     // Collect the chain into a list.
1075     while (Tails.count(I) || Heads.count(I)) {
1076       if (VectorizedStores.count(I))
1077         break;
1078       Operands.push_back(I);
1079       // Move to the next value in the chain.
1080       I = ConsecutiveChain[I];
1081     }
1082
1083     bool Vectorized = vectorizeStoreChain(Operands, costThreshold);
1084
1085     // Mark the vectorized stores so that we don't vectorize them again.
1086     if (Vectorized)
1087       VectorizedStores.insert(Operands.begin(), Operands.end());
1088     Changed |= Vectorized;
1089   }
1090
1091   return Changed;
1092 }
1093
1094 Value *FuncSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) {
1095   Value *Vec = UndefValue::get(Ty);
1096   // Generate the 'InsertElement' instruction.
1097   for (unsigned i = 0; i < Ty->getNumElements(); ++i) {
1098     Vec = Builder.CreateInsertElement(Vec, VL[i], Builder.getInt32(i));
1099     if (Instruction *I = dyn_cast<Instruction>(Vec))
1100       GatherSeq.insert(I);
1101   }
1102
1103   return Vec;
1104 }
1105
1106 Value *FuncSLP::vectorizeTree_rec(ArrayRef<Value *> VL) {
1107   BuilderLocGuard Guard(Builder);
1108
1109   Type *ScalarTy = VL[0]->getType();
1110   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
1111     ScalarTy = SI->getValueOperand()->getType();
1112   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
1113
1114   if (needToGatherAny(VL))
1115     return Gather(VL, VecTy);
1116
1117   if (VectorizedValues.count(VL[0])) {
1118     DEBUG(dbgs() << "SLP: Diamond merged at depth.\n");
1119     return VectorizedValues[VL[0]];
1120   }
1121
1122   Instruction *VL0 = cast<Instruction>(VL[0]);
1123   unsigned Opcode = VL0->getOpcode();
1124   assert(Opcode == getSameOpcode(VL) && "Invalid opcode");
1125
1126   switch (Opcode) {
1127   case Instruction::PHI: {
1128     PHINode *PH = dyn_cast<PHINode>(VL0);
1129     Builder.SetInsertPoint(PH->getParent()->getFirstInsertionPt());
1130     PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues());
1131     VectorizedValues[VL0] = NewPhi;
1132
1133     for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
1134       ValueList Operands;
1135       BasicBlock *IBB = PH->getIncomingBlock(i);
1136
1137       // Prepare the operand vector.
1138       for (unsigned j = 0; j < VL.size(); ++j)
1139         Operands.push_back(cast<PHINode>(VL[j])->getIncomingValueForBlock(IBB));
1140
1141       Builder.SetInsertPoint(IBB->getTerminator());
1142       Value *Vec = vectorizeTree_rec(Operands);
1143       NewPhi->addIncoming(Vec, IBB);
1144     }
1145
1146     assert(NewPhi->getNumIncomingValues() == PH->getNumIncomingValues() &&
1147            "Invalid number of incoming values");
1148     return NewPhi;
1149   }
1150
1151   case Instruction::ExtractElement: {
1152     if (CanReuseExtract(VL, VL.size(), VecTy))
1153       return VL0->getOperand(0);
1154     return Gather(VL, VecTy);
1155   }
1156   case Instruction::ZExt:
1157   case Instruction::SExt:
1158   case Instruction::FPToUI:
1159   case Instruction::FPToSI:
1160   case Instruction::FPExt:
1161   case Instruction::PtrToInt:
1162   case Instruction::IntToPtr:
1163   case Instruction::SIToFP:
1164   case Instruction::UIToFP:
1165   case Instruction::Trunc:
1166   case Instruction::FPTrunc:
1167   case Instruction::BitCast: {
1168     ValueList INVL;
1169     for (int i = 0, e = VL.size(); i < e; ++i)
1170       INVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
1171
1172     Builder.SetInsertPoint(getLastInstruction(VL));
1173     Value *InVec = vectorizeTree_rec(INVL);
1174     CastInst *CI = dyn_cast<CastInst>(VL0);
1175     Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
1176     VectorizedValues[VL0] = V;
1177     return V;
1178   }
1179   case Instruction::FCmp:
1180   case Instruction::ICmp: {
1181     // Check that all of the compares have the same predicate.
1182     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
1183     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
1184       CmpInst *Cmp = cast<CmpInst>(VL[i]);
1185       if (Cmp->getPredicate() != P0)
1186         return Gather(VL, VecTy);
1187     }
1188
1189     ValueList LHSV, RHSV;
1190     for (int i = 0, e = VL.size(); i < e; ++i) {
1191       LHSV.push_back(cast<Instruction>(VL[i])->getOperand(0));
1192       RHSV.push_back(cast<Instruction>(VL[i])->getOperand(1));
1193     }
1194
1195     Builder.SetInsertPoint(getLastInstruction(VL));
1196     Value *L = vectorizeTree_rec(LHSV);
1197     Value *R = vectorizeTree_rec(RHSV);
1198     Value *V;
1199
1200     if (Opcode == Instruction::FCmp)
1201       V = Builder.CreateFCmp(P0, L, R);
1202     else
1203       V = Builder.CreateICmp(P0, L, R);
1204
1205     VectorizedValues[VL0] = V;
1206     return V;
1207   }
1208   case Instruction::Select: {
1209     ValueList TrueVec, FalseVec, CondVec;
1210     for (int i = 0, e = VL.size(); i < e; ++i) {
1211       CondVec.push_back(cast<Instruction>(VL[i])->getOperand(0));
1212       TrueVec.push_back(cast<Instruction>(VL[i])->getOperand(1));
1213       FalseVec.push_back(cast<Instruction>(VL[i])->getOperand(2));
1214     }
1215
1216     Builder.SetInsertPoint(getLastInstruction(VL));
1217     Value *True = vectorizeTree_rec(TrueVec);
1218     Value *False = vectorizeTree_rec(FalseVec);
1219     Value *Cond = vectorizeTree_rec(CondVec);
1220     Value *V = Builder.CreateSelect(Cond, True, False);
1221     VectorizedValues[VL0] = V;
1222     return V;
1223   }
1224   case Instruction::Add:
1225   case Instruction::FAdd:
1226   case Instruction::Sub:
1227   case Instruction::FSub:
1228   case Instruction::Mul:
1229   case Instruction::FMul:
1230   case Instruction::UDiv:
1231   case Instruction::SDiv:
1232   case Instruction::FDiv:
1233   case Instruction::URem:
1234   case Instruction::SRem:
1235   case Instruction::FRem:
1236   case Instruction::Shl:
1237   case Instruction::LShr:
1238   case Instruction::AShr:
1239   case Instruction::And:
1240   case Instruction::Or:
1241   case Instruction::Xor: {
1242     ValueList LHSVL, RHSVL;
1243     for (int i = 0, e = VL.size(); i < e; ++i) {
1244       LHSVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
1245       RHSVL.push_back(cast<Instruction>(VL[i])->getOperand(1));
1246     }
1247
1248     Builder.SetInsertPoint(getLastInstruction(VL));
1249     Value *LHS = vectorizeTree_rec(LHSVL);
1250     Value *RHS = vectorizeTree_rec(RHSVL);
1251
1252     if (LHS == RHS) {
1253       assert((VL0->getOperand(0) == VL0->getOperand(1)) && "Invalid order");
1254     }
1255
1256     BinaryOperator *BinOp = cast<BinaryOperator>(VL0);
1257     Value *V = Builder.CreateBinOp(BinOp->getOpcode(), LHS, RHS);
1258     VectorizedValues[VL0] = V;
1259     return V;
1260   }
1261   case Instruction::Load: {
1262     // Check if all of the loads are consecutive.
1263     for (unsigned i = 1, e = VL.size(); i < e; ++i)
1264       if (!isConsecutiveAccess(VL[i - 1], VL[i]))
1265         return Gather(VL, VecTy);
1266
1267     // Loads are inserted at the head of the tree because we don't want to
1268     // sink them all the way down past store instructions.
1269     Builder.SetInsertPoint(getLastInstruction(VL));
1270     LoadInst *LI = cast<LoadInst>(VL0);
1271     Value *VecPtr =
1272         Builder.CreateBitCast(LI->getPointerOperand(), VecTy->getPointerTo());
1273     unsigned Alignment = LI->getAlignment();
1274     LI = Builder.CreateLoad(VecPtr);
1275     LI->setAlignment(Alignment);
1276
1277     VectorizedValues[VL0] = LI;
1278     return LI;
1279   }
1280   case Instruction::Store: {
1281     StoreInst *SI = cast<StoreInst>(VL0);
1282     unsigned Alignment = SI->getAlignment();
1283
1284     ValueList ValueOp;
1285     for (int i = 0, e = VL.size(); i < e; ++i)
1286       ValueOp.push_back(cast<StoreInst>(VL[i])->getValueOperand());
1287
1288     Value *VecValue = vectorizeTree_rec(ValueOp);
1289
1290     Builder.SetInsertPoint(getLastInstruction(VL));
1291     Value *VecPtr =
1292         Builder.CreateBitCast(SI->getPointerOperand(), VecTy->getPointerTo());
1293     Builder.CreateStore(VecValue, VecPtr)->setAlignment(Alignment);
1294     return 0;
1295   }
1296   default:
1297     return Gather(VL, VecTy);
1298   }
1299 }
1300
1301 Value *FuncSLP::vectorizeTree(ArrayRef<Value *> VL) {
1302   Builder.SetInsertPoint(getLastInstruction(VL));
1303   Value *V = vectorizeTree_rec(VL);
1304
1305   DEBUG(dbgs() << "SLP: Placing 'extracts'\n");
1306   for (SetVector<Instruction*>::iterator it = ExtractedLane.begin(), e =
1307        ExtractedLane.end(); it != e; ++it) {
1308     Instruction *Scalar = *it;
1309     DEBUG(dbgs() << "SLP: Looking at " << *Scalar);
1310
1311     if (!Scalar)
1312       continue;
1313
1314     Instruction *Loc = 0;
1315
1316     assert(MultiUserVals.count(Scalar) && "Can't find the lane to extract");
1317     Instruction *Leader = MultiUserVals[Scalar].Leader;
1318
1319     // This value is gathered so we don't need to extract from anywhere.
1320     if (!VectorizedValues.count(Leader))
1321       continue;
1322
1323     Value *Vec = VectorizedValues[Leader];
1324     if (PHINode *PN = dyn_cast<PHINode>(Vec)) {
1325       Loc = PN->getParent()->getFirstInsertionPt();
1326     } else {
1327       Instruction *I = cast<Instruction>(Vec);
1328       BasicBlock::iterator L = *I;
1329       Loc = ++L;
1330     }
1331
1332     Builder.SetInsertPoint(Loc);
1333     assert(LaneMap.count(Scalar) && "Can't find the extracted lane.");
1334     int Lane = LaneMap[Scalar];
1335     Value *Idx = Builder.getInt32(Lane);
1336     Value *Extract = Builder.CreateExtractElement(Vec, Idx);
1337
1338     bool Replaced = false;;
1339     for (Value::use_iterator U = Scalar->use_begin(), UE = Scalar->use_end();
1340          U != UE; ++U) {
1341       Instruction *UI = cast<Instruction>(*U);
1342       // No need to replace instructions that are inside our lane map.
1343       if (LaneMap.count(UI))
1344         continue;
1345
1346       UI->replaceUsesOfWith(Scalar ,Extract);
1347       Replaced = true;
1348     }
1349     assert(Replaced && "Must replace at least one outside user");
1350     (void)Replaced;
1351   }
1352
1353   // We moved some instructions around. We have to number them again
1354   // before we can do any analysis.
1355   forgetNumbering();
1356
1357   // Clear the state.
1358   MustGather.clear();
1359   VisitedPHIs.clear();
1360   VectorizedValues.clear();
1361   MemBarrierIgnoreList.clear();
1362   return V;
1363 }
1364
1365 Value *FuncSLP::vectorizeArith(ArrayRef<Value *> Operands) {
1366   Instruction *LastInst = getLastInstruction(Operands);
1367   Value *Vec = vectorizeTree(Operands);
1368   // After vectorizing the operands we need to generate extractelement
1369   // instructions and replace all of the uses of the scalar values with
1370   // the values that we extracted from the vectorized tree.
1371   Builder.SetInsertPoint(LastInst);
1372   for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
1373     Value *S = Builder.CreateExtractElement(Vec, Builder.getInt32(i));
1374     Operands[i]->replaceAllUsesWith(S);
1375   }
1376
1377   forgetNumbering();
1378   return Vec;
1379 }
1380
1381 void FuncSLP::optimizeGatherSequence() {
1382   // LICM InsertElementInst sequences.
1383   for (SetVector<Instruction *>::iterator it = GatherSeq.begin(),
1384        e = GatherSeq.end(); it != e; ++it) {
1385     InsertElementInst *Insert = dyn_cast<InsertElementInst>(*it);
1386
1387     if (!Insert)
1388       continue;
1389
1390     // Check if this block is inside a loop.
1391     Loop *L = LI->getLoopFor(Insert->getParent());
1392     if (!L)
1393       continue;
1394
1395     // Check if it has a preheader.
1396     BasicBlock *PreHeader = L->getLoopPreheader();
1397     if (!PreHeader)
1398       continue;
1399
1400     // If the vector or the element that we insert into it are
1401     // instructions that are defined in this basic block then we can't
1402     // hoist this instruction.
1403     Instruction *CurrVec = dyn_cast<Instruction>(Insert->getOperand(0));
1404     Instruction *NewElem = dyn_cast<Instruction>(Insert->getOperand(1));
1405     if (CurrVec && L->contains(CurrVec))
1406       continue;
1407     if (NewElem && L->contains(NewElem))
1408       continue;
1409
1410     // We can hoist this instruction. Move it to the pre-header.
1411     Insert->moveBefore(PreHeader->getTerminator());
1412   }
1413
1414   // Perform O(N^2) search over the gather sequences and merge identical
1415   // instructions. TODO: We can further optimize this scan if we split the
1416   // instructions into different buckets based on the insert lane.
1417   SmallPtrSet<Instruction*, 16> Visited;
1418   SmallVector<Instruction*, 16> ToRemove;
1419   ReversePostOrderTraversal<Function*> RPOT(F);
1420   for (ReversePostOrderTraversal<Function*>::rpo_iterator I = RPOT.begin(),
1421        E = RPOT.end(); I != E; ++I) {
1422     BasicBlock *BB = *I;
1423     // For all instructions in the function:
1424     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1425       InsertElementInst *Insert = dyn_cast<InsertElementInst>(it);
1426       if (!Insert || !GatherSeq.count(Insert))
1427         continue;
1428
1429       // Check if we can replace this instruction with any of the
1430       // visited instructions.
1431       for (SmallPtrSet<Instruction*, 16>::iterator v = Visited.begin(),
1432            ve = Visited.end(); v != ve; ++v) {
1433         if (Insert->isIdenticalTo(*v) &&
1434             DT->dominates((*v)->getParent(), Insert->getParent())) {
1435           Insert->replaceAllUsesWith(*v);
1436           ToRemove.push_back(Insert);
1437           Insert = 0;
1438           break;
1439         }
1440       }
1441       if (Insert)
1442         Visited.insert(Insert);
1443     }
1444   }
1445
1446   // Erase all of the instructions that we RAUWed.
1447   for (SmallVectorImpl<Instruction *>::iterator v = ToRemove.begin(),
1448        ve = ToRemove.end(); v != ve; ++v) {
1449     assert((*v)->getNumUses() == 0 && "Can't remove instructions with uses");
1450     (*v)->eraseFromParent();
1451   }
1452
1453   forgetNumbering();
1454 }
1455
1456 /// The SLPVectorizer Pass.
1457 struct SLPVectorizer : public FunctionPass {
1458   typedef SmallVector<StoreInst *, 8> StoreList;
1459   typedef MapVector<Value *, StoreList> StoreListMap;
1460
1461   /// Pass identification, replacement for typeid
1462   static char ID;
1463
1464   explicit SLPVectorizer() : FunctionPass(ID) {
1465     initializeSLPVectorizerPass(*PassRegistry::getPassRegistry());
1466   }
1467
1468   ScalarEvolution *SE;
1469   DataLayout *DL;
1470   TargetTransformInfo *TTI;
1471   AliasAnalysis *AA;
1472   LoopInfo *LI;
1473   DominatorTree *DT;
1474
1475   virtual bool runOnFunction(Function &F) {
1476     SE = &getAnalysis<ScalarEvolution>();
1477     DL = getAnalysisIfAvailable<DataLayout>();
1478     TTI = &getAnalysis<TargetTransformInfo>();
1479     AA = &getAnalysis<AliasAnalysis>();
1480     LI = &getAnalysis<LoopInfo>();
1481     DT = &getAnalysis<DominatorTree>();
1482
1483     StoreRefs.clear();
1484     bool Changed = false;
1485
1486     // Must have DataLayout. We can't require it because some tests run w/o
1487     // triple.
1488     if (!DL)
1489       return false;
1490
1491     DEBUG(dbgs() << "SLP: Analyzing blocks in " << F.getName() << ".\n");
1492
1493     // Use the bollom up slp vectorizer to construct chains that start with
1494     // he store instructions.
1495     FuncSLP R(&F, SE, DL, TTI, AA, LI, DT);
1496
1497     // Scan the blocks in the function in post order.
1498     for (po_iterator<BasicBlock*> it = po_begin(&F.getEntryBlock()),
1499          e = po_end(&F.getEntryBlock()); it != e; ++it) {
1500       BasicBlock *BB = *it;
1501
1502       // Vectorize trees that end at reductions.
1503       Changed |= vectorizeChainsInBlock(BB, R);
1504
1505       // Vectorize trees that end at stores.
1506       if (unsigned count = collectStores(BB, R)) {
1507         (void)count;
1508         DEBUG(dbgs() << "SLP: Found " << count << " stores to vectorize.\n");
1509         Changed |= vectorizeStoreChains(R);
1510       }
1511     }
1512
1513     if (Changed) {
1514       R.optimizeGatherSequence();
1515       DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n");
1516       DEBUG(verifyFunction(F));
1517     }
1518     return Changed;
1519   }
1520
1521   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
1522     FunctionPass::getAnalysisUsage(AU);
1523     AU.addRequired<ScalarEvolution>();
1524     AU.addRequired<AliasAnalysis>();
1525     AU.addRequired<TargetTransformInfo>();
1526     AU.addRequired<LoopInfo>();
1527     AU.addRequired<DominatorTree>();
1528     AU.addPreserved<LoopInfo>();
1529     AU.addPreserved<DominatorTree>();
1530     AU.setPreservesCFG();
1531   }
1532
1533 private:
1534
1535   /// \brief Collect memory references and sort them according to their base
1536   /// object. We sort the stores to their base objects to reduce the cost of the
1537   /// quadratic search on the stores. TODO: We can further reduce this cost
1538   /// if we flush the chain creation every time we run into a memory barrier.
1539   unsigned collectStores(BasicBlock *BB, FuncSLP &R);
1540
1541   /// \brief Try to vectorize a chain that starts at two arithmetic instrs.
1542   bool tryToVectorizePair(Value *A, Value *B, FuncSLP &R);
1543
1544   /// \brief Try to vectorize a list of operands. If \p NeedExtracts is true
1545   /// then we calculate the cost of extracting the scalars from the vector.
1546   /// \returns true if a value was vectorized.
1547   bool tryToVectorizeList(ArrayRef<Value *> VL, FuncSLP &R, bool NeedExtracts);
1548
1549   /// \brief Try to vectorize a chain that may start at the operands of \V;
1550   bool tryToVectorize(BinaryOperator *V, FuncSLP &R);
1551
1552   /// \brief Vectorize the stores that were collected in StoreRefs.
1553   bool vectorizeStoreChains(FuncSLP &R);
1554
1555   /// \brief Scan the basic block and look for patterns that are likely to start
1556   /// a vectorization chain.
1557   bool vectorizeChainsInBlock(BasicBlock *BB, FuncSLP &R);
1558
1559 private:
1560   StoreListMap StoreRefs;
1561 };
1562
1563 unsigned SLPVectorizer::collectStores(BasicBlock *BB, FuncSLP &R) {
1564   unsigned count = 0;
1565   StoreRefs.clear();
1566   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1567     StoreInst *SI = dyn_cast<StoreInst>(it);
1568     if (!SI)
1569       continue;
1570
1571     // Check that the pointer points to scalars.
1572     Type *Ty = SI->getValueOperand()->getType();
1573     if (Ty->isAggregateType() || Ty->isVectorTy())
1574       return 0;
1575
1576     // Find the base of the GEP.
1577     Value *Ptr = SI->getPointerOperand();
1578     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr))
1579       Ptr = GEP->getPointerOperand();
1580
1581     // Save the store locations.
1582     StoreRefs[Ptr].push_back(SI);
1583     count++;
1584   }
1585   return count;
1586 }
1587
1588 bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, FuncSLP &R) {
1589   if (!A || !B)
1590     return false;
1591   Value *VL[] = { A, B };
1592   return tryToVectorizeList(VL, R, true);
1593 }
1594
1595 bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, FuncSLP &R,
1596                                        bool NeedExtracts) {
1597   if (VL.size() < 2)
1598     return false;
1599
1600   DEBUG(dbgs() << "SLP: Vectorizing a list of length = " << VL.size() << ".\n");
1601
1602   // Check that all of the parts are scalar instructions of the same type.
1603   Instruction *I0 = dyn_cast<Instruction>(VL[0]);
1604   if (!I0)
1605     return 0;
1606
1607   unsigned Opcode0 = I0->getOpcode();
1608
1609   for (int i = 0, e = VL.size(); i < e; ++i) {
1610     Type *Ty = VL[i]->getType();
1611     if (Ty->isAggregateType() || Ty->isVectorTy())
1612       return 0;
1613     Instruction *Inst = dyn_cast<Instruction>(VL[i]);
1614     if (!Inst || Inst->getOpcode() != Opcode0)
1615       return 0;
1616   }
1617
1618   int Cost = R.getTreeCost(VL);
1619   if (Cost == FuncSLP::MAX_COST)
1620     return false;
1621
1622   int ExtrCost = NeedExtracts ? R.getGatherCost(VL) : 0;
1623   DEBUG(dbgs() << "SLP: Cost of pair:" << Cost
1624                << " Cost of extract:" << ExtrCost << ".\n");
1625   if ((Cost + ExtrCost) >= -SLPCostThreshold)
1626     return false;
1627   DEBUG(dbgs() << "SLP: Vectorizing pair.\n");
1628   R.vectorizeArith(VL);
1629   return true;
1630 }
1631
1632 bool SLPVectorizer::tryToVectorize(BinaryOperator *V, FuncSLP &R) {
1633   if (!V)
1634     return false;
1635
1636   // Try to vectorize V.
1637   if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
1638     return true;
1639
1640   BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
1641   BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
1642   // Try to skip B.
1643   if (B && B->hasOneUse()) {
1644     BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
1645     BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
1646     if (tryToVectorizePair(A, B0, R)) {
1647       B->moveBefore(V);
1648       return true;
1649     }
1650     if (tryToVectorizePair(A, B1, R)) {
1651       B->moveBefore(V);
1652       return true;
1653     }
1654   }
1655
1656   // Try to skip A.
1657   if (A && A->hasOneUse()) {
1658     BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
1659     BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
1660     if (tryToVectorizePair(A0, B, R)) {
1661       A->moveBefore(V);
1662       return true;
1663     }
1664     if (tryToVectorizePair(A1, B, R)) {
1665       A->moveBefore(V);
1666       return true;
1667     }
1668   }
1669   return 0;
1670 }
1671
1672 bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, FuncSLP &R) {
1673   bool Changed = false;
1674   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1675     if (isa<DbgInfoIntrinsic>(it))
1676       continue;
1677
1678     // Try to vectorize reductions that use PHINodes.
1679     if (PHINode *P = dyn_cast<PHINode>(it)) {
1680       // Check that the PHI is a reduction PHI.
1681       if (P->getNumIncomingValues() != 2)
1682         return Changed;
1683       Value *Rdx =
1684           (P->getIncomingBlock(0) == BB
1685                ? (P->getIncomingValue(0))
1686                : (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) : 0));
1687       // Check if this is a Binary Operator.
1688       BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
1689       if (!BI)
1690         continue;
1691
1692       Value *Inst = BI->getOperand(0);
1693       if (Inst == P)
1694         Inst = BI->getOperand(1);
1695
1696       Changed |= tryToVectorize(dyn_cast<BinaryOperator>(Inst), R);
1697       continue;
1698     }
1699
1700     // Try to vectorize trees that start at compare instructions.
1701     if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
1702       if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) {
1703         Changed |= true;
1704         continue;
1705       }
1706       for (int i = 0; i < 2; ++i)
1707         if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i)))
1708           Changed |=
1709               tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R);
1710       continue;
1711     }
1712   }
1713
1714   // Scan the PHINodes in our successors in search for pairing hints.
1715   for (succ_iterator it = succ_begin(BB), e = succ_end(BB); it != e; ++it) {
1716     BasicBlock *Succ = *it;
1717     SmallVector<Value *, 4> Incoming;
1718
1719     // Collect the incoming values from the PHIs.
1720     for (BasicBlock::iterator instr = Succ->begin(), ie = Succ->end();
1721          instr != ie; ++instr) {
1722       PHINode *P = dyn_cast<PHINode>(instr);
1723
1724       if (!P)
1725         break;
1726
1727       Value *V = P->getIncomingValueForBlock(BB);
1728       if (Instruction *I = dyn_cast<Instruction>(V))
1729         if (I->getParent() == BB)
1730           Incoming.push_back(I);
1731     }
1732
1733     if (Incoming.size() > 1)
1734       Changed |= tryToVectorizeList(Incoming, R, true);
1735   }
1736
1737   return Changed;
1738 }
1739
1740 bool SLPVectorizer::vectorizeStoreChains(FuncSLP &R) {
1741   bool Changed = false;
1742   // Attempt to sort and vectorize each of the store-groups.
1743   for (StoreListMap::iterator it = StoreRefs.begin(), e = StoreRefs.end();
1744        it != e; ++it) {
1745     if (it->second.size() < 2)
1746       continue;
1747
1748     DEBUG(dbgs() << "SLP: Analyzing a store chain of length "
1749                  << it->second.size() << ".\n");
1750
1751     Changed |= R.vectorizeStores(it->second, -SLPCostThreshold);
1752   }
1753   return Changed;
1754 }
1755
1756 } // end anonymous namespace
1757
1758 char SLPVectorizer::ID = 0;
1759 static const char lv_name[] = "SLP Vectorizer";
1760 INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false)
1761 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
1762 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
1763 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
1764 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
1765 INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)
1766
1767 namespace llvm {
1768 Pass *createSLPVectorizerPass() { return new SLPVectorizer(); }
1769 }