AMDGPU/SI: Attempt to fix Windows bots broken by r244372
[oota-llvm.git] / lib / Target / AMDGPU / AMDGPUOpenCLImageTypeLoweringPass.cpp
1 //===-- AMDGPUOpenCLImageTypeLoweringPass.cpp -----------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
13 ///
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
18 /// re-created using the new signature. The new arguments are added to the
19 /// kernel metadata with kernel_arg_type set to "image_size" or "image_format".
20 /// Note: this pass may invalidate pointers to functions.
21 ///
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //===----------------------------------------------------------------------===//
26
27 #include "AMDGPU.h"
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/Analysis/Passes.h"
32 #include "llvm/IR/Constants.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/Instructions.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/Transforms/Utils/Cloning.h"
37
38 using namespace llvm;
39
40 namespace {
41
42 StringRef GetImageSizeFunc =         "llvm.OpenCL.image.get.size";
43 StringRef GetImageFormatFunc =       "llvm.OpenCL.image.get.format";
44 StringRef GetImageResourceIDFunc =   "llvm.OpenCL.image.get.resource.id";
45 StringRef GetSamplerResourceIDFunc = "llvm.OpenCL.sampler.get.resource.id";
46
47 StringRef ImageSizeArgMDType =   "__llvm_image_size";
48 StringRef ImageFormatArgMDType = "__llvm_image_format";
49
50 StringRef KernelsMDNodeName = "opencl.kernels";
51 StringRef KernelArgMDNodeNames[] = {
52   "kernel_arg_addr_space",
53   "kernel_arg_access_qual",
54   "kernel_arg_type",
55   "kernel_arg_base_type",
56   "kernel_arg_type_qual"};
57 const unsigned NumKernelArgMDNodes = array_lengthof(KernelArgMDNodeNames);
58
59 typedef SmallVector<Metadata *, 8> MDVector;
60 struct KernelArgMD {
61   MDVector ArgVector[NumKernelArgMDNodes];
62 };
63
64 } // end anonymous namespace
65
66 static inline bool
67 IsImageType(StringRef TypeString) {
68   return TypeString == "image2d_t" || TypeString == "image3d_t";
69 }
70
71 static inline bool
72 IsSamplerType(StringRef TypeString) {
73   return TypeString == "sampler_t";
74 }
75
76 static Function *
77 GetFunctionFromMDNode(MDNode *Node) {
78   if (!Node)
79     return nullptr;
80
81   size_t NumOps = Node->getNumOperands();
82   if (NumOps != NumKernelArgMDNodes + 1)
83     return nullptr;
84
85   auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
86   if (!F)
87     return nullptr;
88
89   // Sanity checks.
90   size_t ExpectNumArgNodeOps = F->arg_size() + 1;
91   for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
92     MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
93     if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
94       return nullptr;
95     if (!ArgNode->getOperand(0))
96       return nullptr;
97     assert(cast<MDString>(ArgNode->getOperand(0))->getString() ==
98            KernelArgMDNodeNames[i] && "Wrong kernel arg metadata name");
99   }
100
101   return F;
102 }
103
104 static StringRef
105 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
106   MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
107   return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
108 }
109
110 static StringRef
111 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
112   MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
113   return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
114 }
115
116 static MDVector
117 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
118   MDVector Res;
119   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
120     MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
121     Res.push_back(Node->getOperand(OpIdx));
122   }
123   return Res;
124 }
125
126 static void
127 PushArgMD(KernelArgMD &MD, const MDVector &V) {
128   assert(V.size() == NumKernelArgMDNodes);
129   for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
130     MD.ArgVector[i].push_back(V[i]);
131   }
132 }
133
134 namespace {
135
136 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
137   static char ID;
138
139   LLVMContext *Context;
140   Type *Int32Type;
141   Type *ImageSizeType;
142   Type *ImageFormatType;
143   SmallVector<Instruction *, 4> InstsToErase;
144
145   bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
146                         Argument &ImageSizeArg,
147                         Argument &ImageFormatArg) {
148     bool Modified = false;
149
150     for (auto &Use : ImageArg.uses()) {
151       auto Inst = dyn_cast<CallInst>(Use.getUser());
152       if (!Inst) {
153         continue;
154       }
155
156       Function *F = Inst->getCalledFunction();
157       if (!F)
158         continue;
159
160       Value *Replacement = nullptr;
161       StringRef Name = F->getName();
162       if (Name.startswith(GetImageResourceIDFunc)) {
163         Replacement = ConstantInt::get(Int32Type, ResourceID);
164       } else if (Name.startswith(GetImageSizeFunc)) {
165         Replacement = &ImageSizeArg;
166       } else if (Name.startswith(GetImageFormatFunc)) {
167         Replacement = &ImageFormatArg;
168       } else {
169         continue;
170       }
171
172       Inst->replaceAllUsesWith(Replacement);
173       InstsToErase.push_back(Inst);
174       Modified = true;
175     }
176
177     return Modified;
178   }
179
180   bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
181     bool Modified = false;
182
183     for (const auto &Use : SamplerArg.uses()) {
184       auto Inst = dyn_cast<CallInst>(Use.getUser());
185       if (!Inst) {
186         continue;
187       }
188
189       Function *F = Inst->getCalledFunction();
190       if (!F)
191         continue;
192
193       Value *Replacement = nullptr;
194       StringRef Name = F->getName();
195       if (Name == GetSamplerResourceIDFunc) {
196         Replacement = ConstantInt::get(Int32Type, ResourceID);
197       } else {
198         continue;
199       }
200
201       Inst->replaceAllUsesWith(Replacement);
202       InstsToErase.push_back(Inst);
203       Modified = true;
204     }
205
206     return Modified;
207   }
208
209   bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
210     uint32_t NumReadOnlyImageArgs = 0;
211     uint32_t NumWriteOnlyImageArgs = 0;
212     uint32_t NumSamplerArgs = 0;
213
214     bool Modified = false;
215     InstsToErase.clear();
216     for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
217       Argument &Arg = *ArgI;
218       StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
219
220       // Handle image types.
221       if (IsImageType(Type)) {
222         StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
223         uint32_t ResourceID;
224         if (AccessQual == "read_only") {
225           ResourceID = NumReadOnlyImageArgs++;
226         } else if (AccessQual == "write_only") {
227           ResourceID = NumWriteOnlyImageArgs++;
228         } else {
229           llvm_unreachable("Wrong image access qualifier.");
230         }
231
232         Argument &SizeArg = *(++ArgI);
233         Argument &FormatArg = *(++ArgI);
234         Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
235
236       // Handle sampler type.
237       } else if (IsSamplerType(Type)) {
238         uint32_t ResourceID = NumSamplerArgs++;
239         Modified |= replaceSamplerUses(Arg, ResourceID);
240       }
241     }
242     for (unsigned i = 0; i < InstsToErase.size(); ++i) {
243       InstsToErase[i]->eraseFromParent();
244     }
245
246     return Modified;
247   }
248
249   std::tuple<Function *, MDNode *>
250   addImplicitArgs(Function *F, MDNode *KernelMDNode) {
251     bool Modified = false;
252
253     FunctionType *FT = F->getFunctionType();
254     SmallVector<Type *, 8> ArgTypes;
255
256     // Metadata operands for new MDNode.
257     KernelArgMD NewArgMDs;
258     PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
259
260     // Add implicit arguments to the signature.
261     for (unsigned i = 0; i < FT->getNumParams(); ++i) {
262       ArgTypes.push_back(FT->getParamType(i));
263       MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
264       PushArgMD(NewArgMDs, ArgMD);
265
266       if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
267         continue;
268
269       // Add size implicit argument.
270       ArgTypes.push_back(ImageSizeType);
271       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
272       PushArgMD(NewArgMDs, ArgMD);
273
274       // Add format implicit argument.
275       ArgTypes.push_back(ImageFormatType);
276       ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
277       PushArgMD(NewArgMDs, ArgMD);
278
279       Modified = true;
280     }
281     if (!Modified) {
282       return std::make_tuple(nullptr, nullptr);
283     }
284
285     // Create function with new signature and clone the old body into it.
286     auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
287     auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
288     ValueToValueMapTy VMap;
289     auto NewFArgIt = NewF->arg_begin();
290     for (auto &Arg: F->args()) {
291       auto ArgName = Arg.getName();
292       NewFArgIt->setName(ArgName);
293       VMap[&Arg] = &(*NewFArgIt++);
294       if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
295         (NewFArgIt++)->setName(Twine("__size_") + ArgName);
296         (NewFArgIt++)->setName(Twine("__format_") + ArgName);
297       }
298     }
299     SmallVector<ReturnInst*, 8> Returns;
300     CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
301
302     // Build new MDNode.
303     SmallVector<llvm::Metadata *, 6> KernelMDArgs;
304     KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
305     for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
306       KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
307     MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
308
309     return std::make_tuple(NewF, NewMDNode);
310   }
311
312   bool transformKernels(Module &M) {
313     NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
314     if (!KernelsMDNode)
315       return false;
316
317     bool Modified = false;
318     for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
319       MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
320       Function *F = GetFunctionFromMDNode(KernelMDNode);
321       if (!F)
322         continue;
323
324       Function *NewF;
325       MDNode *NewMDNode;
326       std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
327       if (NewF) {
328         // Replace old function and metadata with new ones.
329         F->eraseFromParent();
330         M.getFunctionList().push_back(NewF);
331         M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
332                               NewF->getAttributes());
333         KernelsMDNode->setOperand(i, NewMDNode);
334
335         F = NewF;
336         KernelMDNode = NewMDNode;
337         Modified = true;
338       }
339
340       Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
341     }
342
343     return Modified;
344   }
345
346  public:
347   AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
348
349   bool runOnModule(Module &M) override {
350     Context = &M.getContext();
351     Int32Type = Type::getInt32Ty(M.getContext());
352     ImageSizeType = ArrayType::get(Int32Type, 3);
353     ImageFormatType = ArrayType::get(Int32Type, 2);
354
355     return transformKernels(M);
356   }
357
358   const char *getPassName() const override {
359     return "AMDGPU OpenCL Image Type Pass";
360   }
361 };
362
363 char AMDGPUOpenCLImageTypeLoweringPass::ID = 0;
364
365 } // end anonymous namespace
366
367 ModulePass *llvm::createAMDGPUOpenCLImageTypeLoweringPass() {
368   return new AMDGPUOpenCLImageTypeLoweringPass();
369 }