1 //===-- AMDGPUOpenCLImageTypeLoweringPass.cpp -----------------------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
11 /// This pass resolves calls to OpenCL image attribute, image resource ID and
12 /// sampler resource ID getter functions.
14 /// Image attributes (size and format) are expected to be passed to the kernel
15 /// as kernel arguments immediately following the image argument itself,
16 /// therefore this pass adds image size and format arguments to the kernel
17 /// functions in the module. The kernel functions with image arguments are
18 /// re-created using the new signature. The new arguments are added to the
19 /// kernel metadata with kernel_arg_type set to "__llvm_image_size" or "__llvm_image_format".
20 /// Note: this pass may invalidate pointers to functions.
22 /// Resource IDs of read-only images, write-only images and samplers are
23 /// defined to be their index among the kernel arguments of the same
24 /// type and access qualifier.
25 //===----------------------------------------------------------------------===//
28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/Analysis/Passes.h"
32 #include "llvm/IR/Constants.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/Instructions.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/Transforms/Utils/Cloning.h"
// Names of the getter functions whose calls this pass replaces. The three
// image getters are matched as name prefixes in replaceImageUses; the sampler
// getter is matched by exact name in replaceSamplerUses.
42 StringRef GetImageSizeFunc = "llvm.OpenCL.image.get.size";
43 StringRef GetImageFormatFunc = "llvm.OpenCL.image.get.format";
44 StringRef GetImageResourceIDFunc = "llvm.OpenCL.image.get.resource.id";
45 StringRef GetSamplerResourceIDFunc = "llvm.OpenCL.sampler.get.resource.id";
// Metadata strings that addImplicitArgs stores for the implicit image size
// and image format arguments it appends after each image argument.
47 StringRef ImageSizeArgMDType = "__llvm_image_size";
48 StringRef ImageFormatArgMDType = "__llvm_image_format";
// Name of the module-level named metadata that transformKernels walks; each
// of its operands is treated as one kernel description node.
50 StringRef KernelsMDNodeName = "opencl.kernels";
// Expected name (operand 0) of each per-kernel argument metadata node, in the
// order the nodes appear as operands 1..N of a kernel node;
// GetFunctionFromMDNode asserts that node i + 1 carries name i.
// NOTE(review): ArgTypeFromMD reads kernel operand 3, which maps to entry 2
// of this array; the numbering embedded in this listing skips a line between
// the second and third entries, so an entry may be missing here - confirm
// against the full file.
51 StringRef KernelArgMDNodeNames[] = {
52 "kernel_arg_addr_space",
53 "kernel_arg_access_qual",
55 "kernel_arg_base_type",
56 "kernel_arg_type_qual"};
// Number of entries in KernelArgMDNodeNames, computed at compile time.
57 constexpr unsigned NumKernelArgMDNodes = array_lengthof(KernelArgMDNodeNames);
// A list of metadata operands.
59 typedef SmallVector<Metadata *, 8> MDVector;
// One operand list per kernel-arg metadata kind, indexed like
// KernelArgMDNodeNames. PushArgMD accesses this as a member of KernelArgMD.
// NOTE(review): the enclosing struct declaration is not visible in this
// listing - confirm against the full file.
61 MDVector ArgVector[NumKernelArgMDNodes];
64 } // end anonymous namespace
/// Returns true when TypeString is exactly "image2d_t" or "image3d_t".
67 IsImageType(StringRef TypeString) {
68 return TypeString == "image2d_t" || TypeString == "image3d_t";
/// Returns true when TypeString is exactly "sampler_t".
72 IsSamplerType(StringRef TypeString) {
73 return TypeString == "sampler_t";
/// Extracts the kernel function stored as operand 0 of Node after checking
/// that the node has the expected shape: the function followed by one
/// argument metadata node per entry of KernelArgMDNodeNames, each holding its
/// name plus one operand per function argument.
/// NOTE(review): the failure paths of the checks below are not visible in this
/// listing; the added guard returns nullptr on the assumption that callers
/// already treat a null result as "not a usable kernel node" - confirm.
77 GetFunctionFromMDNode(MDNode *Node) {
81 size_t NumOps = Node->getNumOperands();
// A kernel node is the function plus NumKernelArgMDNodes argument nodes.
82 if (NumOps != NumKernelArgMDNodes + 1)
85 auto F = mdconst::dyn_extract<Function>(Node->getOperand(0));
// Every argument node holds its name followed by one operand per argument.
90 size_t ExpectNumArgNodeOps = F->arg_size() + 1;
91 for (size_t i = 0; i < NumKernelArgMDNodes; ++i) {
92 MDNode *ArgNode = dyn_cast_or_null<MDNode>(Node->getOperand(i + 1));
// dyn_cast_or_null yields null for a null or non-MDNode operand; reject the
// kernel node instead of dereferencing a null pointer below.
if (!ArgNode) return nullptr;
93 if (ArgNode->getNumOperands() != ExpectNumArgNodeOps)
95 if (!ArgNode->getOperand(0))
// The i-th argument node must be named KernelArgMDNodeNames[i]; this is only
// verified in builds with assertions enabled.
97 assert(cast<MDString>(ArgNode->getOperand(0))->getString() ==
98 KernelArgMDNodeNames[i] && "Wrong kernel arg metadata name");
/// Returns the string stored for argument ArgIdx in the node at operand 2 of
/// KernelMDNode (the node that GetFunctionFromMDNode expects to be named
/// "kernel_arg_access_qual"). Operand 0 of that node is its name, hence the
/// ArgIdx + 1 offset. Both casts are unchecked cast<>, so the node is assumed
/// to have already been validated.
105 AccessQualFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
106 MDNode *ArgAQNode = cast<MDNode>(KernelMDNode->getOperand(2));
107 return cast<MDString>(ArgAQNode->getOperand(ArgIdx + 1))->getString();
/// Returns the string stored for argument ArgIdx in the node at operand 3 of
/// KernelMDNode. Operand 0 of that node is its name, hence the ArgIdx + 1
/// offset. Both casts are unchecked cast<>.
/// NOTE(review): presumably operand 3 is the "kernel_arg_type" node - confirm
/// against the full KernelArgMDNodeNames list.
111 ArgTypeFromMD(MDNode *KernelMDNode, unsigned ArgIdx) {
112 MDNode *ArgTypeNode = cast<MDNode>(KernelMDNode->getOperand(3));
113 return cast<MDString>(ArgTypeNode->getOperand(ArgIdx + 1))->getString();
/// Collects operand OpIdx from each of the NumKernelArgMDNodes argument
/// metadata nodes of KernelMDNode (kernel operands 1..N), in node order.
/// addImplicitArgs calls this with OpIdx 0 and with i + 1 for parameter i.
/// NOTE(review): the declaration and return of Res are not visible in this
/// listing.
117 GetArgMD(MDNode *KernelMDNode, unsigned OpIdx) {
119 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
// Argument metadata nodes start at kernel operand 1; operand 0 is the function.
120 MDNode *Node = cast<MDNode>(KernelMDNode->getOperand(i + 1));
121 Res.push_back(Node->getOperand(OpIdx));
/// Appends V[i] to MD.ArgVector[i] for every kernel-arg metadata kind, i.e.
/// adds one column (as produced by GetArgMD) to the per-kind operand lists.
/// V must contain exactly NumKernelArgMDNodes entries (asserted).
127 PushArgMD(KernelArgMD &MD, const MDVector &V) {
128 assert(V.size() == NumKernelArgMDNodes);
129 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i) {
130 MD.ArgVector[i].push_back(V[i]);
/// Module pass that adds implicit image size/format arguments to kernels and
/// replaces calls to the image and sampler getter functions.
136 class AMDGPUOpenCLImageTypeLoweringPass : public ModulePass {
// Context of the module being processed; set in runOnModule.
139 LLVMContext *Context;
// Type of the implicit image format argument; runOnModule sets it to an
// array of two 32-bit integers.
142 Type *ImageFormatType;
// Calls whose uses have been replaced; filled by replaceImageUses and
// replaceSamplerUses, erased at the end of replaceImageAndSamplerUses.
143 SmallVector<Instruction *, 4> InstsToErase;
// Rewrites calls that use ImageArg and whose callee name starts with one of
// the image getter names: resource-id getters are replaced by the constant
// ResourceID, size getters by ImageSizeArg and format getters by
// ImageFormatArg. Replaced calls are queued in InstsToErase instead of being
// erased while the use list is walked.
// NOTE(review): this listing shows Modified initialised to false but not the
// statements that update or return it - confirm against the full file.
145 bool replaceImageUses(Argument &ImageArg, uint32_t ResourceID,
146 Argument &ImageSizeArg,
147 Argument &ImageFormatArg) {
148 bool Modified = false;
150 for (auto &Use : ImageArg.uses()) {
// dyn_cast yields null for users that are not call instructions.
151 auto Inst = dyn_cast<CallInst>(Use.getUser());
156 Function *F = Inst->getCalledFunction();
160 Value *Replacement = nullptr;
// Getter names are matched as prefixes of the callee name.
161 StringRef Name = F->getName();
162 if (Name.startswith(GetImageResourceIDFunc)) {
163 Replacement = ConstantInt::get(Int32Type, ResourceID);
164 } else if (Name.startswith(GetImageSizeFunc)) {
165 Replacement = &ImageSizeArg;
166 } else if (Name.startswith(GetImageFormatFunc)) {
167 Replacement = &ImageFormatArg;
// Redirect every user of the call result, then defer erasing the call.
172 Inst->replaceAllUsesWith(Replacement);
173 InstsToErase.push_back(Inst);
// Rewrites calls that use SamplerArg and whose callee is named exactly
// GetSamplerResourceIDFunc: the call result is replaced by the constant
// ResourceID and the call is queued in InstsToErase.
// NOTE(review): this listing shows Modified initialised to false but not the
// statements that update or return it - confirm against the full file.
180 bool replaceSamplerUses(Argument &SamplerArg, uint32_t ResourceID) {
181 bool Modified = false;
183 for (const auto &Use : SamplerArg.uses()) {
// dyn_cast yields null for users that are not call instructions.
184 auto Inst = dyn_cast<CallInst>(Use.getUser());
189 Function *F = Inst->getCalledFunction();
193 Value *Replacement = nullptr;
// Unlike the image getters, the sampler getter is matched by exact name.
194 StringRef Name = F->getName();
195 if (Name == GetSamplerResourceIDFunc) {
196 Replacement = ConstantInt::get(Int32Type, ResourceID);
// Redirect every user of the call result, then defer erasing the call.
201 Inst->replaceAllUsesWith(Replacement);
202 InstsToErase.push_back(Inst);
// Walks the arguments of F, classifies each one by the type string recorded
// in KernelMDNode, assigns it a resource ID and replaces the matching getter
// calls. Resource IDs are running counters kept separately for read-only
// images, write-only images and samplers. The function expects every image
// argument to be immediately followed by its size and format arguments, which
// is the layout addImplicitArgs builds.
209 bool replaceImageAndSamplerUses(Function *F, MDNode *KernelMDNode) {
210 uint32_t NumReadOnlyImageArgs = 0;
211 uint32_t NumWriteOnlyImageArgs = 0;
212 uint32_t NumSamplerArgs = 0;
214 bool Modified = false;
// Start from an empty erase list; the helpers below append to it.
215 InstsToErase.clear();
216 for (auto ArgI = F->arg_begin(); ArgI != F->arg_end(); ++ArgI) {
217 Argument &Arg = *ArgI;
218 StringRef Type = ArgTypeFromMD(KernelMDNode, Arg.getArgNo());
220 // Handle image types.
221 if (IsImageType(Type)) {
222 StringRef AccessQual = AccessQualFromMD(KernelMDNode, Arg.getArgNo());
// The ID is the index of this image among images with the same qualifier.
224 if (AccessQual == "read_only") {
225 ResourceID = NumReadOnlyImageArgs++;
226 } else if (AccessQual == "write_only") {
227 ResourceID = NumWriteOnlyImageArgs++;
229 llvm_unreachable("Wrong image access qualifier.");
// Consume the two arguments that follow the image: size first, then format.
// This also makes the outer loop skip them.
232 Argument &SizeArg = *(++ArgI);
233 Argument &FormatArg = *(++ArgI);
234 Modified |= replaceImageUses(Arg, ResourceID, SizeArg, FormatArg);
236 // Handle sampler type.
237 } else if (IsSamplerType(Type)) {
238 uint32_t ResourceID = NumSamplerArgs++;
239 Modified |= replaceSamplerUses(Arg, ResourceID);
// Erase the replaced calls only after all use lists have been traversed.
242 for (unsigned i = 0; i < InstsToErase.size(); ++i) {
243 InstsToErase[i]->eraseFromParent();
// Builds a copy of F whose signature has an image size argument and an image
// format argument inserted directly after every image argument, together
// with a kernel metadata node describing the new signature.
// Returns (new function, new metadata node). The new function is created
// without a parent module; transformKernels inserts it into the module.
// NOTE(review): the (nullptr, nullptr) return below is presumably taken when
// no image argument was found - the guarding condition is not visible in this
// listing.
249 std::tuple<Function *, MDNode *>
250 addImplicitArgs(Function *F, MDNode *KernelMDNode) {
251 bool Modified = false;
253 FunctionType *FT = F->getFunctionType();
254 SmallVector<Type *, 8> ArgTypes;
256 // Metadata operands for new MDNode.
257 KernelArgMD NewArgMDs;
// Operand 0 of each argument metadata node is copied over first.
258 PushArgMD(NewArgMDs, GetArgMD(KernelMDNode, 0));
260 // Add implicit arguments to the signature.
261 for (unsigned i = 0; i < FT->getNumParams(); ++i) {
262 ArgTypes.push_back(FT->getParamType(i));
// Metadata for parameter i lives at operand i + 1 of each argument node.
263 MDVector ArgMD = GetArgMD(KernelMDNode, i + 1);
264 PushArgMD(NewArgMDs, ArgMD);
266 if (!IsImageType(ArgTypeFromMD(KernelMDNode, i)))
269 // Add size implicit argument.
270 ArgTypes.push_back(ImageSizeType);
// Only entries 2 and 3 are overwritten; the remaining metadata entries of the
// implicit argument are inherited from the image argument.
271 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageSizeArgMDType);
272 PushArgMD(NewArgMDs, ArgMD);
274 // Add format implicit argument.
275 ArgTypes.push_back(ImageFormatType);
276 ArgMD[2] = ArgMD[3] = MDString::get(*Context, ImageFormatArgMDType);
277 PushArgMD(NewArgMDs, ArgMD);
282 return std::make_tuple(nullptr, nullptr);
285 // Create function with new signature and clone the old body into it.
286 auto NewFT = FunctionType::get(FT->getReturnType(), ArgTypes, false);
287 auto NewF = Function::Create(NewFT, F->getLinkage(), F->getName());
// Map every original argument to its counterpart in NewF so the cloned body
// refers to the new arguments.
288 ValueToValueMapTy VMap;
289 auto NewFArgIt = NewF->arg_begin();
290 for (auto &Arg: F->args()) {
291 auto ArgName = Arg.getName();
292 NewFArgIt->setName(ArgName);
293 VMap[&Arg] = &(*NewFArgIt++);
// Name the two implicit arguments that follow an image argument and step
// over them.
294 if (IsImageType(ArgTypeFromMD(KernelMDNode, Arg.getArgNo()))) {
295 (NewFArgIt++)->setName(Twine("__size_") + ArgName);
296 (NewFArgIt++)->setName(Twine("__format_") + ArgName);
299 SmallVector<ReturnInst*, 8> Returns;
300 CloneFunctionInto(NewF, F, VMap, /*ModuleLevelChanges=*/false, Returns);
// New kernel node: the new function followed by one rebuilt node per
// kernel-arg metadata kind.
303 SmallVector<llvm::Metadata *, 6> KernelMDArgs;
304 KernelMDArgs.push_back(ConstantAsMetadata::get(NewF));
305 for (unsigned i = 0; i < NumKernelArgMDNodes; ++i)
306 KernelMDArgs.push_back(MDNode::get(*Context, NewArgMDs.ArgVector[i]));
307 MDNode *NewMDNode = MDNode::get(*Context, KernelMDArgs);
309 return std::make_tuple(NewF, NewMDNode);
// Processes every operand of the KernelsMDNodeName named metadata: extends
// the kernel signature with implicit image arguments via addImplicitArgs,
// swaps the old function and metadata node for the new ones, then replaces
// the image and sampler getter calls.
312 bool transformKernels(Module &M) {
313 NamedMDNode *KernelsMDNode = M.getNamedMetadata(KernelsMDNodeName);
317 bool Modified = false;
318 for (unsigned i = 0; i < KernelsMDNode->getNumOperands(); ++i) {
319 MDNode *KernelMDNode = KernelsMDNode->getOperand(i);
320 Function *F = GetFunctionFromMDNode(KernelMDNode);
326 std::tie(NewF, NewMDNode) = addImplicitArgs(F, KernelMDNode);
328 // Replace old function and metadata with new ones.
329 F->eraseFromParent();
// NewF was created without a parent module, so it is inserted explicitly.
330 M.getFunctionList().push_back(NewF);
331 M.getOrInsertFunction(NewF->getName(), NewF->getFunctionType(),
332 NewF->getAttributes());
333 KernelsMDNode->setOperand(i, NewMDNode);
336 KernelMDNode = NewMDNode;
// NOTE(review): F was erased above, so the call below has to receive the new
// function; the reassignment of F is not visible in this listing - confirm
// against the full file.
340 Modified |= replaceImageAndSamplerUses(F, KernelMDNode);
// Constructs the pass, handing the pass ID to the ModulePass base class.
347 AMDGPUOpenCLImageTypeLoweringPass() : ModulePass(ID) {}
// Pass entry point: caches the module context and the types used for the
// implicit arguments, then transforms the kernels. Returns the result of
// transformKernels.
349 bool runOnModule(Module &M) override {
350 Context = &M.getContext();
351 Int32Type = Type::getInt32Ty(M.getContext());
// The image size is passed as three 32-bit integers, the format as two.
352 ImageSizeType = ArrayType::get(Int32Type, 3);
353 ImageFormatType = ArrayType::get(Int32Type, 2);
355 return transformKernels(M);
// Returns the fixed display name of this pass.
358 const char *getPassName() const override {
359 return "AMDGPU OpenCL Image Type Pass";
// Definition of the static pass ID that the constructor passes to ModulePass.
363 char AMDGPUOpenCLImageTypeLoweringPass::ID = 0;
365 } // end anonymous namespace
/// Factory for the pass: returns a newly allocated
/// AMDGPUOpenCLImageTypeLoweringPass; the caller receives the raw pointer.
367 ModulePass *llvm::createAMDGPUOpenCLImageTypeLoweringPass() {
368 return new AMDGPUOpenCLImageTypeLoweringPass();