[SCEV] Add and use SCEVConstant::getAPInt; NFCI
[oota-llvm.git] / lib / Transforms / Scalar / AlignmentFromAssumptions.cpp
1 //===----------------------- AlignmentFromAssumptions.cpp -----------------===//
2 //                  Set Load/Store Alignments From Assumptions
3 //
4 //                     The LLVM Compiler Infrastructure
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This file implements a ScalarEvolution-based transformation to set
12 // the alignments of load, stores and memory intrinsics based on the truth
13 // expressions of assume intrinsics. The primary motivation is to handle
14 // complex alignment assumptions that apply to vector loads and stores that
15 // appear after vectorization and unrolling.
16 //
17 //===----------------------------------------------------------------------===//
18
19 #define AA_NAME "alignment-from-assumptions"
20 #define DEBUG_TYPE AA_NAME
21 #include "llvm/Transforms/Scalar.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/Analysis/AliasAnalysis.h"
25 #include "llvm/Analysis/GlobalsModRef.h"
26 #include "llvm/Analysis/AssumptionCache.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/ScalarEvolution.h"
29 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
30 #include "llvm/Analysis/ValueTracking.h"
31 #include "llvm/IR/Constant.h"
32 #include "llvm/IR/Dominators.h"
33 #include "llvm/IR/Instruction.h"
34 #include "llvm/IR/IntrinsicInst.h"
35 #include "llvm/IR/Intrinsics.h"
36 #include "llvm/IR/Module.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/raw_ostream.h"
39 using namespace llvm;
40
41 STATISTIC(NumLoadAlignChanged,
42   "Number of loads changed by alignment assumptions");
43 STATISTIC(NumStoreAlignChanged,
44   "Number of stores changed by alignment assumptions");
45 STATISTIC(NumMemIntAlignChanged,
46   "Number of memory intrinsics changed by alignment assumptions");
47
48 namespace {
49 struct AlignmentFromAssumptions : public FunctionPass {
50   static char ID; // Pass identification, replacement for typeid
51   AlignmentFromAssumptions() : FunctionPass(ID) {
52     initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
53   }
54
55   bool runOnFunction(Function &F) override;
56
57   void getAnalysisUsage(AnalysisUsage &AU) const override {
58     AU.addRequired<AssumptionCacheTracker>();
59     AU.addRequired<ScalarEvolutionWrapperPass>();
60     AU.addRequired<DominatorTreeWrapperPass>();
61
62     AU.setPreservesCFG();
63     AU.addPreserved<AAResultsWrapperPass>();
64     AU.addPreserved<GlobalsAAWrapperPass>();
65     AU.addPreserved<LoopInfoWrapperPass>();
66     AU.addPreserved<DominatorTreeWrapperPass>();
67     AU.addPreserved<ScalarEvolutionWrapperPass>();
68   }
69
70   // For memory transfers, we need a common alignment for both the source and
71   // destination. If we have a new alignment for only one operand of a transfer
72   // instruction, save it in these maps.  If we reach the other operand through
73   // another assumption later, then we may change the alignment at that point.
74   DenseMap<MemTransferInst *, unsigned> NewDestAlignments, NewSrcAlignments;
75
76   ScalarEvolution *SE;
77   DominatorTree *DT;
78
79   bool extractAlignmentInfo(CallInst *I, Value *&AAPtr, const SCEV *&AlignSCEV,
80                             const SCEV *&OffSCEV);
81   bool processAssumption(CallInst *I);
82 };
83 }
84
85 char AlignmentFromAssumptions::ID = 0;
86 static const char aip_name[] = "Alignment from assumptions";
87 INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
88                       aip_name, false, false)
89 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
90 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
91 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
92 INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
93                     aip_name, false, false)
94
95 FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
96   return new AlignmentFromAssumptions();
97 }
98
99 // Given an expression for the (constant) alignment, AlignSCEV, and an
100 // expression for the displacement between a pointer and the aligned address,
101 // DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
102 // to a constant. Using SCEV to compute alignment handles the case where
103 // DiffSCEV is a recurrence with constant start such that the aligned offset
104 // is constant. e.g. {16,+,32} % 32 -> 16.
105 static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
106                                     const SCEV *AlignSCEV,
107                                     ScalarEvolution *SE) {
108   // DiffUnits = Diff % int64_t(Alignment)
109   const SCEV *DiffAlignDiv = SE->getUDivExpr(DiffSCEV, AlignSCEV);
110   const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV);
111   const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV);
112
113   DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " <<
114                   *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");
115
116   if (const SCEVConstant *ConstDUSCEV =
117       dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
118     int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();
119
120     // If the displacement is an exact multiple of the alignment, then the
121     // displaced pointer has the same alignment as the aligned pointer, so
122     // return the alignment value.
123     if (!DiffUnits)
124       return (unsigned)
125         cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue();
126
127     // If the displacement is not an exact multiple, but the remainder is a
128     // constant, then return this remainder (but only if it is a power of 2).
129     uint64_t DiffUnitsAbs = std::abs(DiffUnits);
130     if (isPowerOf2_64(DiffUnitsAbs))
131       return (unsigned) DiffUnitsAbs;
132   }
133
134   return 0;
135 }
136
137 // There is an address given by an offset OffSCEV from AASCEV which has an
138 // alignment AlignSCEV. Use that information, if possible, to compute a new
139 // alignment for Ptr.
140 static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
141                                 const SCEV *OffSCEV, Value *Ptr,
142                                 ScalarEvolution *SE) {
143   const SCEV *PtrSCEV = SE->getSCEV(Ptr);
144   const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
145
146   // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
147   // sign-extended OffSCEV to i64, so make sure they agree again.
148   DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());
149
150   // What we really want to know is the overall offset to the aligned
151   // address. This address is displaced by the provided offset.
152   DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);
153
154   DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to " <<
155                   *AlignSCEV << " and offset " << *OffSCEV <<
156                   " using diff " << *DiffSCEV << "\n");
157
158   unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE);
159   DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n");
160
161   if (NewAlignment) {
162     return NewAlignment;
163   } else if (const SCEVAddRecExpr *DiffARSCEV =
164              dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
165     // The relative offset to the alignment assumption did not yield a constant,
166     // but we should try harder: if we assume that a is 32-byte aligned, then in
167     // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
168     // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
169     // As a result, the new alignment will not be a constant, but can still
170     // be improved over the default (of 4) to 16.
171
172     const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
173     const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);
174
175     DEBUG(dbgs() << "\ttrying start/inc alignment using start " <<
176                     *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");
177
178     // Now compute the new alignment using the displacement to the value in the
179     // first iteration, and also the alignment using the per-iteration delta.
180     // If these are the same, then use that answer. Otherwise, use the smaller
181     // one, but only if it divides the larger one.
182     NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
183     unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
184
185     DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n");
186     DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n");
187
188     if (!NewAlignment || !NewIncAlignment) {
189       return 0;
190     } else if (NewAlignment > NewIncAlignment) {
191       if (NewAlignment % NewIncAlignment == 0) {
192         DEBUG(dbgs() << "\tnew start/inc alignment: " <<
193                         NewIncAlignment << "\n");
194         return NewIncAlignment;
195       }
196     } else if (NewIncAlignment > NewAlignment) {
197       if (NewIncAlignment % NewAlignment == 0) {
198         DEBUG(dbgs() << "\tnew start/inc alignment: " <<
199                         NewAlignment << "\n");
200         return NewAlignment;
201       }
202     } else if (NewIncAlignment == NewAlignment) {
203       DEBUG(dbgs() << "\tnew start/inc alignment: " <<
204                       NewAlignment << "\n");
205       return NewAlignment;
206     }
207   }
208
209   return 0;
210 }
211
212 bool AlignmentFromAssumptions::extractAlignmentInfo(CallInst *I,
213                                  Value *&AAPtr, const SCEV *&AlignSCEV,
214                                  const SCEV *&OffSCEV) {
215   // An alignment assume must be a statement about the least-significant
216   // bits of the pointer being zero, possibly with some offset.
217   ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
218   if (!ICI)
219     return false;
220
221   // This must be an expression of the form: x & m == 0.
222   if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
223     return false;
224
225   // Swap things around so that the RHS is 0.
226   Value *CmpLHS = ICI->getOperand(0);
227   Value *CmpRHS = ICI->getOperand(1);
228   const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
229   const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
230   if (CmpLHSSCEV->isZero())
231     std::swap(CmpLHS, CmpRHS);
232   else if (!CmpRHSSCEV->isZero())
233     return false;
234
235   BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
236   if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
237     return false;
238
239   // Swap things around so that the right operand of the and is a constant
240   // (the mask); we cannot deal with variable masks.
241   Value *AndLHS = CmpBO->getOperand(0);
242   Value *AndRHS = CmpBO->getOperand(1);
243   const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
244   const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
245   if (isa<SCEVConstant>(AndLHSSCEV)) {
246     std::swap(AndLHS, AndRHS);
247     std::swap(AndLHSSCEV, AndRHSSCEV);
248   }
249
250   const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
251   if (!MaskSCEV)
252     return false;
253
254   // The mask must have some trailing ones (otherwise the condition is
255   // trivial and tells us nothing about the alignment of the left operand).
256   unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
257   if (!TrailingOnes)
258     return false;
259
260   // Cap the alignment at the maximum with which LLVM can deal (and make sure
261   // we don't overflow the shift).
262   uint64_t Alignment;
263   TrailingOnes = std::min(TrailingOnes,
264     unsigned(sizeof(unsigned) * CHAR_BIT - 1));
265   Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
266
267   Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
268   AlignSCEV = SE->getConstant(Int64Ty, Alignment);
269
270   // The LHS might be a ptrtoint instruction, or it might be the pointer
271   // with an offset.
272   AAPtr = nullptr;
273   OffSCEV = nullptr;
274   if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
275     AAPtr = PToI->getPointerOperand();
276     OffSCEV = SE->getZero(Int64Ty);
277   } else if (const SCEVAddExpr* AndLHSAddSCEV =
278              dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
279     // Try to find the ptrtoint; subtract it and the rest is the offset.
280     for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
281          JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
282       if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
283         if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
284           AAPtr = PToI->getPointerOperand();
285           OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
286           break;
287         }
288   }
289
290   if (!AAPtr)
291     return false;
292
293   // Sign extend the offset to 64 bits (so that it is like all of the other
294   // expressions). 
295   unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
296   if (OffSCEVBits < 64)
297     OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
298   else if (OffSCEVBits > 64)
299     return false;
300
301   AAPtr = AAPtr->stripPointerCasts();
302   return true;
303 }
304
305 bool AlignmentFromAssumptions::processAssumption(CallInst *ACall) {
306   Value *AAPtr;
307   const SCEV *AlignSCEV, *OffSCEV;
308   if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
309     return false;
310
311   const SCEV *AASCEV = SE->getSCEV(AAPtr);
312
313   // Apply the assumption to all other users of the specified pointer.
314   SmallPtrSet<Instruction *, 32> Visited;
315   SmallVector<Instruction*, 16> WorkList;
316   for (User *J : AAPtr->users()) {
317     if (J == ACall)
318       continue;
319
320     if (Instruction *K = dyn_cast<Instruction>(J))
321       if (isValidAssumeForContext(ACall, K, DT))
322         WorkList.push_back(K);
323   }
324
325   while (!WorkList.empty()) {
326     Instruction *J = WorkList.pop_back_val();
327
328     if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
329       unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
330         LI->getPointerOperand(), SE);
331
332       if (NewAlignment > LI->getAlignment()) {
333         LI->setAlignment(NewAlignment);
334         ++NumLoadAlignChanged;
335       }
336     } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
337       unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
338         SI->getPointerOperand(), SE);
339
340       if (NewAlignment > SI->getAlignment()) {
341         SI->setAlignment(NewAlignment);
342         ++NumStoreAlignChanged;
343       }
344     } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
345       unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
346         MI->getDest(), SE);
347
348       // For memory transfers, we need a common alignment for both the
349       // source and destination. If we have a new alignment for this
350       // instruction, but only for one operand, save it. If we reach the
351       // other operand through another assumption later, then we may
352       // change the alignment at that point.
353       if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
354         unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
355           MTI->getSource(), SE);
356
357         DenseMap<MemTransferInst *, unsigned>::iterator DI =
358           NewDestAlignments.find(MTI);
359         unsigned AltDestAlignment = (DI == NewDestAlignments.end()) ?
360                                     0 : DI->second;
361
362         DenseMap<MemTransferInst *, unsigned>::iterator SI =
363           NewSrcAlignments.find(MTI);
364         unsigned AltSrcAlignment = (SI == NewSrcAlignments.end()) ?
365                                    0 : SI->second;
366
367         DEBUG(dbgs() << "\tmem trans: " << NewDestAlignment << " " <<
368                         AltDestAlignment << " " << NewSrcAlignment <<
369                         " " << AltSrcAlignment << "\n");
370
371         // Of these four alignments, pick the largest possible...
372         unsigned NewAlignment = 0;
373         if (NewDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment))
374           NewAlignment = std::max(NewAlignment, NewDestAlignment);
375         if (AltDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment))
376           NewAlignment = std::max(NewAlignment, AltDestAlignment);
377         if (NewSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment))
378           NewAlignment = std::max(NewAlignment, NewSrcAlignment);
379         if (AltSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment))
380           NewAlignment = std::max(NewAlignment, AltSrcAlignment);
381
382         if (NewAlignment > MI->getAlignment()) {
383           MI->setAlignment(ConstantInt::get(Type::getInt32Ty(
384             MI->getParent()->getContext()), NewAlignment));
385           ++NumMemIntAlignChanged;
386         }
387
388         NewDestAlignments.insert(std::make_pair(MTI, NewDestAlignment));
389         NewSrcAlignments.insert(std::make_pair(MTI, NewSrcAlignment));
390       } else if (NewDestAlignment > MI->getAlignment()) {
391         assert((!isa<MemIntrinsic>(MI) || isa<MemSetInst>(MI)) &&
392                "Unknown memory intrinsic");
393
394         MI->setAlignment(ConstantInt::get(Type::getInt32Ty(
395           MI->getParent()->getContext()), NewDestAlignment));
396         ++NumMemIntAlignChanged;
397       }
398     }
399
400     // Now that we've updated that use of the pointer, look for other uses of
401     // the pointer to update.
402     Visited.insert(J);
403     for (User *UJ : J->users()) {
404       Instruction *K = cast<Instruction>(UJ);
405       if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
406         WorkList.push_back(K);
407     }
408   }
409
410   return true;
411 }
412
413 bool AlignmentFromAssumptions::runOnFunction(Function &F) {
414   bool Changed = false;
415   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
416   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
417   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
418
419   NewDestAlignments.clear();
420   NewSrcAlignments.clear();
421
422   for (auto &AssumeVH : AC.assumptions())
423     if (AssumeVH)
424       Changed |= processAssumption(cast<CallInst>(AssumeVH));
425
426   return Changed;
427 }
428