[C++] Use 'nullptr'. Target edition.
[oota-llvm.git] / lib / Target / ARM64 / ARM64AddressTypePromotion.cpp
1
2 //===-- ARM64AddressTypePromotion.cpp --- Promote type for addr accesses -===//
3 //
4 //                     The LLVM Compiler Infrastructure
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This pass tries to promote the computations use to obtained a sign extended
12 // value used into memory accesses.
13 // E.g.
14 // a = add nsw i32 b, 3
15 // d = sext i32 a to i64
16 // e = getelementptr ..., i64 d
17 //
18 // =>
19 // f = sext i32 b to i64
20 // a = add nsw i64 f, 3
21 // e = getelementptr ..., i64 a
22 //
23 // This is legal to do so if the computations are markers with either nsw or nuw
24 // markers.
25 // Moreover, the current heuristic is simple: it does not create new sext
26 // operations, i.e., it gives up when a sext would have forked (e.g., if
27 // a = add i32 b, c, two sexts are required to promote the computation).
28 //
29 // FIXME: This pass may be useful for other targets too.
30 // ===---------------------------------------------------------------------===//
31
32 #include "ARM64.h"
33 #include "llvm/ADT/DenseMap.h"
34 #include "llvm/ADT/SmallPtrSet.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/IR/Constants.h"
37 #include "llvm/IR/Dominators.h"
38 #include "llvm/IR/Function.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/Module.h"
41 #include "llvm/IR/Operator.h"
42 #include "llvm/Pass.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Debug.h"
45
46 using namespace llvm;
47
48 #define DEBUG_TYPE "arm64-type-promotion"
49
50 static cl::opt<bool>
51 EnableAddressTypePromotion("arm64-type-promotion", cl::Hidden,
52                            cl::desc("Enable the type promotion pass"),
53                            cl::init(true));
54 static cl::opt<bool>
55 EnableMerge("arm64-type-promotion-merge", cl::Hidden,
56             cl::desc("Enable merging of redundant sexts when one is dominating"
57                      " the other."),
58             cl::init(true));
59
60 //===----------------------------------------------------------------------===//
61 //                       ARM64AddressTypePromotion
62 //===----------------------------------------------------------------------===//
63
64 namespace llvm {
65 void initializeARM64AddressTypePromotionPass(PassRegistry &);
66 }
67
68 namespace {
69 class ARM64AddressTypePromotion : public FunctionPass {
70
71 public:
72   static char ID;
73   ARM64AddressTypePromotion()
74       : FunctionPass(ID), Func(nullptr), ConsideredSExtType(nullptr) {
75     initializeARM64AddressTypePromotionPass(*PassRegistry::getPassRegistry());
76   }
77
78   virtual const char *getPassName() const {
79     return "ARM64 Address Type Promotion";
80   }
81
82   /// Iterate over the functions and promote the computation of interesting
83   // sext instructions.
84   bool runOnFunction(Function &F);
85
86 private:
87   /// The current function.
88   Function *Func;
89   /// Filter out all sexts that does not have this type.
90   /// Currently initialized with Int64Ty.
91   Type *ConsideredSExtType;
92
93   // This transformation requires dominator info.
94   virtual void getAnalysisUsage(AnalysisUsage &AU) const {
95     AU.setPreservesCFG();
96     AU.addRequired<DominatorTreeWrapperPass>();
97     AU.addPreserved<DominatorTreeWrapperPass>();
98     FunctionPass::getAnalysisUsage(AU);
99   }
100
101   typedef SmallPtrSet<Instruction *, 32> SetOfInstructions;
102   typedef SmallVector<Instruction *, 16> Instructions;
103   typedef DenseMap<Value *, Instructions> ValueToInsts;
104
105   /// Check if it is profitable to move a sext through this instruction.
106   /// Currently, we consider it is profitable if:
107   /// - Inst is used only once (no need to insert truncate).
108   /// - Inst has only one operand that will require a sext operation (we do
109   ///   do not create new sext operation).
110   bool shouldGetThrough(const Instruction *Inst);
111
112   /// Check if it is possible and legal to move a sext through this
113   /// instruction.
114   /// Current heuristic considers that we can get through:
115   /// - Arithmetic operation marked with the nsw or nuw flag.
116   /// - Other sext operation.
117   /// - Truncate operation if it was just dropping sign extended bits.
118   bool canGetThrough(const Instruction *Inst);
119
120   /// Move sext operations through safe to sext instructions.
121   bool propagateSignExtension(Instructions &SExtInsts);
122
123   /// Is this sext should be considered for code motion.
124   /// We look for sext with ConsideredSExtType and uses in at least one
125   // GetElementPtrInst.
126   bool shouldConsiderSExt(const Instruction *SExt) const;
127
128   /// Collect all interesting sext operations, i.e., the ones with the right
129   /// type and used in memory accesses.
130   /// More precisely, a sext instruction is considered as interesting if it
131   /// is used in a "complex" getelementptr or it exits at least another
132   /// sext instruction that sign extended the same initial value.
133   /// A getelementptr is considered as "complex" if it has more than 2
134   // operands.
135   void analyzeSExtension(Instructions &SExtInsts);
136
137   /// Merge redundant sign extension operations in common dominator.
138   void mergeSExts(ValueToInsts &ValToSExtendedUses,
139                   SetOfInstructions &ToRemove);
140 };
141 } // end anonymous namespace.
142
143 char ARM64AddressTypePromotion::ID = 0;
144
145 INITIALIZE_PASS_BEGIN(ARM64AddressTypePromotion, "arm64-type-promotion",
146                       "ARM64 Type Promotion Pass", false, false)
147 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
148 INITIALIZE_PASS_END(ARM64AddressTypePromotion, "arm64-type-promotion",
149                     "ARM64 Type Promotion Pass", false, false)
150
151 FunctionPass *llvm::createARM64AddressTypePromotionPass() {
152   return new ARM64AddressTypePromotion();
153 }
154
155 bool ARM64AddressTypePromotion::canGetThrough(const Instruction *Inst) {
156   if (isa<SExtInst>(Inst))
157     return true;
158
159   const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
160   if (BinOp && isa<OverflowingBinaryOperator>(BinOp) &&
161       (BinOp->hasNoUnsignedWrap() || BinOp->hasNoSignedWrap()))
162     return true;
163
164   // sext(trunc(sext)) --> sext
165   if (isa<TruncInst>(Inst) && isa<SExtInst>(Inst->getOperand(0))) {
166     const Instruction *Opnd = cast<Instruction>(Inst->getOperand(0));
167     // Check that the truncate just drop sign extended bits.
168     if (Inst->getType()->getIntegerBitWidth() >=
169             Opnd->getOperand(0)->getType()->getIntegerBitWidth() &&
170         Inst->getOperand(0)->getType()->getIntegerBitWidth() <=
171             ConsideredSExtType->getIntegerBitWidth())
172       return true;
173   }
174
175   return false;
176 }
177
178 bool ARM64AddressTypePromotion::shouldGetThrough(const Instruction *Inst) {
179   // If the type of the sext is the same as the considered one, this sext
180   // will become useless.
181   // Otherwise, we will have to do something to preserve the original value,
182   // unless it is used once.
183   if (isa<SExtInst>(Inst) &&
184       (Inst->getType() == ConsideredSExtType || Inst->hasOneUse()))
185     return true;
186
187   // If the Inst is used more that once, we may need to insert truncate
188   // operations and we don't do that at the moment.
189   if (!Inst->hasOneUse())
190     return false;
191
192   // This truncate is used only once, thus if we can get thourgh, it will become
193   // useless.
194   if (isa<TruncInst>(Inst))
195     return true;
196
197   // If both operands are not constant, a new sext will be created here.
198   // Current heuristic is: each step should be profitable.
199   // Therefore we don't allow to increase the number of sext even if it may
200   // be profitable later on.
201   if (isa<BinaryOperator>(Inst) && isa<ConstantInt>(Inst->getOperand(1)))
202     return true;
203
204   return false;
205 }
206
207 static bool shouldSExtOperand(const Instruction *Inst, int OpIdx) {
208   if (isa<SelectInst>(Inst) && OpIdx == 0)
209     return false;
210   return true;
211 }
212
213 bool
214 ARM64AddressTypePromotion::shouldConsiderSExt(const Instruction *SExt) const {
215   if (SExt->getType() != ConsideredSExtType)
216     return false;
217
218   for (const Use &U : SExt->uses()) {
219     if (isa<GetElementPtrInst>(*U))
220       return true;
221   }
222
223   return false;
224 }
225
226 // Input:
227 // - SExtInsts contains all the sext instructions that are use direclty in
228 //   GetElementPtrInst, i.e., access to memory.
229 // Algorithm:
230 // - For each sext operation in SExtInsts:
231 //   Let var be the operand of sext.
232 //   while it is profitable (see shouldGetThrough), legal, and safe
233 //   (see canGetThrough) to move sext through var's definition:
234 //   * promote the type of var's definition.
235 //   * fold var into sext uses.
236 //   * move sext above var's definition.
237 //   * update sext operand to use the operand of var that should be sign
238 //     extended (by construction there is only one).
239 //
240 //   E.g.,
241 //   a = ... i32 c, 3
242 //   b = sext i32 a to i64 <- is it legal/safe/profitable to get through 'a'
243 //   ...
244 //   = b
245 // => Yes, update the code
246 //   b = sext i32 c to i64
247 //   a = ... i64 b, 3
248 //   ...
249 //   = a
250 // Iterate on 'c'.
251 bool
252 ARM64AddressTypePromotion::propagateSignExtension(Instructions &SExtInsts) {
253   DEBUG(dbgs() << "*** Propagate Sign Extension ***\n");
254
255   bool LocalChange = false;
256   SetOfInstructions ToRemove;
257   ValueToInsts ValToSExtendedUses;
258   while (!SExtInsts.empty()) {
259     // Get through simple chain.
260     Instruction *SExt = SExtInsts.pop_back_val();
261
262     DEBUG(dbgs() << "Consider:\n" << *SExt << '\n');
263
264     // If this SExt has already been merged continue.
265     if (SExt->use_empty() && ToRemove.count(SExt)) {
266       DEBUG(dbgs() << "No uses => marked as delete\n");
267       continue;
268     }
269
270     // Now try to get through the chain of definitions.
271     while (isa<Instruction>(SExt->getOperand(0))) {
272       Instruction *Inst = dyn_cast<Instruction>(SExt->getOperand(0));
273       DEBUG(dbgs() << "Try to get through:\n" << *Inst << '\n');
274       if (!canGetThrough(Inst) || !shouldGetThrough(Inst)) {
275         // We cannot get through something that is not an Instruction
276         // or not safe to SExt.
277         DEBUG(dbgs() << "Cannot get through\n");
278         break;
279       }
280
281       LocalChange = true;
282       // If this is a sign extend, it becomes useless.
283       if (isa<SExtInst>(Inst) || isa<TruncInst>(Inst)) {
284         DEBUG(dbgs() << "SExt or trunc, mark it as to remove\n");
285         // We cannot use replaceAllUsesWith here because we may trigger some
286         // assertion on the type as all involved sext operation may have not
287         // been moved yet.
288         while (!Inst->use_empty()) {
289           Value::use_iterator UseIt = Inst->use_begin();
290           Instruction *UseInst = dyn_cast<Instruction>(*UseIt);
291           assert(UseInst && "Use of sext is not an Instruction!");
292           UseInst->setOperand(UseIt->getOperandNo(), SExt);
293         }
294         ToRemove.insert(Inst);
295         SExt->setOperand(0, Inst->getOperand(0));
296         SExt->moveBefore(Inst);
297         continue;
298       }
299
300       // Get through the Instruction:
301       // 1. Update its type.
302       // 2. Replace the uses of SExt by Inst.
303       // 3. Sign extend each operand that needs to be sign extended.
304
305       // Step #1.
306       Inst->mutateType(SExt->getType());
307       // Step #2.
308       SExt->replaceAllUsesWith(Inst);
309       // Step #3.
310       Instruction *SExtForOpnd = SExt;
311
312       DEBUG(dbgs() << "Propagate SExt to operands\n");
313       for (int OpIdx = 0, EndOpIdx = Inst->getNumOperands(); OpIdx != EndOpIdx;
314            ++OpIdx) {
315         DEBUG(dbgs() << "Operand:\n" << *(Inst->getOperand(OpIdx)) << '\n');
316         if (Inst->getOperand(OpIdx)->getType() == SExt->getType() ||
317             !shouldSExtOperand(Inst, OpIdx)) {
318           DEBUG(dbgs() << "No need to propagate\n");
319           continue;
320         }
321         // Check if we can statically sign extend the operand.
322         Value *Opnd = Inst->getOperand(OpIdx);
323         if (const ConstantInt *Cst = dyn_cast<ConstantInt>(Opnd)) {
324           DEBUG(dbgs() << "Statically sign extend\n");
325           Inst->setOperand(OpIdx, ConstantInt::getSigned(SExt->getType(),
326                                                          Cst->getSExtValue()));
327           continue;
328         }
329         // UndefValue are typed, so we have to statically sign extend them.
330         if (isa<UndefValue>(Opnd)) {
331           DEBUG(dbgs() << "Statically sign extend\n");
332           Inst->setOperand(OpIdx, UndefValue::get(SExt->getType()));
333           continue;
334         }
335
336         // Otherwise we have to explicity sign extend it.
337         assert(SExtForOpnd &&
338                "Only one operand should have been sign extended");
339
340         SExtForOpnd->setOperand(0, Opnd);
341
342         DEBUG(dbgs() << "Move before:\n" << *Inst << "\nSign extend\n");
343         // Move the sign extension before the insertion point.
344         SExtForOpnd->moveBefore(Inst);
345         Inst->setOperand(OpIdx, SExtForOpnd);
346         // If more sext are required, new instructions will have to be created.
347         SExtForOpnd = nullptr;
348       }
349       if (SExtForOpnd == SExt) {
350         DEBUG(dbgs() << "Sign extension is useless now\n");
351         ToRemove.insert(SExt);
352         break;
353       }
354     }
355
356     // If the use is already of the right type, connect its uses to its argument
357     // and delete it.
358     // This can happen for an Instruction which all uses are sign extended.
359     if (!ToRemove.count(SExt) &&
360         SExt->getType() == SExt->getOperand(0)->getType()) {
361       DEBUG(dbgs() << "Sign extension is useless, attach its use to "
362                       "its argument\n");
363       SExt->replaceAllUsesWith(SExt->getOperand(0));
364       ToRemove.insert(SExt);
365     } else
366       ValToSExtendedUses[SExt->getOperand(0)].push_back(SExt);
367   }
368
369   if (EnableMerge)
370     mergeSExts(ValToSExtendedUses, ToRemove);
371
372   // Remove all instructions marked as ToRemove.
373   for (Instruction *I: ToRemove)
374     I->eraseFromParent();
375   return LocalChange;
376 }
377
378 void ARM64AddressTypePromotion::mergeSExts(ValueToInsts &ValToSExtendedUses,
379                                            SetOfInstructions &ToRemove) {
380   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
381
382   for (auto &Entry : ValToSExtendedUses) {
383     Instructions &Insts = Entry.second;
384     Instructions CurPts;
385     for (Instruction *Inst : Insts) {
386       if (ToRemove.count(Inst))
387         continue;
388       bool inserted = false;
389       for (auto Pt : CurPts) {
390         if (DT.dominates(Inst, Pt)) {
391           DEBUG(dbgs() << "Replace all uses of:\n" << *Pt << "\nwith:\n"
392                        << *Inst << '\n');
393           (Pt)->replaceAllUsesWith(Inst);
394           ToRemove.insert(Pt);
395           Pt = Inst;
396           inserted = true;
397           break;
398         }
399         if (!DT.dominates(Pt, Inst))
400           // Give up if we need to merge in a common dominator as the
401           // expermients show it is not profitable.
402           continue;
403
404         DEBUG(dbgs() << "Replace all uses of:\n" << *Inst << "\nwith:\n"
405                      << *Pt << '\n');
406         Inst->replaceAllUsesWith(Pt);
407         ToRemove.insert(Inst);
408         inserted = true;
409         break;
410       }
411       if (!inserted)
412         CurPts.push_back(Inst);
413     }
414   }
415 }
416
417 void ARM64AddressTypePromotion::analyzeSExtension(Instructions &SExtInsts) {
418   DEBUG(dbgs() << "*** Analyze Sign Extensions ***\n");
419
420   DenseMap<Value *, Instruction *> SeenChains;
421
422   for (auto &BB : *Func) {
423     for (auto &II : BB) {
424       Instruction *SExt = &II;
425
426       // Collect all sext operation per type.
427       if (!isa<SExtInst>(SExt) || !shouldConsiderSExt(SExt))
428         continue;
429
430       DEBUG(dbgs() << "Found:\n" << (*SExt) << '\n');
431
432       // Cases where we actually perform the optimization:
433       // 1. SExt is used in a getelementptr with more than 2 operand =>
434       //    likely we can merge some computation if they are done on 64 bits.
435       // 2. The beginning of the SExt chain is SExt several time. =>
436       //    code sharing is possible.
437
438       bool insert = false;
439       // #1.
440       for (const Use &U : SExt->uses()) {
441         const Instruction *Inst = dyn_cast<GetElementPtrInst>(U);
442         if (Inst && Inst->getNumOperands() > 2) {
443           DEBUG(dbgs() << "Interesting use in GetElementPtrInst\n" << *Inst
444                        << '\n');
445           insert = true;
446           break;
447         }
448       }
449
450       // #2.
451       // Check the head of the chain.
452       Instruction *Inst = SExt;
453       Value *Last;
454       do {
455         int OpdIdx = 0;
456         const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
457         if (BinOp && isa<ConstantInt>(BinOp->getOperand(0)))
458           OpdIdx = 1;
459         Last = Inst->getOperand(OpdIdx);
460         Inst = dyn_cast<Instruction>(Last);
461       } while (Inst && canGetThrough(Inst) && shouldGetThrough(Inst));
462
463       DEBUG(dbgs() << "Head of the chain:\n" << *Last << '\n');
464       DenseMap<Value *, Instruction *>::iterator AlreadySeen =
465           SeenChains.find(Last);
466       if (insert || AlreadySeen != SeenChains.end()) {
467         DEBUG(dbgs() << "Insert\n");
468         SExtInsts.push_back(SExt);
469         if (AlreadySeen != SeenChains.end() && AlreadySeen->second != nullptr) {
470           DEBUG(dbgs() << "Insert chain member\n");
471           SExtInsts.push_back(AlreadySeen->second);
472           SeenChains[Last] = nullptr;
473         }
474       } else {
475         DEBUG(dbgs() << "Record its chain membership\n");
476         SeenChains[Last] = SExt;
477       }
478     }
479   }
480 }
481
482 bool ARM64AddressTypePromotion::runOnFunction(Function &F) {
483   if (!EnableAddressTypePromotion || F.isDeclaration())
484     return false;
485   Func = &F;
486   ConsideredSExtType = Type::getInt64Ty(Func->getContext());
487
488   DEBUG(dbgs() << "*** " << getPassName() << ": " << Func->getName() << '\n');
489
490   Instructions SExtInsts;
491   analyzeSExtension(SExtInsts);
492   return propagateSignExtension(SExtInsts);
493 }