Fix a typo I spotted when hacking on SROA. Somewhat alarming that
[oota-llvm.git] / lib / Transforms / Scalar / LoadCombine.cpp
1 //===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 /// \file
10 /// This transformation combines adjacent loads.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Transforms/Scalar.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/AliasAnalysis.h"
18 #include "llvm/Analysis/AliasSetTracker.h"
19 #include "llvm/Analysis/TargetFolder.h"
20 #include "llvm/IR/DataLayout.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/MathExtras.h"
28 #include "llvm/Support/raw_ostream.h"
29
30 using namespace llvm;
31
32 #define DEBUG_TYPE "load-combine"
33
34 STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining");
35 STATISTIC(NumLoadsCombined, "Number of loads combined");
36
37 namespace {
38 struct PointerOffsetPair {
39   Value *Pointer;
40   uint64_t Offset;
41 };
42
43 struct LoadPOPPair {
44   LoadPOPPair() = default;
45   LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O)
46       : Load(L), POP(P), InsertOrder(O) {}
47   LoadInst *Load;
48   PointerOffsetPair POP;
49   /// \brief The new load needs to be created before the first load in IR order.
50   unsigned InsertOrder;
51 };
52
53 class LoadCombine : public BasicBlockPass {
54   LLVMContext *C;
55   AliasAnalysis *AA;
56
57 public:
58   LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) {
59     initializeLoadCombinePass(*PassRegistry::getPassRegistry());
60   }
61   
62   using llvm::Pass::doInitialization;
63   bool doInitialization(Function &) override;
64   bool runOnBasicBlock(BasicBlock &BB) override;
65   void getAnalysisUsage(AnalysisUsage &AU) const override;
66
67   const char *getPassName() const override { return "LoadCombine"; }
68   static char ID;
69
70   typedef IRBuilder<true, TargetFolder> BuilderTy;
71
72 private:
73   BuilderTy *Builder;
74
75   PointerOffsetPair getPointerOffsetPair(LoadInst &);
76   bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &);
77   bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &);
78   bool combineLoads(SmallVectorImpl<LoadPOPPair> &);
79 };
80 }
81
82 bool LoadCombine::doInitialization(Function &F) {
83   DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n");
84   C = &F.getContext();
85   return true;
86 }
87
88 PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) {
89   PointerOffsetPair POP;
90   POP.Pointer = LI.getPointerOperand();
91   POP.Offset = 0;
92   while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) {
93     if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) {
94       auto &DL = LI.getModule()->getDataLayout();
95       unsigned BitWidth = DL.getPointerTypeSizeInBits(GEP->getType());
96       APInt Offset(BitWidth, 0);
97       if (GEP->accumulateConstantOffset(DL, Offset))
98         POP.Offset += Offset.getZExtValue();
99       else
100         // Can't handle GEPs with variable indices.
101         return POP;
102       POP.Pointer = GEP->getPointerOperand();
103     } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer))
104       POP.Pointer = BC->getOperand(0);
105   }
106   return POP;
107 }
108
109 bool LoadCombine::combineLoads(
110     DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) {
111   bool Combined = false;
112   for (auto &Loads : LoadMap) {
113     if (Loads.second.size() < 2)
114       continue;
115     std::sort(Loads.second.begin(), Loads.second.end(),
116               [](const LoadPOPPair &A, const LoadPOPPair &B) {
117       return A.POP.Offset < B.POP.Offset;
118     });
119     if (aggregateLoads(Loads.second))
120       Combined = true;
121   }
122   return Combined;
123 }
124
125 /// \brief Try to aggregate loads from a sorted list of loads to be combined.
126 ///
127 /// It is guaranteed that no writes occur between any of the loads. All loads
128 /// have the same base pointer. There are at least two loads.
129 bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
130   assert(Loads.size() >= 2 && "Insufficient loads!");
131   LoadInst *BaseLoad = nullptr;
132   SmallVector<LoadPOPPair, 8> AggregateLoads;
133   bool Combined = false;
134   uint64_t PrevOffset = -1ull;
135   uint64_t PrevSize = 0;
136   for (auto &L : Loads) {
137     if (PrevOffset == -1ull) {
138       BaseLoad = L.Load;
139       PrevOffset = L.POP.Offset;
140       PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
141           L.Load->getType());
142       AggregateLoads.push_back(L);
143       continue;
144     }
145     if (L.Load->getAlignment() > BaseLoad->getAlignment())
146       continue;
147     if (L.POP.Offset > PrevOffset + PrevSize) {
148       // No other load will be combinable
149       if (combineLoads(AggregateLoads))
150         Combined = true;
151       AggregateLoads.clear();
152       PrevOffset = -1;
153       continue;
154     }
155     if (L.POP.Offset != PrevOffset + PrevSize)
156       // This load is offset less than the size of the last load.
157       // FIXME: We may want to handle this case.
158       continue;
159     PrevOffset = L.POP.Offset;
160     PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
161         L.Load->getType());
162     AggregateLoads.push_back(L);
163   }
164   if (combineLoads(AggregateLoads))
165     Combined = true;
166   return Combined;
167 }
168
169 /// \brief Given a list of combinable load. Combine the maximum number of them.
170 bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
171   // Remove loads from the end while the size is not a power of 2.
172   unsigned TotalSize = 0;
173   for (const auto &L : Loads)
174     TotalSize += L.Load->getType()->getPrimitiveSizeInBits();
175   while (TotalSize != 0 && !isPowerOf2_32(TotalSize))
176     TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits();
177   if (Loads.size() < 2)
178     return false;
179
180   DEBUG({
181     dbgs() << "***** Combining Loads ******\n";
182     for (const auto &L : Loads) {
183       dbgs() << L.POP.Offset << ": " << *L.Load << "\n";
184     }
185   });
186
187   // Find first load. This is where we put the new load.
188   LoadPOPPair FirstLP;
189   FirstLP.InsertOrder = -1u;
190   for (const auto &L : Loads)
191     if (L.InsertOrder < FirstLP.InsertOrder)
192       FirstLP = L;
193
194   unsigned AddressSpace =
195       FirstLP.POP.Pointer->getType()->getPointerAddressSpace();
196
197   Builder->SetInsertPoint(FirstLP.Load);
198   Value *Ptr = Builder->CreateConstGEP1_64(
199       Builder->CreatePointerCast(Loads[0].POP.Pointer,
200                                  Builder->getInt8PtrTy(AddressSpace)),
201       Loads[0].POP.Offset);
202   LoadInst *NewLoad = new LoadInst(
203       Builder->CreatePointerCast(
204           Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
205                                 Ptr->getType()->getPointerAddressSpace())),
206       Twine(Loads[0].Load->getName()) + ".combined", false,
207       Loads[0].Load->getAlignment(), FirstLP.Load);
208
209   for (const auto &L : Loads) {
210     Builder->SetInsertPoint(L.Load);
211     Value *V = Builder->CreateExtractInteger(
212         L.Load->getModule()->getDataLayout(), NewLoad,
213         cast<IntegerType>(L.Load->getType()),
214         L.POP.Offset - Loads[0].POP.Offset, "combine.extract");
215     L.Load->replaceAllUsesWith(V);
216   }
217
218   NumLoadsCombined = NumLoadsCombined + Loads.size();
219   return true;
220 }
221
222 bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
223   if (skipOptnoneFunction(BB))
224     return false;
225
226   AA = &getAnalysis<AliasAnalysis>();
227
228   IRBuilder<true, TargetFolder> TheBuilder(
229       BB.getContext(), TargetFolder(BB.getModule()->getDataLayout()));
230   Builder = &TheBuilder;
231
232   DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap;
233   AliasSetTracker AST(*AA);
234
235   bool Combined = false;
236   unsigned Index = 0;
237   for (auto &I : BB) {
238     if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) {
239       if (combineLoads(LoadMap))
240         Combined = true;
241       LoadMap.clear();
242       AST.clear();
243       continue;
244     }
245     LoadInst *LI = dyn_cast<LoadInst>(&I);
246     if (!LI)
247       continue;
248     ++NumLoadsAnalyzed;
249     if (!LI->isSimple() || !LI->getType()->isIntegerTy())
250       continue;
251     auto POP = getPointerOffsetPair(*LI);
252     if (!POP.Pointer)
253       continue;
254     LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++));
255     AST.add(LI);
256   }
257   if (combineLoads(LoadMap))
258     Combined = true;
259   return Combined;
260 }
261
262 void LoadCombine::getAnalysisUsage(AnalysisUsage &AU) const {
263   AU.setPreservesCFG();
264
265   AU.addRequired<AliasAnalysis>();
266   AU.addPreserved<AliasAnalysis>();
267 }
268
269 char LoadCombine::ID = 0;
270
271 BasicBlockPass *llvm::createLoadCombinePass() {
272   return new LoadCombine();
273 }
274
275 INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", "Combine Adjacent Loads",
276                       false, false)
277 INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
278 INITIALIZE_PASS_END(LoadCombine, "load-combine", "Combine Adjacent Loads",
279                     false, false)
280