SLP Vectorizer: Fix a bug in the code that does CSE on the generated gather sequences.
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
1 //===- SLPVectorizer.cpp - A bottom up SLP Vectorizer ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass implements the Bottom Up SLP vectorizer. It detects consecutive
10 // stores that can be put together into vector-stores. Next, it attempts to
11 // construct vectorizable tree using the use-def chains. If a profitable tree
12 // was found, the SLP vectorizer performs vectorization on the tree.
13 //
14 // The pass is inspired by the work described in the paper:
15 //  "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
16 //
17 //===----------------------------------------------------------------------===//
18 #define SV_NAME "slp-vectorizer"
19 #define DEBUG_TYPE "SLP"
20
21 #include "llvm/Transforms/Vectorize.h"
22 #include "llvm/ADT/MapVector.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/Analysis/AliasAnalysis.h"
26 #include "llvm/Analysis/ScalarEvolution.h"
27 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
28 #include "llvm/Analysis/AliasAnalysis.h"
29 #include "llvm/Analysis/TargetTransformInfo.h"
30 #include "llvm/Analysis/Verifier.h"
31 #include "llvm/Analysis/LoopInfo.h"
32 #include "llvm/IR/DataLayout.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/IntrinsicInst.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/Module.h"
37 #include "llvm/IR/Type.h"
38 #include "llvm/IR/Value.h"
39 #include "llvm/Pass.h"
40 #include "llvm/Support/CommandLine.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Support/raw_ostream.h"
43 #include <algorithm>
44 #include <map>
45
46 using namespace llvm;
47
48 static cl::opt<int>
49     SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
50                      cl::desc("Only vectorize trees if the gain is above this "
51                               "number. (gain = -cost of vectorization)"));
52 namespace {
53
54 static const unsigned MinVecRegSize = 128;
55
56 static const unsigned RecursionMaxDepth = 6;
57
58 /// RAII pattern to save the insertion point of the IR builder.
59 class BuilderLocGuard {
60 public:
61   BuilderLocGuard(IRBuilder<> &B) : Builder(B), Loc(B.GetInsertPoint()) {}
62   ~BuilderLocGuard() { Builder.SetInsertPoint(Loc); }
63
64 private:
65   // Prevent copying.
66   BuilderLocGuard(const BuilderLocGuard &);
67   BuilderLocGuard &operator=(const BuilderLocGuard &);
68   IRBuilder<> &Builder;
69   BasicBlock::iterator Loc;
70 };
71
72 /// A helper class for numbering instructions in multible blocks.
73 /// Numbers starts at zero for each basic block.
74 struct BlockNumbering {
75
76   BlockNumbering(BasicBlock *Bb) : BB(Bb), Valid(false) {}
77
78   BlockNumbering() : BB(0), Valid(false) {}
79
80   void numberInstructions() {
81     unsigned Loc = 0;
82     InstrIdx.clear();
83     InstrVec.clear();
84     // Number the instructions in the block.
85     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
86       InstrIdx[it] = Loc++;
87       InstrVec.push_back(it);
88       assert(InstrVec[InstrIdx[it]] == it && "Invalid allocation");
89     }
90     Valid = true;
91   }
92
93   int getIndex(Instruction *I) {
94     if (!Valid)
95       numberInstructions();
96     assert(InstrIdx.count(I) && "Unknown instruction");
97     return InstrIdx[I];
98   }
99
100   Instruction *getInstruction(unsigned loc) {
101     if (!Valid)
102       numberInstructions();
103     assert(InstrVec.size() > loc && "Invalid Index");
104     return InstrVec[loc];
105   }
106
107   void forget() { Valid = false; }
108
109 private:
110   /// The block we are numbering.
111   BasicBlock *BB;
112   /// Is the block numbered.
113   bool Valid;
114   /// Maps instructions to numbers and back.
115   SmallDenseMap<Instruction *, int> InstrIdx;
116   /// Maps integers to Instructions.
117   std::vector<Instruction *> InstrVec;
118 };
119
120 class FuncSLP {
121   typedef SmallVector<Value *, 8> ValueList;
122   typedef SmallVector<Instruction *, 16> InstrList;
123   typedef SmallPtrSet<Value *, 16> ValueSet;
124   typedef SmallVector<StoreInst *, 8> StoreList;
125
126 public:
127   static const int MAX_COST = INT_MIN;
128
129   FuncSLP(Function *Func, ScalarEvolution *Se, DataLayout *Dl,
130           TargetTransformInfo *Tti, AliasAnalysis *Aa, LoopInfo *Li, 
131           DominatorTree *Dt) :
132     F(Func), SE(Se), DL(Dl), TTI(Tti), AA(Aa), LI(Li), DT(Dt),
133     Builder(Se->getContext()) {
134     for (Function::iterator it = F->begin(), e = F->end(); it != e; ++it) {
135       BasicBlock *BB = it;
136       BlocksNumbers[BB] = BlockNumbering(BB);
137     }
138   }
139
140   /// \brief Take the pointer operand from the Load/Store instruction.
141   /// \returns NULL if this is not a valid Load/Store instruction.
142   static Value *getPointerOperand(Value *I);
143
144   /// \brief Take the address space operand from the Load/Store instruction.
145   /// \returns -1 if this is not a valid Load/Store instruction.
146   static unsigned getAddressSpaceOperand(Value *I);
147
148   /// \returns true if the memory operations A and B are consecutive.
149   bool isConsecutiveAccess(Value *A, Value *B);
150
151   /// \brief Vectorize the tree that starts with the elements in \p VL.
152   /// \returns the vectorized value.
153   Value *vectorizeTree(ArrayRef<Value *> VL);
154
155   /// \returns the vectorization cost of the subtree that starts at \p VL.
156   /// A negative number means that this is profitable.
157   int getTreeCost(ArrayRef<Value *> VL);
158
159   /// \returns the scalarization cost for this list of values. Assuming that
160   /// this subtree gets vectorized, we may need to extract the values from the
161   /// roots. This method calculates the cost of extracting the values.
162   int getGatherCost(ArrayRef<Value *> VL);
163
164   /// \brief Attempts to order and vectorize a sequence of stores. This
165   /// function does a quadratic scan of the given stores.
166   /// \returns true if the basic block was modified.
167   bool vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold);
168
169   /// \brief Vectorize a group of scalars into a vector tree.
170   /// \returns the vectorized value.
171   Value *vectorizeArith(ArrayRef<Value *> Operands);
172
173   /// \brief This method contains the recursive part of getTreeCost.
174   int getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth);
175
176   /// \brief This recursive method looks for vectorization hazards such as
177   /// values that are used by multiple users and checks that values are used
178   /// by only one vector lane. It updates the variables LaneMap, MultiUserVals.
179   void getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth);
180
181   /// \brief This method contains the recursive part of vectorizeTree.
182   Value *vectorizeTree_rec(ArrayRef<Value *> VL);
183
184   ///  \brief Vectorize a sorted sequence of stores.
185   bool vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold);
186
187   /// \returns the scalarization cost for this type. Scalarization in this
188   /// context means the creation of vectors from a group of scalars.
189   int getGatherCost(Type *Ty);
190
191   /// \returns the AA location that is being access by the instruction.
192   AliasAnalysis::Location getLocation(Instruction *I);
193
194   /// \brief Checks if it is possible to sink an instruction from
195   /// \p Src to \p Dst.
196   /// \returns the pointer to the barrier instruction if we can't sink.
197   Value *getSinkBarrier(Instruction *Src, Instruction *Dst);
198
199   /// \returns the index of the last instrucion in the BB from \p VL.
200   int getLastIndex(ArrayRef<Value *> VL);
201
202   /// \returns the Instrucion in the bundle \p VL.
203   Instruction *getLastInstruction(ArrayRef<Value *> VL);
204
205   /// \returns the Instruction at index \p Index which is in Block \p BB.
206   Instruction *getInstructionForIndex(unsigned Index, BasicBlock *BB);
207
208   /// \returns the index of the first User of \p VL.
209   int getFirstUserIndex(ArrayRef<Value *> VL);
210
211   /// \returns a vector from a collection of scalars in \p VL.
212   Value *Gather(ArrayRef<Value *> VL, VectorType *Ty);
213
214   /// \brief Perform LICM and CSE on the newly generated gather sequences.
215   void optimizeGatherSequence();
216
217   bool needToGatherAny(ArrayRef<Value *> VL) {
218     for (int i = 0, e = VL.size(); i < e; ++i)
219       if (MustGather.count(VL[i]))
220         return true;
221     return false;
222   }
223
224   /// -- Vectorization State --
225
226   /// Maps values in the tree to the vector lanes that uses them. This map must
227   /// be reset between runs of getCost.
228   std::map<Value *, int> LaneMap;
229   /// A list of instructions to ignore while sinking
230   /// memory instructions. This map must be reset between runs of getCost.
231   ValueSet MemBarrierIgnoreList;
232
233   /// Maps between the first scalar to the vector. This map must be reset
234   /// between runs.
235   DenseMap<Value *, Value *> VectorizedValues;
236
237   /// Contains values that must be gathered because they are used
238   /// by multiple lanes, or by users outside the tree.
239   /// NOTICE: The vectorization methods also use this set.
240   ValueSet MustGather;
241
242   /// Contains a list of values that are used outside the current tree. This
243   /// set must be reset between runs.
244   SetVector<Value *> MultiUserVals;
245
246   /// Holds all of the instructions that we gathered.
247   SetVector<Instruction *> GatherSeq;
248
249   /// Numbers instructions in different blocks.
250   std::map<BasicBlock *, BlockNumbering> BlocksNumbers;
251
252   // Analysis and block reference.
253   Function *F;
254   ScalarEvolution *SE;
255   DataLayout *DL;
256   TargetTransformInfo *TTI;
257   AliasAnalysis *AA;
258   LoopInfo *LI;
259   DominatorTree *DT;
260   /// Instruction builder to construct the vectorized tree.
261   IRBuilder<> Builder;
262 };
263
264 int FuncSLP::getGatherCost(Type *Ty) {
265   int Cost = 0;
266   for (unsigned i = 0, e = cast<VectorType>(Ty)->getNumElements(); i < e; ++i)
267     Cost += TTI->getVectorInstrCost(Instruction::InsertElement, Ty, i);
268   return Cost;
269 }
270
271 int FuncSLP::getGatherCost(ArrayRef<Value *> VL) {
272   // Find the type of the operands in VL.
273   Type *ScalarTy = VL[0]->getType();
274   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
275     ScalarTy = SI->getValueOperand()->getType();
276   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
277   // Find the cost of inserting/extracting values from the vector.
278   return getGatherCost(VecTy);
279 }
280
281 AliasAnalysis::Location FuncSLP::getLocation(Instruction *I) {
282   if (StoreInst *SI = dyn_cast<StoreInst>(I))
283     return AA->getLocation(SI);
284   if (LoadInst *LI = dyn_cast<LoadInst>(I))
285     return AA->getLocation(LI);
286   return AliasAnalysis::Location();
287 }
288
289 Value *FuncSLP::getPointerOperand(Value *I) {
290   if (LoadInst *LI = dyn_cast<LoadInst>(I))
291     return LI->getPointerOperand();
292   if (StoreInst *SI = dyn_cast<StoreInst>(I))
293     return SI->getPointerOperand();
294   return 0;
295 }
296
297 unsigned FuncSLP::getAddressSpaceOperand(Value *I) {
298   if (LoadInst *L = dyn_cast<LoadInst>(I))
299     return L->getPointerAddressSpace();
300   if (StoreInst *S = dyn_cast<StoreInst>(I))
301     return S->getPointerAddressSpace();
302   return -1;
303 }
304
305 bool FuncSLP::isConsecutiveAccess(Value *A, Value *B) {
306   Value *PtrA = getPointerOperand(A);
307   Value *PtrB = getPointerOperand(B);
308   unsigned ASA = getAddressSpaceOperand(A);
309   unsigned ASB = getAddressSpaceOperand(B);
310
311   // Check that the address spaces match and that the pointers are valid.
312   if (!PtrA || !PtrB || (ASA != ASB))
313     return false;
314
315   // Check that A and B are of the same type.
316   if (PtrA->getType() != PtrB->getType())
317     return false;
318
319   // Calculate the distance.
320   const SCEV *PtrSCEVA = SE->getSCEV(PtrA);
321   const SCEV *PtrSCEVB = SE->getSCEV(PtrB);
322   const SCEV *OffsetSCEV = SE->getMinusSCEV(PtrSCEVA, PtrSCEVB);
323   const SCEVConstant *ConstOffSCEV = dyn_cast<SCEVConstant>(OffsetSCEV);
324
325   // Non constant distance.
326   if (!ConstOffSCEV)
327     return false;
328
329   int64_t Offset = ConstOffSCEV->getValue()->getSExtValue();
330   Type *Ty = cast<PointerType>(PtrA->getType())->getElementType();
331   // The Instructions are connsecutive if the size of the first load/store is
332   // the same as the offset.
333   int64_t Sz = DL->getTypeStoreSize(Ty);
334   return ((-Offset) == Sz);
335 }
336
337 Value *FuncSLP::getSinkBarrier(Instruction *Src, Instruction *Dst) {
338   assert(Src->getParent() == Dst->getParent() && "Not the same BB");
339   BasicBlock::iterator I = Src, E = Dst;
340   /// Scan all of the instruction from SRC to DST and check if
341   /// the source may alias.
342   for (++I; I != E; ++I) {
343     // Ignore store instructions that are marked as 'ignore'.
344     if (MemBarrierIgnoreList.count(I))
345       continue;
346     if (Src->mayWriteToMemory()) /* Write */ {
347       if (!I->mayReadOrWriteMemory())
348         continue;
349     } else /* Read */ {
350       if (!I->mayWriteToMemory())
351         continue;
352     }
353     AliasAnalysis::Location A = getLocation(&*I);
354     AliasAnalysis::Location B = getLocation(Src);
355
356     if (!A.Ptr || !B.Ptr || AA->alias(A, B))
357       return I;
358   }
359   return 0;
360 }
361
362 static BasicBlock *getSameBlock(ArrayRef<Value *> VL) {
363   BasicBlock *BB = 0;
364   for (int i = 0, e = VL.size(); i < e; i++) {
365     Instruction *I = dyn_cast<Instruction>(VL[i]);
366     if (!I)
367       return 0;
368
369     if (!BB) {
370       BB = I->getParent();
371       continue;
372     }
373
374     if (BB != I->getParent())
375       return 0;
376   }
377   return BB;
378 }
379
380 static bool allConstant(ArrayRef<Value *> VL) {
381   for (unsigned i = 0, e = VL.size(); i < e; ++i)
382     if (!isa<Constant>(VL[i]))
383       return false;
384   return true;
385 }
386
387 static bool isSplat(ArrayRef<Value *> VL) {
388   for (unsigned i = 1, e = VL.size(); i < e; ++i)
389     if (VL[i] != VL[0])
390       return false;
391   return true;
392 }
393
394 static unsigned getSameOpcode(ArrayRef<Value *> VL) {
395   unsigned Opcode = 0;
396   for (int i = 0, e = VL.size(); i < e; i++) {
397     if (Instruction *I = dyn_cast<Instruction>(VL[i])) {
398       if (!Opcode) {
399         Opcode = I->getOpcode();
400         continue;
401       }
402       if (Opcode != I->getOpcode())
403         return 0;
404     }
405   }
406   return Opcode;
407 }
408
409 static bool CanReuseExtract(ArrayRef<Value *> VL, unsigned VF,
410                             VectorType *VecTy) {
411   assert(Instruction::ExtractElement == getSameOpcode(VL) && "Invalid opcode");
412   // Check if all of the extracts come from the same vector and from the
413   // correct offset.
414   Value *VL0 = VL[0];
415   ExtractElementInst *E0 = cast<ExtractElementInst>(VL0);
416   Value *Vec = E0->getOperand(0);
417
418   // We have to extract from the same vector type.
419   if (Vec->getType() != VecTy)
420     return false;
421
422   // Check that all of the indices extract from the correct offset.
423   ConstantInt *CI = dyn_cast<ConstantInt>(E0->getOperand(1));
424   if (!CI || CI->getZExtValue())
425     return false;
426
427   for (unsigned i = 1, e = VF; i < e; ++i) {
428     ExtractElementInst *E = cast<ExtractElementInst>(VL[i]);
429     ConstantInt *CI = dyn_cast<ConstantInt>(E->getOperand(1));
430
431     if (!CI || CI->getZExtValue() != i || E->getOperand(0) != Vec)
432       return false;
433   }
434
435   return true;
436 }
437
438 void FuncSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
439   if (Depth == RecursionMaxDepth)
440     return MustGather.insert(VL.begin(), VL.end());
441
442   // Don't handle vectors.
443   if (VL[0]->getType()->isVectorTy())
444     return;
445
446   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
447     if (SI->getValueOperand()->getType()->isVectorTy())
448       return;
449
450   // If all of the operands are identical or constant we have a simple solution.
451   if (allConstant(VL) || isSplat(VL) || !getSameBlock(VL))
452     return MustGather.insert(VL.begin(), VL.end());
453
454   // Stop the scan at unknown IR.
455   Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
456   assert(VL0 && "Invalid instruction");
457
458   // Mark instructions with multiple users.
459   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
460     Instruction *I = dyn_cast<Instruction>(VL[i]);
461     // Remember to check if all of the users of this instruction are vectorized
462     // within our tree. At depth zero we have no local users, only external
463     // users that we don't care about.
464     if (Depth && I && I->getNumUses() > 1) {
465       DEBUG(dbgs() << "SLP: Adding to MultiUserVals "
466                       "because it has multiple users:" << *I << " \n");
467       MultiUserVals.insert(I);
468     }
469   }
470
471   // Check that the instruction is only used within one lane.
472   for (int i = 0, e = VL.size(); i < e; ++i) {
473     if (LaneMap.count(VL[i]) && LaneMap[VL[i]] != i) {
474       DEBUG(dbgs() << "SLP: Value used by multiple lanes:" << *VL[i] << "\n");
475       return MustGather.insert(VL.begin(), VL.end());
476     }
477     // Make this instruction as 'seen' and remember the lane.
478     LaneMap[VL[i]] = i;
479   }
480
481   unsigned Opcode = getSameOpcode(VL);
482   if (!Opcode)
483     return MustGather.insert(VL.begin(), VL.end());
484
485   switch (Opcode) {
486   case Instruction::ExtractElement: {
487     VectorType *VecTy = VectorType::get(VL[0]->getType(), VL.size());
488     // No need to follow ExtractElements that are going to be optimized away.
489     if (CanReuseExtract(VL, VL.size(), VecTy))
490       return;
491     // Fall through.
492   }
493   case Instruction::Load:
494     return;
495   case Instruction::ZExt:
496   case Instruction::SExt:
497   case Instruction::FPToUI:
498   case Instruction::FPToSI:
499   case Instruction::FPExt:
500   case Instruction::PtrToInt:
501   case Instruction::IntToPtr:
502   case Instruction::SIToFP:
503   case Instruction::UIToFP:
504   case Instruction::Trunc:
505   case Instruction::FPTrunc:
506   case Instruction::BitCast:
507   case Instruction::Select:
508   case Instruction::ICmp:
509   case Instruction::FCmp:
510   case Instruction::Add:
511   case Instruction::FAdd:
512   case Instruction::Sub:
513   case Instruction::FSub:
514   case Instruction::Mul:
515   case Instruction::FMul:
516   case Instruction::UDiv:
517   case Instruction::SDiv:
518   case Instruction::FDiv:
519   case Instruction::URem:
520   case Instruction::SRem:
521   case Instruction::FRem:
522   case Instruction::Shl:
523   case Instruction::LShr:
524   case Instruction::AShr:
525   case Instruction::And:
526   case Instruction::Or:
527   case Instruction::Xor: {
528     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
529       ValueList Operands;
530       // Prepare the operand vector.
531       for (unsigned j = 0; j < VL.size(); ++j)
532         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
533
534       getTreeUses_rec(Operands, Depth + 1);
535     }
536     return;
537   }
538   case Instruction::Store: {
539     ValueList Operands;
540     for (unsigned j = 0; j < VL.size(); ++j)
541       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
542     getTreeUses_rec(Operands, Depth + 1);
543     return;
544   }
545   default:
546     return MustGather.insert(VL.begin(), VL.end());
547   }
548 }
549
550 int FuncSLP::getLastIndex(ArrayRef<Value *> VL) {
551   BasicBlock *BB = cast<Instruction>(VL[0])->getParent();
552   assert(BB == getSameBlock(VL) && BlocksNumbers.count(BB) && "Invalid block");
553   BlockNumbering &BN = BlocksNumbers[BB];
554
555   int MaxIdx = BN.getIndex(BB->getFirstNonPHI());
556   for (unsigned i = 0, e = VL.size(); i < e; ++i)
557     MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i])));
558   return MaxIdx;
559 }
560
561 Instruction *FuncSLP::getLastInstruction(ArrayRef<Value *> VL) {
562   BasicBlock *BB = cast<Instruction>(VL[0])->getParent();
563   assert(BB == getSameBlock(VL) && BlocksNumbers.count(BB) && "Invalid block");
564   BlockNumbering &BN = BlocksNumbers[BB];
565
566   int MaxIdx = BN.getIndex(cast<Instruction>(VL[0]));
567   for (unsigned i = 1, e = VL.size(); i < e; ++i)
568     MaxIdx = std::max(MaxIdx, BN.getIndex(cast<Instruction>(VL[i])));
569   return BN.getInstruction(MaxIdx);
570 }
571
572 Instruction *FuncSLP::getInstructionForIndex(unsigned Index, BasicBlock *BB) {
573   BlockNumbering &BN = BlocksNumbers[BB];
574   return BN.getInstruction(Index);
575 }
576
577 int FuncSLP::getFirstUserIndex(ArrayRef<Value *> VL) {
578   BasicBlock *BB = getSameBlock(VL);
579   assert(BB && "All instructions must come from the same block");
580   BlockNumbering &BN = BlocksNumbers[BB];
581
582   // Find the first user of the values.
583   int FirstUser = BN.getIndex(BB->getTerminator());
584   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
585     for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
586          U != UE; ++U) {
587       Instruction *Instr = dyn_cast<Instruction>(*U);
588
589       if (!Instr || Instr->getParent() != BB)
590         continue;
591
592       FirstUser = std::min(FirstUser, BN.getIndex(Instr));
593     }
594   }
595   return FirstUser;
596 }
597
598 int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
599   Type *ScalarTy = VL[0]->getType();
600
601   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
602     ScalarTy = SI->getValueOperand()->getType();
603
604   /// Don't mess with vectors.
605   if (ScalarTy->isVectorTy())
606     return FuncSLP::MAX_COST;
607
608   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
609
610   if (allConstant(VL))
611     return 0;
612
613   if (isSplat(VL))
614     return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0);
615
616   if (Depth == RecursionMaxDepth || needToGatherAny(VL))
617     return getGatherCost(VecTy);
618
619   BasicBlock *BB = getSameBlock(VL);
620   unsigned Opcode = getSameOpcode(VL);
621   assert(Opcode && BB && "Invalid Instruction Value");
622
623   // Check if it is safe to sink the loads or the stores.
624   if (Opcode == Instruction::Load || Opcode == Instruction::Store) {
625     int MaxIdx = getLastIndex(VL);
626     Instruction *Last = getInstructionForIndex(MaxIdx, BB);
627
628     for (unsigned i = 0, e = VL.size(); i < e; ++i) {
629       if (VL[i] == Last)
630         continue;
631       Value *Barrier = getSinkBarrier(cast<Instruction>(VL[i]), Last);
632       if (Barrier) {
633         DEBUG(dbgs() << "SLP: Can't sink " << *VL[i] << "\n down to " << *Last
634                      << "\n because of " << *Barrier << "\n");
635         return MAX_COST;
636       }
637     }
638   }
639
640   Instruction *VL0 = cast<Instruction>(VL[0]);
641   switch (Opcode) {
642   case Instruction::ExtractElement: {
643     if (CanReuseExtract(VL, VL.size(), VecTy))
644       return 0;
645     return getGatherCost(VecTy);
646   }
647   case Instruction::ZExt:
648   case Instruction::SExt:
649   case Instruction::FPToUI:
650   case Instruction::FPToSI:
651   case Instruction::FPExt:
652   case Instruction::PtrToInt:
653   case Instruction::IntToPtr:
654   case Instruction::SIToFP:
655   case Instruction::UIToFP:
656   case Instruction::Trunc:
657   case Instruction::FPTrunc:
658   case Instruction::BitCast: {
659     ValueList Operands;
660     Type *SrcTy = VL0->getOperand(0)->getType();
661     // Prepare the operand vector.
662     for (unsigned j = 0; j < VL.size(); ++j) {
663       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
664       // Check that the casted type is the same for all users.
665       if (cast<Instruction>(VL[j])->getOperand(0)->getType() != SrcTy)
666         return getGatherCost(VecTy);
667     }
668
669     int Cost = getTreeCost_rec(Operands, Depth + 1);
670     if (Cost == FuncSLP::MAX_COST)
671       return Cost;
672
673     // Calculate the cost of this instruction.
674     int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
675                                                        VL0->getType(), SrcTy);
676
677     VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
678     int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
679     Cost += (VecCost - ScalarCost);
680     return Cost;
681   }
682   case Instruction::FCmp:
683   case Instruction::ICmp: {
684     // Check that all of the compares have the same predicate.
685     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
686     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
687       CmpInst *Cmp = cast<CmpInst>(VL[i]);
688       if (Cmp->getPredicate() != P0)
689         return getGatherCost(VecTy);
690     }
691     // Fall through.
692   }
693   case Instruction::Select:
694   case Instruction::Add:
695   case Instruction::FAdd:
696   case Instruction::Sub:
697   case Instruction::FSub:
698   case Instruction::Mul:
699   case Instruction::FMul:
700   case Instruction::UDiv:
701   case Instruction::SDiv:
702   case Instruction::FDiv:
703   case Instruction::URem:
704   case Instruction::SRem:
705   case Instruction::FRem:
706   case Instruction::Shl:
707   case Instruction::LShr:
708   case Instruction::AShr:
709   case Instruction::And:
710   case Instruction::Or:
711   case Instruction::Xor: {
712     int TotalCost = 0;
713     // Calculate the cost of all of the operands.
714     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
715       ValueList Operands;
716       // Prepare the operand vector.
717       for (unsigned j = 0; j < VL.size(); ++j)
718         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
719
720       int Cost = getTreeCost_rec(Operands, Depth + 1);
721       if (Cost == MAX_COST)
722         return MAX_COST;
723       TotalCost += TotalCost;
724     }
725
726     // Calculate the cost of this instruction.
727     int ScalarCost = 0;
728     int VecCost = 0;
729     if (Opcode == Instruction::FCmp || Opcode == Instruction::ICmp ||
730         Opcode == Instruction::Select) {
731       VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size());
732       ScalarCost =
733           VecTy->getNumElements() *
734           TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty());
735       VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy);
736     } else {
737       ScalarCost = VecTy->getNumElements() *
738                    TTI->getArithmeticInstrCost(Opcode, ScalarTy);
739       VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy);
740     }
741     TotalCost += (VecCost - ScalarCost);
742     return TotalCost;
743   }
744   case Instruction::Load: {
745     // If we are scalarize the loads, add the cost of forming the vector.
746     for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
747       if (!isConsecutiveAccess(VL[i], VL[i + 1]))
748         return getGatherCost(VecTy);
749
750     // Cost of wide load - cost of scalar loads.
751     int ScalarLdCost = VecTy->getNumElements() *
752                        TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
753     int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
754     return VecLdCost - ScalarLdCost;
755   }
756   case Instruction::Store: {
757     // We know that we can merge the stores. Calculate the cost.
758     int ScalarStCost = VecTy->getNumElements() *
759                        TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
760     int VecStCost = TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
761     int StoreCost = VecStCost - ScalarStCost;
762
763     ValueList Operands;
764     for (unsigned j = 0; j < VL.size(); ++j) {
765       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
766       MemBarrierIgnoreList.insert(VL[j]);
767     }
768
769     int Cost = getTreeCost_rec(Operands, Depth + 1);
770     if (Cost == MAX_COST)
771       return MAX_COST;
772
773     int TotalCost = StoreCost + Cost;
774     return TotalCost;
775   }
776   default:
777     // Unable to vectorize unknown instructions.
778     return getGatherCost(VecTy);
779   }
780 }
781
782 int FuncSLP::getTreeCost(ArrayRef<Value *> VL) {
783   // Get rid of the list of stores that were removed, and from the
784   // lists of instructions with multiple users.
785   MemBarrierIgnoreList.clear();
786   LaneMap.clear();
787   MultiUserVals.clear();
788   MustGather.clear();
789
790   if (!getSameBlock(VL))
791     return MAX_COST;
792
793   // Find the location of the last root.
794   int LastRootIndex = getLastIndex(VL);
795   int FirstUserIndex = getFirstUserIndex(VL);
796
797   // Don't vectorize if there are users of the tree roots inside the tree
798   // itself.
799   if (LastRootIndex > FirstUserIndex)
800     return MAX_COST;
801
802   // Scan the tree and find which value is used by which lane, and which values
803   // must be scalarized.
804   getTreeUses_rec(VL, 0);
805
806   // Check that instructions with multiple users can be vectorized. Mark unsafe
807   // instructions.
808   for (SetVector<Value *>::iterator it = MultiUserVals.begin(),
809                                     e = MultiUserVals.end();
810        it != e; ++it) {
811     // Check that all of the users of this instr are within the tree.
812     for (Value::use_iterator I = (*it)->use_begin(), E = (*it)->use_end();
813          I != E; ++I) {
814       if (LaneMap.find(*I) == LaneMap.end()) {
815         DEBUG(dbgs() << "SLP: Adding to MustExtract "
816                         "because of an out of tree usage.\n");
817         MustGather.insert(*it);
818         continue;
819       }
820     }
821   }
822
823   // Now calculate the cost of vectorizing the tree.
824   return getTreeCost_rec(VL, 0);
825 }
826 bool FuncSLP::vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold) {
827   unsigned ChainLen = Chain.size();
828   DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen
829                << "\n");
830   Type *StoreTy = cast<StoreInst>(Chain[0])->getValueOperand()->getType();
831   unsigned Sz = DL->getTypeSizeInBits(StoreTy);
832   unsigned VF = MinVecRegSize / Sz;
833
834   if (!isPowerOf2_32(Sz) || VF < 2)
835     return false;
836
837   bool Changed = false;
838   // Look for profitable vectorizable trees at all offsets, starting at zero.
839   for (unsigned i = 0, e = ChainLen; i < e; ++i) {
840     if (i + VF > e)
841       break;
842     DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i
843                  << "\n");
844     ArrayRef<Value *> Operands = Chain.slice(i, VF);
845
846     int Cost = getTreeCost(Operands);
847     if (Cost == FuncSLP::MAX_COST)
848       continue;
849     DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n");
850     if (Cost < CostThreshold) {
851       DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n");
852       vectorizeTree(Operands);
853
854       // Remove the scalar stores.
855       for (int i = 0, e = VF; i < e; ++i)
856         cast<Instruction>(Operands[i])->eraseFromParent();
857
858       // Move to the next bundle.
859       i += VF - 1;
860       Changed = true;
861     }
862   }
863
864   if (Changed || ChainLen > VF)
865     return Changed;
866
867   // Handle short chains. This helps us catch types such as <3 x float> that
868   // are smaller than vector size.
869   int Cost = getTreeCost(Chain);
870   if (Cost == FuncSLP::MAX_COST)
871     return false;
872   if (Cost < CostThreshold) {
873     DEBUG(dbgs() << "SLP: Found store chain cost = " << Cost
874                  << " for size = " << ChainLen << "\n");
875     vectorizeTree(Chain);
876
877     // Remove all of the scalar stores.
878     for (int i = 0, e = Chain.size(); i < e; ++i)
879       cast<Instruction>(Chain[i])->eraseFromParent();
880
881     return true;
882   }
883
884   return false;
885 }
886
887 bool FuncSLP::vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold) {
888   SetVector<Value *> Heads, Tails;
889   SmallDenseMap<Value *, Value *> ConsecutiveChain;
890
891   // We may run into multiple chains that merge into a single chain. We mark the
892   // stores that we vectorized so that we don't visit the same store twice.
893   ValueSet VectorizedStores;
894   bool Changed = false;
895
896   // Do a quadratic search on all of the given stores and find
897   // all of the pairs of loads that follow each other.
898   for (unsigned i = 0, e = Stores.size(); i < e; ++i)
899     for (unsigned j = 0; j < e; ++j) {
900       if (i == j)
901         continue;
902
903       if (isConsecutiveAccess(Stores[i], Stores[j])) {
904         Tails.insert(Stores[j]);
905         Heads.insert(Stores[i]);
906         ConsecutiveChain[Stores[i]] = Stores[j];
907       }
908     }
909
910   // For stores that start but don't end a link in the chain:
911   for (SetVector<Value *>::iterator it = Heads.begin(), e = Heads.end();
912        it != e; ++it) {
913     if (Tails.count(*it))
914       continue;
915
916     // We found a store instr that starts a chain. Now follow the chain and try
917     // to vectorize it.
918     ValueList Operands;
919     Value *I = *it;
920     // Collect the chain into a list.
921     while (Tails.count(I) || Heads.count(I)) {
922       if (VectorizedStores.count(I))
923         break;
924       Operands.push_back(I);
925       // Move to the next value in the chain.
926       I = ConsecutiveChain[I];
927     }
928
929     bool Vectorized = vectorizeStoreChain(Operands, costThreshold);
930
931     // Mark the vectorized stores so that we don't vectorize them again.
932     if (Vectorized)
933       VectorizedStores.insert(Operands.begin(), Operands.end());
934     Changed |= Vectorized;
935   }
936
937   return Changed;
938 }
939
940 Value *FuncSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) {
941   Value *Vec = UndefValue::get(Ty);
942   // Generate the 'InsertElement' instruction.
943   for (unsigned i = 0; i < Ty->getNumElements(); ++i) {
944     Vec = Builder.CreateInsertElement(Vec, VL[i], Builder.getInt32(i));
945     if (Instruction *I = dyn_cast<Instruction>(Vec))
946       GatherSeq.insert(I);
947   }
948
949   return Vec;
950 }
951
952 Value *FuncSLP::vectorizeTree_rec(ArrayRef<Value *> VL) {
953   BuilderLocGuard Guard(Builder);
954
955   Type *ScalarTy = VL[0]->getType();
956   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
957     ScalarTy = SI->getValueOperand()->getType();
958   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
959
960   if (needToGatherAny(VL))
961     return Gather(VL, VecTy);
962
963   if (VectorizedValues.count(VL[0])) {
964     DEBUG(dbgs() << "SLP: Diamond merged at depth.\n");
965     return VectorizedValues[VL[0]];
966   }
967
968   Instruction *VL0 = cast<Instruction>(VL[0]);
969   unsigned Opcode = VL0->getOpcode();
970   assert(Opcode == getSameOpcode(VL) && "Invalid opcode");
971
972   switch (Opcode) {
973   case Instruction::ExtractElement: {
974     if (CanReuseExtract(VL, VL.size(), VecTy))
975       return VL0->getOperand(0);
976     return Gather(VL, VecTy);
977   }
978   case Instruction::ZExt:
979   case Instruction::SExt:
980   case Instruction::FPToUI:
981   case Instruction::FPToSI:
982   case Instruction::FPExt:
983   case Instruction::PtrToInt:
984   case Instruction::IntToPtr:
985   case Instruction::SIToFP:
986   case Instruction::UIToFP:
987   case Instruction::Trunc:
988   case Instruction::FPTrunc:
989   case Instruction::BitCast: {
990     ValueList INVL;
991     for (int i = 0, e = VL.size(); i < e; ++i)
992       INVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
993
994     Builder.SetInsertPoint(getLastInstruction(VL));
995     Value *InVec = vectorizeTree_rec(INVL);
996     CastInst *CI = dyn_cast<CastInst>(VL0);
997     Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
998     VectorizedValues[VL0] = V;
999     return V;
1000   }
1001   case Instruction::FCmp:
1002   case Instruction::ICmp: {
1003     // Check that all of the compares have the same predicate.
1004     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
1005     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
1006       CmpInst *Cmp = cast<CmpInst>(VL[i]);
1007       if (Cmp->getPredicate() != P0)
1008         return Gather(VL, VecTy);
1009     }
1010
1011     ValueList LHSV, RHSV;
1012     for (int i = 0, e = VL.size(); i < e; ++i) {
1013       LHSV.push_back(cast<Instruction>(VL[i])->getOperand(0));
1014       RHSV.push_back(cast<Instruction>(VL[i])->getOperand(1));
1015     }
1016
1017     Builder.SetInsertPoint(getLastInstruction(VL));
1018     Value *L = vectorizeTree_rec(LHSV);
1019     Value *R = vectorizeTree_rec(RHSV);
1020     Value *V;
1021
1022     if (Opcode == Instruction::FCmp)
1023       V = Builder.CreateFCmp(P0, L, R);
1024     else
1025       V = Builder.CreateICmp(P0, L, R);
1026
1027     VectorizedValues[VL0] = V;
1028     return V;
1029   }
1030   case Instruction::Select: {
1031     ValueList TrueVec, FalseVec, CondVec;
1032     for (int i = 0, e = VL.size(); i < e; ++i) {
1033       CondVec.push_back(cast<Instruction>(VL[i])->getOperand(0));
1034       TrueVec.push_back(cast<Instruction>(VL[i])->getOperand(1));
1035       FalseVec.push_back(cast<Instruction>(VL[i])->getOperand(2));
1036     }
1037
1038     Builder.SetInsertPoint(getLastInstruction(VL));
1039     Value *True = vectorizeTree_rec(TrueVec);
1040     Value *False = vectorizeTree_rec(FalseVec);
1041     Value *Cond = vectorizeTree_rec(CondVec);
1042     Value *V = Builder.CreateSelect(Cond, True, False);
1043     VectorizedValues[VL0] = V;
1044     return V;
1045   }
1046   case Instruction::Add:
1047   case Instruction::FAdd:
1048   case Instruction::Sub:
1049   case Instruction::FSub:
1050   case Instruction::Mul:
1051   case Instruction::FMul:
1052   case Instruction::UDiv:
1053   case Instruction::SDiv:
1054   case Instruction::FDiv:
1055   case Instruction::URem:
1056   case Instruction::SRem:
1057   case Instruction::FRem:
1058   case Instruction::Shl:
1059   case Instruction::LShr:
1060   case Instruction::AShr:
1061   case Instruction::And:
1062   case Instruction::Or:
1063   case Instruction::Xor: {
1064     ValueList LHSVL, RHSVL;
1065     for (int i = 0, e = VL.size(); i < e; ++i) {
1066       LHSVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
1067       RHSVL.push_back(cast<Instruction>(VL[i])->getOperand(1));
1068     }
1069
1070     Builder.SetInsertPoint(getLastInstruction(VL));
1071     Value *LHS = vectorizeTree_rec(LHSVL);
1072     Value *RHS = vectorizeTree_rec(RHSVL);
1073
1074     if (LHS == RHS) {
1075       assert((VL0->getOperand(0) == VL0->getOperand(1)) && "Invalid order");
1076     }
1077
1078     BinaryOperator *BinOp = cast<BinaryOperator>(VL0);
1079     Value *V = Builder.CreateBinOp(BinOp->getOpcode(), LHS, RHS);
1080     VectorizedValues[VL0] = V;
1081     return V;
1082   }
1083   case Instruction::Load: {
1084     // Check if all of the loads are consecutive.
1085     for (unsigned i = 1, e = VL.size(); i < e; ++i)
1086       if (!isConsecutiveAccess(VL[i - 1], VL[i]))
1087         return Gather(VL, VecTy);
1088
1089     // Loads are inserted at the head of the tree because we don't want to
1090     // sink them all the way down past store instructions.
1091     Builder.SetInsertPoint(getLastInstruction(VL));
1092     LoadInst *LI = cast<LoadInst>(VL0);
1093     Value *VecPtr =
1094         Builder.CreateBitCast(LI->getPointerOperand(), VecTy->getPointerTo());
1095     unsigned Alignment = LI->getAlignment();
1096     LI = Builder.CreateLoad(VecPtr);
1097     LI->setAlignment(Alignment);
1098
1099     VectorizedValues[VL0] = LI;
1100     return LI;
1101   }
1102   case Instruction::Store: {
1103     StoreInst *SI = cast<StoreInst>(VL0);
1104     unsigned Alignment = SI->getAlignment();
1105
1106     ValueList ValueOp;
1107     for (int i = 0, e = VL.size(); i < e; ++i)
1108       ValueOp.push_back(cast<StoreInst>(VL[i])->getValueOperand());
1109
1110     Value *VecValue = vectorizeTree_rec(ValueOp);
1111
1112     Builder.SetInsertPoint(getLastInstruction(VL));
1113     Value *VecPtr =
1114         Builder.CreateBitCast(SI->getPointerOperand(), VecTy->getPointerTo());
1115     Builder.CreateStore(VecValue, VecPtr)->setAlignment(Alignment);
1116     return 0;
1117   }
1118   default:
1119     return Gather(VL, VecTy);
1120   }
1121 }
1122
1123 Value *FuncSLP::vectorizeTree(ArrayRef<Value *> VL) {
1124   Builder.SetInsertPoint(getLastInstruction(VL));
1125   Value *V = vectorizeTree_rec(VL);
1126
1127   // We moved some instructions around. We have to number them again
1128   // before we can do any analysis.
1129   for (Function::iterator it = F->begin(), e = F->end(); it != e; ++it)
1130     BlocksNumbers[it].forget();
1131   // Clear the state.
1132   MustGather.clear();
1133   VectorizedValues.clear();
1134   MemBarrierIgnoreList.clear();
1135   return V;
1136 }
1137
1138 Value *FuncSLP::vectorizeArith(ArrayRef<Value *> Operands) {
1139   Value *Vec = vectorizeTree(Operands);
1140   // After vectorizing the operands we need to generate extractelement
1141   // instructions and replace all of the uses of the scalar values with
1142   // the values that we extracted from the vectorized tree.
1143   for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
1144     Value *S = Builder.CreateExtractElement(Vec, Builder.getInt32(i));
1145     Operands[i]->replaceAllUsesWith(S);
1146   }
1147
1148   return Vec;
1149 }
1150
1151 void FuncSLP::optimizeGatherSequence() {
1152   // LICM InsertElementInst sequences.
1153   for (SetVector<Instruction *>::iterator it = GatherSeq.begin(),
1154        e = GatherSeq.end(); it != e; ++it) {
1155     InsertElementInst *Insert = dyn_cast<InsertElementInst>(*it);
1156
1157     if (!Insert)
1158       continue;
1159
1160     // Check if this block is inside a loop.
1161     Loop *L = LI->getLoopFor(Insert->getParent());
1162     if (!L)
1163       continue;
1164
1165     // Check if it has a preheader.
1166     BasicBlock *PreHeader = L->getLoopPreheader();
1167     if (!PreHeader)
1168       return;
1169
1170     // If the vector or the element that we insert into it are
1171     // instructions that are defined in this basic block then we can't
1172     // hoist this instruction.
1173     Instruction *CurrVec = dyn_cast<Instruction>(Insert->getOperand(0));
1174     Instruction *NewElem = dyn_cast<Instruction>(Insert->getOperand(1));
1175     if (CurrVec && L->contains(CurrVec))
1176       continue;
1177     if (NewElem && L->contains(NewElem))
1178       continue;
1179
1180     // We can hoist this instruction. Move it to the pre-header.
1181     Insert->moveBefore(PreHeader->getTerminator());
1182   }
1183
1184   // Perform O(N^2) search over the gather sequences and merge identical
1185   // instructions. TODO: We can further optimize this scan if we split the
1186   // instructions into different buckets based on the insert lane.
1187   SmallPtrSet<Instruction*, 16> Visited;
1188   ReversePostOrderTraversal<Function*> RPOT(F);
1189   for (ReversePostOrderTraversal<Function*>::rpo_iterator I = RPOT.begin(),
1190        E = RPOT.end(); I != E; ++I) {
1191     BasicBlock *BB = *I;
1192     // For all instructions in the function:
1193     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1194       InsertElementInst *Insert = dyn_cast<InsertElementInst>(it);
1195       if (!Insert || !GatherSeq.count(Insert))
1196         continue;
1197
1198      // Check if we can replace this instruction with any of the
1199      // visited instructions.
1200       for (SmallPtrSet<Instruction*, 16>::iterator v = Visited.begin(),
1201            ve = Visited.end(); v != ve; ++v) {
1202         if (Insert->isIdenticalTo(*v) &&
1203           DT->dominates((*v)->getParent(), Insert->getParent())) {
1204           Insert->replaceAllUsesWith(*v);
1205           break;
1206         }
1207       }
1208       Visited.insert(Insert);
1209     }
1210   }
1211 }
1212
1213 /// The SLPVectorizer Pass.
1214 struct SLPVectorizer : public FunctionPass {
1215   typedef SmallVector<StoreInst *, 8> StoreList;
1216   typedef MapVector<Value *, StoreList> StoreListMap;
1217
1218   /// Pass identification, replacement for typeid
1219   static char ID;
1220
1221   explicit SLPVectorizer() : FunctionPass(ID) {
1222     initializeSLPVectorizerPass(*PassRegistry::getPassRegistry());
1223   }
1224
1225   ScalarEvolution *SE;
1226   DataLayout *DL;
1227   TargetTransformInfo *TTI;
1228   AliasAnalysis *AA;
1229   LoopInfo *LI;
1230   DominatorTree *DT;
1231
1232   virtual bool runOnFunction(Function &F) {
1233     SE = &getAnalysis<ScalarEvolution>();
1234     DL = getAnalysisIfAvailable<DataLayout>();
1235     TTI = &getAnalysis<TargetTransformInfo>();
1236     AA = &getAnalysis<AliasAnalysis>();
1237     LI = &getAnalysis<LoopInfo>();
1238     DT = &getAnalysis<DominatorTree>();
1239
1240     StoreRefs.clear();
1241     bool Changed = false;
1242
1243     // Must have DataLayout. We can't require it because some tests run w/o
1244     // triple.
1245     if (!DL)
1246       return false;
1247
1248     DEBUG(dbgs() << "SLP: Analyzing blocks in " << F.getName() << ".\n");
1249
1250     // Use the bollom up slp vectorizer to construct chains that start with
1251     // he store instructions.
1252     FuncSLP R(&F, SE, DL, TTI, AA, LI, DT);
1253
1254     for (Function::iterator it = F.begin(), e = F.end(); it != e; ++it) {
1255       BasicBlock *BB = it;
1256
1257       // Vectorize trees that end at reductions.
1258       Changed |= vectorizeChainsInBlock(BB, R);
1259
1260       // Vectorize trees that end at stores.
1261       if (unsigned count = collectStores(BB, R)) {
1262         (void)count;
1263         DEBUG(dbgs() << "SLP: Found " << count << " stores to vectorize.\n");
1264         Changed |= vectorizeStoreChains(R);
1265       }
1266     }
1267
1268     if (Changed) {
1269       R.optimizeGatherSequence();
1270       DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n");
1271       DEBUG(verifyFunction(F));
1272     }
1273     return Changed;
1274   }
1275
1276   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
1277     FunctionPass::getAnalysisUsage(AU);
1278     AU.addRequired<ScalarEvolution>();
1279     AU.addRequired<AliasAnalysis>();
1280     AU.addRequired<TargetTransformInfo>();
1281     AU.addRequired<LoopInfo>();
1282     AU.addRequired<DominatorTree>();
1283   }
1284
1285 private:
1286
1287   /// \brief Collect memory references and sort them according to their base
1288   /// object. We sort the stores to their base objects to reduce the cost of the
1289   /// quadratic search on the stores. TODO: We can further reduce this cost
1290   /// if we flush the chain creation every time we run into a memory barrier.
1291   unsigned collectStores(BasicBlock *BB, FuncSLP &R);
1292
1293   /// \brief Try to vectorize a chain that starts at two arithmetic instrs.
1294   bool tryToVectorizePair(Value *A, Value *B, FuncSLP &R);
1295
1296   /// \brief Try to vectorize a list of operands. If \p NeedExtracts is true
1297   /// then we calculate the cost of extracting the scalars from the vector.
1298   /// \returns true if a value was vectorized.
1299   bool tryToVectorizeList(ArrayRef<Value *> VL, FuncSLP &R, bool NeedExtracts);
1300
1301   /// \brief Try to vectorize a chain that may start at the operands of \V;
1302   bool tryToVectorize(BinaryOperator *V, FuncSLP &R);
1303
1304   /// \brief Vectorize the stores that were collected in StoreRefs.
1305   bool vectorizeStoreChains(FuncSLP &R);
1306
1307   /// \brief Scan the basic block and look for patterns that are likely to start
1308   /// a vectorization chain.
1309   bool vectorizeChainsInBlock(BasicBlock *BB, FuncSLP &R);
1310
1311 private:
1312   StoreListMap StoreRefs;
1313 };
1314
1315 unsigned SLPVectorizer::collectStores(BasicBlock *BB, FuncSLP &R) {
1316   unsigned count = 0;
1317   StoreRefs.clear();
1318   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1319     StoreInst *SI = dyn_cast<StoreInst>(it);
1320     if (!SI)
1321       continue;
1322
1323     // Check that the pointer points to scalars.
1324     Type *Ty = SI->getValueOperand()->getType();
1325     if (Ty->isAggregateType() || Ty->isVectorTy())
1326       return 0;
1327
1328     // Find the base of the GEP.
1329     Value *Ptr = SI->getPointerOperand();
1330     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr))
1331       Ptr = GEP->getPointerOperand();
1332
1333     // Save the store locations.
1334     StoreRefs[Ptr].push_back(SI);
1335     count++;
1336   }
1337   return count;
1338 }
1339
1340 bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B, FuncSLP &R) {
1341   if (!A || !B)
1342     return false;
1343   Value *VL[] = { A, B };
1344   return tryToVectorizeList(VL, R, true);
1345 }
1346
1347 bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, FuncSLP &R,
1348                                        bool NeedExtracts) {
1349   if (VL.size() < 2)
1350     return false;
1351
1352   DEBUG(dbgs() << "SLP: Vectorizing a list of length = " << VL.size() << ".\n");
1353
1354   // Check that all of the parts are scalar instructions of the same type.
1355   Instruction *I0 = dyn_cast<Instruction>(VL[0]);
1356   if (!I0)
1357     return 0;
1358
1359   unsigned Opcode0 = I0->getOpcode();
1360
1361   for (int i = 0, e = VL.size(); i < e; ++i) {
1362     Type *Ty = VL[i]->getType();
1363     if (Ty->isAggregateType() || Ty->isVectorTy())
1364       return 0;
1365     Instruction *Inst = dyn_cast<Instruction>(VL[i]);
1366     if (!Inst || Inst->getOpcode() != Opcode0)
1367       return 0;
1368   }
1369
1370   int Cost = R.getTreeCost(VL);
1371   if (Cost == FuncSLP::MAX_COST)
1372     return false;
1373
1374   int ExtrCost = NeedExtracts ? R.getGatherCost(VL) : 0;
1375   DEBUG(dbgs() << "SLP: Cost of pair:" << Cost
1376                << " Cost of extract:" << ExtrCost << ".\n");
1377   if ((Cost + ExtrCost) >= -SLPCostThreshold)
1378     return false;
1379   DEBUG(dbgs() << "SLP: Vectorizing pair.\n");
1380   R.vectorizeArith(VL);
1381   return true;
1382 }
1383
1384 bool SLPVectorizer::tryToVectorize(BinaryOperator *V, FuncSLP &R) {
1385   if (!V)
1386     return false;
1387
1388   // Try to vectorize V.
1389   if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
1390     return true;
1391
1392   BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
1393   BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
1394   // Try to skip B.
1395   if (B && B->hasOneUse()) {
1396     BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
1397     BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
1398     if (tryToVectorizePair(A, B0, R)) {
1399       B->moveBefore(V);
1400       return true;
1401     }
1402     if (tryToVectorizePair(A, B1, R)) {
1403       B->moveBefore(V);
1404       return true;
1405     }
1406   }
1407
1408   // Try to skip A.
1409   if (A && A->hasOneUse()) {
1410     BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
1411     BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
1412     if (tryToVectorizePair(A0, B, R)) {
1413       A->moveBefore(V);
1414       return true;
1415     }
1416     if (tryToVectorizePair(A1, B, R)) {
1417       A->moveBefore(V);
1418       return true;
1419     }
1420   }
1421   return 0;
1422 }
1423
1424 bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, FuncSLP &R) {
1425   bool Changed = false;
1426   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
1427     if (isa<DbgInfoIntrinsic>(it))
1428       continue;
1429
1430     // Try to vectorize reductions that use PHINodes.
1431     if (PHINode *P = dyn_cast<PHINode>(it)) {
1432       // Check that the PHI is a reduction PHI.
1433       if (P->getNumIncomingValues() != 2)
1434         return Changed;
1435       Value *Rdx =
1436           (P->getIncomingBlock(0) == BB
1437                ? (P->getIncomingValue(0))
1438                : (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) : 0));
1439       // Check if this is a Binary Operator.
1440       BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
1441       if (!BI)
1442         continue;
1443
1444       Value *Inst = BI->getOperand(0);
1445       if (Inst == P)
1446         Inst = BI->getOperand(1);
1447
1448       Changed |= tryToVectorize(dyn_cast<BinaryOperator>(Inst), R);
1449       continue;
1450     }
1451
1452     // Try to vectorize trees that start at compare instructions.
1453     if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
1454       if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) {
1455         Changed |= true;
1456         continue;
1457       }
1458       for (int i = 0; i < 2; ++i)
1459         if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i)))
1460           Changed |=
1461               tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R);
1462       continue;
1463     }
1464   }
1465
1466   // Scan the PHINodes in our successors in search for pairing hints.
1467   for (succ_iterator it = succ_begin(BB), e = succ_end(BB); it != e; ++it) {
1468     BasicBlock *Succ = *it;
1469     SmallVector<Value *, 4> Incoming;
1470
1471     // Collect the incoming values from the PHIs.
1472     for (BasicBlock::iterator instr = Succ->begin(), ie = Succ->end();
1473          instr != ie; ++instr) {
1474       PHINode *P = dyn_cast<PHINode>(instr);
1475
1476       if (!P)
1477         break;
1478
1479       Value *V = P->getIncomingValueForBlock(BB);
1480       if (Instruction *I = dyn_cast<Instruction>(V))
1481         if (I->getParent() == BB)
1482           Incoming.push_back(I);
1483     }
1484
1485     if (Incoming.size() > 1)
1486       Changed |= tryToVectorizeList(Incoming, R, true);
1487   }
1488
1489   return Changed;
1490 }
1491
1492 bool SLPVectorizer::vectorizeStoreChains(FuncSLP &R) {
1493   bool Changed = false;
1494   // Attempt to sort and vectorize each of the store-groups.
1495   for (StoreListMap::iterator it = StoreRefs.begin(), e = StoreRefs.end();
1496        it != e; ++it) {
1497     if (it->second.size() < 2)
1498       continue;
1499
1500     DEBUG(dbgs() << "SLP: Analyzing a store chain of length "
1501                  << it->second.size() << ".\n");
1502
1503     Changed |= R.vectorizeStores(it->second, -SLPCostThreshold);
1504   }
1505   return Changed;
1506 }
1507
1508 } // end anonymous namespace
1509
1510 char SLPVectorizer::ID = 0;
1511 static const char lv_name[] = "SLP Vectorizer";
1512 INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false)
1513 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
1514 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
1515 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
1516 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
1517 INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)
1518
1519 namespace llvm {
1520 Pass *createSLPVectorizerPass() { return new SLPVectorizer(); }
1521 }