R600: Use LDS and vectors for private memory
[oota-llvm.git] / lib / Target / R600 / AMDGPUPromoteAlloca.cpp
1 //===-- AMDGPUPromoteAlloca.cpp - Promote Allocas -------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass eliminates allocas by either converting them into vectors or
11 // by migrating them to local address space.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "AMDGPU.h"
16 #include "AMDGPUSubtarget.h"
17 #include "llvm/Analysis/ValueTracking.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/InstVisitor.h"
20 #include "llvm/Support/Debug.h"
21
22 #define DEBUG_TYPE "amdgpu-promote-alloca"
23
24 using namespace llvm;
25
26 namespace {
27
28 class AMDGPUPromoteAlloca : public FunctionPass,
29                        public InstVisitor<AMDGPUPromoteAlloca> {
30
31   static char ID;
32   Module *Mod;
33   const AMDGPUSubtarget &ST;
34   int LocalMemAvailable;
35
36 public:
37   AMDGPUPromoteAlloca(const AMDGPUSubtarget &st) : FunctionPass(ID), ST(st),
38                                                    LocalMemAvailable(0) { }
39   virtual bool doInitialization(Module &M);
40   virtual bool runOnFunction(Function &F);
41   virtual const char *getPassName() const {
42     return "AMDGPU Promote Alloca";
43   }
44   void visitAlloca(AllocaInst &I);
45 };
46
47 } // End anonymous namespace
48
49 char AMDGPUPromoteAlloca::ID = 0;
50
51 bool AMDGPUPromoteAlloca::doInitialization(Module &M) {
52   Mod = &M;
53   return false;
54 }
55
56 bool AMDGPUPromoteAlloca::runOnFunction(Function &F) {
57
58   const FunctionType *FTy = F.getFunctionType();
59
60   LocalMemAvailable = ST.getLocalMemorySize();
61
62
63   // If the function has any arguments in the local address space, then it's
64   // possible these arguments require the entire local memory space, so
65   // we cannot use local memory in the pass.
66   for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) {
67     const Type *ParamTy = FTy->getParamType(i);
68     if (ParamTy->isPointerTy() &&
69         ParamTy->getPointerAddressSpace() == AMDGPUAS::LOCAL_ADDRESS) {
70       LocalMemAvailable = 0;
71       DEBUG(dbgs() << "Function has local memory argument.  Promoting to "
72                       "local memory disabled.\n");
73       break;
74     }
75   }
76
77   if (LocalMemAvailable > 0) {
78     // Check how much local memory is being used by global objects
79     for (Module::global_iterator I = Mod->global_begin(),
80                                  E = Mod->global_end(); I != E; ++I) {
81       GlobalVariable *GV = I;
82       PointerType *GVTy = GV->getType();
83       if (GVTy->getAddressSpace() != AMDGPUAS::LOCAL_ADDRESS)
84         continue;
85       for (Value::use_iterator U = GV->use_begin(),
86                                UE = GV->use_end(); U != UE; ++U) {
87         Instruction *Use = dyn_cast<Instruction>(*U);
88         if (!Use)
89           continue;
90         if (Use->getParent()->getParent() == &F)
91           LocalMemAvailable -=
92               Mod->getDataLayout()->getTypeAllocSize(GVTy->getElementType());
93       }
94     }
95   }
96
97   LocalMemAvailable = std::max(0, LocalMemAvailable);
98   DEBUG(dbgs() << LocalMemAvailable << "bytes free in local memory.\n");
99
100   visit(F);
101
102   return false;
103 }
104
105 static VectorType *arrayTypeToVecType(const Type *ArrayTy) {
106   return VectorType::get(ArrayTy->getArrayElementType(),
107                          ArrayTy->getArrayNumElements());
108 }
109
110 static Value* calculateVectorIndex(Value *Ptr,
111                                   std::map<GetElementPtrInst*, Value*> GEPIdx) {
112   if (isa<AllocaInst>(Ptr))
113     return Constant::getNullValue(Type::getInt32Ty(Ptr->getContext()));
114
115   GetElementPtrInst *GEP = cast<GetElementPtrInst>(Ptr);
116
117   return GEPIdx[GEP];
118 }
119
120 static Value* GEPToVectorIndex(GetElementPtrInst *GEP) {
121   // FIXME we only support simple cases
122   if (GEP->getNumOperands() != 3)
123     return NULL;
124
125   ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
126   if (!I0 || !I0->isZero())
127     return NULL;
128
129   return GEP->getOperand(2);
130 }
131
132 static bool tryPromoteAllocaToVector(AllocaInst *Alloca) {
133   Type *AllocaTy = Alloca->getAllocatedType();
134
135   DEBUG(dbgs() << "Alloca Candidate for vectorization \n");
136
137   // FIXME: There is no reason why we can't support larger arrays, we
138   // are just being conservative for now.
139   if (!AllocaTy->isArrayTy() ||
140       AllocaTy->getArrayElementType()->isVectorTy() ||
141       AllocaTy->getArrayNumElements() > 4) {
142
143     DEBUG(dbgs() << "  Cannot convert type to vector");
144     return false;
145   }
146
147   std::map<GetElementPtrInst*, Value*> GEPVectorIdx;
148   std::vector<Value*> WorkList;
149   for (User *AllocaUser : Alloca->users()) {
150     GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(AllocaUser);
151     if (!GEP) {
152       WorkList.push_back(AllocaUser);
153       continue;
154     }
155
156     Value *Index = GEPToVectorIndex(GEP);
157
158     // If we can't compute a vector index from this GEP, then we can't
159     // promote this alloca to vector.
160     if (!Index) {
161       DEBUG(dbgs() << "  Cannot compute vector index for GEP " << *GEP << "\n");
162       return false;
163     }
164
165     GEPVectorIdx[GEP] = Index;
166     for (User *GEPUser : AllocaUser->users()) {
167       WorkList.push_back(GEPUser);
168     }
169   }
170
171   VectorType *VectorTy = arrayTypeToVecType(AllocaTy);
172
173   DEBUG(dbgs() << "  Converting alloca to vector "; AllocaTy->dump();
174         dbgs() << " -> "; VectorTy->dump(); dbgs() << "\n");
175
176   for (std::vector<Value*>::iterator I = WorkList.begin(),
177                                      E = WorkList.end(); I != E; ++I) {
178     Instruction *Inst = cast<Instruction>(*I);
179     IRBuilder<> Builder(Inst);
180     switch (Inst->getOpcode()) {
181     case Instruction::Load: {
182       Value *Ptr = Inst->getOperand(0);
183       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
184       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
185       Value *VecValue = Builder.CreateLoad(BitCast);
186       Value *ExtractElement = Builder.CreateExtractElement(VecValue, Index);
187       Inst->replaceAllUsesWith(ExtractElement);
188       Inst->eraseFromParent();
189       break;
190     }
191     case Instruction::Store: {
192       Value *Ptr = Inst->getOperand(1);
193       Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
194       Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
195       Value *VecValue = Builder.CreateLoad(BitCast);
196       Value *NewVecValue = Builder.CreateInsertElement(VecValue,
197                                                        Inst->getOperand(0),
198                                                        Index);
199       Builder.CreateStore(NewVecValue, BitCast);
200       Inst->eraseFromParent();
201       break;
202     }
203     case Instruction::BitCast:
204       break;
205
206     default:
207       Inst->dump();
208       llvm_unreachable("Do not know how to replace this instruction "
209                               "with vector op");
210     }
211   }
212   return true;
213 }
214
215 static void collectUsesWithPtrTypes(Value *Val, std::vector<Value*> &WorkList) {
216   for (User *User : Val->users()) {
217     if(std::find(WorkList.begin(), WorkList.end(), User) != WorkList.end())
218       continue;
219     if (isa<CallInst>(User)) {
220       WorkList.push_back(User);
221       continue;
222     }
223     if (!User->getType()->isPointerTy())
224       continue;
225     WorkList.push_back(User);
226     collectUsesWithPtrTypes(User, WorkList);
227   }
228 }
229
230 void AMDGPUPromoteAlloca::visitAlloca(AllocaInst &I) {
231   IRBuilder<> Builder(&I);
232
233   // First try to replace the alloca with a vector
234   Type *AllocaTy = I.getAllocatedType();
235
236   DEBUG(dbgs() << "Trying to promote " << I);
237
238   if (tryPromoteAllocaToVector(&I))
239     return;
240
241   DEBUG(dbgs() << " alloca is not a candidate for vectorization.\n");
242
243   // FIXME: This is the maximum work group size.  We should try to get
244   // value from the reqd_work_group_size function attribute if it is
245   // available.
246   unsigned WorkGroupSize = 256;
247   int AllocaSize = WorkGroupSize *
248       Mod->getDataLayout()->getTypeAllocSize(AllocaTy);
249
250   if (AllocaSize > LocalMemAvailable) {
251     DEBUG(dbgs() << " Not enough local memory to promote alloca.\n");
252     return;
253   }
254
255   DEBUG(dbgs() << "Promoting alloca to local memory\n");
256   LocalMemAvailable -= AllocaSize;
257
258   GlobalVariable *GV = new GlobalVariable(
259       *Mod, ArrayType::get(I.getAllocatedType(), 256), false,
260       GlobalValue::ExternalLinkage, 0, I.getName(), 0,
261       GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS);
262
263   FunctionType *FTy = FunctionType::get(
264       Type::getInt32Ty(Mod->getContext()), false);
265   AttributeSet AttrSet;
266   AttrSet.addAttribute(Mod->getContext(), 0, Attribute::ReadNone);
267
268   Value *ReadLocalSizeY = Mod->getOrInsertFunction(
269       "llvm.r600.read.local.size.y", FTy, AttrSet);
270   Value *ReadLocalSizeZ = Mod->getOrInsertFunction(
271       "llvm.r600.read.local.size.z", FTy, AttrSet);
272   Value *ReadTIDIGX = Mod->getOrInsertFunction(
273       "llvm.r600.read.tidig.x", FTy, AttrSet);
274   Value *ReadTIDIGY = Mod->getOrInsertFunction(
275       "llvm.r600.read.tidig.y", FTy, AttrSet);
276   Value *ReadTIDIGZ = Mod->getOrInsertFunction(
277       "llvm.r600.read.tidig.z", FTy, AttrSet);
278
279
280   Value *TCntY = Builder.CreateCall(ReadLocalSizeY);
281   Value *TCntZ = Builder.CreateCall(ReadLocalSizeZ);
282   Value *TIdX  = Builder.CreateCall(ReadTIDIGX);
283   Value *TIdY  = Builder.CreateCall(ReadTIDIGY);
284   Value *TIdZ  = Builder.CreateCall(ReadTIDIGZ);
285
286   Value *Tmp0 = Builder.CreateMul(TCntY, TCntZ);
287   Tmp0 = Builder.CreateMul(Tmp0, TIdX);
288   Value *Tmp1 = Builder.CreateMul(TIdY, TCntZ);
289   Value *TID = Builder.CreateAdd(Tmp0, Tmp1);
290   TID = Builder.CreateAdd(TID, TIdZ);
291
292   std::vector<Value*> Indices;
293   Indices.push_back(Constant::getNullValue(Type::getInt32Ty(Mod->getContext())));
294   Indices.push_back(TID);
295
296   Value *Offset = Builder.CreateGEP(GV, Indices);
297   I.mutateType(Offset->getType());
298   I.replaceAllUsesWith(Offset);
299   I.eraseFromParent();
300
301   std::vector<Value*> WorkList;
302
303   collectUsesWithPtrTypes(Offset, WorkList);
304
305   for (std::vector<Value*>::iterator i = WorkList.begin(),
306                                      e = WorkList.end(); i != e; ++i) {
307     Value *V = *i;
308     CallInst *Call = dyn_cast<CallInst>(V);
309     if (!Call) {
310       Type *EltTy = V->getType()->getPointerElementType();
311       PointerType *NewTy = PointerType::get(EltTy, AMDGPUAS::LOCAL_ADDRESS);
312       V->mutateType(NewTy);
313       continue;
314     }
315
316     IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Call);
317     if (!Intr) {
318       std::vector<Type*> ArgTypes;
319       for (unsigned ArgIdx = 0, ArgEnd = Call->getNumArgOperands();
320                                 ArgIdx != ArgEnd; ++ArgIdx) {
321         ArgTypes.push_back(Call->getArgOperand(ArgIdx)->getType());
322       }
323       Function *F = Call->getCalledFunction();
324       FunctionType *NewType = FunctionType::get(Call->getType(), ArgTypes,
325                                                 F->isVarArg());
326       Constant *C = Mod->getOrInsertFunction(StringRef(F->getName().str() + ".local"), NewType,
327                                              F->getAttributes());
328       Function *NewF = cast<Function>(C);
329       Call->setCalledFunction(NewF);
330       continue;
331     }
332
333     Builder.SetInsertPoint(Intr);
334     switch (Intr->getIntrinsicID()) {
335     case Intrinsic::lifetime_start:
336     case Intrinsic::lifetime_end:
337       // These intrinsics are for address space 0 only
338       Intr->eraseFromParent();
339       continue;
340     case Intrinsic::memcpy: {
341       MemCpyInst *MemCpy = cast<MemCpyInst>(Intr);
342       Builder.CreateMemCpy(MemCpy->getRawDest(), MemCpy->getRawSource(),
343                            MemCpy->getLength(), MemCpy->getAlignment(),
344                            MemCpy->isVolatile());
345       Intr->eraseFromParent();
346       continue;
347     }
348     case Intrinsic::memset: {
349       MemSetInst *MemSet = cast<MemSetInst>(Intr);
350       Builder.CreateMemSet(MemSet->getRawDest(), MemSet->getValue(),
351                            MemSet->getLength(), MemSet->getAlignment(),
352                            MemSet->isVolatile());
353       Intr->eraseFromParent();
354       continue;
355     }
356     default:
357       Intr->dump();
358       llvm_unreachable("Don't know how to promote alloca intrinsic use.");
359     }
360   }
361 }
362
363 FunctionPass *llvm::createAMDGPUPromoteAlloca(const AMDGPUSubtarget &ST) {
364   return new AMDGPUPromoteAlloca(ST);
365 }