ed758e8d1069c9905c7f62f185dd84e73ccbd4ec
[oota-llvm.git] / lib / Target / NVPTX / NVPTXLowerAggrCopies.cpp
1 //===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Lower aggregate copies, memset, memcpy, memmov intrinsics into loops when
11 // the size is large or is not a compile-time constant.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "NVPTXLowerAggrCopies.h"
16 #include "llvm/CodeGen/MachineFunctionAnalysis.h"
17 #include "llvm/CodeGen/StackProtector.h"
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/DataLayout.h"
20 #include "llvm/IR/Function.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/InstIterator.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/LLVMContext.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
30
31 #define DEBUG_TYPE "nvptx"
32
33 using namespace llvm;
34
35 namespace {
36
37 // actual analysis class, which is a functionpass
38 struct NVPTXLowerAggrCopies : public FunctionPass {
39   static char ID;
40
41   NVPTXLowerAggrCopies() : FunctionPass(ID) {}
42
43   void getAnalysisUsage(AnalysisUsage &AU) const override {
44     AU.addPreserved<MachineFunctionAnalysis>();
45     AU.addPreserved<StackProtector>();
46   }
47
48   bool runOnFunction(Function &F) override;
49
50   static const unsigned MaxAggrCopySize = 128;
51
52   const char *getPassName() const override {
53     return "Lower aggregate copies/intrinsics into loops";
54   }
55 };
56
57 char NVPTXLowerAggrCopies::ID = 0;
58
59 // Lower memcpy to loop.
60 void convertMemCpyToLoop(Instruction *splitAt, Value *srcAddr, Value *dstAddr,
61                          Value *len, bool srcVolatile, bool dstVolatile,
62                          LLVMContext &Context, Function &F) {
63   Type *indType = len->getType();
64
65   BasicBlock *origBB = splitAt->getParent();
66   BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split");
67   BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB);
68
69   origBB->getTerminator()->setSuccessor(0, loopBB);
70   IRBuilder<> builder(origBB, origBB->getTerminator());
71
72   // srcAddr and dstAddr are expected to be pointer types,
73   // so no check is made here.
74   unsigned srcAS = cast<PointerType>(srcAddr->getType())->getAddressSpace();
75   unsigned dstAS = cast<PointerType>(dstAddr->getType())->getAddressSpace();
76
77   // Cast pointers to (char *)
78   srcAddr = builder.CreateBitCast(srcAddr, Type::getInt8PtrTy(Context, srcAS));
79   dstAddr = builder.CreateBitCast(dstAddr, Type::getInt8PtrTy(Context, dstAS));
80
81   IRBuilder<> loop(loopBB);
82   // The loop index (ind) is a phi node.
83   PHINode *ind = loop.CreatePHI(indType, 0);
84   // Incoming value for ind is 0
85   ind->addIncoming(ConstantInt::get(indType, 0), origBB);
86
87   // load from srcAddr+ind
88   // TODO: we can leverage the align parameter of llvm.memcpy for more efficient
89   // word-sized loads and stores.
90   Value *val = loop.CreateLoad(loop.CreateGEP(loop.getInt8Ty(), srcAddr, ind),
91                                srcVolatile);
92   // store at dstAddr+ind
93   loop.CreateStore(val, loop.CreateGEP(loop.getInt8Ty(), dstAddr, ind),
94                    dstVolatile);
95
96   // The value for ind coming from backedge is (ind + 1)
97   Value *newind = loop.CreateAdd(ind, ConstantInt::get(indType, 1));
98   ind->addIncoming(newind, loopBB);
99
100   loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB);
101 }
102
103 // Lower memmove to IR. memmove is required to correctly copy overlapping memory
104 // regions; therefore, it has to check the relative positions of the source and
105 // destination pointers and choose the copy direction accordingly.
106 //
107 // The code below is an IR rendition of this C function:
108 //
109 // void* memmove(void* dst, const void* src, size_t n) {
110 //   unsigned char* d = dst;
111 //   const unsigned char* s = src;
112 //   if (s < d) {
113 //     // copy backwards
114 //     while (n--) {
115 //       d[n] = s[n];
116 //     }
117 //   } else {
118 //     // copy forward
119 //     for (size_t i = 0; i < n; ++i) {
120 //       d[i] = s[i];
121 //     }
122 //   }
123 //   return dst;
124 // }
125 void convertMemMoveToLoop(Instruction *splitAt, Value *srcAddr, Value *dstAddr,
126                           Value *len, bool srcVolatile, bool dstVolatile,
127                           LLVMContext &Context, Function &F) {
128   Type *TypeOfLen = len->getType();
129   BasicBlock *OrigBB = splitAt->getParent();
130
131   // Create the a comparison of src and dst, based on which we jump to either
132   // the forward-copy part of the function (if src >= dst) or the backwards-copy
133   // part (if src < dst).
134   // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
135   // structure. Its block terminators (unconditional branches) are replaced by
136   // the appropriate conditional branches when the loop is built.
137   ICmpInst *PtrCompare = new ICmpInst(splitAt, ICmpInst::ICMP_ULT, srcAddr,
138                                       dstAddr, "compare_src_dst");
139   TerminatorInst *ThenTerm, *ElseTerm;
140   SplitBlockAndInsertIfThenElse(PtrCompare, splitAt, &ThenTerm, &ElseTerm);
141
142   // Each part of the function consists of two blocks:
143   //   copy_backwards:        used to skip the loop when n == 0
144   //   copy_backwards_loop:   the actual backwards loop BB
145   //   copy_forward:          used to skip the loop when n == 0
146   //   copy_forward_loop:     the actual forward loop BB
147   BasicBlock *CopyBackwardsBB = ThenTerm->getParent();
148   CopyBackwardsBB->setName("copy_backwards");
149   BasicBlock *CopyForwardBB = ElseTerm->getParent();
150   CopyForwardBB->setName("copy_forward");
151   BasicBlock *ExitBB = splitAt->getParent();
152   ExitBB->setName("memmove_done");
153
154   // Initial comparison of n == 0 that lets us skip the loops altogether. Shared
155   // between both backwards and forward copy clauses.
156   ICmpInst *CompareN =
157       new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, len,
158                    ConstantInt::get(TypeOfLen, 0), "compare_n_to_0");
159
160   // Copying backwards.
161   BasicBlock *LoopBB =
162       BasicBlock::Create(Context, "copy_backwards_loop", &F, CopyForwardBB);
163   IRBuilder<> LoopBuilder(LoopBB);
164   PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfLen, 0);
165   Value *IndexPtr = LoopBuilder.CreateSub(
166       LoopPhi, ConstantInt::get(TypeOfLen, 1), "index_ptr");
167   Value *Element = LoopBuilder.CreateLoad(
168       LoopBuilder.CreateInBoundsGEP(srcAddr, IndexPtr), "element");
169   LoopBuilder.CreateStore(Element,
170                           LoopBuilder.CreateInBoundsGEP(dstAddr, IndexPtr));
171   LoopBuilder.CreateCondBr(
172       LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfLen, 0)),
173       ExitBB, LoopBB);
174   LoopPhi->addIncoming(IndexPtr, LoopBB);
175   LoopPhi->addIncoming(len, CopyBackwardsBB);
176   BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm);
177   ThenTerm->removeFromParent();
178
179   // Copying forward.
180   BasicBlock *FwdLoopBB =
181       BasicBlock::Create(Context, "copy_forward_loop", &F, ExitBB);
182   IRBuilder<> FwdLoopBuilder(FwdLoopBB);
183   PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfLen, 0, "index_ptr");
184   Value *FwdElement = FwdLoopBuilder.CreateLoad(
185       FwdLoopBuilder.CreateInBoundsGEP(srcAddr, FwdCopyPhi), "element");
186   FwdLoopBuilder.CreateStore(
187       FwdElement, FwdLoopBuilder.CreateInBoundsGEP(dstAddr, FwdCopyPhi));
188   Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd(
189       FwdCopyPhi, ConstantInt::get(TypeOfLen, 1), "index_increment");
190   FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, len),
191                               ExitBB, FwdLoopBB);
192   FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB);
193   FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfLen, 0), CopyForwardBB);
194
195   BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm);
196   ElseTerm->removeFromParent();
197 }
198
199 // Lower memset to loop.
200 void convertMemSetToLoop(Instruction *splitAt, Value *dstAddr, Value *len,
201                          Value *val, LLVMContext &Context, Function &F) {
202   BasicBlock *origBB = splitAt->getParent();
203   BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split");
204   BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB);
205
206   origBB->getTerminator()->setSuccessor(0, loopBB);
207   IRBuilder<> builder(origBB, origBB->getTerminator());
208
209   unsigned dstAS = cast<PointerType>(dstAddr->getType())->getAddressSpace();
210
211   // Cast pointer to the type of value getting stored
212   dstAddr =
213       builder.CreateBitCast(dstAddr, PointerType::get(val->getType(), dstAS));
214
215   IRBuilder<> loop(loopBB);
216   PHINode *ind = loop.CreatePHI(len->getType(), 0);
217   ind->addIncoming(ConstantInt::get(len->getType(), 0), origBB);
218
219   loop.CreateStore(val, loop.CreateGEP(val->getType(), dstAddr, ind), false);
220
221   Value *newind = loop.CreateAdd(ind, ConstantInt::get(len->getType(), 1));
222   ind->addIncoming(newind, loopBB);
223
224   loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB);
225 }
226
227 bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
228   SmallVector<LoadInst *, 4> aggrLoads;
229   SmallVector<MemIntrinsic *, 4> MemCalls;
230
231   const DataLayout &DL = F.getParent()->getDataLayout();
232   LLVMContext &Context = F.getParent()->getContext();
233
234   // Collect all aggregate loads and mem* calls.
235   for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
236     for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE;
237          ++II) {
238       if (LoadInst *load = dyn_cast<LoadInst>(II)) {
239         if (!load->hasOneUse())
240           continue;
241
242         if (DL.getTypeStoreSize(load->getType()) < MaxAggrCopySize)
243           continue;
244
245         User *use = load->user_back();
246         if (StoreInst *store = dyn_cast<StoreInst>(use)) {
247           if (store->getOperand(0) != load)
248             continue;
249           aggrLoads.push_back(load);
250         }
251       } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(II)) {
252         // Convert intrinsic calls with variable size or with constant size
253         // larger than the MaxAggrCopySize threshold.
254         if (ConstantInt *LenCI = dyn_cast<ConstantInt>(IntrCall->getLength())) {
255           if (LenCI->getZExtValue() >= MaxAggrCopySize) {
256             MemCalls.push_back(IntrCall);
257           }
258         } else {
259           MemCalls.push_back(IntrCall);
260         }
261       }
262     }
263   }
264
265   if (aggrLoads.size() == 0 && MemCalls.size() == 0) {
266     return false;
267   }
268
269   //
270   // Do the transformation of an aggr load/copy/set to a loop
271   //
272   for (LoadInst *load : aggrLoads) {
273     StoreInst *store = dyn_cast<StoreInst>(*load->user_begin());
274     Value *srcAddr = load->getOperand(0);
275     Value *dstAddr = store->getOperand(1);
276     unsigned numLoads = DL.getTypeStoreSize(load->getType());
277     Value *len = ConstantInt::get(Type::getInt32Ty(Context), numLoads);
278
279     convertMemCpyToLoop(store, srcAddr, dstAddr, len, load->isVolatile(),
280                         store->isVolatile(), Context, F);
281
282     store->eraseFromParent();
283     load->eraseFromParent();
284   }
285
286   // Transform mem* intrinsic calls.
287   for (MemIntrinsic *MemCall : MemCalls) {
288     if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) {
289       convertMemCpyToLoop(/* splitAt */ Memcpy,
290                           /* srcAddr */ Memcpy->getRawSource(),
291                           /* dstAddr */ Memcpy->getRawDest(),
292                           /* len */ Memcpy->getLength(),
293                           /* srcVolatile */ Memcpy->isVolatile(),
294                           /* dstVolatile */ Memcpy->isVolatile(),
295                           /* Context */ Context,
296                           /* Function F */ F);
297     } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
298       convertMemMoveToLoop(/* splitAt */ Memmove,
299                            /* srcAddr */ Memmove->getRawSource(),
300                            /* dstAddr */ Memmove->getRawDest(),
301                            /* len */ Memmove->getLength(),
302                            /* srcVolatile */ Memmove->isVolatile(),
303                            /* dstVolatile */ Memmove->isVolatile(),
304                            /* Context */ Context,
305                            /* Function F */ F);
306
307     } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
308       convertMemSetToLoop(/* splitAt */ Memset,
309                           /* dstAddr */ Memset->getRawDest(),
310                           /* len */ Memset->getLength(),
311                           /* val */ Memset->getValue(),
312                           /* Context */ Context,
313                           /* F */ F);
314     }
315     MemCall->eraseFromParent();
316   }
317
318   return true;
319 }
320
321 } // namespace
322
323 namespace llvm {
324 void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
325 }
326
327 INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies",
328                 "Lower aggregate copies, and llvm.mem* intrinsics into loops",
329                 false, false)
330
331 FunctionPass *llvm::createLowerAggrCopies() {
332   return new NVPTXLowerAggrCopies();
333 }