[EarlyCSE] Fix handling of target memory intrinsics for CSE'ing loads.
[oota-llvm.git] / lib / Transforms / Scalar / LoadCombine.cpp
1 //===- LoadCombine.cpp - Combine Adjacent Loads ---------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 /// \file
10 /// This transformation combines adjacent loads.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Transforms/Scalar.h"
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/AliasAnalysis.h"
18 #include "llvm/Analysis/AliasSetTracker.h"
19 #include "llvm/Analysis/GlobalsModRef.h"
20 #include "llvm/Analysis/TargetFolder.h"
21 #include "llvm/IR/DataLayout.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/Pass.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/MathExtras.h"
29 #include "llvm/Support/raw_ostream.h"
30
31 using namespace llvm;
32
33 #define DEBUG_TYPE "load-combine"
34
35 STATISTIC(NumLoadsAnalyzed, "Number of loads analyzed for combining");
36 STATISTIC(NumLoadsCombined, "Number of loads combined");
37
38 namespace {
39 struct PointerOffsetPair {
40   Value *Pointer;
41   uint64_t Offset;
42 };
43
44 struct LoadPOPPair {
45   LoadPOPPair() = default;
46   LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O)
47       : Load(L), POP(P), InsertOrder(O) {}
48   LoadInst *Load;
49   PointerOffsetPair POP;
50   /// \brief The new load needs to be created before the first load in IR order.
51   unsigned InsertOrder;
52 };
53
54 class LoadCombine : public BasicBlockPass {
55   LLVMContext *C;
56   AliasAnalysis *AA;
57
58 public:
59   LoadCombine() : BasicBlockPass(ID), C(nullptr), AA(nullptr) {
60     initializeLoadCombinePass(*PassRegistry::getPassRegistry());
61   }
62   
63   using llvm::Pass::doInitialization;
64   bool doInitialization(Function &) override;
65   bool runOnBasicBlock(BasicBlock &BB) override;
66   void getAnalysisUsage(AnalysisUsage &AU) const override;
67
68   const char *getPassName() const override { return "LoadCombine"; }
69   static char ID;
70
71   typedef IRBuilder<true, TargetFolder> BuilderTy;
72
73 private:
74   BuilderTy *Builder;
75
76   PointerOffsetPair getPointerOffsetPair(LoadInst &);
77   bool combineLoads(DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &);
78   bool aggregateLoads(SmallVectorImpl<LoadPOPPair> &);
79   bool combineLoads(SmallVectorImpl<LoadPOPPair> &);
80 };
81 }
82
83 bool LoadCombine::doInitialization(Function &F) {
84   DEBUG(dbgs() << "LoadCombine function: " << F.getName() << "\n");
85   C = &F.getContext();
86   return true;
87 }
88
89 PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) {
90   PointerOffsetPair POP;
91   POP.Pointer = LI.getPointerOperand();
92   POP.Offset = 0;
93   while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) {
94     if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) {
95       auto &DL = LI.getModule()->getDataLayout();
96       unsigned BitWidth = DL.getPointerTypeSizeInBits(GEP->getType());
97       APInt Offset(BitWidth, 0);
98       if (GEP->accumulateConstantOffset(DL, Offset))
99         POP.Offset += Offset.getZExtValue();
100       else
101         // Can't handle GEPs with variable indices.
102         return POP;
103       POP.Pointer = GEP->getPointerOperand();
104     } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer))
105       POP.Pointer = BC->getOperand(0);
106   }
107   return POP;
108 }
109
110 bool LoadCombine::combineLoads(
111     DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> &LoadMap) {
112   bool Combined = false;
113   for (auto &Loads : LoadMap) {
114     if (Loads.second.size() < 2)
115       continue;
116     std::sort(Loads.second.begin(), Loads.second.end(),
117               [](const LoadPOPPair &A, const LoadPOPPair &B) {
118       return A.POP.Offset < B.POP.Offset;
119     });
120     if (aggregateLoads(Loads.second))
121       Combined = true;
122   }
123   return Combined;
124 }
125
126 /// \brief Try to aggregate loads from a sorted list of loads to be combined.
127 ///
128 /// It is guaranteed that no writes occur between any of the loads. All loads
129 /// have the same base pointer. There are at least two loads.
130 bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
131   assert(Loads.size() >= 2 && "Insufficient loads!");
132   LoadInst *BaseLoad = nullptr;
133   SmallVector<LoadPOPPair, 8> AggregateLoads;
134   bool Combined = false;
135   uint64_t PrevOffset = -1ull;
136   uint64_t PrevSize = 0;
137   for (auto &L : Loads) {
138     if (PrevOffset == -1ull) {
139       BaseLoad = L.Load;
140       PrevOffset = L.POP.Offset;
141       PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
142           L.Load->getType());
143       AggregateLoads.push_back(L);
144       continue;
145     }
146     if (L.Load->getAlignment() > BaseLoad->getAlignment())
147       continue;
148     if (L.POP.Offset > PrevOffset + PrevSize) {
149       // No other load will be combinable
150       if (combineLoads(AggregateLoads))
151         Combined = true;
152       AggregateLoads.clear();
153       PrevOffset = -1;
154       continue;
155     }
156     if (L.POP.Offset != PrevOffset + PrevSize)
157       // This load is offset less than the size of the last load.
158       // FIXME: We may want to handle this case.
159       continue;
160     PrevOffset = L.POP.Offset;
161     PrevSize = L.Load->getModule()->getDataLayout().getTypeStoreSize(
162         L.Load->getType());
163     AggregateLoads.push_back(L);
164   }
165   if (combineLoads(AggregateLoads))
166     Combined = true;
167   return Combined;
168 }
169
170 /// \brief Given a list of combinable load. Combine the maximum number of them.
171 bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
172   // Remove loads from the end while the size is not a power of 2.
173   unsigned TotalSize = 0;
174   for (const auto &L : Loads)
175     TotalSize += L.Load->getType()->getPrimitiveSizeInBits();
176   while (TotalSize != 0 && !isPowerOf2_32(TotalSize))
177     TotalSize -= Loads.pop_back_val().Load->getType()->getPrimitiveSizeInBits();
178   if (Loads.size() < 2)
179     return false;
180
181   DEBUG({
182     dbgs() << "***** Combining Loads ******\n";
183     for (const auto &L : Loads) {
184       dbgs() << L.POP.Offset << ": " << *L.Load << "\n";
185     }
186   });
187
188   // Find first load. This is where we put the new load.
189   LoadPOPPair FirstLP;
190   FirstLP.InsertOrder = -1u;
191   for (const auto &L : Loads)
192     if (L.InsertOrder < FirstLP.InsertOrder)
193       FirstLP = L;
194
195   unsigned AddressSpace =
196       FirstLP.POP.Pointer->getType()->getPointerAddressSpace();
197
198   Builder->SetInsertPoint(FirstLP.Load);
199   Value *Ptr = Builder->CreateConstGEP1_64(
200       Builder->CreatePointerCast(Loads[0].POP.Pointer,
201                                  Builder->getInt8PtrTy(AddressSpace)),
202       Loads[0].POP.Offset);
203   LoadInst *NewLoad = new LoadInst(
204       Builder->CreatePointerCast(
205           Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
206                                 Ptr->getType()->getPointerAddressSpace())),
207       Twine(Loads[0].Load->getName()) + ".combined", false,
208       Loads[0].Load->getAlignment(), FirstLP.Load);
209
210   for (const auto &L : Loads) {
211     Builder->SetInsertPoint(L.Load);
212     Value *V = Builder->CreateExtractInteger(
213         L.Load->getModule()->getDataLayout(), NewLoad,
214         cast<IntegerType>(L.Load->getType()),
215         L.POP.Offset - Loads[0].POP.Offset, "combine.extract");
216     L.Load->replaceAllUsesWith(V);
217   }
218
219   NumLoadsCombined = NumLoadsCombined + Loads.size();
220   return true;
221 }
222
223 bool LoadCombine::runOnBasicBlock(BasicBlock &BB) {
224   if (skipOptnoneFunction(BB))
225     return false;
226
227   AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
228
229   IRBuilder<true, TargetFolder> TheBuilder(
230       BB.getContext(), TargetFolder(BB.getModule()->getDataLayout()));
231   Builder = &TheBuilder;
232
233   DenseMap<const Value *, SmallVector<LoadPOPPair, 8>> LoadMap;
234   AliasSetTracker AST(*AA);
235
236   bool Combined = false;
237   unsigned Index = 0;
238   for (auto &I : BB) {
239     if (I.mayThrow() || (I.mayWriteToMemory() && AST.containsUnknown(&I))) {
240       if (combineLoads(LoadMap))
241         Combined = true;
242       LoadMap.clear();
243       AST.clear();
244       continue;
245     }
246     LoadInst *LI = dyn_cast<LoadInst>(&I);
247     if (!LI)
248       continue;
249     ++NumLoadsAnalyzed;
250     if (!LI->isSimple() || !LI->getType()->isIntegerTy())
251       continue;
252     auto POP = getPointerOffsetPair(*LI);
253     if (!POP.Pointer)
254       continue;
255     LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++));
256     AST.add(LI);
257   }
258   if (combineLoads(LoadMap))
259     Combined = true;
260   return Combined;
261 }
262
263 void LoadCombine::getAnalysisUsage(AnalysisUsage &AU) const {
264   AU.setPreservesCFG();
265
266   AU.addRequired<AAResultsWrapperPass>();
267   AU.addPreserved<GlobalsAAWrapperPass>();
268 }
269
270 char LoadCombine::ID = 0;
271
272 BasicBlockPass *llvm::createLoadCombinePass() {
273   return new LoadCombine();
274 }
275
276 INITIALIZE_PASS_BEGIN(LoadCombine, "load-combine", "Combine Adjacent Loads",
277                       false, false)
278 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
279 INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
280 INITIALIZE_PASS_END(LoadCombine, "load-combine", "Combine Adjacent Loads",
281                     false, false)
282