Delete trailing whitespace; NFC
[oota-llvm.git] / lib / Target / AArch64 / AArch64AddressTypePromotion.cpp
1 //===-- AArch64AddressTypePromotion.cpp --- Promote type for addr accesses -==//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass tries to promote the computations use to obtained a sign extended
11 // value used into memory accesses.
12 // E.g.
13 // a = add nsw i32 b, 3
14 // d = sext i32 a to i64
15 // e = getelementptr ..., i64 d
16 //
17 // =>
18 // f = sext i32 b to i64
19 // a = add nsw i64 f, 3
20 // e = getelementptr ..., i64 a
21 //
22 // This is legal to do if the computations are marked with either nsw or nuw
23 // markers.
24 // Moreover, the current heuristic is simple: it does not create new sext
25 // operations, i.e., it gives up when a sext would have forked (e.g., if
26 // a = add i32 b, c, two sexts are required to promote the computation).
27 //
28 // FIXME: This pass may be useful for other targets too.
29 // ===---------------------------------------------------------------------===//
30
31 #include "AArch64.h"
32 #include "llvm/ADT/DenseMap.h"
33 #include "llvm/ADT/SmallPtrSet.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/Dominators.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/Instructions.h"
39 #include "llvm/IR/Module.h"
40 #include "llvm/IR/Operator.h"
41 #include "llvm/Pass.h"
42 #include "llvm/Support/CommandLine.h"
43 #include "llvm/Support/Debug.h"
44 #include "llvm/Support/raw_ostream.h"
45
46 using namespace llvm;
47
48 #define DEBUG_TYPE "aarch64-type-promotion"
49
50 static cl::opt<bool>
51 EnableAddressTypePromotion("aarch64-type-promotion", cl::Hidden,
52                            cl::desc("Enable the type promotion pass"),
53                            cl::init(true));
54 static cl::opt<bool>
55 EnableMerge("aarch64-type-promotion-merge", cl::Hidden,
56             cl::desc("Enable merging of redundant sexts when one is dominating"
57                      " the other."),
58             cl::init(true));
59
60 #define AARCH64_TYPE_PROMO_NAME "AArch64 Address Type Promotion"
61
62 //===----------------------------------------------------------------------===//
63 //                       AArch64AddressTypePromotion
64 //===----------------------------------------------------------------------===//
65
66 namespace llvm {
67 void initializeAArch64AddressTypePromotionPass(PassRegistry &);
68 }
69
70 namespace {
71 class AArch64AddressTypePromotion : public FunctionPass {
72
73 public:
74   static char ID;
75   AArch64AddressTypePromotion()
76       : FunctionPass(ID), Func(nullptr), ConsideredSExtType(nullptr) {
77     initializeAArch64AddressTypePromotionPass(*PassRegistry::getPassRegistry());
78   }
79
80   const char *getPassName() const override {
81     return AARCH64_TYPE_PROMO_NAME;
82   }
83
84   /// Iterate over the functions and promote the computation of interesting
85   // sext instructions.
86   bool runOnFunction(Function &F) override;
87
88 private:
89   /// The current function.
90   Function *Func;
91   /// Filter out all sexts that does not have this type.
92   /// Currently initialized with Int64Ty.
93   Type *ConsideredSExtType;
94
95   // This transformation requires dominator info.
96   void getAnalysisUsage(AnalysisUsage &AU) const override {
97     AU.setPreservesCFG();
98     AU.addRequired<DominatorTreeWrapperPass>();
99     AU.addPreserved<DominatorTreeWrapperPass>();
100     FunctionPass::getAnalysisUsage(AU);
101   }
102
103   typedef SmallPtrSet<Instruction *, 32> SetOfInstructions;
104   typedef SmallVector<Instruction *, 16> Instructions;
105   typedef DenseMap<Value *, Instructions> ValueToInsts;
106
107   /// Check if it is profitable to move a sext through this instruction.
108   /// Currently, we consider it is profitable if:
109   /// - Inst is used only once (no need to insert truncate).
110   /// - Inst has only one operand that will require a sext operation (we do
111   ///   do not create new sext operation).
112   bool shouldGetThrough(const Instruction *Inst);
113
114   /// Check if it is possible and legal to move a sext through this
115   /// instruction.
116   /// Current heuristic considers that we can get through:
117   /// - Arithmetic operation marked with the nsw or nuw flag.
118   /// - Other sext operation.
119   /// - Truncate operation if it was just dropping sign extended bits.
120   bool canGetThrough(const Instruction *Inst);
121
122   /// Move sext operations through safe to sext instructions.
123   bool propagateSignExtension(Instructions &SExtInsts);
124
125   /// Is this sext should be considered for code motion.
126   /// We look for sext with ConsideredSExtType and uses in at least one
127   // GetElementPtrInst.
128   bool shouldConsiderSExt(const Instruction *SExt) const;
129
130   /// Collect all interesting sext operations, i.e., the ones with the right
131   /// type and used in memory accesses.
132   /// More precisely, a sext instruction is considered as interesting if it
133   /// is used in a "complex" getelementptr or it exits at least another
134   /// sext instruction that sign extended the same initial value.
135   /// A getelementptr is considered as "complex" if it has more than 2
136   // operands.
137   void analyzeSExtension(Instructions &SExtInsts);
138
139   /// Merge redundant sign extension operations in common dominator.
140   void mergeSExts(ValueToInsts &ValToSExtendedUses,
141                   SetOfInstructions &ToRemove);
142 };
143 } // end anonymous namespace.
144
145 char AArch64AddressTypePromotion::ID = 0;
146
147 INITIALIZE_PASS_BEGIN(AArch64AddressTypePromotion, "aarch64-type-promotion",
148                       AARCH64_TYPE_PROMO_NAME, false, false)
149 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
150 INITIALIZE_PASS_END(AArch64AddressTypePromotion, "aarch64-type-promotion",
151                     AARCH64_TYPE_PROMO_NAME, false, false)
152
153 FunctionPass *llvm::createAArch64AddressTypePromotionPass() {
154   return new AArch64AddressTypePromotion();
155 }
156
157 bool AArch64AddressTypePromotion::canGetThrough(const Instruction *Inst) {
158   if (isa<SExtInst>(Inst))
159     return true;
160
161   const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
162   if (BinOp && isa<OverflowingBinaryOperator>(BinOp) &&
163       (BinOp->hasNoUnsignedWrap() || BinOp->hasNoSignedWrap()))
164     return true;
165
166   // sext(trunc(sext)) --> sext
167   if (isa<TruncInst>(Inst) && isa<SExtInst>(Inst->getOperand(0))) {
168     const Instruction *Opnd = cast<Instruction>(Inst->getOperand(0));
169     // Check that the truncate just drop sign extended bits.
170     if (Inst->getType()->getIntegerBitWidth() >=
171             Opnd->getOperand(0)->getType()->getIntegerBitWidth() &&
172         Inst->getOperand(0)->getType()->getIntegerBitWidth() <=
173             ConsideredSExtType->getIntegerBitWidth())
174       return true;
175   }
176
177   return false;
178 }
179
180 bool AArch64AddressTypePromotion::shouldGetThrough(const Instruction *Inst) {
181   // If the type of the sext is the same as the considered one, this sext
182   // will become useless.
183   // Otherwise, we will have to do something to preserve the original value,
184   // unless it is used once.
185   if (isa<SExtInst>(Inst) &&
186       (Inst->getType() == ConsideredSExtType || Inst->hasOneUse()))
187     return true;
188
189   // If the Inst is used more that once, we may need to insert truncate
190   // operations and we don't do that at the moment.
191   if (!Inst->hasOneUse())
192     return false;
193
194   // This truncate is used only once, thus if we can get thourgh, it will become
195   // useless.
196   if (isa<TruncInst>(Inst))
197     return true;
198
199   // If both operands are not constant, a new sext will be created here.
200   // Current heuristic is: each step should be profitable.
201   // Therefore we don't allow to increase the number of sext even if it may
202   // be profitable later on.
203   if (isa<BinaryOperator>(Inst) && isa<ConstantInt>(Inst->getOperand(1)))
204     return true;
205
206   return false;
207 }
208
209 static bool shouldSExtOperand(const Instruction *Inst, int OpIdx) {
210   if (isa<SelectInst>(Inst) && OpIdx == 0)
211     return false;
212   return true;
213 }
214
215 bool
216 AArch64AddressTypePromotion::shouldConsiderSExt(const Instruction *SExt) const {
217   if (SExt->getType() != ConsideredSExtType)
218     return false;
219
220   for (const User *U : SExt->users()) {
221     if (isa<GetElementPtrInst>(U))
222       return true;
223   }
224
225   return false;
226 }
227
228 // Input:
229 // - SExtInsts contains all the sext instructions that are used directly in
230 //   GetElementPtrInst, i.e., access to memory.
231 // Algorithm:
232 // - For each sext operation in SExtInsts:
233 //   Let var be the operand of sext.
234 //   while it is profitable (see shouldGetThrough), legal, and safe
235 //   (see canGetThrough) to move sext through var's definition:
236 //   * promote the type of var's definition.
237 //   * fold var into sext uses.
238 //   * move sext above var's definition.
239 //   * update sext operand to use the operand of var that should be sign
240 //     extended (by construction there is only one).
241 //
242 //   E.g.,
243 //   a = ... i32 c, 3
244 //   b = sext i32 a to i64 <- is it legal/safe/profitable to get through 'a'
245 //   ...
246 //   = b
247 // => Yes, update the code
248 //   b = sext i32 c to i64
249 //   a = ... i64 b, 3
250 //   ...
251 //   = a
252 // Iterate on 'c'.
253 bool
254 AArch64AddressTypePromotion::propagateSignExtension(Instructions &SExtInsts) {
255   DEBUG(dbgs() << "*** Propagate Sign Extension ***\n");
256
257   bool LocalChange = false;
258   SetOfInstructions ToRemove;
259   ValueToInsts ValToSExtendedUses;
260   while (!SExtInsts.empty()) {
261     // Get through simple chain.
262     Instruction *SExt = SExtInsts.pop_back_val();
263
264     DEBUG(dbgs() << "Consider:\n" << *SExt << '\n');
265
266     // If this SExt has already been merged continue.
267     if (SExt->use_empty() && ToRemove.count(SExt)) {
268       DEBUG(dbgs() << "No uses => marked as delete\n");
269       continue;
270     }
271
272     // Now try to get through the chain of definitions.
273     while (auto *Inst = dyn_cast<Instruction>(SExt->getOperand(0))) {
274       DEBUG(dbgs() << "Try to get through:\n" << *Inst << '\n');
275       if (!canGetThrough(Inst) || !shouldGetThrough(Inst)) {
276         // We cannot get through something that is not an Instruction
277         // or not safe to SExt.
278         DEBUG(dbgs() << "Cannot get through\n");
279         break;
280       }
281
282       LocalChange = true;
283       // If this is a sign extend, it becomes useless.
284       if (isa<SExtInst>(Inst) || isa<TruncInst>(Inst)) {
285         DEBUG(dbgs() << "SExt or trunc, mark it as to remove\n");
286         // We cannot use replaceAllUsesWith here because we may trigger some
287         // assertion on the type as all involved sext operation may have not
288         // been moved yet.
289         while (!Inst->use_empty()) {
290           Use &U = *Inst->use_begin();
291           Instruction *User = dyn_cast<Instruction>(U.getUser());
292           assert(User && "User of sext is not an Instruction!");
293           User->setOperand(U.getOperandNo(), SExt);
294         }
295         ToRemove.insert(Inst);
296         SExt->setOperand(0, Inst->getOperand(0));
297         SExt->moveBefore(Inst);
298         continue;
299       }
300
301       // Get through the Instruction:
302       // 1. Update its type.
303       // 2. Replace the uses of SExt by Inst.
304       // 3. Sign extend each operand that needs to be sign extended.
305
306       // Step #1.
307       Inst->mutateType(SExt->getType());
308       // Step #2.
309       SExt->replaceAllUsesWith(Inst);
310       // Step #3.
311       Instruction *SExtForOpnd = SExt;
312
313       DEBUG(dbgs() << "Propagate SExt to operands\n");
314       for (int OpIdx = 0, EndOpIdx = Inst->getNumOperands(); OpIdx != EndOpIdx;
315            ++OpIdx) {
316         DEBUG(dbgs() << "Operand:\n" << *(Inst->getOperand(OpIdx)) << '\n');
317         if (Inst->getOperand(OpIdx)->getType() == SExt->getType() ||
318             !shouldSExtOperand(Inst, OpIdx)) {
319           DEBUG(dbgs() << "No need to propagate\n");
320           continue;
321         }
322         // Check if we can statically sign extend the operand.
323         Value *Opnd = Inst->getOperand(OpIdx);
324         if (const ConstantInt *Cst = dyn_cast<ConstantInt>(Opnd)) {
325           DEBUG(dbgs() << "Statically sign extend\n");
326           Inst->setOperand(OpIdx, ConstantInt::getSigned(SExt->getType(),
327                                                          Cst->getSExtValue()));
328           continue;
329         }
330         // UndefValue are typed, so we have to statically sign extend them.
331         if (isa<UndefValue>(Opnd)) {
332           DEBUG(dbgs() << "Statically sign extend\n");
333           Inst->setOperand(OpIdx, UndefValue::get(SExt->getType()));
334           continue;
335         }
336
337         // Otherwise we have to explicity sign extend it.
338         assert(SExtForOpnd &&
339                "Only one operand should have been sign extended");
340
341         SExtForOpnd->setOperand(0, Opnd);
342
343         DEBUG(dbgs() << "Move before:\n" << *Inst << "\nSign extend\n");
344         // Move the sign extension before the insertion point.
345         SExtForOpnd->moveBefore(Inst);
346         Inst->setOperand(OpIdx, SExtForOpnd);
347         // If more sext are required, new instructions will have to be created.
348         SExtForOpnd = nullptr;
349       }
350       if (SExtForOpnd == SExt) {
351         DEBUG(dbgs() << "Sign extension is useless now\n");
352         ToRemove.insert(SExt);
353         break;
354       }
355     }
356
357     // If the use is already of the right type, connect its uses to its argument
358     // and delete it.
359     // This can happen for an Instruction all uses of which are sign extended.
360     if (!ToRemove.count(SExt) &&
361         SExt->getType() == SExt->getOperand(0)->getType()) {
362       DEBUG(dbgs() << "Sign extension is useless, attach its use to "
363                       "its argument\n");
364       SExt->replaceAllUsesWith(SExt->getOperand(0));
365       ToRemove.insert(SExt);
366     } else
367       ValToSExtendedUses[SExt->getOperand(0)].push_back(SExt);
368   }
369
370   if (EnableMerge)
371     mergeSExts(ValToSExtendedUses, ToRemove);
372
373   // Remove all instructions marked as ToRemove.
374   for (Instruction *I: ToRemove)
375     I->eraseFromParent();
376   return LocalChange;
377 }
378
379 void AArch64AddressTypePromotion::mergeSExts(ValueToInsts &ValToSExtendedUses,
380                                              SetOfInstructions &ToRemove) {
381   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
382
383   for (auto &Entry : ValToSExtendedUses) {
384     Instructions &Insts = Entry.second;
385     Instructions CurPts;
386     for (Instruction *Inst : Insts) {
387       if (ToRemove.count(Inst))
388         continue;
389       bool inserted = false;
390       for (auto &Pt : CurPts) {
391         if (DT.dominates(Inst, Pt)) {
392           DEBUG(dbgs() << "Replace all uses of:\n" << *Pt << "\nwith:\n"
393                        << *Inst << '\n');
394           Pt->replaceAllUsesWith(Inst);
395           ToRemove.insert(Pt);
396           Pt = Inst;
397           inserted = true;
398           break;
399         }
400         if (!DT.dominates(Pt, Inst))
401           // Give up if we need to merge in a common dominator as the
402           // expermients show it is not profitable.
403           continue;
404
405         DEBUG(dbgs() << "Replace all uses of:\n" << *Inst << "\nwith:\n"
406                      << *Pt << '\n');
407         Inst->replaceAllUsesWith(Pt);
408         ToRemove.insert(Inst);
409         inserted = true;
410         break;
411       }
412       if (!inserted)
413         CurPts.push_back(Inst);
414     }
415   }
416 }
417
418 void AArch64AddressTypePromotion::analyzeSExtension(Instructions &SExtInsts) {
419   DEBUG(dbgs() << "*** Analyze Sign Extensions ***\n");
420
421   DenseMap<Value *, Instruction *> SeenChains;
422
423   for (auto &BB : *Func) {
424     for (auto &II : BB) {
425       Instruction *SExt = &II;
426
427       // Collect all sext operation per type.
428       if (!isa<SExtInst>(SExt) || !shouldConsiderSExt(SExt))
429         continue;
430
431       DEBUG(dbgs() << "Found:\n" << (*SExt) << '\n');
432
433       // Cases where we actually perform the optimization:
434       // 1. SExt is used in a getelementptr with more than 2 operand =>
435       //    likely we can merge some computation if they are done on 64 bits.
436       // 2. The beginning of the SExt chain is SExt several time. =>
437       //    code sharing is possible.
438
439       bool insert = false;
440       // #1.
441       for (const User *U : SExt->users()) {
442         const Instruction *Inst = dyn_cast<GetElementPtrInst>(U);
443         if (Inst && Inst->getNumOperands() > 2) {
444           DEBUG(dbgs() << "Interesting use in GetElementPtrInst\n" << *Inst
445                        << '\n');
446           insert = true;
447           break;
448         }
449       }
450
451       // #2.
452       // Check the head of the chain.
453       Instruction *Inst = SExt;
454       Value *Last;
455       do {
456         int OpdIdx = 0;
457         const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
458         if (BinOp && isa<ConstantInt>(BinOp->getOperand(0)))
459           OpdIdx = 1;
460         Last = Inst->getOperand(OpdIdx);
461         Inst = dyn_cast<Instruction>(Last);
462       } while (Inst && canGetThrough(Inst) && shouldGetThrough(Inst));
463
464       DEBUG(dbgs() << "Head of the chain:\n" << *Last << '\n');
465       DenseMap<Value *, Instruction *>::iterator AlreadySeen =
466           SeenChains.find(Last);
467       if (insert || AlreadySeen != SeenChains.end()) {
468         DEBUG(dbgs() << "Insert\n");
469         SExtInsts.push_back(SExt);
470         if (AlreadySeen != SeenChains.end() && AlreadySeen->second != nullptr) {
471           DEBUG(dbgs() << "Insert chain member\n");
472           SExtInsts.push_back(AlreadySeen->second);
473           SeenChains[Last] = nullptr;
474         }
475       } else {
476         DEBUG(dbgs() << "Record its chain membership\n");
477         SeenChains[Last] = SExt;
478       }
479     }
480   }
481 }
482
483 bool AArch64AddressTypePromotion::runOnFunction(Function &F) {
484   if (!EnableAddressTypePromotion || F.isDeclaration())
485     return false;
486   Func = &F;
487   ConsideredSExtType = Type::getInt64Ty(Func->getContext());
488
489   DEBUG(dbgs() << "*** " << getPassName() << ": " << Func->getName() << '\n');
490
491   Instructions SExtInsts;
492   analyzeSExtension(SExtInsts);
493   return propagateSignExtension(SExtInsts);
494 }