SLPVectorization: Add a basic support for cross-basic block slp vectorization.
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
1 //===- SLPVectorizer.cpp - A bottom up SLP Vectorizer ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This pass implements the Bottom Up SLP vectorizer. It detects consecutive
10 // stores that can be put together into vector-stores. Next, it attempts to
11 // construct vectorizable tree using the use-def chains. If a profitable tree
12 // was found, the SLP vectorizer performs vectorization on the tree.
13 //
14 // The pass is inspired by the work described in the paper:
15 //  "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
16 //
17 //===----------------------------------------------------------------------===//
18 #define SV_NAME "slp-vectorizer"
19 #define DEBUG_TYPE "SLP"
20
21 #include "VecUtils.h"
22 #include "llvm/Transforms/Vectorize.h"
23 #include "llvm/ADT/MapVector.h"
24 #include "llvm/Analysis/AliasAnalysis.h"
25 #include "llvm/Analysis/ScalarEvolution.h"
26 #include "llvm/Analysis/TargetTransformInfo.h"
27 #include "llvm/Analysis/Verifier.h"
28 #include "llvm/Analysis/LoopInfo.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/Value.h"
35 #include "llvm/Pass.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include <map>
40
41 using namespace llvm;
42
43 static cl::opt<int>
44 SLPCostThreshold("slp-threshold", cl::init(0), cl::Hidden,
45                  cl::desc("Only vectorize trees if the gain is above this "
46                           "number. (gain = -cost of vectorization)"));
47 namespace {
48
49 /// The SLPVectorizer Pass.
50 struct SLPVectorizer : public FunctionPass {
51   typedef MapVector<Value*, BoUpSLP::StoreList> StoreListMap;
52
53   /// Pass identification, replacement for typeid
54   static char ID;
55
56   explicit SLPVectorizer() : FunctionPass(ID) {
57     initializeSLPVectorizerPass(*PassRegistry::getPassRegistry());
58   }
59
60   ScalarEvolution *SE;
61   DataLayout *DL;
62   TargetTransformInfo *TTI;
63   AliasAnalysis *AA;
64   LoopInfo *LI;
65
66   virtual bool runOnFunction(Function &F) {
67     SE = &getAnalysis<ScalarEvolution>();
68     DL = getAnalysisIfAvailable<DataLayout>();
69     TTI = &getAnalysis<TargetTransformInfo>();
70     AA = &getAnalysis<AliasAnalysis>();
71     LI = &getAnalysis<LoopInfo>();
72
73     StoreRefs.clear();
74     bool Changed = false;
75
76     // Must have DataLayout. We can't require it because some tests run w/o
77     // triple.
78     if (!DL)
79       return false;
80
81     DEBUG(dbgs()<<"SLP: Analyzing blocks in " << F.getName() << ".\n");
82
83     for (Function::iterator it = F.begin(), e = F.end(); it != e; ++it) {
84       BasicBlock *BB = it;
85       bool BBChanged = false;
86
87       // Use the bollom up slp vectorizer to construct chains that start with
88       // he store instructions.
89       BoUpSLP R(BB, SE, DL, TTI, AA, LI->getLoopFor(BB));
90
91       // Vectorize trees that end at reductions.
92       BBChanged |= vectorizeChainsInBlock(BB, R);
93
94       // Vectorize trees that end at stores.
95       if (unsigned count = collectStores(BB, R)) {
96         (void)count;
97         DEBUG(dbgs()<<"SLP: Found " << count << " stores to vectorize.\n");
98         BBChanged |= vectorizeStoreChains(R);
99       }
100
101       // Try to hoist some of the scalarization code to the preheader.
102       if (BBChanged) {
103         hoistGatherSequence(LI, BB, R);
104         Changed |= vectorizeUsingGatherHints(R.getGatherSeqInstructions());
105       }
106
107       Changed |= BBChanged;
108     }
109
110     if (Changed) {
111       DEBUG(dbgs()<<"SLP: vectorized \""<<F.getName()<<"\"\n");
112       DEBUG(verifyFunction(F));
113     }
114     return Changed;
115   }
116
117   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
118     FunctionPass::getAnalysisUsage(AU);
119     AU.addRequired<ScalarEvolution>();
120     AU.addRequired<AliasAnalysis>();
121     AU.addRequired<TargetTransformInfo>();
122     AU.addRequired<LoopInfo>();
123   }
124
125 private:
126
127   /// \brief Collect memory references and sort them according to their base
128   /// object. We sort the stores to their base objects to reduce the cost of the
129   /// quadratic search on the stores. TODO: We can further reduce this cost
130   /// if we flush the chain creation every time we run into a memory barrier.
131   unsigned collectStores(BasicBlock *BB, BoUpSLP &R);
132
133   /// \brief Try to vectorize a chain that starts at two arithmetic instrs.
134   bool tryToVectorizePair(Value *A, Value *B,  BoUpSLP &R);
135
136   /// \brief Try to vectorize a list of operands. If \p NeedExtracts is true
137   /// then we calculate the cost of extracting the scalars from the vector.
138   /// \returns true if a value was vectorized.
139   bool tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, bool NeedExtracts);
140
141   /// \brief Try to vectorize a chain that may start at the operands of \V;
142   bool tryToVectorize(BinaryOperator *V,  BoUpSLP &R);
143
144   /// \brief Vectorize the stores that were collected in StoreRefs.
145   bool vectorizeStoreChains(BoUpSLP &R);
146
147   /// \brief Try to hoist gather sequences outside of the loop in cases where
148   /// all of the sources are loop invariant.
149   void hoistGatherSequence(LoopInfo *LI, BasicBlock *BB, BoUpSLP &R);
150
151   /// \brief Try to vectorize additional sequences in different basic blocks
152   /// based on values that we gathered in previous blocks. The list \p Gathers
153   /// holds the gather InsertElement instructions that were generated during
154   /// vectorization.
155   /// \returns True if some code was vectorized.
156   bool vectorizeUsingGatherHints(BoUpSLP::InstrList &Gathers);
157
158   /// \brief Scan the basic block and look for patterns that are likely to start
159   /// a vectorization chain.
160   bool vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R);
161
162 private:
163   StoreListMap StoreRefs;
164 };
165
166 unsigned SLPVectorizer::collectStores(BasicBlock *BB, BoUpSLP &R) {
167   unsigned count = 0;
168   StoreRefs.clear();
169   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
170     StoreInst *SI = dyn_cast<StoreInst>(it);
171     if (!SI)
172       continue;
173
174     // Check that the pointer points to scalars.
175     Type *Ty = SI->getValueOperand()->getType();
176     if (Ty->isAggregateType() || Ty->isVectorTy())
177       return 0;
178
179     // Find the base of the GEP.
180     Value *Ptr = SI->getPointerOperand();
181     if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr))
182       Ptr = GEP->getPointerOperand();
183
184     // Save the store locations.
185     StoreRefs[Ptr].push_back(SI);
186     count++;
187   }
188   return count;
189 }
190
191 bool SLPVectorizer::tryToVectorizePair(Value *A, Value *B,  BoUpSLP &R) {
192   if (!A || !B) return false;
193   Value *VL[] = { A, B };
194   return tryToVectorizeList(VL, R, true);
195 }
196
197 bool SLPVectorizer::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
198                                        bool NeedExtracts) {
199   if (VL.size() < 2)
200     return false;
201
202   DEBUG(dbgs()<<"SLP: Vectorizing a list of length = " << VL.size() << ".\n");
203
204   // Check that all of the parts are scalar instructions of the same type.
205   Instruction *I0 = dyn_cast<Instruction>(VL[0]);
206   if (!I0) return 0;
207
208   unsigned Opcode0 = I0->getOpcode();
209
210   for (int i = 0, e = VL.size(); i < e; ++i) {
211     Type *Ty = VL[i]->getType();
212     if (Ty->isAggregateType() || Ty->isVectorTy())
213       return 0;
214     Instruction *Inst = dyn_cast<Instruction>(VL[i]);
215     if (!Inst || Inst->getOpcode() != Opcode0)
216       return 0;
217   }
218
219   int Cost = R.getTreeCost(VL);
220   int ExtrCost =  NeedExtracts ? R.getScalarizationCost(VL) : 0;
221   DEBUG(dbgs()<<"SLP: Cost of pair:" << Cost <<
222         " Cost of extract:" << ExtrCost << ".\n");
223   if ((Cost+ExtrCost) >= -SLPCostThreshold) return false;
224   DEBUG(dbgs()<<"SLP: Vectorizing pair.\n");
225   R.vectorizeArith(VL);
226   return true;
227 }
228
229 bool SLPVectorizer::tryToVectorize(BinaryOperator *V,  BoUpSLP &R) {
230   if (!V) return false;
231   // Try to vectorize V.
232   if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
233     return true;
234
235   BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
236   BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
237   // Try to skip B.
238   if (B && B->hasOneUse()) {
239     BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
240     BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1));
241     if (tryToVectorizePair(A, B0, R)) {
242       B->moveBefore(V);
243       return true;
244     }
245     if (tryToVectorizePair(A, B1, R)) {
246       B->moveBefore(V);
247       return true;
248     }
249   }
250
251   // Try to skip A.
252   if (A && A->hasOneUse()) {
253     BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0));
254     BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1));
255     if (tryToVectorizePair(A0, B, R)) {
256       A->moveBefore(V);
257       return true;
258     }
259     if (tryToVectorizePair(A1, B, R)) {
260       A->moveBefore(V);
261       return true;
262     }
263   }
264   return 0;
265 }
266
267 bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
268   bool Changed = false;
269   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
270     if (isa<DbgInfoIntrinsic>(it)) continue;
271
272     // Try to vectorize reductions that use PHINodes.
273     if (PHINode *P = dyn_cast<PHINode>(it)) {
274       // Check that the PHI is a reduction PHI.
275       if (P->getNumIncomingValues() != 2) return Changed;
276       Value *Rdx = (P->getIncomingBlock(0) == BB ? P->getIncomingValue(0) :
277                     (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) :
278                      0));
279       // Check if this is a Binary Operator.
280       BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
281       if (!BI)
282         continue;
283
284       Value *Inst = BI->getOperand(0);
285       if (Inst == P) Inst = BI->getOperand(1);
286       Changed |= tryToVectorize(dyn_cast<BinaryOperator>(Inst), R);
287       continue;
288     }
289
290     // Try to vectorize trees that start at compare instructions.
291     if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
292       if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) {
293         Changed |= true;
294         continue;
295       }
296       for (int i = 0; i < 2; ++i)
297         if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i)))
298           Changed |= tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R);
299       continue;
300     }
301   }
302
303   // Scan the PHINodes in our successors in search for pairing hints.
304   for (succ_iterator it = succ_begin(BB), e = succ_end(BB); it != e; ++it) {
305     BasicBlock *Succ = *it;
306     SmallVector<Value*, 4> Incoming;
307
308     // Collect the incoming values from the PHIs.
309     for (BasicBlock::iterator instr = Succ->begin(), ie = Succ->end();
310          instr != ie; ++instr) {
311       PHINode *P = dyn_cast<PHINode>(instr);
312
313       if (!P)
314         break;
315
316       Value *V = P->getIncomingValueForBlock(BB);
317       if (Instruction *I = dyn_cast<Instruction>(V))
318         if (I->getParent() == BB)
319           Incoming.push_back(I);
320     }
321
322     if (Incoming.size() > 1)
323       Changed |= tryToVectorizeList(Incoming, R, true);
324   }
325   
326   return Changed;
327 }
328
329 bool SLPVectorizer::vectorizeStoreChains(BoUpSLP &R) {
330   bool Changed = false;
331   // Attempt to sort and vectorize each of the store-groups.
332   for (StoreListMap::iterator it = StoreRefs.begin(), e = StoreRefs.end();
333        it != e; ++it) {
334     if (it->second.size() < 2)
335       continue;
336
337     DEBUG(dbgs()<<"SLP: Analyzing a store chain of length " <<
338           it->second.size() << ".\n");
339
340     Changed |= R.vectorizeStores(it->second, -SLPCostThreshold);
341   }
342   return Changed;
343 }
344
345 bool SLPVectorizer::vectorizeUsingGatherHints(BoUpSLP::InstrList &Gathers) {
346   SmallVector<Value*, 4> Seq;
347   bool Changed = false;
348   for (int i = 0, e = Gathers.size(); i < e; ++i) {
349     InsertElementInst *IEI = dyn_cast_or_null<InsertElementInst>(Gathers[i]);
350
351     if (IEI) {
352       if (Instruction *I = dyn_cast<Instruction>(IEI->getOperand(1)))
353         Seq.push_back(I);
354     } else {
355
356       if (!Seq.size())
357         continue;
358
359       Instruction *I = cast<Instruction>(Seq[0]);
360       BasicBlock *BB = I->getParent();
361
362       DEBUG(dbgs()<<"SLP: Inspecting a gather list of size " << Seq.size() <<
363             " in " << BB->getName() << ".\n");
364
365       // Check if the gathered values have multiple uses. If they only have one
366       // user then we know that the insert/extract pair will go away.
367       bool HasMultipleUsers = false;
368       for (int i=0; e = Seq.size(), i < e; ++i) {
369         if (!Seq[i]->hasOneUse()) {
370           HasMultipleUsers = true;
371           break;
372         }
373       }
374
375       BoUpSLP BO(BB, SE, DL, TTI, AA, LI->getLoopFor(BB));
376
377       if (tryToVectorizeList(Seq, BO, HasMultipleUsers)) {
378         DEBUG(dbgs()<<"SLP: Vectorized a gather list of len " << Seq.size() <<
379               " in " << BB->getName() << ".\n");
380         Changed = true;
381       }
382
383       Seq.clear();
384     }
385   }
386
387   return Changed;
388 }
389
390 void SLPVectorizer::hoistGatherSequence(LoopInfo *LI, BasicBlock *BB,
391                                         BoUpSLP &R) {
392   // Check if this block is inside a loop.
393   Loop *L = LI->getLoopFor(BB);
394   if (!L)
395     return;
396
397   // Check if it has a preheader.
398   BasicBlock *PreHeader = L->getLoopPreheader();
399   if (!PreHeader)
400     return;
401
402   // Mark the insertion point for the block.
403   Instruction *Location = PreHeader->getTerminator();
404
405   BoUpSLP::InstrList &Gathers = R.getGatherSeqInstructions();
406   for (BoUpSLP::InstrList::iterator it = Gathers.begin(), e = Gathers.end();
407        it != e; ++it) {
408     InsertElementInst *Insert = dyn_cast_or_null<InsertElementInst>(*it);
409
410     // The InsertElement sequence can be simplified into a constant.
411     // Also Ignore NULL pointers because they are only here to separate
412     // sequences.
413     if (!Insert)
414       continue;
415
416     // If the vector or the element that we insert into it are
417     // instructions that are defined in this basic block then we can't
418     // hoist this instruction.
419     Instruction *CurrVec = dyn_cast<Instruction>(Insert->getOperand(0));
420     Instruction *NewElem = dyn_cast<Instruction>(Insert->getOperand(1));
421     if (CurrVec && L->contains(CurrVec)) continue;
422     if (NewElem && L->contains(NewElem)) continue;
423
424     // We can hoist this instruction. Move it to the pre-header.
425     Insert->moveBefore(Location);
426   }
427 }
428
429 } // end anonymous namespace
430
431 char SLPVectorizer::ID = 0;
432 static const char lv_name[] = "SLP Vectorizer";
433 INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false)
434 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
435 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
436 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
437 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
438 INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false)
439
440 namespace llvm {
441   Pass *createSLPVectorizerPass() {
442     return new SLPVectorizer();
443   }
444 }
445