SLPVectorizer: support slp-vectorization of PHINodes between basic blocks
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
1 //===- SLPVectorizer.cpp - A bottom up SLP Vectorizer ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass implements the Bottom Up SLP vectorizer. It detects consecutive
10 // stores that can be put together into vector-stores. Next, it attempts to
11 // construct vectorizable tree using the use-def chains. If a profitable tree
12 // was found, the SLP vectorizer performs vectorization on the tree.
13 //
14 // The pass is inspired by the work described in the paper:
15 //  "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
16 //
17 //===----------------------------------------------------------------------===//
18 #define SV_NAME "slp-vectorizer"
19 #define DEBUG_TYPE "SLP"
20
21 #include "llvm/Transforms/Vectorize.h"
22 #include "llvm/ADT/MapVector.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/Analysis/AliasAnalysis.h"
26 #include "llvm/Analysis/ScalarEvolution.h"
27 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
28 #include "llvm/Analysis/AliasAnalysis.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/Analysis/Verifier.h"
31 #include "llvm/Analysis/LoopInfo.h"
32 #include "llvm/IR/DataLayout.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/IntrinsicInst.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/Module.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43 #include <algorithm>
44 #include <map>
45
46 using namespace llvm;
47
48 static cl::opt<int>
49     SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
50                      cl::desc("Only vectorize trees if the gain is above this "
51                               "number. (gain = -cost of vectorization)"));
52 namespace {
53
54 static const unsigned MinVecRegSize = 128;
55
56 static const unsigned RecursionMaxDepth = 12;
57
58 /// RAII pattern to save the insertion point of the IR builder.
59 class BuilderLocGuard {
60 public:
61   BuilderLocGuard(IRBuilder<> &B) : Builder(B), Loc(B.GetInsertPoint()) {}
62   ~BuilderLocGuard() { Builder.SetInsertPoint(Loc); }
63
64 private:
65   // Prevent copying.
66   BuilderLocGuard(const BuilderLocGuard &);
67   BuilderLocGuard &operator=(const BuilderLocGuard &);
68   IRBuilder<> &Builder;
69   BasicBlock::iterator Loc;
70 };
71
72 /// A helper class for numbering instructions in multible blocks.
73 /// Numbers starts at zero for each basic block.
74 struct BlockNumbering {
75
76   BlockNumbering(BasicBlock *Bb) : BB(Bb), Valid(false) {}
77
78   BlockNumbering() : BB(0), Valid(false) {}
79
80   void numberInstructions() {
81     unsigned Loc = 0;
82     InstrIdx.clear();
83     InstrVec.clear();
84     // Number the instructions in the block.
85     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
86       InstrIdx[it] = Loc++;
87       InstrVec.push_back(it);
88       assert(InstrVec[InstrIdx[it]] == it && "Invalid allocation");
89     }
90     Valid = true;
91   }
92
93   int getIndex(Instruction *I) {
94     if (!Valid)
95       numberInstructions();
96     assert(InstrIdx.count(I) && "Unknown instruction");
97     return InstrIdx[I];
98   }
99
100   Instruction *getInstruction(unsigned loc) {
101     if (!Valid)
102       numberInstructions();
103     assert(InstrVec.size() > loc && "Invalid Index");
104     return InstrVec[loc];
105   }
106
107   void forget() { Valid = false; }
108
109 private:
110   /// The block we are numbering.
111   BasicBlock *BB;
112   /// Is the block numbered.
113   bool Valid;
114   /// Maps instructions to numbers and back.
115   SmallDenseMap<Instruction *, int> InstrIdx;
116   /// Maps integers to Instructions.
117   std::vector<Instruction *> InstrVec;
118 };
119
120 class FuncSLP {
121   typedef SmallVector<Value *, 8> ValueList;
122   typedef SmallVector<Instruction *, 16> InstrList;
123   typedef SmallPtrSet<Value *, 16> ValueSet;
124   typedef SmallVector<StoreInst *, 8> StoreList;
125
126 public:
127   static const int MAX_COST = INT_MIN;
128
129   FuncSLP(Function *Func, ScalarEvolution *Se, DataLayout *Dl,
130           TargetTransformInfo *Tti, AliasAnalysis *Aa, LoopInfo *Li, 
131           DominatorTree *Dt) :
132     F(Func), SE(Se), DL(Dl), TTI(Tti), AA(Aa), LI(Li), DT(Dt),
133     Builder(Se->getContext()) {
134     for (Function::iterator it = F->begin(), e = F->end(); it != e; ++it) {
135       BasicBlock *BB = it;
136       BlocksNumbers[BB] = BlockNumbering(BB);
137     }
138   }
139
140   /// \brief Take the pointer operand from the Load/Store instruction.
141   /// \returns NULL if this is not a valid Load/Store instruction.
142   static Value *getPointerOperand(Value *I);
143
144   /// \brief Take the address space operand from the Load/Store instruction.
145   /// \returns -1 if this is not a valid Load/Store instruction.
146   static unsigned getAddressSpaceOperand(Value *I);
147
148   /// \returns true if the memory operations A and B are consecutive.
149   bool isConsecutiveAccess(Value *A, Value *B);
150
151   /// \brief Vectorize the tree that starts with the elements in \p VL.
152   /// \returns the vectorized value.
153   Value *vectorizeTree(ArrayRef<Value *> VL);
154
155   /// \returns the vectorization cost of the subtree that starts at \p VL.
156   /// A negative number means that this is profitable.
157   int getTreeCost(ArrayRef<Value *> VL);
158
159   /// \returns the scalarization cost for this list of values. Assuming that
160   /// this subtree gets vectorized, we may need to extract the values from the
161   /// roots. This method calculates the cost of extracting the values.
162   int getGatherCost(ArrayRef<Value *> VL);
163
164   /// \brief Attempts to order and vectorize a sequence of stores. This
165   /// function does a quadratic scan of the given stores.
166   /// \returns true if the basic block was modified.
167   bool vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold);
168
169   /// \brief Vectorize a group of scalars into a vector tree.
170   /// \returns the vectorized value.
171   Value *vectorizeArith(ArrayRef<Value *> Operands);
172
173   /// \brief This method contains the recursive part of getTreeCost.
174   int getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth);
175
176   /// \brief This recursive method looks for vectorization hazards such as
177   /// values that are used by multiple users and checks that values are used
178   /// by only one vector lane. It updates the variables LaneMap, MultiUserVals.
179   void getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth);
180
181   /// \brief This method contains the recursive part of vectorizeTree.
182   Value *vectorizeTree_rec(ArrayRef<Value *> VL);
183
184   ///  \brief Vectorize a sorted sequence of stores.
185   bool vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold);
186
187   /// \returns the scalarization cost for this type. Scalarization in this
188   /// context means the creation of vectors from a group of scalars.
189   int getGatherCost(Type *Ty);
190
191   /// \returns the AA location that is being access by the instruction.
192   AliasAnalysis::Location getLocation(Instruction *I);
193
194   /// \brief Checks if it is possible to sink an instruction from
195   /// \p Src to \p Dst.
196   /// \returns the pointer to the barrier instruction if we can't sink.
197   Value *getSinkBarrier(Instruction *Src, Instruction *Dst);
198
199   /// \returns the index of the last instrucion in the BB from \p VL.
200   int getLastIndex(ArrayRef<Value *> VL);
201
202   /// \returns the Instrucion in the bundle \p VL.
203   Instruction *getLastInstruction(ArrayRef<Value *> VL);
204
205   /// \returns the Instruction at index \p Index which is in Block \p BB.
206   Instruction *getInstructionForIndex(unsigned Index, BasicBlock *BB);
207
208   /// \returns the index of the first User of \p VL.
209   int getFirstUserIndex(ArrayRef<Value *> VL);
210
211   /// \returns a vector from a collection of scalars in \p VL.
212   Value *Gather(ArrayRef<Value *> VL, VectorType *Ty);
213
214   /// \brief Perform LICM and CSE on the newly generated gather sequences.
215   void optimizeGatherSequence();
216
217   bool needToGatherAny(ArrayRef<Value *> VL) {
218     for (int i = 0, e = VL.size(); i < e; ++i)
219       if (MustGather.count(VL[i]))
220         return true;
221     return false;
222   }
223
224   /// -- Vectorization State --
225
226   /// Maps values in the tree to the vector lanes that uses them. This map must
227   /// be reset between runs of getCost.
228   std::map<Value *, int> LaneMap;
229   /// A list of instructions to ignore while sinking
230   /// memory instructions. This map must be reset between runs of getCost.
231   ValueSet MemBarrierIgnoreList;
232
233   /// Maps between the first scalar to the vector. This map must be reset
234   /// between runs.
235   DenseMap<Value *, Value *> VectorizedValues;
236
237   /// Contains values that must be gathered because they are used
238   /// by multiple lanes, or by users outside the tree.
239   /// NOTICE: The vectorization methods also use this set.
240   ValueSet MustGather;
241
242   /// Contains PHINodes that are being processed. We use this data structure
243   /// to stop cycles in the graph.
244   ValueSet VisitedPHIs;
245
246   /// Contains a list of values that are used outside the current tree. This
247   /// set must be reset between runs.
248   SetVector<Value *> MultiUserVals;
249
250   /// Holds all of the instructions that we gathered.
251   SetVector<Instruction *> GatherSeq;
252
253   /// Numbers instructions in different blocks.
254   std::map<BasicBlock *, BlockNumbering> BlocksNumbers;
255
256   // Analysis and block reference.
257   Function *F;
258   ScalarEvolution *SE;
259   DataLayout *DL;
260   TargetTransformInfo *TTI;
261   AliasAnalysis *AA;
262   LoopInfo *LI;
263   DominatorTree *DT;
264   /// Instruction builder to construct the vectorized tree.
265   IRBuilder<> Builder;
266 };
267
268 int FuncSLP::getGatherCost(Type *Ty) {
269   int Cost = 0;
270   for (unsigned i = 0, e = cast<VectorType>(Ty)->getNumElements(); i < e; ++i)
271     Cost += TTI->getVectorInstrCost(Instruction::InsertElement, Ty, i);
272   return Cost;
273 }
274
275 int FuncSLP::getGatherCost(ArrayRef<Value *> VL) {
276   // Find the type of the operands in VL.
277   Type *ScalarTy = VL[0]->getType();
278   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
279     ScalarTy = SI->getValueOperand()->getType();
280   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
281   // Find the cost of inserting/extracting values from the vector.
282   return getGatherCost(VecTy);
283 }
284
285 AliasAnalysis::Location FuncSLP::getLocation(Instruction *I) {
286   if (StoreInst *SI = dyn_cast<StoreInst>(I))
287     return AA->getLocation(SI);
288   if (LoadInst *LI = dyn_cast<LoadInst>(I))
289     return AA->getLocation(LI);
290   return AliasAnalysis::Location();
291 }
292
293 Value *FuncSLP::getPointerOperand(Value *I) {
294   if (LoadInst *LI = dyn_cast<LoadInst>(I))
295     return LI->getPointerOperand();
296   if (StoreInst *SI = dyn_cast<StoreInst>(I))
297     return SI->getPointerOperand();
298   return 0;
299 }
300
301 unsigned FuncSLP::getAddressSpaceOperand(Value *I) {
302   if (LoadInst *L = dyn_cast<LoadInst>(I))
303     return L->getPointerAddressSpace();
304   if (StoreInst *S = dyn_cast<StoreInst>(I))
305     return S->getPointerAddressSpace();
306   return -1;
307 }
308
309 bool FuncSLP::isConsecutiveAccess(Value *A, Value *B) {
310   Value *PtrA = getPointerOperand(A);
311   Value *PtrB = getPointerOperand(B);
312   unsigned ASA = getAddressSpaceOperand(A);
313   unsigned ASB = getAddressSpaceOperand(B);
314
315   // Check that the address spaces match and that the pointers are valid.
316   if (!PtrA || !PtrB || (ASA != ASB))
317     return false;
318
319   // Check that A and B are of the same type.
320   if (PtrA->getType() != PtrB->getType())
321     return false;
322
323   // Calculate the distance.
324   const SCEV *PtrSCEVA = SE->getSCEV(PtrA);
325   const SCEV *PtrSCEVB = SE->getSCEV(PtrB);
326   const SCEV *OffsetSCEV = SE->getMinusSCEV(PtrSCEVA, PtrSCEVB);
327   const SCEVConstant *ConstOffSCEV = dyn_cast<SCEVConstant>(OffsetSCEV);
328
329   // Non constant distance.
330   if (!ConstOffSCEV)
331     return false;
332
333   int64_t Offset = ConstOffSCEV->getValue()->getSExtValue();
334   Type *Ty = cast<PointerType>(PtrA->getType())->getElementType();
335   // The Instructions are connsecutive if the size of the first load/store is
336   // the same as the offset.
337   int64_t Sz = DL->getTypeStoreSize(Ty);
338   return ((-Offset) == Sz);
339 }
340
341 Value *FuncSLP::getSinkBarrier(Instruction *Src, Instruction *Dst) {
342   assert(Src->getParent() == Dst->getParent() && "Not the same BB");
343   BasicBlock::iterator I = Src, E = Dst;
344   /// Scan all of the instruction from SRC to DST and check if
345   /// the source may alias.
346   for (++I; I != E; ++I) {
347     // Ignore store instructions that are marked as 'ignore'.
348     if (MemBarrierIgnoreList.count(I))
349       continue;
350     if (Src->mayWriteToMemory()) /* Write */ {
351       if (!I->mayReadOrWriteMemory())
352         continue;
353     } else /* Read */ {
354       if (!I->mayWriteToMemory())
355         continue;
356     }
357     AliasAnalysis::Location A = getLocation(&*I);
358     AliasAnalysis::Location B = getLocation(Src);
359
360     if (!A.Ptr || !B.Ptr || AA->alias(A, B))
361       return I;
362   }
363   return 0;
364 }
365
366 static BasicBlock *getSameBlock(ArrayRef<Value *> VL) {
367   BasicBlock *BB = 0;
368   for (int i = 0, e = VL.size(); i < e; i++) {
369     Instruction *I = dyn_cast<Instruction>(VL[i]);
370     if (!I)
371       return 0;
372
373     if (!BB) {
374       BB = I->getParent();
375       continue;
376     }
377
378     if (BB != I->getParent())
379       return 0;
380   }
381   return BB;
382 }
383
384 static bool allConstant(ArrayRef<Value *> VL) {
385   for (unsigned i = 0, e = VL.size(); i < e; ++i)
386     if (!isa<Constant>(VL[i]))
387       return false;
388   return true;
389 }
390
391 static bool isSplat(ArrayRef<Value *> VL) {
392   for (unsigned i = 1, e = VL.size(); i < e; ++i)
393     if (VL[i] != VL[0])
394       return false;
395   return true;
396 }
397
398 static unsigned getSameOpcode(ArrayRef<Value *> VL) {
399   unsigned Opcode = 0;
400   for (int i = 0, e = VL.size(); i < e; i++) {
401     if (Instruction *I = dyn_cast<Instruction>(VL[i])) {
402       if (!Opcode) {
403         Opcode = I->getOpcode();
404         continue;
405       }
406       if (Opcode != I->getOpcode())
407         return 0;
408     }
409   }
410   return Opcode;
411 }
412
413 static bool CanReuseExtract(ArrayRef<Value *> VL, unsigned VF,
414                             VectorType *VecTy) {
415   assert(Instruction::ExtractElement == getSameOpcode(VL) && "Invalid opcode");
416   // Check if all of the extracts come from the same vector and from the
417   // correct offset.
418   Value *VL0 = VL[0];
419   ExtractElementInst *E0 = cast<ExtractElementInst>(VL0);
420   Value *Vec = E0->getOperand(0);
421
422   // We have to extract from the same vector type.
423   if (Vec->getType() != VecTy)
424     return false;
425
426   // Check that all of the indices extract from the correct offset.
427   ConstantInt *CI = dyn_cast<ConstantInt>(E0->getOperand(1));
428   if (!CI || CI->getZExtValue())
429     return false;
430
431   for (unsigned i = 1, e = VF; i < e; ++i) {
432     ExtractElementInst *E = cast<ExtractElementInst>(VL[i]);
433     ConstantInt *CI = dyn_cast<ConstantInt>(E->getOperand(1));
434
435     if (!CI || CI->getZExtValue() != i || E->getOperand(0) != Vec)
436       return false;
437   }
438
439   return true;
440 }
441
442 void FuncSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
443   if (Depth == RecursionMaxDepth)
444     return MustGather.insert(VL.begin(), VL.end());
445
446   // Don't handle vectors.
447   if (VL[0]->getType()->isVectorTy())
448     return;
449
450   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
451     if (SI->getValueOperand()->getType()->isVectorTy())
452       return;
453
454   // If all of the operands are identical or constant we have a simple solution.
455   if (allConstant(VL) || isSplat(VL) || !getSameBlock(VL))
456     return MustGather.insert(VL.begin(), VL.end());
457
458   // Stop the scan at unknown IR.
459   Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
460   assert(VL0 && "Invalid instruction");
461
462   // Mark instructions with multiple users.
463   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
464     if (PHINode *PN = dyn_cast<PHINode>(VL[i])) {
465       unsigned NumUses = 0;
466       // Check that PHINodes have only one external (non-self) use.
467       for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
468            U != UE; ++U) {
469         // Don't count self uses.
470         if (*U == PN)
471           continue;
472         NumUses++;
473       }
474       if (NumUses > 1) {
475         DEBUG(dbgs() << "SLP: Adding PHI to MultiUserVals "
476               "because it has " << NumUses << " users:" << *PN << " \n");
477         MultiUserVals.insert(PN);
478       }
479       continue;
480     }
481
482     Instruction *I = dyn_cast<Instruction>(VL[i]);
483     // Remember to check if all of the users of this instruction are vectorized
484     // within our tree. At depth zero we have no local users, only external
485     // users that we don't care about.
486     if (Depth && I && I->getNumUses() > 1) {
487       DEBUG(dbgs() << "SLP: Adding to MultiUserVals "
488             "because it has " << I->getNumUses() << " users:" << *I << " \n");
489       MultiUserVals.insert(I);
490     }
491   }
492
493   // Check that the instruction is only used within one lane.
494   for (int i = 0, e = VL.size(); i < e; ++i) {
495     if (LaneMap.count(VL[i]) && LaneMap[VL[i]] != i) {
496       DEBUG(dbgs() << "SLP: Value used by multiple lanes:" << *VL[i] << "\n");
497       return MustGather.insert(VL.begin(), VL.end());
498     }
499     // Make this instruction as 'seen' and remember the lane.
500     LaneMap[VL[i]] = i;
501   }
502
503   unsigned Opcode = getSameOpcode(VL);
504   if (!Opcode)
505     return MustGather.insert(VL.begin(), VL.end());
506
507   switch (Opcode) {
508   case Instruction::PHI: {
509     PHINode *PH = dyn_cast<PHINode>(VL0);
510
511     // Stop self cycles.
512     if (VisitedPHIs.count(PH))
513         return;
514
515     VisitedPHIs.insert(PH);
516     for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
517       ValueList Operands;
518       // Prepare the operand vector.
519       for (unsigned j = 0; j < VL.size(); ++j)
520         Operands.push_back(cast<PHINode>(VL[j])->getIncomingValue(i));
521
522       getTreeUses_rec(Operands, Depth + 1);
523     }
524     return;
525   }
526   case Instruction::ExtractElement: {
527     VectorType *VecTy = VectorType::get(VL[0]->getType(), VL.size());
528     // No need to follow ExtractElements that are going to be optimized away.
529     if (CanReuseExtract(VL, VL.size(), VecTy))
530       return;
531     // Fall through.
532   }
533   case Instruction::Load:
534     return;
535   case Instruction::ZExt:
536   case Instruction::SExt:
537   case Instruction::FPToUI:
538   case Instruction::FPToSI:
539   case Instruction::FPExt:
540   case Instruction::PtrToInt:
541   case Instruction::IntToPtr:
542   case Instruction::SIToFP:
543   case Instruction::UIToFP:
544   case Instruction::Trunc:
545   case Instruction::FPTrunc:
546   case Instruction::BitCast:
547   case Instruction::Select:
548   case Instruction::ICmp:
549   case Instruction::FCmp:
550   case Instruction::Add:
551   case Instruction::FAdd:
552   case Instruction::Sub:
553   case Instruction::FSub:
554   case Instruction::Mul:
555   case Instruction::FMul:
556   case Instruction::UDiv:
557   case Instruction::SDiv:
558   case Instruction::FDiv:
559   case Instruction::URem:
560   case Instruction::SRem:
561   case Instruction::FRem:
562   case Instruction::Shl:
563   case Instruction::LShr:
564   case Instruction::AShr:
565   case Instruction::And:
566   case Instruction::Or:
567   case Instruction::Xor: {
568     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
569       ValueList Operands;
570       // Prepare the operand vector.
571       for (unsigned j = 0; j < VL.size(); ++j)
572         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
573
574       getTreeUses_rec(Operands, Depth + 1);
575     }
576     return;
577   }
578   case Instruction::Store: {
579     ValueList Operands;
580     for (unsigned j = 0; j < VL.size(); ++j)
581       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
582     getTreeUses_rec(Operands, Depth + 1);
583     return;
584   }
585   default:
586     return MustGather.insert(VL.begin(), VL.end());
587   }
588 }
589
590 int FuncSLP::getLastIndex(ArrayRef<Value *> VL) {
591   BasicBlock *BB = cast<Instruction>(VL[0])->getParent();
592   assert(BB == getSameBlock(VL) && BlocksNumbers.count(BB) && "Invalid block");
593   BlockNumbering &BN = BlocksNumbers[BB];
594
595   int MaxIdx = BN.getIndex(BB->getFirstNonPHI());
596   for (unsigned i = 0, e = VL.size(); i < e; ++i)
597     MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i])));
598   return MaxIdx;
599 }
600
601 Instruction *FuncSLP::getLastInstruction(ArrayRef<Value *> VL) {
602   BasicBlock *BB = cast<Instruction>(VL[0])->getParent();
603   assert(BB == getSameBlock(VL) && BlocksNumbers.count(BB) && "Invalid block");
604   BlockNumbering &BN = BlocksNumbers[BB];
605
606   int MaxIdx = BN.getIndex(cast<Instruction>(VL[0]));
607   for (unsigned i = 1, e = VL.size(); i < e; ++i)
608     MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i])));
609   return BN.getInstruction(MaxIdx);
610 }
611
612 Instruction *FuncSLP::getInstructionForIndex(unsigned Index, BasicBlock *BB) {
613   BlockNumbering &BN = BlocksNumbers[BB];
614   return BN.getInstruction(Index);
615 }
616
617 int FuncSLP::getFirstUserIndex(ArrayRef<Value *> VL) {
618   BasicBlock *BB = getSameBlock(VL);
619   assert(BB && "All instructions must come from the same block");
620   BlockNumbering &BN = BlocksNumbers[BB];
621
622   // Find the first user of the values.
623   int FirstUser = BN.getIndex(BB->getTerminator());
624   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
625     for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
626          U != UE; ++U) {
627       Instruction *Instr = dyn_cast<Instruction>(*U);
628
629       if (!Instr || Instr->getParent() != BB)
630         continue;
631
632       FirstUser = std::min(FirstUser, BN.getIndex(Instr));
633     }
634   }
635   return FirstUser;
636 }
637
638 int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
639   Type *ScalarTy = VL[0]->getType();
640
641   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
642     ScalarTy = SI->getValueOperand()->getType();
643
644   /// Don't mess with vectors.
645   if (ScalarTy->isVectorTy())
646     return FuncSLP::MAX_COST;
647
648   if (allConstant(VL))
649     return 0;
650
651   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
652
653   if (isSplat(VL))
654     return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0);
655
656   int GatherCost = getGatherCost(VecTy);
657   if (Depth == RecursionMaxDepth || needToGatherAny(VL))
658     return GatherCost;
659
660   BasicBlock *BB = getSameBlock(VL);
661   unsigned Opcode = getSameOpcode(VL);
662   assert(Opcode && BB && "Invalid Instruction Value");
663
664   // Check if it is safe to sink the loads or the stores.
665   if (Opcode == Instruction::Load || Opcode == Instruction::Store) {
666     int MaxIdx = getLastIndex(VL);
667     Instruction *Last = getInstructionForIndex(MaxIdx, BB);
668
669     for (unsigned i = 0, e = VL.size(); i < e; ++i) {
670       if (VL[i] == Last)
671         continue;
672       Value *Barrier = getSinkBarrier(cast<Instruction>(VL[i]), Last);
673       if (Barrier) {
674         DEBUG(dbgs() << "SLP: Can't sink " << *VL[i] << "\n down to " << *Last
675                      << "\n because of " << *Barrier << "\n");
676         return MAX_COST;
677       }
678     }
679   }
680
681   Instruction *VL0 = cast<Instruction>(VL[0]);
682   switch (Opcode) {
683   case Instruction::PHI: {
684     PHINode *PH = dyn_cast<PHINode>(VL0);
685
686     // Stop self cycles.
687     if (VisitedPHIs.count(PH))
688         return 0;
689
690     VisitedPHIs.insert(PH);
691     int TotalCost = 0;
692     // Calculate the cost of all of the operands.
693     for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {      
694       ValueList Operands;
695       // Prepare the operand vector.
696       for (unsigned j = 0; j < VL.size(); ++j)
697         Operands.push_back(cast<PHINode>(VL[j])->getIncomingValue(i));
698
699       int Cost = getTreeCost_rec(Operands, Depth + 1);
700       if (Cost == MAX_COST)
701         return MAX_COST;
702       TotalCost += TotalCost;
703     }
704
705     if (TotalCost > GatherCost) {
706       MustGather.insert(VL.begin(), VL.end());
707       return GatherCost;
708     }
709
710     return TotalCost;
711   }
712   case Instruction::ExtractElement: {
713     if (CanReuseExtract(VL, VL.size(), VecTy))
714       return 0;
715     return getGatherCost(VecTy);
716   }
717   case Instruction::ZExt:
718   case Instruction::SExt:
719   case Instruction::FPToUI:
720   case Instruction::FPToSI:
721   case Instruction::FPExt:
722   case Instruction::PtrToInt:
723   case Instruction::IntToPtr:
724   case Instruction::SIToFP:
725   case Instruction::UIToFP:
726   case Instruction::Trunc:
727   case Instruction::FPTrunc:
728   case Instruction::BitCast: {
729     ValueList Operands;
730     Type *SrcTy = VL0->getOperand(0)->getType();
731     // Prepare the operand vector.
732     for (unsigned j = 0; j < VL.size(); ++j) {
733       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
734       // Check that the casted type is the same for all users.
735       if (cast<Instruction>(VL[j])->getOperand(0)->getType() != SrcTy)
736         return getGatherCost(VecTy);
737     }
738
739     int Cost = getTreeCost_rec(Operands, Depth + 1);
740     if (Cost == FuncSLP::MAX_COST)
741       return Cost;
742
743     // Calculate the cost of this instruction.
744     int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
745                                                        VL0->getType(), SrcTy);
746
747     VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
748     int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
749     Cost += (VecCost - ScalarCost);
750
751     if (Cost > GatherCost) {
752       MustGather.insert(VL.begin(), VL.end());
753       return GatherCost;
754     }
755
756     return Cost;
757   }
758   case Instruction::FCmp:
759   case Instruction::ICmp: {
760     // Check that all of the compares have the same predicate.
761     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
762     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
763       CmpInst *Cmp = cast<CmpInst>(VL[i]);
764       if (Cmp->getPredicate() != P0)
765         return getGatherCost(VecTy);
766     }
767     // Fall through.
768   }
769   case Instruction::Select:
770   case Instruction::Add:
771   case Instruction::FAdd:
772   case Instruction::Sub:
773   case Instruction::FSub:
774   case Instruction::Mul:
775   case Instruction::FMul:
776   case Instruction::UDiv:
777   case Instruction::SDiv:
778   case Instruction::FDiv:
779   case Instruction::URem:
780   case Instruction::SRem:
781   case Instruction::FRem:
782   case Instruction::Shl:
783   case Instruction::LShr:
784   case Instruction::AShr:
785   case Instruction::And:
786   case Instruction::Or:
787   case Instruction::Xor: {
788     int TotalCost = 0;
789     // Calculate the cost of all of the operands.
790     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
791       ValueList Operands;
792       // Prepare the operand vector.
793       for (unsigned j = 0; j < VL.size(); ++j)
794         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
795
796       int Cost = getTreeCost_rec(Operands, Depth + 1);
797       if (Cost == MAX_COST)
798         return MAX_COST;
799       TotalCost += Cost;
800     }
801
802     // Calculate the cost of this instruction.
803     int ScalarCost = 0;
804     int VecCost = 0;
805     if (Opcode == Instruction::FCmp || Opcode == Instruction::ICmp ||
806         Opcode == Instruction::Select) {
807       VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size());
808       ScalarCost =
809           VecTy->getNumElements() *
810           TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty());
811       VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy);
812     } else {
813       ScalarCost = VecTy->getNumElements() *
814                    TTI->getArithmeticInstrCost(Opcode, ScalarTy);
815       VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy);
816     }
817     TotalCost += (VecCost - ScalarCost);
818
819     if (TotalCost > GatherCost) {
820       MustGather.insert(VL.begin(), VL.end());
821       return GatherCost;
822     }
823
824     return TotalCost;
825   }
826   case Instruction::Load: {
827     // If we are scalarize the loads, add the cost of forming the vector.
828     for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
829       if (!isConsecutiveAccess(VL[i], VL[i + 1]))
830         return getGatherCost(VecTy);
831
832     // Cost of wide load - cost of scalar loads.
833     int ScalarLdCost = VecTy->getNumElements() *
834                        TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
835     int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
836     int TotalCost = VecLdCost - ScalarLdCost;
837
838     if (TotalCost > GatherCost) {
839       MustGather.insert(VL.begin(), VL.end());
840       return GatherCost;
841     }
842
843     return TotalCost;
844   }
845   case Instruction::Store: {
846     // We know that we can merge the stores. Calculate the cost.
847     int ScalarStCost = VecTy->getNumElements() *
848                        TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
849     int VecStCost = TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
850     int StoreCost = VecStCost - ScalarStCost;
851
852     ValueList Operands;
853     for (unsigned j = 0; j < VL.size(); ++j) {
854       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
855       MemBarrierIgnoreList.insert(VL[j]);
856     }
857
858     int Cost = getTreeCost_rec(Operands, Depth + 1);
859     if (Cost == MAX_COST)
860       return MAX_COST;
861
862     int TotalCost = StoreCost + Cost;
863     return TotalCost;
864   }
865   default:
866     // Unable to vectorize unknown instructions.
867     return getGatherCost(VecTy);
868   }
869 }
870
871 int FuncSLP::getTreeCost(ArrayRef<Value *> VL) {
872   // Get rid of the list of stores that were removed, and from the
873   // lists of instructions with multiple users.
874   MemBarrierIgnoreList.clear();
875   LaneMap.clear();
876   MultiUserVals.clear();
877   MustGather.clear();
878   VisitedPHIs.clear();
879
880   if (!getSameBlock(VL))
881     return MAX_COST;
882
883   // Find the location of the last root.
884   int LastRootIndex = getLastIndex(VL);
885   int FirstUserIndex = getFirstUserIndex(VL);
886
887   // Don't vectorize if there are users of the tree roots inside the tree
888   // itself.
889   if (LastRootIndex > FirstUserIndex)
890     return MAX_COST;
891
892   // Scan the tree and find which value is used by which lane, and which values
893   // must be scalarized.
894   getTreeUses_rec(VL, 0);
895
896   // Check that instructions with multiple users can be vectorized. Mark unsafe
897   // instructions.
898   for (SetVector<Value *>::iterator it = MultiUserVals.begin(),
899                                     e = MultiUserVals.end();
900        it != e; ++it) {
901     // Check that all of the users of this instr are within the tree.
902     for (Value::use_iterator I = (*it)->use_begin(), E = (*it)->use_end();
903          I != E; ++I) {
904       if (LaneMap.find(*I) == LaneMap.end()) {
905         DEBUG(dbgs() << "SLP: Adding to MustExtract "
906                         "because of an out of tree usage.\n");
907         MustGather.insert(*it);
908         continue;
909       }
910     }
911   }
912
913   // Now calculate the cost of vectorizing the tree.
914   return getTreeCost_rec(VL, 0);
915 }
916 bool FuncSLP::vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold) {
917   unsigned ChainLen = Chain.size();
918   DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen
919                << "\n");
920   Type *StoreTy = cast<StoreInst>(Chain[0])->getValueOperand()->getType();
921   unsigned Sz = DL->getTypeSizeInBits(StoreTy);
922   unsigned VF = MinVecRegSize / Sz;
923
924   if (!isPowerOf2_32(Sz) || VF < 2)
925     return false;
926
927   bool Changed = false;
928   // Look for profitable vectorizable trees at all offsets, starting at zero.
929   for (unsigned i = 0, e = ChainLen; i < e; ++i) {
930     if (i + VF > e)
931       break;
932     DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i
933                  << "\n");
934     ArrayRef<Value *> Operands = Chain.slice(i, VF);
935
936     int Cost = getTreeCost(Operands);
937     if (Cost == FuncSLP::MAX_COST)
938       continue;
939     DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n");
940     if (Cost < CostThreshold) {
941       DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n");
942       vectorizeTree(Operands);
943
944       // Remove the scalar stores.
945       for (int j = 0, e = VF; j < e; ++j)
946         cast<Instruction>(Operands[j])->eraseFromParent();
947
948       // Move to the next bundle.
949       i += VF - 1;
950       Changed = true;
951     }
952   }
953
954   if (Changed || ChainLen > VF)
955     return Changed;
956
957   // Handle short chains. This helps us catch types such as <3 x float> that
958   // are smaller than vector size.
959   int Cost = getTreeCost(Chain);
960   if (Cost == FuncSLP::MAX_COST)
961     return false;
962   if (Cost < CostThreshold) {
963     DEBUG(dbgs() << "SLP: Found store chain cost = " << Cost
964                  << " for size = " << ChainLen << "\n");
965     vectorizeTree(Chain);
966
967     // Remove all of the scalar stores.
968     for (int i = 0, e = Chain.size(); i < e; ++i)
969       cast<Instruction>(Chain[i])->eraseFromParent();
970
971     return true;
972   }
973
974   return false;
975 }
976
977 bool FuncSLP::vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold) {
978   SetVector<Value *> Heads, Tails;
979   SmallDenseMap<Value *, Value *> ConsecutiveChain;
980
981   // We may run into multiple chains that merge into a single chain. We mark the
982   // stores that we vectorized so that we don't visit the same store twice.
983   ValueSet VectorizedStores;
984   bool Changed = false;
985
986   // Do a quadratic search on all of the given stores and find
987   // all of the pairs of loads that follow each other.
988   for (unsigned i = 0, e = Stores.size(); i < e; ++i)
989     for (unsigned j = 0; j < e; ++j) {
990       if (i == j)
991         continue;
992
993       if (isConsecutiveAccess(Stores[i], Stores[j])) {
994         Tails.insert(Stores[j]);
995         Heads.insert(Stores[i]);
996         ConsecutiveChain[Stores[i]] = Stores[j];
997       }
998     }
999
1000   // For stores that start but don't end a link in the chain:
1001   for (SetVector<Value *>::iterator it = Heads.begin(), e = Heads.end();
1002        it != e; ++it) {
1003     if (Tails.count(*it))
1004       continue;
1005
1006     // We found a store instr that starts a chain. Now follow the chain and try
1007     // to vectorize it.
1008     ValueList Operands;
1009     Value *I = *it;
1010     // Collect the chain into a list.
1011     while (Tails.count(I) || Heads.count(I)) {
1012       if (VectorizedStores.count(I))
1013         break;
1014       Operands.push_back(I);
1015       // Move to the next value in the chain.
1016       I = ConsecutiveChain[I];
1017     }
1018
1019     bool Vectorized = vectorizeStoreChain(Operands, costThreshold);
1020
1021     // Mark the vectorized stores so that we don't vectorize them again.
1022     if (Vectorized)
1023       VectorizedStores.insert(Operands.begin(), Operands.end());
1024     Changed |= Vectorized;
1025   }
1026
1027   return Changed;
1028 }
1029
1030 Value *FuncSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) {
1031   Value *Vec = UndefValue::get(Ty);
1032   // Generate the 'InsertElement' instruction.
1033   for (unsigned i = 0; i < Ty->getNumElements(); ++i) {
1034     Vec = Builder.CreateInsertElement(Vec, VL[i], Builder.getInt32(i));
1035     if (Instruction *I = dyn_cast<Instruction>(Vec))
1036       GatherSeq.insert(I);
1037   }
1038
1039   return Vec;
1040 }
1041
1042 Value *FuncSLP::vectorizeTree_rec(ArrayRef<Value *> VL) {
1043   BuilderLocGuard Guard(Builder);
1044
1045   Type *ScalarTy = VL[0]->getType();
1046   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
1047     ScalarTy = SI->getValueOperand()->getType();
1048   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
1049
1050   if (needToGatherAny(VL))
1051     return Gather(VL, VecTy);
1052
1053   if (VectorizedValues.count(VL[0])) {
1054     DEBUG(dbgs() << "SLP: Diamond merged at depth.\n");
1055     return VectorizedValues[VL[0]];
1056   }
1057
1058   Instruction *VL0 = cast<Instruction>(VL[0]);
1059   unsigned Opcode = VL0->getOpcode();
1060   assert(Opcode == getSameOpcode(VL) && "Invalid opcode");
1061
1062   switch (Opcode) {
1063   case Instruction::PHI: {
1064     PHINode *PH = dyn_cast<PHINode>(VL0);
1065     Builder.SetInsertPoint(PH->getParent()->getFirstInsertionPt());
1066     PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues());
1067     VectorizedValues[VL0] = NewPhi;
1068
1069     for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
1070       ValueList Operands;
1071       BasicBlock *IBB = PH->getIncomingBlock(i);
1072
1073       // Prepare the operand vector.
1074       for (unsigned j = 0; j < VL.size(); ++j)
1075         Operands.push_back(cast<PHINode>(VL[j])->getIncomingValueForBlock(IBB));
1076
1077       Builder.SetInsertPoint(IBB->getTerminator());
1078       Value *Vec = vectorizeTree_rec(Operands);
1079       NewPhi->addIncoming(Vec, IBB);
1080     }
1081
1082     assert(NewPhi->getNumIncomingValues() == PH->getNumIncomingValues() &&
1083            "Invalid number of incoming values");
1084     return NewPhi;
1085   }
1086
1087   case Instruction::ExtractElement: {
1088     if (CanReuseExtract(VL, VL.size(), VecTy))
1089       return VL0->getOperand(0);
1090     return Gather(VL, VecTy);
1091   }
1092   case Instruction::ZExt:
1093   case Instruction::SExt:
1094   case Instruction::FPToUI:
1095   case Instruction::FPToSI:
1096   case Instruction::FPExt:
1097   case Instruction::PtrToInt:
1098   case Instruction::IntToPtr:
1099   case Instruction::SIToFP:
1100   case Instruction::UIToFP:
1101   case Instruction::Trunc:
1102   case Instruction::FPTrunc:
1103   case Instruction::BitCast: {
1104     ValueList INVL;
1105     for (int i = 0, e = VL.size(); i < e; ++i)
1106       INVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
1107
1108     Builder.SetInsertPoint(getLastInstruction(VL));
1109     Value *InVec = vectorizeTree_rec(INVL);
1110     CastInst *CI = dyn_cast<CastInst>(VL0);
1111     Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
1112     VectorizedValues[VL0] = V;
1113     return V;
1114   }
1115   case Instruction::FCmp:
1116   case Instruction::ICmp: {
1117     // Check that all of the compares have the same predicate.
1118     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
1119     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
1120       CmpInst *Cmp = cast<CmpInst>(VL[i]);
1121       if (Cmp->getPredicate() != P0)
1122         return Gather(VL, VecTy);
1123     }
1124
1125     ValueList LHSV, RHSV;
1126     for (int i = 0, e = VL.size(); i < e; ++i) {
1127       LHSV.push_back(cast<Instruction>(VL[i])->getOperand(0));
1128       RHSV.push_back(cast<Instruction>(VL[i])->getOperand(1));
1129     }
1130
1131     Builder.SetInsertPoint(getLastInstruction(VL));
1132     Value *L = vectorizeTree_rec(LHSV);
1133     Value *R = vectorizeTree_rec(RHSV);
1134     Value *V;
1135
1136     if (Opcode == Instruction::FCmp)
1137       V = Builder.CreateFCmp(P0, L, R);
1138     else
1139       V = Builder.CreateICmp(P0, L, R);
1140
1141     VectorizedValues[VL0] = V;
1142     return V;
1143   }
1144   case Instruction::Select: {
1145     ValueList TrueVec, FalseVec, CondVec;
1146     for (int i = 0, e = VL.size(); i < e; ++i) {
1147       CondVec.push_back(cast<Instruction>(VL[i])->getOperand(0));
1148       TrueVec.push_back(cast<Instruction>(VL[i])->getOperand(1));
1149       FalseVec.push_back(cast<Instruction>(VL[i])->getOperand(2));
1150     }
1151
1152     Builder.SetInsertPoint(getLastInstruction(VL));
1153     Value *True = vectorizeTree_rec(TrueVec);
1154     Value *False = vectorizeTree_rec(FalseVec);
1155     Value *Cond = vectorizeTree_rec(CondVec);
1156     Value *V = Builder.CreateSelect(Cond, True, False);
1157     VectorizedValues[VL0] = V;
1158     return V;
1159   }
1160   case Instruction::Add:
1161   case Instruction::FAdd:
1162   case Instruction::Sub:
1163   case Instruction::FSub:
1164   case Instruction::Mul:
1165   case Instruction::FMul:
1166   case Instruction::UDiv:
1167   case Instruction::SDiv:
1168   case Instruction::FDiv:
1169   case Instruction::URem:
1170   case Instruction::SRem:
1171   case Instruction::FRem:
1172   case Instruction::Shl:
1173   case Instruction::LShr:
1174   case Instruction::AShr:
1175   case Instruction::And:
1176   case Instruction::Or:
1177   case Instruction::Xor: {
1178     ValueList LHSVL, RHSVL;
1179     for (int i = 0, e = VL.size(); i < e; ++i) {
1180       LHSVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
1181       RHSVL.push_back(cast<Instruction>(VL[i])->getOperand(1));
1182     }
1183
1184     Builder.SetInsertPoint(getLastInstruction(VL));
1185     Value *LHS = vectorizeTree_rec(LHSVL);
1186     Value *RHS = vectorizeTree_rec(RHSVL);
1187
1188     if (LHS == RHS) {
1189       assert((VL0->getOperand(0) == VL0->getOperand(1)) && "Invalid order");
1190     }
1191
1192     BinaryOperator *BinOp = cast<BinaryOperator>(VL0);
1193     Value *V = Builder.CreateBinOp(BinOp->getOpcode(), LHS, RHS);
1194     VectorizedValues[VL0] = V;
1195     return V;
1196   }
1197   case Instruction::Load: {
1198     // Check if all of the loads are consecutive.
1199     for (unsigned i = 1, e = VL.size(); i < e; ++i)
1200       if (!isConsecutiveAccess(VL[i - 1], VL[i]))
1201         return Gather(VL, VecTy);
1202
1203     // Loads are inserted at the head of the tree because we don't want to
1204     // sink them all the way down past store instructions.
1205     Builder.SetInsertPoint(getLastInstruction(VL));
1206     LoadInst *LI = cast<LoadInst>(VL0);
1207     Value *VecPtr =
1208         Builder.CreateBitCast(LI->getPointerOperand(), VecTy->getPointerTo());
1209     unsigned Alignment = LI->getAlignment();
1210     LI = Builder.CreateLoad(VecPtr);
1211     LI->setAlignment(Alignment);
1212
1213     VectorizedValues[VL0] = LI;
1214     return LI;
1215   }
1216   case Instruction::Store: {
1217     StoreInst *SI = cast<StoreInst>(VL0);
1218     unsigned Alignment = SI->getAlignment();
1219
1220     ValueList ValueOp;
1221     for (int i = 0, e = VL.size(); i < e; ++i)
1222       ValueOp.push_back(cast<StoreInst>(VL[i])->getValueOperand());
1223
1224     Value *VecValue = vectorizeTree_rec(ValueOp);
1225
1226     Builder.SetInsertPoint(getLastInstruction(VL));
1227     Value *VecPtr =
1228         Builder.CreateBitCast(SI->getPointerOperand(), VecTy->getPointerTo());
1229     Builder.CreateStore(VecValue, VecPtr)->setAlignment(Alignment);
1230     return 0;
1231   }
1232   default:
1233     return Gather(VL, VecTy);
1234   }
1235 }
1236
1237 Value *FuncSLP::vectorizeTree(ArrayRef<Value *> VL) {
1238   Builder.SetInsertPoint(getLastInstruction(VL));
1239   Value *V = vectorizeTree_rec(VL);
1240
1241   // We moved some instructions around. We have to number them again
1242   // before we can do any analysis.
1243   for (Function::iterator it = F->begin(), e = F->end(); it != e; ++it)
1244     BlocksNumbers[it].forget();
1245   // Clear the state.
1246   MustGather.clear();
1247   VisitedPHIs.clear();
1248   VectorizedValues.clear();
1249   MemBarrierIgnoreList.clear();
1250   return V;
1251 }
1252
1253 Value *FuncSLP::vectorizeArith(ArrayRef<Value *> Operands) {
1254   Value *Vec = vectorizeTree(Operands);
1255   // After vectorizing the operands we need to generate extractelement
1256   // instructions and replace all of the uses of the scalar values with
1257   // the values that we extracted from the vectorized tree.
1258   for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
1259     Value *S = Builder.CreateExtractElement(Vec, Builder.getInt32(i));
1260     Operands[i]->replaceAllUsesWith(S);
1261   }
1262
1263   return Vec;
1264 }
1265
1266 void FuncSLP::optimizeGatherSequence() {
1267   // LICM InsertElementInst sequences.
1268   for (SetVector<Instruction *>::iterator it = GatherSeq.begin(),
1269        e = GatherSeq.end(); it != e; ++it) {
1270     InsertElementInst *Insert = dyn_cast<InsertElementInst>(*it);
1271
1272     if (!Insert)
1273       continue;
1274
1275     // Check if this block is inside a loop.
1276     Loop *L = LI->getLoopFor(Insert->getParent());
1277     if (!L)
1278       continue;
1279
1280     // Check if it has a preheader.
1281     BasicBlock *PreHeader = L->getLoopPreheader();
1282     if (!PreHeader)
1283       return;
1284
1285     // If the vector or the element that we insert into it are
1286     // instructions that are defined in this basic block then we can't
1287     // hoist this instruction.
1288     Instruction *CurrVec = dyn_cast<Instruction>(Insert->getOperand(0));
1289     Instruction *NewElem = dyn_cast<Instruction>(Insert->getOperand(1));
1290     if (CurrVec && L->contains(CurrVec))
1291       continue;
1292     if (NewElem && L->contains(NewElem))
1293       continue;
1294
1295     // We can hoist this instruction. Move it to the pre-header.
1296     Insert->moveBefore(PreHeader->getTerminator());
1297   }
1298
1299   // Perform O(N^2) search over the gather sequences and merge identical
1300   // instructions. TODO: We can further optimize this scan if we split the
1301   // instructions into different buckets based on the insert lane.
1302   SmallPtrSet<Instruction*, 16> Visited;
1303   ReversePostOrderTraversal<Function*> RPOT(F);
1304   for (ReversePostOrderTraversal<Function*>::rpo_iterator I = RPOT.begin(),
1305        E = RPOT.end(); I != E; ++I) {
1306     BasicBlock *BB = *I;
1307     // For all instructions in the function:
1308     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1309       InsertElementInst *Insert = dyn_cast<InsertElementInst>(it);
1310       if (!Insert || !GatherSeq.count(Insert))
1311         continue;
1312
1313      // Check if we can replace this instruction with any of the
1314      // visited instructions.
1315       for (SmallPtrSet<Instruction*, 16>::iterator v = Visited.begin(),
1316            ve = Visited.end(); v != ve; ++v) {
1317         if (Insert->isIdenticalTo(*v) &&
1318           DT->dominates((*v)->getParent(), Insert->getParent())) {
1319           Insert->replaceAllUsesWith(*v);
1320           break;
1321         }
1322       }
1323       Visited.insert(Insert);
1324     }
1325   }
1326 }
1327
1328 /// The SLPVectorizer Pass.
1329 struct SLPVectorizer : public FunctionPass {
1330   typedef SmallVector<StoreInst *, 8> StoreList;
1331   typedef MapVector<Value *, StoreList> StoreListMap;
1332
1333   /// Pass identification, replacement for typeid
1334   static char ID;
1335
1336   explicit SLPVectorizer() : FunctionPass(ID) {
1337     initializeSLPVectorizerPass(*PassRegistry::getPassRegistry());
1338   }
1339
1340   ScalarEvolution *SE;
1341   DataLayout *DL;
1342   TargetTransformInfo *TTI;
1343   AliasAnalysis *AA;
1344   LoopInfo *LI;
1345   DominatorTree *DT;
1346
1347   virtual bool runOnFunction(Function &F) {
1348     SE = &getAnalysis<ScalarEvolution>();
1349     DL = getAnalysisIfAvailable<DataLayout>();
1350     TTI = &getAnalysis<TargetTransformInfo>();
1351     AA = &getAnalysis<AliasAnalysis>();
1352     LI = &getAnalysis<LoopInfo>();
1353     DT = &getAnalysis<DominatorTree>();
1354
1355     StoreRefs.clear();
1356     bool Changed = false;
1357
1358     // Must have DataLayout. We can't require it because some tests run w/o
1359     // triple.
1360     if (!DL)
1361       return false;
1362
1363     DEBUG(dbgs() << "SLP: Analyzing blocks in " << F.getName() << ".\n");
1364
1365     // Use the bollom up slp vectorizer to construct chains that start with
1366     // he store instructions.
1367     FuncSLP R(&F, SE, DL, TTI, AA, LI, DT);
1368
1369     for (Function::iterator it = F.begin(), e = F.end(); it != e; ++it) {
1370       BasicBlock *BB = it;
1371
1372       // Vectorize trees that end at reductions.
1373       Changed |= vectorizeChainsInBlock(BB, R);
1374
1375       // Vectorize trees that end at stores.
1376       if (unsigned count = collectStores(BB, R)) {
1377         (void)count;
1378         DEBUG(dbgs() << "SLP: Found " << count << " stores to vectorize.\n");
1379         Changed |= vectorizeStoreChains(R);
1380       }
1381     }
1382
1383     if (Changed) {
1384       R.optimizeGatherSequence();
1385       DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n");
1386       DEBUG(verifyFunction(F));
1387     }
1388     return Changed;
1389   }
1390
1391   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
1392     FunctionPass::getAnalysisUsage(AU);
1393     AU.addRequired<ScalarEvolution>();
1394     AU.addRequired<AliasAnalysis>();
1395     AU.addRequired<TargetTransformInfo>();
1396     AU.addRequired<LoopInfo>();
1397     AU.addRequired<DominatorTree>();
1398   }
1399
1400 private:
1401
1402   /// \brief Collect memory references and sort them according to their base
1403   /// object. We sort the stores to their base objects to reduce the cost of the
1404   /// quadratic search on the stores. TODO: We can further reduce this cost
1405   /// if we flush the chain creation every time we run into a memory barrier.
1406   unsigned collectStores(BasicBlock *BB, FuncSLP &R);
1407
1408   /// \brief Try to vectorize a chain that starts at two arithmetic instrs.
1409   bool tryToVectorizePair(Value *A, Value *B, FuncSLP &R);
1410
1411   /// \brief Try to vectorize a list of operands. If \p NeedExtracts is true
1412   /// then we calculate the cost of extracting the scalars from the vector.
1413   /// \returns true if a value was vectorized.
1414   bool tryToVectorizeList(ArrayRef<Value *> VL, FuncSLP &R, bool NeedExtracts);
1415
1416   /// \brief Try to vectorize a chain that may start at the operands of \V;
1417   bool tryToVectorize(BinaryOperator *V, FuncSLP &R);
1418
1419   /// \brief Vectorize the stores that were collected in StoreRefs.
1420   bool vectorizeStoreChains(FuncSLP &R);
1421
1422   /// \brief Scan the basic block and look for patterns that are likely to start
1423   /// a vectorization chain.
1424   bool vectorizeChainsInBlock(BasicBlock *BB, FuncSLP &R);
1425
1426 private:
1427   StoreListMap StoreRefs;
1428 };
1429
1430 unsigned SLPVectorizer::collectStores(BasicBlock *BB, FuncSLP &R) {
1431   unsigned count = 0;
1432   StoreRefs.clear();
1433   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1434     StoreInst *SI = dyn_cast<StoreInst>(it);
1435     if (!SI)
1436       continue;
1437
1438     // Check that the pointer points to scalars.
1439     Type *Ty = SI->getValueOperand()->getType();
1440     if (Ty->isAggregateType() || Ty->isVectorTy())
1441       return 0;
1442
1443     // Find the base of the GEP.
1444     Value *Ptr = SI->getPointerOperand();
1445     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr))
1446       Ptr = GEP->getPointerOperand();
1447
1448     // Save the store locations.
1449     StoreRefs[Ptr].push_back(SI);
1450     count++;
1451   }
1452   return count;
1453 }
1454
1455 bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, FuncSLP &R) {
1456   if (!A || !B)
1457     return false;
1458   Value *VL[] = { A, B };
1459   return tryToVectorizeList(VL, R, true);
1460 }
1461
1462 bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, FuncSLP &R,
1463                                        bool NeedExtracts) {
1464   if (VL.size() < 2)
1465     return false;
1466
1467   DEBUG(dbgs() << "SLP: Vectorizing a list of length = " << VL.size() << ".\n");
1468
1469   // Check that all of the parts are scalar instructions of the same type.
1470   Instruction *I0 = dyn_cast<Instruction>(VL[0]);
1471   if (!I0)
1472     return 0;
1473
1474   unsigned Opcode0 = I0->getOpcode();
1475
1476   for (int i = 0, e = VL.size(); i < e; ++i) {
1477     Type *Ty = VL[i]->getType();
1478     if (Ty->isAggregateType() || Ty->isVectorTy())
1479       return 0;
1480     Instruction *Inst = dyn_cast<Instruction>(VL[i]);
1481     if (!Inst || Inst->getOpcode() != Opcode0)
1482       return 0;
1483   }
1484
1485   int Cost = R.getTreeCost(VL);
1486   if (Cost == FuncSLP::MAX_COST)
1487     return false;
1488
1489   int ExtrCost = NeedExtracts ? R.getGatherCost(VL) : 0;
1490   DEBUG(dbgs() << "SLP: Cost of pair:" << Cost
1491                << " Cost of extract:" << ExtrCost << ".\n");
1492   if ((Cost + ExtrCost) >= -SLPCostThreshold)
1493     return false;
1494   DEBUG(dbgs() << "SLP: Vectorizing pair.\n");
1495   R.vectorizeArith(VL);
1496   return true;
1497 }
1498
1499 bool SLPVectorizer::tryToVectorize(BinaryOperator *V, FuncSLP &R) {
1500   if (!V)
1501     return false;
1502
1503   // Try to vectorize V.
1504   if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
1505     return true;
1506
1507   BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
1508   BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
1509   // Try to skip B.
1510   if (B && B->hasOneUse()) {
1511     BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
1512     BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
1513     if (tryToVectorizePair(A, B0, R)) {
1514       B->moveBefore(V);
1515       return true;
1516     }
1517     if (tryToVectorizePair(A, B1, R)) {
1518       B->moveBefore(V);
1519       return true;
1520     }
1521   }
1522
1523   // Try to skip A.
1524   if (A && A->hasOneUse()) {
1525     BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
1526     BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
1527     if (tryToVectorizePair(A0, B, R)) {
1528       A->moveBefore(V);
1529       return true;
1530     }
1531     if (tryToVectorizePair(A1, B, R)) {
1532       A->moveBefore(V);
1533       return true;
1534     }
1535   }
1536   return 0;
1537 }
1538
1539 bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, FuncSLP &R) {
1540   bool Changed = false;
1541   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1542     if (isa<DbgInfoIntrinsic>(it))
1543       continue;
1544
1545     // Try to vectorize reductions that use PHINodes.
1546     if (PHINode *P = dyn_cast<PHINode>(it)) {
1547       // Check that the PHI is a reduction PHI.
1548       if (P->getNumIncomingValues() != 2)
1549         return Changed;
1550       Value *Rdx =
1551           (P->getIncomingBlock(0) == BB
1552                ? (P->getIncomingValue(0))
1553                : (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) : 0));
1554       // Check if this is a Binary Operator.
1555       BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
1556       if (!BI)
1557         continue;
1558
1559       Value *Inst = BI->getOperand(0);
1560       if (Inst == P)
1561         Inst = BI->getOperand(1);
1562
1563       Changed |= tryToVectorize(dyn_cast<BinaryOperator>(Inst), R);
1564       continue;
1565     }
1566
1567     // Try to vectorize trees that start at compare instructions.
1568     if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
1569       if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) {
1570         Changed |= true;
1571         continue;
1572       }
1573       for (int i = 0; i < 2; ++i)
1574         if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i)))
1575           Changed |=
1576               tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R);
1577       continue;
1578     }
1579   }
1580
1581   // Scan the PHINodes in our successors in search for pairing hints.
1582   for (succ_iterator it = succ_begin(BB), e = succ_end(BB); it != e; ++it) {
1583     BasicBlock *Succ = *it;
1584     SmallVector<Value *, 4> Incoming;
1585
1586     // Collect the incoming values from the PHIs.
1587     for (BasicBlock::iterator instr = Succ->begin(), ie = Succ->end();
1588          instr != ie; ++instr) {
1589       PHINode *P = dyn_cast<PHINode>(instr);
1590
1591       if (!P)
1592         break;
1593
1594       Value *V = P->getIncomingValueForBlock(BB);
1595       if (Instruction *I = dyn_cast<Instruction>(V))
1596         if (I->getParent() == BB)
1597           Incoming.push_back(I);
1598     }
1599
1600     if (Incoming.size() > 1)
1601       Changed |= tryToVectorizeList(Incoming, R, true);
1602   }
1603
1604   return Changed;
1605 }
1606
1607 bool SLPVectorizer::vectorizeStoreChains(FuncSLP &R) {
1608   bool Changed = false;
1609   // Attempt to sort and vectorize each of the store-groups.
1610   for (StoreListMap::iterator it = StoreRefs.begin(), e = StoreRefs.end();
1611        it != e; ++it) {
1612     if (it->second.size() < 2)
1613       continue;
1614
1615     DEBUG(dbgs() << "SLP: Analyzing a store chain of length "
1616                  << it->second.size() << ".\n");
1617
1618     Changed |= R.vectorizeStores(it->second, -SLPCostThreshold);
1619   }
1620   return Changed;
1621 }
1622
1623 } // end anonymous namespace
1624
1625 char SLPVectorizer::ID = 0;
1626 static const char lv_name[] = "SLP Vectorizer";
1627 INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false)
1628 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
1629 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
1630 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
1631 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
1632 INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)
1633
1634 namespace llvm {
1635 Pass *createSLPVectorizerPass() { return new SLPVectorizer(); }
1636 }