Add override to overriden virtual methods, remove virtual keywords.
[oota-llvm.git] / lib / Target / R600 / AMDGPUPromoteAlloca.cpp
1 //===-- AMDGPUPromoteAlloca.cpp - Promote Allocas -------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass eliminates allocas by either converting them into vectors or
11 // by migrating them to local address space.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "AMDGPU.h"
16 #include "AMDGPUSubtarget.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/InstVisitor.h"
20 #include "llvm/Support/Debug.h"
21
22 #define DEBUG_TYPE "amdgpu-promote-alloca"
23
24 using namespace llvm;
25
26 namespace {
27
28 class AMDGPUPromoteAlloca : public FunctionPass,
29                        public InstVisitor<AMDGPUPromoteAlloca> {
30
31   static char ID;
32   Module *Mod;
33   const AMDGPUSubtarget &ST;
34   int LocalMemAvailable;
35
36 public:
37   AMDGPUPromoteAlloca(const AMDGPUSubtarget &st) : FunctionPass(ID), ST(st),
38                                                    LocalMemAvailable(0) { }
39   bool doInitialization(Module &M) override;
40   bool runOnFunction(Function &F) override;
41   const char *getPassName() const override { return "AMDGPU Promote Alloca"; }
42   void visitAlloca(AllocaInst &I);
43 };
44
45 } // End anonymous namespace
46
47 char AMDGPUPromoteAlloca::ID = 0;
48
49 bool AMDGPUPromoteAlloca::doInitialization(Module &M) {
50   Mod = &M;
51   return false;
52 }
53
54 bool AMDGPUPromoteAlloca::runOnFunction(Function &F) {
55
56   const FunctionType *FTy = F.getFunctionType();
57
58   LocalMemAvailable = ST.getLocalMemorySize();
59
60
61   // If the function has any arguments in the local address space, then it's
62   // possible these arguments require the entire local memory space, so
63   // we cannot use local memory in the pass.
64   for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) {
65     const Type *ParamTy = FTy->getParamType(i);
66     if (ParamTy->isPointerTy() &&
67         ParamTy->getPointerAddressSpace() == AMDGPUAS::LOCAL_ADDRESS) {
68       LocalMemAvailable = 0;
69       DEBUG(dbgs() << "Function has local memory argument.  Promoting to "
70                       "local memory disabled.\n");
71       break;
72     }
73   }
74
75   if (LocalMemAvailable > 0) {
76     // Check how much local memory is being used by global objects
77     for (Module::global_iterator I = Mod->global_begin(),
78                                  E = Mod->global_end(); I != E; ++I) {
79       GlobalVariable *GV = I;
80       PointerType *GVTy = GV->getType();
81       if (GVTy->getAddressSpace() != AMDGPUAS::LOCAL_ADDRESS)
82         continue;
83       for (Value::use_iterator U = GV->use_begin(),
84                                UE = GV->use_end(); U != UE; ++U) {
85         Instruction *Use = dyn_cast<Instruction>(*U);
86         if (!Use)
87           continue;
88         if (Use->getParent()->getParent() == &F)
89           LocalMemAvailable -=
90               Mod->getDataLayout()->getTypeAllocSize(GVTy->getElementType());
91       }
92     }
93   }
94
95   LocalMemAvailable = std::max(0, LocalMemAvailable);
96   DEBUG(dbgs() << LocalMemAvailable << "bytes free in local memory.\n");
97
98   visit(F);
99
100   return false;
101 }
102
103 static VectorType *arrayTypeToVecType(const Type *ArrayTy) {
104   return VectorType::get(ArrayTy->getArrayElementType(),
105                          ArrayTy->getArrayNumElements());
106 }
107
108 static Value* calculateVectorIndex(Value *Ptr,
109                                   std::map<GetElementPtrInst*, Value*> GEPIdx) {
110   if (isa<AllocaInst>(Ptr))
111     return Constant::getNullValue(Type::getInt32Ty(Ptr->getContext()));
112
113   GetElementPtrInst *GEP = cast<GetElementPtrInst>(Ptr);
114
115   return GEPIdx[GEP];
116 }
117
118 static Value* GEPToVectorIndex(GetElementPtrInst *GEP) {
119   // FIXME we only support simple cases
120   if (GEP->getNumOperands() != 3)
121     return NULL;
122
123   ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
124   if (!I0 || !I0->isZero())
125     return NULL;
126
127   return GEP->getOperand(2);
128 }
129
130 // Not an instruction handled below to turn into a vector.
131 //
132 // TODO: Check isTriviallyVectorizable for calls and handle other
133 // instructions.
134 static bool canVectorizeInst(Instruction *Inst) {
135   switch (Inst->getOpcode()) {
136   case Instruction::Load:
137   case Instruction::Store:
138   case Instruction::BitCast:
139   case Instruction::AddrSpaceCast:
140     return true;
141   default:
142     return false;
143   }
144 }
145
146 static bool tryPromoteAllocaToVector(AllocaInst *Alloca) {
147   Type *AllocaTy = Alloca->getAllocatedType();
148
149   DEBUG(dbgs() << "Alloca Candidate for vectorization \n");
150
151   // FIXME: There is no reason why we can't support larger arrays, we
152   // are just being conservative for now.
153   if (!AllocaTy->isArrayTy() ||
154       AllocaTy->getArrayElementType()->isVectorTy() ||
155       AllocaTy->getArrayNumElements() > 4) {
156
157     DEBUG(dbgs() << "  Cannot convert type to vector");
158     return false;
159   }
160
161   std::map<GetElementPtrInst*, Value*> GEPVectorIdx;
162   std::vector<Value*> WorkList;
163   for (User *AllocaUser : Alloca->users()) {
164     GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(AllocaUser);
165     if (!GEP) {
166       if (!canVectorizeInst(cast<Instruction>(AllocaUser)))
167         return false;
168
169       WorkList.push_back(AllocaUser);
170       continue;
171     }
172
173     Value *Index = GEPToVectorIndex(GEP);
174
175     // If we can't compute a vector index from this GEP, then we can't
176     // promote this alloca to vector.
177     if (!Index) {
178       DEBUG(dbgs() << "  Cannot compute vector index for GEP " << *GEP << '\n');
179       return false;
180     }
181
182     GEPVectorIdx[GEP] = Index;
183     for (User *GEPUser : AllocaUser->users()) {
184       if (!canVectorizeInst(cast<Instruction>(GEPUser)))
185         return false;
186
187       WorkList.push_back(GEPUser);
188     }
189   }
190
191   VectorType *VectorTy = arrayTypeToVecType(AllocaTy);
192
193   DEBUG(dbgs() << "  Converting alloca to vector "
194         << *AllocaTy << " -> " << *VectorTy << '\n');
195
196   for (std::vector<Value*>::iterator I = WorkList.begin(),
197                                      E = WorkList.end(); I != E; ++I) {
198     Instruction *Inst = cast<Instruction>(*I);
199     IRBuilder<> Builder(Inst);
200     switch (Inst->getOpcode()) {
201     case Instruction::Load: {
202       Value *Ptr = Inst->getOperand(0);
203       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
204       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
205       Value *VecValue = Builder.CreateLoad(BitCast);
206       Value *ExtractElement = Builder.CreateExtractElement(VecValue, Index);
207       Inst->replaceAllUsesWith(ExtractElement);
208       Inst->eraseFromParent();
209       break;
210     }
211     case Instruction::Store: {
212       Value *Ptr = Inst->getOperand(1);
213       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
214       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
215       Value *VecValue = Builder.CreateLoad(BitCast);
216       Value *NewVecValue = Builder.CreateInsertElement(VecValue,
217                                                        Inst->getOperand(0),
218                                                        Index);
219       Builder.CreateStore(NewVecValue, BitCast);
220       Inst->eraseFromParent();
221       break;
222     }
223     case Instruction::BitCast:
224     case Instruction::AddrSpaceCast:
225       break;
226
227     default:
228       Inst->dump();
229       llvm_unreachable("Inconsistency in instructions promotable to vector");
230     }
231   }
232   return true;
233 }
234
235 static void collectUsesWithPtrTypes(Value *Val, std::vector<Value*> &WorkList) {
236   for (User *User : Val->users()) {
237     if(std::find(WorkList.begin(), WorkList.end(), User) != WorkList.end())
238       continue;
239     if (isa<CallInst>(User)) {
240       WorkList.push_back(User);
241       continue;
242     }
243     if (!User->getType()->isPointerTy())
244       continue;
245     WorkList.push_back(User);
246     collectUsesWithPtrTypes(User, WorkList);
247   }
248 }
249
250 void AMDGPUPromoteAlloca::visitAlloca(AllocaInst &I) {
251   IRBuilder<> Builder(&I);
252
253   // First try to replace the alloca with a vector
254   Type *AllocaTy = I.getAllocatedType();
255
256   DEBUG(dbgs() << "Trying to promote " << I << '\n');
257
258   if (tryPromoteAllocaToVector(&I))
259     return;
260
261   DEBUG(dbgs() << " alloca is not a candidate for vectorization.\n");
262
263   // FIXME: This is the maximum work group size.  We should try to get
264   // value from the reqd_work_group_size function attribute if it is
265   // available.
266   unsigned WorkGroupSize = 256;
267   int AllocaSize = WorkGroupSize *
268       Mod->getDataLayout()->getTypeAllocSize(AllocaTy);
269
270   if (AllocaSize > LocalMemAvailable) {
271     DEBUG(dbgs() << " Not enough local memory to promote alloca.\n");
272     return;
273   }
274
275   DEBUG(dbgs() << "Promoting alloca to local memory\n");
276   LocalMemAvailable -= AllocaSize;
277
278   GlobalVariable *GV = new GlobalVariable(
279       *Mod, ArrayType::get(I.getAllocatedType(), 256), false,
280       GlobalValue::ExternalLinkage, 0, I.getName(), 0,
281       GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS);
282
283   FunctionType *FTy = FunctionType::get(
284       Type::getInt32Ty(Mod->getContext()), false);
285   AttributeSet AttrSet;
286   AttrSet.addAttribute(Mod->getContext(), 0, Attribute::ReadNone);
287
288   Value *ReadLocalSizeY = Mod->getOrInsertFunction(
289       "llvm.r600.read.local.size.y", FTy, AttrSet);
290   Value *ReadLocalSizeZ = Mod->getOrInsertFunction(
291       "llvm.r600.read.local.size.z", FTy, AttrSet);
292   Value *ReadTIDIGX = Mod->getOrInsertFunction(
293       "llvm.r600.read.tidig.x", FTy, AttrSet);
294   Value *ReadTIDIGY = Mod->getOrInsertFunction(
295       "llvm.r600.read.tidig.y", FTy, AttrSet);
296   Value *ReadTIDIGZ = Mod->getOrInsertFunction(
297       "llvm.r600.read.tidig.z", FTy, AttrSet);
298
299
300   Value *TCntY = Builder.CreateCall(ReadLocalSizeY);
301   Value *TCntZ = Builder.CreateCall(ReadLocalSizeZ);
302   Value *TIdX  = Builder.CreateCall(ReadTIDIGX);
303   Value *TIdY  = Builder.CreateCall(ReadTIDIGY);
304   Value *TIdZ  = Builder.CreateCall(ReadTIDIGZ);
305
306   Value *Tmp0 = Builder.CreateMul(TCntY, TCntZ);
307   Tmp0 = Builder.CreateMul(Tmp0, TIdX);
308   Value *Tmp1 = Builder.CreateMul(TIdY, TCntZ);
309   Value *TID = Builder.CreateAdd(Tmp0, Tmp1);
310   TID = Builder.CreateAdd(TID, TIdZ);
311
312   std::vector<Value*> Indices;
313   Indices.push_back(Constant::getNullValue(Type::getInt32Ty(Mod->getContext())));
314   Indices.push_back(TID);
315
316   Value *Offset = Builder.CreateGEP(GV, Indices);
317   I.mutateType(Offset->getType());
318   I.replaceAllUsesWith(Offset);
319   I.eraseFromParent();
320
321   std::vector<Value*> WorkList;
322
323   collectUsesWithPtrTypes(Offset, WorkList);
324
325   for (std::vector<Value*>::iterator i = WorkList.begin(),
326                                      e = WorkList.end(); i != e; ++i) {
327     Value *V = *i;
328     CallInst *Call = dyn_cast<CallInst>(V);
329     if (!Call) {
330       Type *EltTy = V->getType()->getPointerElementType();
331       PointerType *NewTy = PointerType::get(EltTy, AMDGPUAS::LOCAL_ADDRESS);
332       V->mutateType(NewTy);
333       continue;
334     }
335
336     IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Call);
337     if (!Intr) {
338       std::vector<Type*> ArgTypes;
339       for (unsigned ArgIdx = 0, ArgEnd = Call->getNumArgOperands();
340                                 ArgIdx != ArgEnd; ++ArgIdx) {
341         ArgTypes.push_back(Call->getArgOperand(ArgIdx)->getType());
342       }
343       Function *F = Call->getCalledFunction();
344       FunctionType *NewType = FunctionType::get(Call->getType(), ArgTypes,
345                                                 F->isVarArg());
346       Constant *C = Mod->getOrInsertFunction(StringRef(F->getName().str() + ".local"), NewType,
347                                              F->getAttributes());
348       Function *NewF = cast<Function>(C);
349       Call->setCalledFunction(NewF);
350       continue;
351     }
352
353     Builder.SetInsertPoint(Intr);
354     switch (Intr->getIntrinsicID()) {
355     case Intrinsic::lifetime_start:
356     case Intrinsic::lifetime_end:
357       // These intrinsics are for address space 0 only
358       Intr->eraseFromParent();
359       continue;
360     case Intrinsic::memcpy: {
361       MemCpyInst *MemCpy = cast<MemCpyInst>(Intr);
362       Builder.CreateMemCpy(MemCpy->getRawDest(), MemCpy->getRawSource(),
363                            MemCpy->getLength(), MemCpy->getAlignment(),
364                            MemCpy->isVolatile());
365       Intr->eraseFromParent();
366       continue;
367     }
368     case Intrinsic::memset: {
369       MemSetInst *MemSet = cast<MemSetInst>(Intr);
370       Builder.CreateMemSet(MemSet->getRawDest(), MemSet->getValue(),
371                            MemSet->getLength(), MemSet->getAlignment(),
372                            MemSet->isVolatile());
373       Intr->eraseFromParent();
374       continue;
375     }
376     default:
377       Intr->dump();
378       llvm_unreachable("Don't know how to promote alloca intrinsic use.");
379     }
380   }
381 }
382
383 FunctionPass *llvm::createAMDGPUPromoteAlloca(const AMDGPUSubtarget &ST) {
384   return new AMDGPUPromoteAlloca(ST);
385 }