Remove unnecessary copying or replace it with moves in a bunch of places.
[oota-llvm.git] / lib / Target / R600 / AMDGPUPromoteAlloca.cpp
1 //===-- AMDGPUPromoteAlloca.cpp - Promote Allocas -------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass eliminates allocas by either converting them into vectors or
11 // by migrating them to local address space.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "AMDGPU.h"
16 #include "AMDGPUSubtarget.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/InstVisitor.h"
20 #include "llvm/Support/Debug.h"
21
22 #define DEBUG_TYPE "amdgpu-promote-alloca"
23
24 using namespace llvm;
25
26 namespace {
27
28 class AMDGPUPromoteAlloca : public FunctionPass,
29                        public InstVisitor<AMDGPUPromoteAlloca> {
30
31   static char ID;
32   Module *Mod;
33   const AMDGPUSubtarget &ST;
34   int LocalMemAvailable;
35
36 public:
37   AMDGPUPromoteAlloca(const AMDGPUSubtarget &st) : FunctionPass(ID), ST(st),
38                                                    LocalMemAvailable(0) { }
39   bool doInitialization(Module &M) override;
40   bool runOnFunction(Function &F) override;
41   const char *getPassName() const override { return "AMDGPU Promote Alloca"; }
42   void visitAlloca(AllocaInst &I);
43 };
44
45 } // End anonymous namespace
46
47 char AMDGPUPromoteAlloca::ID = 0;
48
49 bool AMDGPUPromoteAlloca::doInitialization(Module &M) {
50   Mod = &M;
51   return false;
52 }
53
54 bool AMDGPUPromoteAlloca::runOnFunction(Function &F) {
55
56   const FunctionType *FTy = F.getFunctionType();
57
58   LocalMemAvailable = ST.getLocalMemorySize();
59
60
61   // If the function has any arguments in the local address space, then it's
62   // possible these arguments require the entire local memory space, so
63   // we cannot use local memory in the pass.
64   for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) {
65     const Type *ParamTy = FTy->getParamType(i);
66     if (ParamTy->isPointerTy() &&
67         ParamTy->getPointerAddressSpace() == AMDGPUAS::LOCAL_ADDRESS) {
68       LocalMemAvailable = 0;
69       DEBUG(dbgs() << "Function has local memory argument.  Promoting to "
70                       "local memory disabled.\n");
71       break;
72     }
73   }
74
75   if (LocalMemAvailable > 0) {
76     // Check how much local memory is being used by global objects
77     for (Module::global_iterator I = Mod->global_begin(),
78                                  E = Mod->global_end(); I != E; ++I) {
79       GlobalVariable *GV = I;
80       PointerType *GVTy = GV->getType();
81       if (GVTy->getAddressSpace() != AMDGPUAS::LOCAL_ADDRESS)
82         continue;
83       for (Value::use_iterator U = GV->use_begin(),
84                                UE = GV->use_end(); U != UE; ++U) {
85         Instruction *Use = dyn_cast<Instruction>(*U);
86         if (!Use)
87           continue;
88         if (Use->getParent()->getParent() == &F)
89           LocalMemAvailable -=
90               Mod->getDataLayout()->getTypeAllocSize(GVTy->getElementType());
91       }
92     }
93   }
94
95   LocalMemAvailable = std::max(0, LocalMemAvailable);
96   DEBUG(dbgs() << LocalMemAvailable << "bytes free in local memory.\n");
97
98   visit(F);
99
100   return false;
101 }
102
103 static VectorType *arrayTypeToVecType(const Type *ArrayTy) {
104   return VectorType::get(ArrayTy->getArrayElementType(),
105                          ArrayTy->getArrayNumElements());
106 }
107
108 static Value *
109 calculateVectorIndex(Value *Ptr,
110                      const std::map<GetElementPtrInst *, Value *> &GEPIdx) {
111   if (isa<AllocaInst>(Ptr))
112     return Constant::getNullValue(Type::getInt32Ty(Ptr->getContext()));
113
114   GetElementPtrInst *GEP = cast<GetElementPtrInst>(Ptr);
115
116   auto I = GEPIdx.find(GEP);
117   return I == GEPIdx.end() ? nullptr : I->second;
118 }
119
120 static Value* GEPToVectorIndex(GetElementPtrInst *GEP) {
121   // FIXME we only support simple cases
122   if (GEP->getNumOperands() != 3)
123     return NULL;
124
125   ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
126   if (!I0 || !I0->isZero())
127     return NULL;
128
129   return GEP->getOperand(2);
130 }
131
132 // Not an instruction handled below to turn into a vector.
133 //
134 // TODO: Check isTriviallyVectorizable for calls and handle other
135 // instructions.
136 static bool canVectorizeInst(Instruction *Inst) {
137   switch (Inst->getOpcode()) {
138   case Instruction::Load:
139   case Instruction::Store:
140   case Instruction::BitCast:
141   case Instruction::AddrSpaceCast:
142     return true;
143   default:
144     return false;
145   }
146 }
147
148 static bool tryPromoteAllocaToVector(AllocaInst *Alloca) {
149   Type *AllocaTy = Alloca->getAllocatedType();
150
151   DEBUG(dbgs() << "Alloca Candidate for vectorization \n");
152
153   // FIXME: There is no reason why we can't support larger arrays, we
154   // are just being conservative for now.
155   if (!AllocaTy->isArrayTy() ||
156       AllocaTy->getArrayElementType()->isVectorTy() ||
157       AllocaTy->getArrayNumElements() > 4) {
158
159     DEBUG(dbgs() << "  Cannot convert type to vector");
160     return false;
161   }
162
163   std::map<GetElementPtrInst*, Value*> GEPVectorIdx;
164   std::vector<Value*> WorkList;
165   for (User *AllocaUser : Alloca->users()) {
166     GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(AllocaUser);
167     if (!GEP) {
168       if (!canVectorizeInst(cast<Instruction>(AllocaUser)))
169         return false;
170
171       WorkList.push_back(AllocaUser);
172       continue;
173     }
174
175     Value *Index = GEPToVectorIndex(GEP);
176
177     // If we can't compute a vector index from this GEP, then we can't
178     // promote this alloca to vector.
179     if (!Index) {
180       DEBUG(dbgs() << "  Cannot compute vector index for GEP " << *GEP << '\n');
181       return false;
182     }
183
184     GEPVectorIdx[GEP] = Index;
185     for (User *GEPUser : AllocaUser->users()) {
186       if (!canVectorizeInst(cast<Instruction>(GEPUser)))
187         return false;
188
189       WorkList.push_back(GEPUser);
190     }
191   }
192
193   VectorType *VectorTy = arrayTypeToVecType(AllocaTy);
194
195   DEBUG(dbgs() << "  Converting alloca to vector "
196         << *AllocaTy << " -> " << *VectorTy << '\n');
197
198   for (std::vector<Value*>::iterator I = WorkList.begin(),
199                                      E = WorkList.end(); I != E; ++I) {
200     Instruction *Inst = cast<Instruction>(*I);
201     IRBuilder<> Builder(Inst);
202     switch (Inst->getOpcode()) {
203     case Instruction::Load: {
204       Value *Ptr = Inst->getOperand(0);
205       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
206       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
207       Value *VecValue = Builder.CreateLoad(BitCast);
208       Value *ExtractElement = Builder.CreateExtractElement(VecValue, Index);
209       Inst->replaceAllUsesWith(ExtractElement);
210       Inst->eraseFromParent();
211       break;
212     }
213     case Instruction::Store: {
214       Value *Ptr = Inst->getOperand(1);
215       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
216       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
217       Value *VecValue = Builder.CreateLoad(BitCast);
218       Value *NewVecValue = Builder.CreateInsertElement(VecValue,
219                                                        Inst->getOperand(0),
220                                                        Index);
221       Builder.CreateStore(NewVecValue, BitCast);
222       Inst->eraseFromParent();
223       break;
224     }
225     case Instruction::BitCast:
226     case Instruction::AddrSpaceCast:
227       break;
228
229     default:
230       Inst->dump();
231       llvm_unreachable("Inconsistency in instructions promotable to vector");
232     }
233   }
234   return true;
235 }
236
237 static void collectUsesWithPtrTypes(Value *Val, std::vector<Value*> &WorkList) {
238   for (User *User : Val->users()) {
239     if(std::find(WorkList.begin(), WorkList.end(), User) != WorkList.end())
240       continue;
241     if (isa<CallInst>(User)) {
242       WorkList.push_back(User);
243       continue;
244     }
245     if (!User->getType()->isPointerTy())
246       continue;
247     WorkList.push_back(User);
248     collectUsesWithPtrTypes(User, WorkList);
249   }
250 }
251
252 void AMDGPUPromoteAlloca::visitAlloca(AllocaInst &I) {
253   IRBuilder<> Builder(&I);
254
255   // First try to replace the alloca with a vector
256   Type *AllocaTy = I.getAllocatedType();
257
258   DEBUG(dbgs() << "Trying to promote " << I << '\n');
259
260   if (tryPromoteAllocaToVector(&I))
261     return;
262
263   DEBUG(dbgs() << " alloca is not a candidate for vectorization.\n");
264
265   // FIXME: This is the maximum work group size.  We should try to get
266   // value from the reqd_work_group_size function attribute if it is
267   // available.
268   unsigned WorkGroupSize = 256;
269   int AllocaSize = WorkGroupSize *
270       Mod->getDataLayout()->getTypeAllocSize(AllocaTy);
271
272   if (AllocaSize > LocalMemAvailable) {
273     DEBUG(dbgs() << " Not enough local memory to promote alloca.\n");
274     return;
275   }
276
277   DEBUG(dbgs() << "Promoting alloca to local memory\n");
278   LocalMemAvailable -= AllocaSize;
279
280   GlobalVariable *GV = new GlobalVariable(
281       *Mod, ArrayType::get(I.getAllocatedType(), 256), false,
282       GlobalValue::ExternalLinkage, 0, I.getName(), 0,
283       GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS);
284
285   FunctionType *FTy = FunctionType::get(
286       Type::getInt32Ty(Mod->getContext()), false);
287   AttributeSet AttrSet;
288   AttrSet.addAttribute(Mod->getContext(), 0, Attribute::ReadNone);
289
290   Value *ReadLocalSizeY = Mod->getOrInsertFunction(
291       "llvm.r600.read.local.size.y", FTy, AttrSet);
292   Value *ReadLocalSizeZ = Mod->getOrInsertFunction(
293       "llvm.r600.read.local.size.z", FTy, AttrSet);
294   Value *ReadTIDIGX = Mod->getOrInsertFunction(
295       "llvm.r600.read.tidig.x", FTy, AttrSet);
296   Value *ReadTIDIGY = Mod->getOrInsertFunction(
297       "llvm.r600.read.tidig.y", FTy, AttrSet);
298   Value *ReadTIDIGZ = Mod->getOrInsertFunction(
299       "llvm.r600.read.tidig.z", FTy, AttrSet);
300
301
302   Value *TCntY = Builder.CreateCall(ReadLocalSizeY);
303   Value *TCntZ = Builder.CreateCall(ReadLocalSizeZ);
304   Value *TIdX  = Builder.CreateCall(ReadTIDIGX);
305   Value *TIdY  = Builder.CreateCall(ReadTIDIGY);
306   Value *TIdZ  = Builder.CreateCall(ReadTIDIGZ);
307
308   Value *Tmp0 = Builder.CreateMul(TCntY, TCntZ);
309   Tmp0 = Builder.CreateMul(Tmp0, TIdX);
310   Value *Tmp1 = Builder.CreateMul(TIdY, TCntZ);
311   Value *TID = Builder.CreateAdd(Tmp0, Tmp1);
312   TID = Builder.CreateAdd(TID, TIdZ);
313
314   std::vector<Value*> Indices;
315   Indices.push_back(Constant::getNullValue(Type::getInt32Ty(Mod->getContext())));
316   Indices.push_back(TID);
317
318   Value *Offset = Builder.CreateGEP(GV, Indices);
319   I.mutateType(Offset->getType());
320   I.replaceAllUsesWith(Offset);
321   I.eraseFromParent();
322
323   std::vector<Value*> WorkList;
324
325   collectUsesWithPtrTypes(Offset, WorkList);
326
327   for (std::vector<Value*>::iterator i = WorkList.begin(),
328                                      e = WorkList.end(); i != e; ++i) {
329     Value *V = *i;
330     CallInst *Call = dyn_cast<CallInst>(V);
331     if (!Call) {
332       Type *EltTy = V->getType()->getPointerElementType();
333       PointerType *NewTy = PointerType::get(EltTy, AMDGPUAS::LOCAL_ADDRESS);
334
335       // The operand's value should be corrected on its own.
336       if (isa<AddrSpaceCastInst>(V))
337         continue;
338
339       // FIXME: It doesn't really make sense to try to do this for all
340       // instructions.
341       V->mutateType(NewTy);
342       continue;
343     }
344
345     IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Call);
346     if (!Intr) {
347       std::vector<Type*> ArgTypes;
348       for (unsigned ArgIdx = 0, ArgEnd = Call->getNumArgOperands();
349                                 ArgIdx != ArgEnd; ++ArgIdx) {
350         ArgTypes.push_back(Call->getArgOperand(ArgIdx)->getType());
351       }
352       Function *F = Call->getCalledFunction();
353       FunctionType *NewType = FunctionType::get(Call->getType(), ArgTypes,
354                                                 F->isVarArg());
355       Constant *C = Mod->getOrInsertFunction(StringRef(F->getName().str() + ".local"), NewType,
356                                              F->getAttributes());
357       Function *NewF = cast<Function>(C);
358       Call->setCalledFunction(NewF);
359       continue;
360     }
361
362     Builder.SetInsertPoint(Intr);
363     switch (Intr->getIntrinsicID()) {
364     case Intrinsic::lifetime_start:
365     case Intrinsic::lifetime_end:
366       // These intrinsics are for address space 0 only
367       Intr->eraseFromParent();
368       continue;
369     case Intrinsic::memcpy: {
370       MemCpyInst *MemCpy = cast<MemCpyInst>(Intr);
371       Builder.CreateMemCpy(MemCpy->getRawDest(), MemCpy->getRawSource(),
372                            MemCpy->getLength(), MemCpy->getAlignment(),
373                            MemCpy->isVolatile());
374       Intr->eraseFromParent();
375       continue;
376     }
377     case Intrinsic::memset: {
378       MemSetInst *MemSet = cast<MemSetInst>(Intr);
379       Builder.CreateMemSet(MemSet->getRawDest(), MemSet->getValue(),
380                            MemSet->getLength(), MemSet->getAlignment(),
381                            MemSet->isVolatile());
382       Intr->eraseFromParent();
383       continue;
384     }
385     default:
386       Intr->dump();
387       llvm_unreachable("Don't know how to promote alloca intrinsic use.");
388     }
389   }
390 }
391
392 FunctionPass *llvm::createAMDGPUPromoteAlloca(const AMDGPUSubtarget &ST) {
393   return new AMDGPUPromoteAlloca(ST);
394 }