[NVPTX] kernel pointer arguments point to the global address space
[oota-llvm.git] / lib / Target / NVPTX / NVPTXLowerKernelArgs.cpp
1 //===-- NVPTXLowerKernelArgs.cpp - Lower kernel arguments -----------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Pointer arguments to kernel functions need to be lowered specially.
11 //
12 // 1. Copy byval struct args to local memory. This is a preparation for handling
13 //    cases like
14 //
15 //    kernel void foo(struct A arg, ...)
16 //    {
17 //      struct A *p = &arg;
18 //      ...
19 //      ... = p->filed1 ...  (this is no generic address for .param)
20 //      p->filed2 = ...      (this is no write access to .param)
21 //    }
22 //
23 // 2. Convert non-byval pointer arguments of CUDA kernels to pointers in the
24 //    global address space. This allows later optimizations to emit
25 //    ld.global.*/st.global.* for accessing these pointer arguments. For
26 //    example,
27 //
28 //    define void @foo(float* %input) {
29 //      %v = load float, float* %input, align 4
30 //      ...
31 //    }
32 //
33 //    becomes
34 //
35 //    define void @foo(float* %input) {
36 //      %input2 = addrspacecast float* %input to float addrspace(1)*
37 //      %input3 = addrspacecast float addrspace(1)* %input2 to float*
38 //      %v = load float, float* %input3, align 4
39 //      ...
40 //    }
41 //
42 //    Later, NVPTXFavorNonGenericAddrSpaces will optimize it to
43 //
44 //    define void @foo(float* %input) {
45 //      %input2 = addrspacecast float* %input to float addrspace(1)*
46 //      %v = load float, float addrspace(1)* %input2, align 4
47 //      ...
48 //    }
49 //
50 // TODO: merge this pass with NVPTXFavorNonGenericAddrSpace so that other passes
51 // don't cancel the addrspacecast pair this pass emits.
52 //===----------------------------------------------------------------------===//
53
54 #include "NVPTX.h"
55 #include "NVPTXUtilities.h"
56 #include "NVPTXTargetMachine.h"
57 #include "llvm/IR/Function.h"
58 #include "llvm/IR/Instructions.h"
59 #include "llvm/IR/Module.h"
60 #include "llvm/IR/Type.h"
61 #include "llvm/Pass.h"
62
63 using namespace llvm;
64
65 namespace llvm {
66 void initializeNVPTXLowerKernelArgsPass(PassRegistry &);
67 }
68
69 namespace {
70 class NVPTXLowerKernelArgs : public FunctionPass {
71   bool runOnFunction(Function &F) override;
72
73   // handle byval parameters
74   void handleByValParam(Argument *);
75   // handle non-byval pointer parameters
76   void handlePointerParam(Argument *);
77
78 public:
79   static char ID; // Pass identification, replacement for typeid
80   NVPTXLowerKernelArgs(const NVPTXTargetMachine *TM = nullptr)
81       : FunctionPass(ID), TM(TM) {}
82   const char *getPassName() const override {
83     return "Lower pointer arguments of CUDA kernels";
84   }
85
86 private:
87   const NVPTXTargetMachine *TM;
88 };
89 } // namespace
90
91 char NVPTXLowerKernelArgs::ID = 1;
92
93 INITIALIZE_PASS(NVPTXLowerKernelArgs, "nvptx-lower-kernel-args",
94                 "Lower kernel arguments (NVPTX)", false, false)
95
96 // =============================================================================
97 // If the function had a byval struct ptr arg, say foo(%struct.x *byval %d),
98 // then add the following instructions to the first basic block:
99 //
100 // %temp = alloca %struct.x, align 8
101 // %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
102 // %tv = load %struct.x addrspace(101)* %tempd
103 // store %struct.x %tv, %struct.x* %temp, align 8
104 //
105 // The above code allocates some space in the stack and copies the incoming
106 // struct from param space to local space.
107 // Then replace all occurences of %d by %temp.
108 // =============================================================================
109 void NVPTXLowerKernelArgs::handleByValParam(Argument *Arg) {
110   Function *Func = Arg->getParent();
111   Instruction *FirstInst = &(Func->getEntryBlock().front());
112   PointerType *PType = dyn_cast<PointerType>(Arg->getType());
113
114   assert(PType && "Expecting pointer type in handleByValParam");
115
116   Type *StructType = PType->getElementType();
117   AllocaInst *AllocA = new AllocaInst(StructType, Arg->getName(), FirstInst);
118   // Set the alignment to alignment of the byval parameter. This is because,
119   // later load/stores assume that alignment, and we are going to replace
120   // the use of the byval parameter with this alloca instruction.
121   AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1));
122   Arg->replaceAllUsesWith(AllocA);
123
124   Value *ArgInParam = new AddrSpaceCastInst(
125       Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
126       FirstInst);
127   LoadInst *LI = new LoadInst(ArgInParam, Arg->getName(), FirstInst);
128   new StoreInst(LI, AllocA, FirstInst);
129 }
130
131 void NVPTXLowerKernelArgs::handlePointerParam(Argument *Arg) {
132   assert(!Arg->hasByValAttr() &&
133          "byval params should be handled by handleByValParam");
134
135   Instruction *FirstInst = Arg->getParent()->getEntryBlock().begin();
136   Instruction *ArgInGlobal = new AddrSpaceCastInst(
137       Arg, PointerType::get(Arg->getType()->getPointerElementType(),
138                             ADDRESS_SPACE_GLOBAL),
139       Arg->getName(), FirstInst);
140   Value *ArgInGeneric = new AddrSpaceCastInst(ArgInGlobal, Arg->getType(),
141                                               Arg->getName(), FirstInst);
142   // Replace with ArgInGeneric all uses of Args except ArgInGlobal.
143   Arg->replaceAllUsesWith(ArgInGeneric);
144   ArgInGlobal->setOperand(0, Arg);
145 }
146
147
148 // =============================================================================
149 // Main function for this pass.
150 // =============================================================================
151 bool NVPTXLowerKernelArgs::runOnFunction(Function &F) {
152   // Skip non-kernels. See the comments at the top of this file.
153   if (!isKernelFunction(F))
154     return false;
155
156   for (Argument &Arg : F.args()) {
157     if (Arg.getType()->isPointerTy()) {
158       if (Arg.hasByValAttr())
159         handleByValParam(&Arg);
160       else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
161         handlePointerParam(&Arg);
162     }
163   }
164   return true;
165 }
166
167 FunctionPass *
168 llvm::createNVPTXLowerKernelArgsPass(const NVPTXTargetMachine *TM) {
169   return new NVPTXLowerKernelArgs(TM);
170 }