Clang-format the SLP vectorizer. No functionality change.
[oota-llvm.git] / lib / Transforms / Vectorize / VecUtils.cpp
1 //===- VecUtils.cpp --- Vectorization Utilities ---------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 #define DEBUG_TYPE "SLP"
10
11 #include "VecUtils.h"
12 #include "llvm/ADT/DenseMap.h"
13 #include "llvm/ADT/SmallPtrSet.h"
14 #include "llvm/ADT/SmallSet.h"
15 #include "llvm/ADT/SmallVector.h"
16 #include "llvm/Analysis/AliasAnalysis.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/Verifier.h"
21 #include "llvm/Analysis/LoopInfo.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/DataLayout.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/Instructions.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/IR/Value.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "llvm/Target/TargetLibraryInfo.h"
34 #include "llvm/Transforms/Scalar.h"
35 #include "llvm/Transforms/Utils/Local.h"
36 #include <algorithm>
37 #include <map>
38
39 using namespace llvm;
40
41 static const unsigned MinVecRegSize = 128;
42
43 static const unsigned RecursionMaxDepth = 6;
44
45 namespace llvm {
46
47 BoUpSLP::BoUpSLP(BasicBlock *Bb, ScalarEvolution *S, DataLayout *Dl,
48                  TargetTransformInfo *Tti, AliasAnalysis *Aa, Loop *Lp)
49     : Builder(S->getContext()), BB(Bb), SE(S), DL(Dl), TTI(Tti), AA(Aa), L(Lp) {
50   numberInstructions();
51 }
52
53 void BoUpSLP::numberInstructions() {
54   int Loc = 0;
55   InstrIdx.clear();
56   InstrVec.clear();
57   // Number the instructions in the block.
58   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
59     InstrIdx[it] = Loc++;
60     InstrVec.push_back(it);
61     assert(InstrVec[InstrIdx[it]] == it && "Invalid allocation");
62   }
63 }
64
65 Value *BoUpSLP::getPointerOperand(Value *I) {
66   if (LoadInst *LI = dyn_cast<LoadInst>(I))
67     return LI->getPointerOperand();
68   if (StoreInst *SI = dyn_cast<StoreInst>(I))
69     return SI->getPointerOperand();
70   return 0;
71 }
72
73 unsigned BoUpSLP::getAddressSpaceOperand(Value *I) {
74   if (LoadInst *L = dyn_cast<LoadInst>(I))
75     return L->getPointerAddressSpace();
76   if (StoreInst *S = dyn_cast<StoreInst>(I))
77     return S->getPointerAddressSpace();
78   return -1;
79 }
80
81 bool BoUpSLP::isConsecutiveAccess(Value *A, Value *B) {
82   Value *PtrA = getPointerOperand(A);
83   Value *PtrB = getPointerOperand(B);
84   unsigned ASA = getAddressSpaceOperand(A);
85   unsigned ASB = getAddressSpaceOperand(B);
86
87   // Check that the address spaces match and that the pointers are valid.
88   if (!PtrA || !PtrB || (ASA != ASB))
89     return false;
90
91   // Check that A and B are of the same type.
92   if (PtrA->getType() != PtrB->getType())
93     return false;
94
95   // Calculate the distance.
96   const SCEV *PtrSCEVA = SE->getSCEV(PtrA);
97   const SCEV *PtrSCEVB = SE->getSCEV(PtrB);
98   const SCEV *OffsetSCEV = SE->getMinusSCEV(PtrSCEVA, PtrSCEVB);
99   const SCEVConstant *ConstOffSCEV = dyn_cast<SCEVConstant>(OffsetSCEV);
100
101   // Non constant distance.
102   if (!ConstOffSCEV)
103     return false;
104
105   int64_t Offset = ConstOffSCEV->getValue()->getSExtValue();
106   Type *Ty = cast<PointerType>(PtrA->getType())->getElementType();
107   // The Instructions are connsecutive if the size of the first load/store is
108   // the same as the offset.
109   int64_t Sz = DL->getTypeStoreSize(Ty);
110   return ((-Offset) == Sz);
111 }
112
113 bool BoUpSLP::vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold) {
114   unsigned ChainLen = Chain.size();
115   DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen
116                << "\n");
117   Type *StoreTy = cast<StoreInst>(Chain[0])->getValueOperand()->getType();
118   unsigned Sz = DL->getTypeSizeInBits(StoreTy);
119   unsigned VF = MinVecRegSize / Sz;
120
121   if (!isPowerOf2_32(Sz) || VF < 2)
122     return false;
123
124   bool Changed = false;
125   // Look for profitable vectorizable trees at all offsets, starting at zero.
126   for (unsigned i = 0, e = ChainLen; i < e; ++i) {
127     if (i + VF > e)
128       break;
129     DEBUG(dbgs() << "SLP: Analyzing " << VF << " stores at offset " << i
130                  << "\n");
131     ArrayRef<Value *> Operands = Chain.slice(i, VF);
132
133     int Cost = getTreeCost(Operands);
134     DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n");
135     if (Cost < CostThreshold) {
136       DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n");
137       Builder.SetInsertPoint(getInsertionPoint(getLastIndex(Operands, VF)));
138       vectorizeTree(Operands, VF);
139       i += VF - 1;
140       Changed = true;
141     }
142   }
143
144   if (Changed)
145     return true;
146
147   int Cost = getTreeCost(Chain);
148   if (Cost < CostThreshold) {
149     DEBUG(dbgs() << "SLP: Found store chain cost = " << Cost
150                  << " for size = " << ChainLen << "\n");
151     Builder.SetInsertPoint(getInsertionPoint(getLastIndex(Chain, ChainLen)));
152     vectorizeTree(Chain, ChainLen);
153     return true;
154   }
155
156   return false;
157 }
158
159 bool BoUpSLP::vectorizeStores(ArrayRef<StoreInst *> Stores, int costThreshold) {
160   SetVector<Value *> Heads, Tails;
161   SmallDenseMap<Value *, Value *> ConsecutiveChain;
162
163   // We may run into multiple chains that merge into a single chain. We mark the
164   // stores that we vectorized so that we don't visit the same store twice.
165   ValueSet VectorizedStores;
166   bool Changed = false;
167
168   // Do a quadratic search on all of the given stores and find
169   // all of the pairs of loads that follow each other.
170   for (unsigned i = 0, e = Stores.size(); i < e; ++i)
171     for (unsigned j = 0; j < e; ++j) {
172       if (i == j)
173         continue;
174       
175       if (isConsecutiveAccess(Stores[i], Stores[j])) {
176         Tails.insert(Stores[j]);
177         Heads.insert(Stores[i]);
178         ConsecutiveChain[Stores[i]] = Stores[j];
179       }
180     }
181
182   // For stores that start but don't end a link in the chain:
183   for (SetVector<Value *>::iterator it = Heads.begin(), e = Heads.end();
184        it != e; ++it) {
185     if (Tails.count(*it))
186       continue;
187
188     // We found a store instr that starts a chain. Now follow the chain and try
189     // to vectorize it.
190     ValueList Operands;
191     Value *I = *it;
192     // Collect the chain into a list.
193     while (Tails.count(I) || Heads.count(I)) {
194       if (VectorizedStores.count(I))
195         break;
196       Operands.push_back(I);
197       // Move to the next value in the chain.
198       I = ConsecutiveChain[I];
199     }
200
201     bool Vectorized = vectorizeStoreChain(Operands, costThreshold);
202
203     // Mark the vectorized stores so that we don't vectorize them again.
204     if (Vectorized)
205       VectorizedStores.insert(Operands.begin(), Operands.end());
206     Changed |= Vectorized;
207   }
208
209   return Changed;
210 }
211
212 int BoUpSLP::getScalarizationCost(ArrayRef<Value *> VL) {
213   // Find the type of the operands in VL.
214   Type *ScalarTy = VL[0]->getType();
215   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
216     ScalarTy = SI->getValueOperand()->getType();
217   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
218   // Find the cost of inserting/extracting values from the vector.
219   return getScalarizationCost(VecTy);
220 }
221
222 int BoUpSLP::getScalarizationCost(Type *Ty) {
223   int Cost = 0;
224   for (unsigned i = 0, e = cast<VectorType>(Ty)->getNumElements(); i < e; ++i)
225     Cost += TTI->getVectorInstrCost(Instruction::InsertElement, Ty, i);
226   return Cost;
227 }
228
229 AliasAnalysis::Location BoUpSLP::getLocation(Instruction *I) {
230   if (StoreInst *SI = dyn_cast<StoreInst>(I))
231     return AA->getLocation(SI);
232   if (LoadInst *LI = dyn_cast<LoadInst>(I))
233     return AA->getLocation(LI);
234   return AliasAnalysis::Location();
235 }
236
237 Value *BoUpSLP::isUnsafeToSink(Instruction *Src, Instruction *Dst) {
238   assert(Src->getParent() == Dst->getParent() && "Not the same BB");
239   BasicBlock::iterator I = Src, E = Dst;
240   /// Scan all of the instruction from SRC to DST and check if
241   /// the source may alias.
242   for (++I; I != E; ++I) {
243     // Ignore store instructions that are marked as 'ignore'.
244     if (MemBarrierIgnoreList.count(I))
245       continue;
246     if (Src->mayWriteToMemory()) /* Write */ {
247       if (!I->mayReadOrWriteMemory())
248         continue;
249     } else /* Read */ {
250       if (!I->mayWriteToMemory())
251         continue;
252     }
253     AliasAnalysis::Location A = getLocation(&*I);
254     AliasAnalysis::Location B = getLocation(Src);
255
256     if (!A.Ptr || !B.Ptr || AA->alias(A, B))
257       return I;
258   }
259   return 0;
260 }
261
262 Value *BoUpSLP::vectorizeArith(ArrayRef<Value *> Operands) {
263   int LastIdx = getLastIndex(Operands, Operands.size());
264   Instruction *Loc = getInsertionPoint(LastIdx);
265   Builder.SetInsertPoint(Loc);
266
267   assert(getFirstUserIndex(Operands, Operands.size()) > LastIdx &&
268          "Vectorizing with in-tree users");
269
270   Value *Vec = vectorizeTree(Operands, Operands.size());
271   // After vectorizing the operands we need to generate extractelement
272   // instructions and replace all of the uses of the scalar values with
273   // the values that we extracted from the vectorized tree.
274   for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
275     Value *S = Builder.CreateExtractElement(Vec, Builder.getInt32(i));
276     Operands[i]->replaceAllUsesWith(S);
277   }
278
279   return Vec;
280 }
281
282 int BoUpSLP::getTreeCost(ArrayRef<Value *> VL) {
283   // Get rid of the list of stores that were removed, and from the
284   // lists of instructions with multiple users.
285   MemBarrierIgnoreList.clear();
286   LaneMap.clear();
287   MultiUserVals.clear();
288   MustScalarize.clear();
289   MustExtract.clear();
290
291   // Find the location of the last root.
292   int LastRootIndex = getLastIndex(VL, VL.size());
293   int FirstUserIndex = getFirstUserIndex(VL, VL.size());
294
295   // Don't vectorize if there are users of the tree roots inside the tree
296   // itself.
297   if (LastRootIndex > FirstUserIndex)
298     return max_cost;
299
300   // Scan the tree and find which value is used by which lane, and which values
301   // must be scalarized.
302   getTreeUses_rec(VL, 0);
303
304   // Check that instructions with multiple users can be vectorized. Mark unsafe
305   // instructions.
306   for (SetVector<Value *>::iterator it = MultiUserVals.begin(),
307                                     e = MultiUserVals.end();
308        it != e; ++it) {
309     // Check that all of the users of this instr are within the tree
310     // and that they are all from the same lane.
311     int Lane = -1;
312     for (Value::use_iterator I = (*it)->use_begin(), E = (*it)->use_end();
313          I != E; ++I) {
314       if (LaneMap.find(*I) == LaneMap.end()) {
315         DEBUG(dbgs() << "SLP: Instr " << **it << " has multiple users.\n");
316
317         // We don't have an ordering problem if the user is not in this basic
318         // block.
319         Instruction *Inst = cast<Instruction>(*I);
320         if (Inst->getParent() != BB) {
321           MustExtract.insert(*it);
322           continue;
323         }
324
325         // We don't have an ordering problem if the user is after the last root.
326         int Idx = InstrIdx[Inst];
327         if (Idx < LastRootIndex) {
328           MustScalarize.insert(*it);
329           DEBUG(dbgs() << "SLP: Adding to MustScalarize "
330                           "because of an unsafe out of tree usage.\n");
331           break;
332         }
333
334         DEBUG(dbgs() << "SLP: Adding to MustExtract "
335                         "because of a safe out of tree usage.\n");
336         MustExtract.insert(*it);
337         continue;
338       }
339       if (Lane == -1)
340         Lane = LaneMap[*I];
341       if (Lane != LaneMap[*I]) {
342         MustScalarize.insert(*it);
343         DEBUG(dbgs() << "SLP: Adding " << **it
344                      << " to MustScalarize because multiple lane use it: "
345                      << Lane << " and " << LaneMap[*I] << ".\n");
346         break;
347       }
348     }
349   }
350
351   // Now calculate the cost of vectorizing the tree.
352   return getTreeCost_rec(VL, 0);
353 }
354
355 static bool CanReuseExtract(ArrayRef<Value *> VL, unsigned VF,
356                             VectorType *VecTy) {
357   // Check if all of the extracts come from the same vector and from the
358   // correct offset.
359   Value *VL0 = VL[0];
360   ExtractElementInst *E0 = cast<ExtractElementInst>(VL0);
361   Value *Vec = E0->getOperand(0);
362
363   // We have to extract from the same vector type.
364   if (Vec->getType() != VecTy)
365     return false;
366
367   // Check that all of the indices extract from the correct offset.
368   ConstantInt *CI = dyn_cast<ConstantInt>(E0->getOperand(1));
369   if (!CI || CI->getZExtValue())
370     return false;
371
372   for (unsigned i = 1, e = VF; i < e; ++i) {
373     ExtractElementInst *E = cast<ExtractElementInst>(VL[i]);
374     ConstantInt *CI = dyn_cast<ConstantInt>(E->getOperand(1));
375
376     if (!CI || CI->getZExtValue() != i || E->getOperand(0) != Vec)
377       return false;
378   }
379
380   return true;
381 }
382
383 void BoUpSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
384   if (Depth == RecursionMaxDepth)
385     return;
386
387   // Don't handle vectors.
388   if (VL[0]->getType()->isVectorTy())
389     return;
390   
391   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
392     if (SI->getValueOperand()->getType()->isVectorTy())
393       return;
394
395   // Check if all of the operands are constants.
396   bool AllConst = true;
397   bool AllSameScalar = true;
398   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
399     AllConst &= isa<Constant>(VL[i]);
400     AllSameScalar &= (VL[0] == VL[i]);
401     Instruction *I = dyn_cast<Instruction>(VL[i]);
402     // If one of the instructions is out of this BB, we need to scalarize all.
403     if (I && I->getParent() != BB)
404       return;
405   }
406
407   // If all of the operands are identical or constant we have a simple solution.
408   if (AllConst || AllSameScalar)
409     return;
410
411   // Scalarize unknown structures.
412   Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
413   if (!VL0)
414     return;
415
416   unsigned Opcode = VL0->getOpcode();
417   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
418     Instruction *I = dyn_cast<Instruction>(VL[i]);
419     // If not all of the instructions are identical then we have to scalarize.
420     if (!I || Opcode != I->getOpcode())
421       return;
422   }
423
424   for (int i = 0, e = VL.size(); i < e; ++i) {
425     // Check that the instruction is only used within
426     // one lane.
427     if (LaneMap.count(VL[i]) && LaneMap[VL[i]] != i)
428       return;
429     // Make this instruction as 'seen' and remember the lane.
430     LaneMap[VL[i]] = i;
431   }
432
433   // Mark instructions with multiple users.
434   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
435     Instruction *I = dyn_cast<Instruction>(VL[i]);
436     // Remember to check if all of the users of this instr are vectorized
437     // within our tree. At depth zero we have no local users, only external
438     // users that we don't care about.
439     if (Depth && I && I->getNumUses() > 1) {
440       DEBUG(dbgs() << "SLP: Adding to MultiUserVals "
441                       "because it has multiple users:" << *I << " \n");
442       MultiUserVals.insert(I);
443     }
444   }
445
446   switch (Opcode) {
447   case Instruction::ExtractElement: {
448     VectorType *VecTy = VectorType::get(VL[0]->getType(), VL.size());
449     // No need to follow ExtractElements that are going to be optimized away.
450     if (CanReuseExtract(VL, VL.size(), VecTy))
451       return;
452     // Fall through.
453   }
454   case Instruction::ZExt:
455   case Instruction::SExt:
456   case Instruction::FPToUI:
457   case Instruction::FPToSI:
458   case Instruction::FPExt:
459   case Instruction::PtrToInt:
460   case Instruction::IntToPtr:
461   case Instruction::SIToFP:
462   case Instruction::UIToFP:
463   case Instruction::Trunc:
464   case Instruction::FPTrunc:
465   case Instruction::BitCast:
466   case Instruction::Select:
467   case Instruction::ICmp:
468   case Instruction::FCmp:
469   case Instruction::Add:
470   case Instruction::FAdd:
471   case Instruction::Sub:
472   case Instruction::FSub:
473   case Instruction::Mul:
474   case Instruction::FMul:
475   case Instruction::UDiv:
476   case Instruction::SDiv:
477   case Instruction::FDiv:
478   case Instruction::URem:
479   case Instruction::SRem:
480   case Instruction::FRem:
481   case Instruction::Shl:
482   case Instruction::LShr:
483   case Instruction::AShr:
484   case Instruction::And:
485   case Instruction::Or:
486   case Instruction::Xor: {
487     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
488       ValueList Operands;
489       // Prepare the operand vector.
490       for (unsigned j = 0; j < VL.size(); ++j)
491         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
492
493       getTreeUses_rec(Operands, Depth + 1);
494     }
495     return;
496   }
497   case Instruction::Store: {
498     ValueList Operands;
499     for (unsigned j = 0; j < VL.size(); ++j)
500       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
501     getTreeUses_rec(Operands, Depth + 1);
502     return;
503   }
504   default:
505     return;
506   }
507 }
508
509 int BoUpSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
510   Type *ScalarTy = VL[0]->getType();
511
512   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
513     ScalarTy = SI->getValueOperand()->getType();
514
515   /// Don't mess with vectors.
516   if (ScalarTy->isVectorTy())
517     return max_cost;
518   
519   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
520
521   if (Depth == RecursionMaxDepth)
522     return getScalarizationCost(VecTy);
523
524   // Check if all of the operands are constants.
525   bool AllConst = true;
526   bool AllSameScalar = true;
527   bool MustScalarizeFlag = false;
528   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
529     AllConst &= isa<Constant>(VL[i]);
530     AllSameScalar &= (VL[0] == VL[i]);
531     // Must have a single use.
532     Instruction *I = dyn_cast<Instruction>(VL[i]);
533     MustScalarizeFlag |= MustScalarize.count(VL[i]);
534     // This instruction is outside the basic block.
535     if (I && I->getParent() != BB)
536       return getScalarizationCost(VecTy);
537   }
538
539   // Is this a simple vector constant.
540   if (AllConst)
541     return 0;
542
543   // If all of the operands are identical we can broadcast them.
544   Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
545   if (AllSameScalar) {
546     // If we are in a loop, and this is not an instruction (e.g. constant or
547     // argument) or the instruction is defined outside the loop then assume
548     // that the cost is zero.
549     if (L && (!VL0 || !L->contains(VL0)))
550       return 0;
551
552     // We need to broadcast the scalar.
553     return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0);
554   }
555
556   // If this is not a constant, or a scalar from outside the loop then we
557   // need to scalarize it.
558   if (MustScalarizeFlag)
559     return getScalarizationCost(VecTy);
560
561   if (!VL0)
562     return getScalarizationCost(VecTy);
563   assert(VL0->getParent() == BB && "Wrong BB");
564
565   unsigned Opcode = VL0->getOpcode();
566   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
567     Instruction *I = dyn_cast<Instruction>(VL[i]);
568     // If not all of the instructions are identical then we have to scalarize.
569     if (!I || Opcode != I->getOpcode())
570       return getScalarizationCost(VecTy);
571   }
572
573   // Check if it is safe to sink the loads or the stores.
574   if (Opcode == Instruction::Load || Opcode == Instruction::Store) {
575     int MaxIdx = getLastIndex(VL, VL.size());
576     Instruction *Last = InstrVec[MaxIdx];
577
578     for (unsigned i = 0, e = VL.size(); i < e; ++i) {
579       if (VL[i] == Last)
580         continue;
581       Value *Barrier = isUnsafeToSink(cast<Instruction>(VL[i]), Last);
582       if (Barrier) {
583         DEBUG(dbgs() << "SLP: Can't sink " << *VL[i] << "\n down to " << *Last
584                      << "\n because of " << *Barrier << "\n");
585         return max_cost;
586       }
587     }
588   }
589
590   // Calculate the extract cost.
591   unsigned ExternalUserExtractCost = 0;
592   for (unsigned i = 0, e = VL.size(); i < e; ++i)
593     if (MustExtract.count(VL[i]))
594       ExternalUserExtractCost +=
595           TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i);
596
597   switch (Opcode) {
598   case Instruction::ExtractElement: {
599     if (CanReuseExtract(VL, VL.size(), VecTy))
600       return 0;
601     return getScalarizationCost(VecTy);
602   }
603   case Instruction::ZExt:
604   case Instruction::SExt:
605   case Instruction::FPToUI:
606   case Instruction::FPToSI:
607   case Instruction::FPExt:
608   case Instruction::PtrToInt:
609   case Instruction::IntToPtr:
610   case Instruction::SIToFP:
611   case Instruction::UIToFP:
612   case Instruction::Trunc:
613   case Instruction::FPTrunc:
614   case Instruction::BitCast: {
615     int Cost = ExternalUserExtractCost;
616     ValueList Operands;
617     Type *SrcTy = VL0->getOperand(0)->getType();
618     // Prepare the operand vector.
619     for (unsigned j = 0; j < VL.size(); ++j) {
620       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
621       // Check that the casted type is the same for all users.
622       if (cast<Instruction>(VL[j])->getOperand(0)->getType() != SrcTy)
623         return getScalarizationCost(VecTy);
624     }
625
626     Cost += getTreeCost_rec(Operands, Depth + 1);
627     if (Cost >= max_cost)
628       return max_cost;
629
630     // Calculate the cost of this instruction.
631     int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
632                                                        VL0->getType(), SrcTy);
633
634     VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
635     int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
636     Cost += (VecCost - ScalarCost);
637     return Cost;
638   }
639   case Instruction::FCmp:
640   case Instruction::ICmp: {
641     // Check that all of the compares have the same predicate.
642     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
643     for (unsigned i = 1, e = VL.size(); i < e; ++i) {
644       CmpInst *Cmp = cast<CmpInst>(VL[i]);
645       if (Cmp->getPredicate() != P0)
646         return getScalarizationCost(VecTy);
647     }
648     // Fall through.
649   }
650   case Instruction::Select:
651   case Instruction::Add:
652   case Instruction::FAdd:
653   case Instruction::Sub:
654   case Instruction::FSub:
655   case Instruction::Mul:
656   case Instruction::FMul:
657   case Instruction::UDiv:
658   case Instruction::SDiv:
659   case Instruction::FDiv:
660   case Instruction::URem:
661   case Instruction::SRem:
662   case Instruction::FRem:
663   case Instruction::Shl:
664   case Instruction::LShr:
665   case Instruction::AShr:
666   case Instruction::And:
667   case Instruction::Or:
668   case Instruction::Xor: {
669     int Cost = ExternalUserExtractCost;
670     // Calculate the cost of all of the operands.
671     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
672       ValueList Operands;
673       // Prepare the operand vector.
674       for (unsigned j = 0; j < VL.size(); ++j)
675         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
676
677       Cost += getTreeCost_rec(Operands, Depth + 1);
678       if (Cost >= max_cost)
679         return max_cost;
680     }
681
682     // Calculate the cost of this instruction.
683     int ScalarCost = 0;
684     int VecCost = 0;
685     if (Opcode == Instruction::FCmp || Opcode == Instruction::ICmp ||
686         Opcode == Instruction::Select) {
687       VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size());
688       ScalarCost =
689           VecTy->getNumElements() *
690           TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty());
691       VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy);
692     } else {
693       ScalarCost = VecTy->getNumElements() *
694                    TTI->getArithmeticInstrCost(Opcode, ScalarTy);
695       VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy);
696     }
697     Cost += (VecCost - ScalarCost);
698     return Cost;
699   }
700   case Instruction::Load: {
701     // If we are scalarize the loads, add the cost of forming the vector.
702     for (unsigned i = 0, e = VL.size() - 1; i < e; ++i)
703       if (!isConsecutiveAccess(VL[i], VL[i + 1]))
704         return getScalarizationCost(VecTy);
705
706     // Cost of wide load - cost of scalar loads.
707     int ScalarLdCost = VecTy->getNumElements() *
708                        TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
709     int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
710     return VecLdCost - ScalarLdCost + ExternalUserExtractCost;
711   }
712   case Instruction::Store: {
713     // We know that we can merge the stores. Calculate the cost.
714     int ScalarStCost = VecTy->getNumElements() *
715                        TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
716     int VecStCost = TTI->getMemoryOpCost(Instruction::Store, ScalarTy, 1, 0);
717     int StoreCost = VecStCost - ScalarStCost;
718
719     ValueList Operands;
720     for (unsigned j = 0; j < VL.size(); ++j) {
721       Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
722       MemBarrierIgnoreList.insert(VL[j]);
723     }
724
725     int TotalCost = StoreCost + getTreeCost_rec(Operands, Depth + 1);
726     return TotalCost + ExternalUserExtractCost;
727   }
728   default:
729     // Unable to vectorize unknown instructions.
730     return getScalarizationCost(VecTy);
731   }
732 }
733
734 int BoUpSLP::getLastIndex(ArrayRef<Value *> VL, unsigned VF) {
735   int MaxIdx = InstrIdx[BB->getFirstNonPHI()];
736   for (unsigned i = 0; i < VF; ++i)
737     MaxIdx = std::max(MaxIdx, InstrIdx[VL[i]]);
738   return MaxIdx;
739 }
740
741 int BoUpSLP::getFirstUserIndex(ArrayRef<Value *> VL, unsigned VF) {
742   // Find the first user of the values.
743   int FirstUser = InstrVec.size();
744   for (unsigned i = 0; i < VF; ++i) {
745     for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
746          U != UE; ++U) {
747       Instruction *Instr = dyn_cast<Instruction>(*U);
748       if (!Instr || Instr->getParent() != BB)
749         continue;
750
751       FirstUser = std::min(FirstUser, InstrIdx[Instr]);
752     }
753   }
754   return FirstUser;
755 }
756
757 int BoUpSLP::getLastIndex(Instruction *I, Instruction *J) {
758   assert(I->getParent() == BB && "Invalid parent for instruction I");
759   assert(J->getParent() == BB && "Invalid parent for instruction J");
760   return std::max(InstrIdx[I], InstrIdx[J]);
761 }
762
763 Instruction *BoUpSLP::getInsertionPoint(unsigned Index) {
764   return InstrVec[Index + 1];
765 }
766
767 Value *BoUpSLP::Scalarize(ArrayRef<Value *> VL, VectorType *Ty) {
768   Value *Vec = UndefValue::get(Ty);
769   for (unsigned i = 0; i < Ty->getNumElements(); ++i) {
770     // Generate the 'InsertElement' instruction.
771     Vec = Builder.CreateInsertElement(Vec, VL[i], Builder.getInt32(i));
772     // Remember that this instruction is used as part of a 'gather' sequence.
773     // The caller of the bottom-up slp vectorizer can try to hoist the sequence
774     // if the users are outside of the basic block.
775     if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(Vec))
776       GatherInstructions.push_back(IEI);
777   }
778
779   // Mark the end of the gather sequence.
780   GatherInstructions.push_back(0);
781
782   for (unsigned i = 0; i < Ty->getNumElements(); ++i)
783     VectorizedValues[VL[i]] = Vec;
784
785   return Vec;
786 }
787
788 Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL, int VF) {
789   Value *V = vectorizeTree_rec(VL, VF);
790
791   int LastInstrIdx = getLastIndex(VL, VL.size());
792   for (SetVector<Value *>::iterator it = MustExtract.begin(),
793                                     e = MustExtract.end();
794        it != e; ++it) {
795     Instruction *I = cast<Instruction>(*it);
796
797     // This is a scalarized value, so we can use the original value.
798     // No need to extract from the vector.
799     if (!LaneMap.count(I))
800       continue;
801
802     Value *Vec = VectorizedValues[I];
803     // We decided not to vectorize I because one of its users was not
804     // vectorizerd. This is okay.
805     if (!Vec)
806       continue;
807
808     Value *Idx = Builder.getInt32(LaneMap[I]);
809     Value *Extract = Builder.CreateExtractElement(Vec, Idx);
810     bool Replaced = false;
811     for (Value::use_iterator U = I->use_begin(), UE = I->use_end(); U != UE;
812          ++U) {
813       Instruction *UI = cast<Instruction>(*U);
814       if (UI->getParent() != I->getParent() || InstrIdx[UI] > LastInstrIdx)
815         UI->replaceUsesOfWith(I, Extract);
816       Replaced = true;
817     }
818     assert(Replaced && "Must replace at least one outside user");
819     (void)Replaced;
820   }
821
822   // We moved some instructions around. We have to number them again
823   // before we can do any analysis.
824   numberInstructions();
825   MustScalarize.clear();
826   MustExtract.clear();
827   VectorizedValues.clear();
828   return V;
829 }
830
831 Value *BoUpSLP::vectorizeTree_rec(ArrayRef<Value *> VL, int VF) {
832   Type *ScalarTy = VL[0]->getType();
833   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
834     ScalarTy = SI->getValueOperand()->getType();
835   VectorType *VecTy = VectorType::get(ScalarTy, VF);
836
837   // Check if all of the operands are constants or identical.
838   bool AllConst = true;
839   bool AllSameScalar = true;
840   for (unsigned i = 0, e = VF; i < e; ++i) {
841     AllConst &= isa<Constant>(VL[i]);
842     AllSameScalar &= (VL[0] == VL[i]);
843     // The instruction must be in the same BB, and it must be vectorizable.
844     Instruction *I = dyn_cast<Instruction>(VL[i]);
845     if (MustScalarize.count(VL[i]) || (I && I->getParent() != BB))
846       return Scalarize(VL, VecTy);
847   }
848
849   // Check that this is a simple vector constant.
850   if (AllConst || AllSameScalar)
851     return Scalarize(VL, VecTy);
852
853   // Scalarize unknown structures.
854   Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
855   if (!VL0)
856     return Scalarize(VL, VecTy);
857
858   if (VectorizedValues.count(VL0)) {
859     Value *Vec = VectorizedValues[VL0];
860     for (int i = 0; i < VF; ++i)
861       VectorizedValues[VL[i]] = Vec;
862     return Vec;
863   }
864
865   unsigned Opcode = VL0->getOpcode();
866   for (unsigned i = 0, e = VF; i < e; ++i) {
867     Instruction *I = dyn_cast<Instruction>(VL[i]);
868     // If not all of the instructions are identical then we have to scalarize.
869     if (!I || Opcode != I->getOpcode())
870       return Scalarize(VL, VecTy);
871   }
872
873   switch (Opcode) {
874   case Instruction::ExtractElement: {
875     if (CanReuseExtract(VL, VL.size(), VecTy))
876       return VL0->getOperand(0);
877     return Scalarize(VL, VecTy);
878   }
879   case Instruction::ZExt:
880   case Instruction::SExt:
881   case Instruction::FPToUI:
882   case Instruction::FPToSI:
883   case Instruction::FPExt:
884   case Instruction::PtrToInt:
885   case Instruction::IntToPtr:
886   case Instruction::SIToFP:
887   case Instruction::UIToFP:
888   case Instruction::Trunc:
889   case Instruction::FPTrunc:
890   case Instruction::BitCast: {
891     ValueList INVL;
892     for (int i = 0; i < VF; ++i)
893       INVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
894     Value *InVec = vectorizeTree_rec(INVL, VF);
895     CastInst *CI = dyn_cast<CastInst>(VL0);
896     Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
897
898     for (int i = 0; i < VF; ++i)
899       VectorizedValues[VL[i]] = V;
900
901     return V;
902   }
903   case Instruction::FCmp:
904   case Instruction::ICmp: {
905     // Check that all of the compares have the same predicate.
906     CmpInst::Predicate P0 = dyn_cast<CmpInst>(VL0)->getPredicate();
907     for (unsigned i = 1, e = VF; i < e; ++i) {
908       CmpInst *Cmp = cast<CmpInst>(VL[i]);
909       if (Cmp->getPredicate() != P0)
910         return Scalarize(VL, VecTy);
911     }
912
913     ValueList LHSV, RHSV;
914     for (int i = 0; i < VF; ++i) {
915       LHSV.push_back(cast<Instruction>(VL[i])->getOperand(0));
916       RHSV.push_back(cast<Instruction>(VL[i])->getOperand(1));
917     }
918
919     Value *L = vectorizeTree_rec(LHSV, VF);
920     Value *R = vectorizeTree_rec(RHSV, VF);
921     Value *V;
922     if (VL0->getOpcode() == Instruction::FCmp)
923       V = Builder.CreateFCmp(P0, L, R);
924     else
925       V = Builder.CreateICmp(P0, L, R);
926
927     for (int i = 0; i < VF; ++i)
928       VectorizedValues[VL[i]] = V;
929
930     return V;
931   }
932   case Instruction::Select: {
933     ValueList TrueVec, FalseVec, CondVec;
934     for (int i = 0; i < VF; ++i) {
935       CondVec.push_back(cast<Instruction>(VL[i])->getOperand(0));
936       TrueVec.push_back(cast<Instruction>(VL[i])->getOperand(1));
937       FalseVec.push_back(cast<Instruction>(VL[i])->getOperand(2));
938     }
939
940     Value *True = vectorizeTree_rec(TrueVec, VF);
941     Value *False = vectorizeTree_rec(FalseVec, VF);
942     Value *Cond = vectorizeTree_rec(CondVec, VF);
943     Value *V = Builder.CreateSelect(Cond, True, False);
944
945     for (int i = 0; i < VF; ++i)
946       VectorizedValues[VL[i]] = V;
947
948     return V;
949   }
950   case Instruction::Add:
951   case Instruction::FAdd:
952   case Instruction::Sub:
953   case Instruction::FSub:
954   case Instruction::Mul:
955   case Instruction::FMul:
956   case Instruction::UDiv:
957   case Instruction::SDiv:
958   case Instruction::FDiv:
959   case Instruction::URem:
960   case Instruction::SRem:
961   case Instruction::FRem:
962   case Instruction::Shl:
963   case Instruction::LShr:
964   case Instruction::AShr:
965   case Instruction::And:
966   case Instruction::Or:
967   case Instruction::Xor: {
968     ValueList LHSVL, RHSVL;
969     for (int i = 0; i < VF; ++i) {
970       LHSVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
971       RHSVL.push_back(cast<Instruction>(VL[i])->getOperand(1));
972     }
973
974     Value *LHS = vectorizeTree_rec(LHSVL, VF);
975     Value *RHS = vectorizeTree_rec(RHSVL, VF);
976     BinaryOperator *BinOp = cast<BinaryOperator>(VL0);
977     Value *V = Builder.CreateBinOp(BinOp->getOpcode(), LHS, RHS);
978
979     for (int i = 0; i < VF; ++i)
980       VectorizedValues[VL[i]] = V;
981
982     return V;
983   }
984   case Instruction::Load: {
985     LoadInst *LI = cast<LoadInst>(VL0);
986     unsigned Alignment = LI->getAlignment();
987
988     // Check if all of the loads are consecutive.
989     for (unsigned i = 1, e = VF; i < e; ++i)
990       if (!isConsecutiveAccess(VL[i - 1], VL[i]))
991         return Scalarize(VL, VecTy);
992
993     // Loads are inserted at the head of the tree because we don't want to sink
994     // them all the way down past store instructions.
995     Instruction *Loc = getInsertionPoint(getLastIndex(VL, VL.size()));
996     IRBuilder<> LoadBuilder(Loc);
997     Value *VecPtr = LoadBuilder.CreateBitCast(LI->getPointerOperand(),
998                                               VecTy->getPointerTo());
999     LI = LoadBuilder.CreateLoad(VecPtr);
1000     LI->setAlignment(Alignment);
1001
1002     for (int i = 0; i < VF; ++i)
1003       VectorizedValues[VL[i]] = LI;
1004
1005     return LI;
1006   }
1007   case Instruction::Store: {
1008     StoreInst *SI = cast<StoreInst>(VL0);
1009     unsigned Alignment = SI->getAlignment();
1010
1011     ValueList ValueOp;
1012     for (int i = 0; i < VF; ++i)
1013       ValueOp.push_back(cast<StoreInst>(VL[i])->getValueOperand());
1014
1015     Value *VecValue = vectorizeTree_rec(ValueOp, VF);
1016     Value *VecPtr =
1017         Builder.CreateBitCast(SI->getPointerOperand(), VecTy->getPointerTo());
1018     Builder.CreateStore(VecValue, VecPtr)->setAlignment(Alignment);
1019
1020     for (int i = 0; i < VF; ++i)
1021       cast<Instruction>(VL[i])->eraseFromParent();
1022     return 0;
1023   }
1024   default:
1025     return Scalarize(VL, VecTy);
1026   }
1027 }
1028
1029 } // end of namespace