[Constant Hoisting] Replace the MapVector with a separate Map and Vector to keep...
[oota-llvm.git] / lib / Transforms / Scalar / ConstantHoisting.cpp
1 //===- ConstantHoisting.cpp - Prepare code for expensive constants --------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass identifies expensive constants to hoist and coalesces them to
11 // better prepare it for SelectionDAG-based code generation. This works around
12 // the limitations of the basic-block-at-a-time approach.
13 //
14 // First it scans all instructions for integer constants and calculates its
15 // cost. If the constant can be folded into the instruction (the cost is
16 // TCC_Free) or the cost is just a simple operation (TCC_BASIC), then we don't
17 // consider it expensive and leave it alone. This is the default behavior and
18 // the default implementation of getIntImmCost will always return TCC_Free.
19 //
20 // If the cost is more than TCC_BASIC, then the integer constant can't be folded
21 // into the instruction and it might be beneficial to hoist the constant.
22 // Similar constants are coalesced to reduce register pressure and
23 // materialization code.
24 //
25 // When a constant is hoisted, it is also hidden behind a bitcast to force it to
26 // be live-out of the basic block. Otherwise the constant would be just
27 // duplicated and each basic block would have its own copy in the SelectionDAG.
28 // The SelectionDAG recognizes such constants as opaque and doesn't perform
29 // certain transformations on them, which would create a new expensive constant.
30 //
31 // This optimization is only applied to integer constants in instructions and
32 // simple (this means not nested) constant cast experessions. For example:
33 // %0 = load i64* inttoptr (i64 big_constant to i64*)
34 //===----------------------------------------------------------------------===//
35
36 #define DEBUG_TYPE "consthoist"
37 #include "llvm/Transforms/Scalar.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/Statistic.h"
41 #include "llvm/Analysis/TargetTransformInfo.h"
42 #include "llvm/IR/Constants.h"
43 #include "llvm/IR/Dominators.h"
44 #include "llvm/IR/IntrinsicInst.h"
45 #include "llvm/Pass.h"
46 #include "llvm/Support/Debug.h"
47
48 using namespace llvm;
49
50 STATISTIC(NumConstantsHoisted, "Number of constants hoisted");
51 STATISTIC(NumConstantsRebased, "Number of constants rebased");
52
53 namespace {
54 typedef SmallVector<User *, 4> ConstantUseListType;
55 struct ConstantCandidate {
56   ConstantUseListType Uses;
57   ConstantInt *ConstInt;
58   unsigned CumulativeCost;
59
60   ConstantCandidate(ConstantInt *ConstInt)
61     : ConstInt(ConstInt), CumulativeCost(0) { }
62 };
63
64 struct ConstantInfo {
65   ConstantInt *BaseConstant;
66   struct RebasedConstantInfo {
67     ConstantInt *OriginalConstant;
68     Constant *Offset;
69     ConstantUseListType Uses;
70   };
71   typedef SmallVector<RebasedConstantInfo, 4> RebasedConstantListType;
72   RebasedConstantListType RebasedConstants;
73 };
74
75 class ConstantHoisting : public FunctionPass {
76   typedef DenseMap<ConstantInt *, unsigned> ConstCandMapType;
77   typedef std::vector<ConstantCandidate> ConstCandVecType;
78
79   const TargetTransformInfo *TTI;
80   DominatorTree *DT;
81
82   /// Keeps track of constant candidates found in the function.
83   ConstCandMapType ConstCandMap;
84   ConstCandVecType ConstCandVec;
85
86   /// These are the final constants we decided to hoist.
87   SmallVector<ConstantInfo, 4> Constants;
88 public:
89   static char ID; // Pass identification, replacement for typeid
90   ConstantHoisting() : FunctionPass(ID), TTI(0) {
91     initializeConstantHoistingPass(*PassRegistry::getPassRegistry());
92   }
93
94   bool runOnFunction(Function &F) override;
95
96   const char *getPassName() const override { return "Constant Hoisting"; }
97
98   void getAnalysisUsage(AnalysisUsage &AU) const override {
99     AU.setPreservesCFG();
100     AU.addRequired<DominatorTreeWrapperPass>();
101     AU.addRequired<TargetTransformInfo>();
102   }
103
104 private:
105   void CollectConstant(User *U, unsigned Opcode, Intrinsic::ID IID,
106                         ConstantInt *C);
107   void CollectConstants(Instruction *I);
108   void CollectConstants(Function &F);
109   void FindAndMakeBaseConstant(ConstCandVecType::iterator S,
110                                ConstCandVecType::iterator E);
111   void FindBaseConstants();
112   Instruction *FindConstantInsertionPoint(Function &F,
113                                           const ConstantInfo &CI) const;
114   void EmitBaseConstants(Function &F, User *U, Instruction *Base,
115                          Constant *Offset, ConstantInt *OriginalConstant);
116   bool EmitBaseConstants(Function &F);
117   bool OptimizeConstants(Function &F);
118 };
119 }
120
121 char ConstantHoisting::ID = 0;
122 INITIALIZE_PASS_BEGIN(ConstantHoisting, "consthoist", "Constant Hoisting",
123                       false, false)
124 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
125 INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
126 INITIALIZE_PASS_END(ConstantHoisting, "consthoist", "Constant Hoisting",
127                     false, false)
128
129 FunctionPass *llvm::createConstantHoistingPass() {
130   return new ConstantHoisting();
131 }
132
133 /// \brief Perform the constant hoisting optimization for the given function.
134 bool ConstantHoisting::runOnFunction(Function &F) {
135   DEBUG(dbgs() << "********** Constant Hoisting **********\n");
136   DEBUG(dbgs() << "********** Function: " << F.getName() << '\n');
137
138   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
139   TTI = &getAnalysis<TargetTransformInfo>();
140
141   return OptimizeConstants(F);
142 }
143
144 void ConstantHoisting::CollectConstant(User * U, unsigned Opcode,
145                                        Intrinsic::ID IID, ConstantInt *C) {
146   unsigned Cost;
147   if (Opcode)
148     Cost = TTI->getIntImmCost(Opcode, C->getValue(), C->getType());
149   else
150     Cost = TTI->getIntImmCost(IID, C->getValue(), C->getType());
151
152   // Ignore cheap integer constants.
153   if (Cost > TargetTransformInfo::TCC_Basic) {
154     ConstCandMapType::iterator Itr;
155     bool Inserted;
156     std::tie(Itr, Inserted) = ConstCandMap.insert(std::make_pair(C, 0));
157     if (Inserted) {
158       ConstCandVec.push_back(ConstantCandidate(C));
159       Itr->second = ConstCandVec.size() - 1;
160     }
161     ConstantCandidate &CC = ConstCandVec[Itr->second];
162     CC.CumulativeCost += Cost;
163     CC.Uses.push_back(U);
164     DEBUG(dbgs() << "Collect constant " << *C << " with cost " << Cost
165                  << " from " << *U << '\n');
166   }
167 }
168
169 /// \brief Scan the instruction or constant expression for expensive integer
170 /// constants and record them in the constant map.
171 void ConstantHoisting::CollectConstants(Instruction *I) {
172   unsigned Opcode = 0;
173   Intrinsic::ID IID = Intrinsic::not_intrinsic;
174   if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
175     IID = II->getIntrinsicID();
176   else
177     Opcode = I->getOpcode();
178
179   // Scan all operands.
180   for (User::op_iterator O = I->op_begin(), E = I->op_end(); O != E; ++O) {
181     if (ConstantInt *C = dyn_cast<ConstantInt>(O)) {
182       CollectConstant(I, Opcode, IID, C);
183       continue;
184     }
185     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(O)) {
186       // We only handle constant cast expressions.
187       if (!CE->isCast())
188         continue;
189
190       if (ConstantInt *C = dyn_cast<ConstantInt>(CE->getOperand(0))) {
191         // Ignore the cast expression and use the opcode of the instruction.
192         CollectConstant(CE, Opcode, IID, C);
193         continue;
194       }
195     }
196   }
197 }
198
199 /// \brief Collect all integer constants in the function that cannot be folded
200 /// into an instruction itself.
201 void ConstantHoisting::CollectConstants(Function &F) {
202   for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
203     for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
204       CollectConstants(I);
205 }
206
207 /// \brief Find the base constant within the given range and rebase all other
208 /// constants with respect to the base constant.
209 void ConstantHoisting::FindAndMakeBaseConstant(ConstCandVecType::iterator S,
210                                                ConstCandVecType::iterator E) {
211   ConstCandVecType::iterator MaxCostItr = S;
212   unsigned NumUses = 0;
213   // Use the constant that has the maximum cost as base constant.
214   for (ConstCandVecType::iterator I = S; I != E; ++I) {
215     NumUses += I->Uses.size();
216     if (I->CumulativeCost > MaxCostItr->CumulativeCost)
217       MaxCostItr = I;
218   }
219
220   // Don't hoist constants that have only one use.
221   if (NumUses <= 1)
222     return;
223
224   ConstantInfo CI;
225   CI.BaseConstant = MaxCostItr->ConstInt;
226   Type *Ty = CI.BaseConstant->getType();
227   // Rebase the constants with respect to the base constant.
228   for (ConstCandVecType::iterator I = S; I != E; ++I) {
229     APInt Diff = I->ConstInt->getValue() - CI.BaseConstant->getValue();
230     ConstantInfo::RebasedConstantInfo RCI;
231     RCI.OriginalConstant = I->ConstInt;
232     RCI.Offset = ConstantInt::get(Ty, Diff);
233     RCI.Uses = std::move(I->Uses);
234     CI.RebasedConstants.push_back(RCI);
235   }
236   Constants.push_back(CI);
237 }
238
239 /// \brief Finds and combines constants that can be easily rematerialized with
240 /// an add from a common base constant.
241 void ConstantHoisting::FindBaseConstants() {
242   // Sort the constants by value and type. This invalidates the mapping.
243   std::sort(ConstCandVec.begin(), ConstCandVec.end(),
244             [](const ConstantCandidate &LHS, const ConstantCandidate &RHS) {
245     if (LHS.ConstInt->getType() != RHS.ConstInt->getType())
246       return LHS.ConstInt->getType()->getBitWidth() <
247              RHS.ConstInt->getType()->getBitWidth();
248     return LHS.ConstInt->getValue().ult(RHS.ConstInt->getValue());
249   });
250
251   // Simple linear scan through the sorted constant map for viable merge
252   // candidates.
253   ConstCandVecType::iterator MinValItr = ConstCandVec.begin();
254   for (ConstCandVecType::iterator I = std::next(ConstCandVec.begin()),
255        E = ConstCandVec.end(); I != E; ++I) {
256     if (MinValItr->ConstInt->getType() == I->ConstInt->getType()) {
257       // Check if the constant is in range of an add with immediate.
258       APInt Diff = I->ConstInt->getValue() - MinValItr->ConstInt->getValue();
259       if ((Diff.getBitWidth() <= 64) &&
260           TTI->isLegalAddImmediate(Diff.getSExtValue()))
261         continue;
262     }
263     // We either have now a different constant type or the constant is not in
264     // range of an add with immediate anymore.
265     FindAndMakeBaseConstant(MinValItr, I);
266     // Start a new base constant search.
267     MinValItr = I;
268   }
269   // Finalize the last base constant search.
270   FindAndMakeBaseConstant(MinValItr, ConstCandVec.end());
271 }
272
273 /// \brief Records the basic block of the instruction or all basic blocks of the
274 /// users of the constant expression.
275 static void CollectBasicBlocks(SmallPtrSet<BasicBlock *, 4> &BBs, Function &F,
276                                User *U) {
277   if (Instruction *I = dyn_cast<Instruction>(U))
278     BBs.insert(I->getParent());
279   else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U))
280     // Find all users of this constant expression.
281     for (User *UU : CE->users())
282       // Only record users that are instructions. We don't want to go down a
283       // nested constant expression chain. Also check if the instruction is even
284       // in the current function.
285       if (Instruction *I = dyn_cast<Instruction>(UU))
286         if(I->getParent()->getParent() == &F)
287           BBs.insert(I->getParent());
288 }
289
290 /// \brief Find the instruction we should insert the constant materialization
291 /// before.
292 static Instruction *getMatInsertPt(Instruction *I, const DominatorTree *DT) {
293   if (!isa<PHINode>(I) && !isa<LandingPadInst>(I)) // Simple case.
294     return I;
295
296   // We can't insert directly before a phi node or landing pad. Insert before
297   // the terminator of the dominating block.
298   assert(&I->getParent()->getParent()->getEntryBlock() != I->getParent() &&
299          "PHI or landing pad in entry block!");
300   BasicBlock *IDom = DT->getNode(I->getParent())->getIDom()->getBlock();
301   return IDom->getTerminator();
302 }
303
304 /// \brief Find an insertion point that dominates all uses.
305 Instruction *ConstantHoisting::
306 FindConstantInsertionPoint(Function &F, const ConstantInfo &CI) const {
307   BasicBlock *Entry = &F.getEntryBlock();
308
309   // Collect all basic blocks.
310   SmallPtrSet<BasicBlock *, 4> BBs;
311   ConstantInfo::RebasedConstantListType::const_iterator RCI, RCE;
312   for (RCI = CI.RebasedConstants.begin(), RCE = CI.RebasedConstants.end();
313        RCI != RCE; ++RCI)
314     for (SmallVectorImpl<User *>::const_iterator U = RCI->Uses.begin(),
315          E = RCI->Uses.end(); U != E; ++U)
316       CollectBasicBlocks(BBs, F, *U);
317
318   if (BBs.count(Entry))
319     return getMatInsertPt(&Entry->front(), DT);
320
321   while (BBs.size() >= 2) {
322     BasicBlock *BB, *BB1, *BB2;
323     BB1 = *BBs.begin();
324     BB2 = *std::next(BBs.begin());
325     BB = DT->findNearestCommonDominator(BB1, BB2);
326     if (BB == Entry)
327       return getMatInsertPt(&Entry->front(), DT);
328     BBs.erase(BB1);
329     BBs.erase(BB2);
330     BBs.insert(BB);
331   }
332   assert((BBs.size() == 1) && "Expected only one element.");
333   Instruction &FirstInst = (*BBs.begin())->front();
334   return getMatInsertPt(&FirstInst, DT);
335 }
336
337 /// \brief Emit materialization code for all rebased constants and update their
338 /// users.
339 void ConstantHoisting::EmitBaseConstants(Function &F, User *U,
340                                          Instruction *Base, Constant *Offset,
341                                          ConstantInt *OriginalConstant) {
342   if (Instruction *I = dyn_cast<Instruction>(U)) {
343     Instruction *Mat = Base;
344     if (!Offset->isNullValue()) {
345       Mat = BinaryOperator::Create(Instruction::Add, Base, Offset,
346                                    "const_mat", getMatInsertPt(I, DT));
347
348       // Use the same debug location as the instruction we are about to update.
349       Mat->setDebugLoc(I->getDebugLoc());
350
351       DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0)
352                    << " + " << *Offset << ") in BB "
353                    << I->getParent()->getName() << '\n' << *Mat << '\n');
354     }
355     DEBUG(dbgs() << "Update: " << *I << '\n');
356     I->replaceUsesOfWith(OriginalConstant, Mat);
357     DEBUG(dbgs() << "To: " << *I << '\n');
358     return;
359   }
360   assert(isa<ConstantExpr>(U) && "Expected a ConstantExpr.");
361   ConstantExpr *CE = cast<ConstantExpr>(U);
362   SmallVector<std::pair<Instruction *, Instruction *>, 8> WorkList;
363   DEBUG(dbgs() << "Visit ConstantExpr " << *CE << '\n');
364   for (User *UU : CE->users()) {
365     DEBUG(dbgs() << "Check user "; UU->print(dbgs()); dbgs() << '\n');
366     // We only handel instructions here and won't walk down a ConstantExpr chain
367     // to replace all ConstExpr with instructions.
368     if (Instruction *I = dyn_cast<Instruction>(UU)) {
369       // Only update constant expressions in the current function.
370       if (I->getParent()->getParent() != &F) {
371         DEBUG(dbgs() << "Not in the same function - skip.\n");
372         continue;
373       }
374
375       Instruction *Mat = Base;
376       Instruction *InsertBefore = getMatInsertPt(I, DT);
377       if (!Offset->isNullValue()) {
378         Mat = BinaryOperator::Create(Instruction::Add, Base, Offset,
379                                      "const_mat", InsertBefore);
380
381         // Use the same debug location as the instruction we are about to
382         // update.
383         Mat->setDebugLoc(I->getDebugLoc());
384
385         DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0)
386                      << " + " << *Offset << ") in BB "
387                      << I->getParent()->getName() << '\n' << *Mat << '\n');
388       }
389       Instruction *ICE = CE->getAsInstruction();
390       ICE->replaceUsesOfWith(OriginalConstant, Mat);
391       ICE->insertBefore(InsertBefore);
392
393       // Use the same debug location as the instruction we are about to update.
394       ICE->setDebugLoc(I->getDebugLoc());
395
396       WorkList.push_back(std::make_pair(I, ICE));
397     } else {
398       DEBUG(dbgs() << "Not an instruction - skip.\n");
399     }
400   }
401   SmallVectorImpl<std::pair<Instruction *, Instruction *> >::iterator I, E;
402   for (I = WorkList.begin(), E = WorkList.end(); I != E; ++I) {
403     DEBUG(dbgs() << "Create instruction: " << *I->second << '\n');
404     DEBUG(dbgs() << "Update: " << *I->first << '\n');
405     I->first->replaceUsesOfWith(CE, I->second);
406     DEBUG(dbgs() << "To: " << *I->first << '\n');
407   }
408 }
409
410 /// \brief Hoist and hide the base constant behind a bitcast and emit
411 /// materialization code for derived constants.
412 bool ConstantHoisting::EmitBaseConstants(Function &F) {
413   bool MadeChange = false;
414   SmallVectorImpl<ConstantInfo>::iterator CI, CE;
415   for (CI = Constants.begin(), CE = Constants.end(); CI != CE; ++CI) {
416     // Hoist and hide the base constant behind a bitcast.
417     Instruction *IP = FindConstantInsertionPoint(F, *CI);
418     IntegerType *Ty = CI->BaseConstant->getType();
419     Instruction *Base = new BitCastInst(CI->BaseConstant, Ty, "const", IP);
420     DEBUG(dbgs() << "Hoist constant (" << *CI->BaseConstant << ") to BB "
421                  << IP->getParent()->getName() << '\n');
422     NumConstantsHoisted++;
423
424     // Emit materialization code for all rebased constants.
425     ConstantInfo::RebasedConstantListType::iterator RCI, RCE;
426     for (RCI = CI->RebasedConstants.begin(), RCE = CI->RebasedConstants.end();
427          RCI != RCE; ++RCI) {
428       NumConstantsRebased++;
429       for (SmallVectorImpl<User *>::iterator U = RCI->Uses.begin(),
430            E = RCI->Uses.end(); U != E; ++U)
431         EmitBaseConstants(F, *U, Base, RCI->Offset, RCI->OriginalConstant);
432     }
433
434     // Use the same debug location as the last user of the constant.
435     assert(!Base->use_empty() && "The use list is empty!?");
436     assert(isa<Instruction>(Base->user_back()) &&
437            "All uses should be instructions.");
438     Base->setDebugLoc(cast<Instruction>(Base->user_back())->getDebugLoc());
439
440     // Correct for base constant, which we counted above too.
441     NumConstantsRebased--;
442     MadeChange = true;
443   }
444   return MadeChange;
445 }
446
447 /// \brief Optimize expensive integer constants in the given function.
448 bool ConstantHoisting::OptimizeConstants(Function &F) {
449   bool MadeChange = false;
450
451   // Collect all constant candidates.
452   CollectConstants(F);
453
454   // There are no constant candidates to worry about.
455   if (ConstCandVec.empty())
456     return false;
457
458   // Combine constants that can be easily materialized with an add from a common
459   // base constant.
460   FindBaseConstants();
461
462   // Finally hoist the base constant and emit materializating code for dependent
463   // constants.
464   MadeChange |= EmitBaseConstants(F);
465
466   ConstCandMap.clear();
467   ConstCandVec.clear();
468   Constants.clear();
469
470   return MadeChange;
471 }