R600/SI: Convert v16i8 resource descriptors to i128
[oota-llvm.git] / lib / Target / R600 / SITypeRewriter.cpp
1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass removes performs the following type substitution on all
12 /// non-compute shaders:
13 ///
14 /// v16i8 => i128
15 ///   - v16i8 is used for constant memory resource descriptors.  This type is
16 ///      legal for some compute APIs, and we don't want to declare it as legal
17 ///      in the backend, because we want the legalizer to expand all v16i8
18 ///      operations.
19 //===----------------------------------------------------------------------===//
20
21 #include "AMDGPU.h"
22
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/InstVisitor.h"
25
26 using namespace llvm;
27
28 namespace {
29
30 class SITypeRewriter : public FunctionPass,
31                        public InstVisitor<SITypeRewriter> {
32
33   static char ID;
34   Module *Mod;
35   Type *v16i8;
36   Type *i128;
37
38 public:
39   SITypeRewriter() : FunctionPass(ID) { }
40   virtual bool doInitialization(Module &M);
41   virtual bool runOnFunction(Function &F);
42   virtual const char *getPassName() const {
43     return "SI Type Rewriter";
44   }
45   void visitLoadInst(LoadInst &I);
46   void visitCallInst(CallInst &I);
47   void visitBitCast(BitCastInst &I);
48 };
49
50 } // End anonymous namespace
51
52 char SITypeRewriter::ID = 0;
53
54 bool SITypeRewriter::doInitialization(Module &M) {
55   Mod = &M;
56   v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
57   i128 = Type::getIntNTy(M.getContext(), 128);
58   return false;
59 }
60
61 bool SITypeRewriter::runOnFunction(Function &F) {
62   AttributeSet Set = F.getAttributes();
63   Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
64
65   unsigned ShaderType = ShaderType::COMPUTE;
66   if (A.isStringAttribute()) {
67     StringRef Str = A.getValueAsString();
68     Str.getAsInteger(0, ShaderType);
69   }
70   if (ShaderType != ShaderType::COMPUTE) {
71     visit(F);
72   }
73
74   visit(F);
75
76   return false;
77 }
78
79 void SITypeRewriter::visitLoadInst(LoadInst &I) {
80   Value *Ptr = I.getPointerOperand();
81   Type *PtrTy = Ptr->getType();
82   Type *ElemTy = PtrTy->getPointerElementType();
83   IRBuilder<> Builder(&I);
84   if (ElemTy == v16i8)  {
85     Value *BitCast = Builder.CreateBitCast(Ptr, Type::getIntNPtrTy(I.getContext(), 128, 2));
86     LoadInst *Load = Builder.CreateLoad(BitCast);
87     SmallVector <std::pair<unsigned, MDNode*>, 8> MD;
88     I.getAllMetadataOtherThanDebugLoc(MD);
89     for (unsigned i = 0, e = MD.size(); i != e; ++i) {
90       Load->setMetadata(MD[i].first, MD[i].second);
91     }
92     Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
93     I.replaceAllUsesWith(BitCastLoad);
94     I.eraseFromParent();
95   }
96 }
97
98 void SITypeRewriter::visitCallInst(CallInst &I) {
99   IRBuilder<> Builder(&I);
100   SmallVector <Value*, 8> Args;
101   SmallVector <Type*, 8> Types;
102   bool NeedToReplace = false;
103   Function *F = I.getCalledFunction();
104   std::string Name = F->getName().str();
105   for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
106     Value *Arg = I.getArgOperand(i);
107     if (Arg->getType() == v16i8) {
108       Args.push_back(Builder.CreateBitCast(Arg, i128));
109       Types.push_back(i128);
110       NeedToReplace = true;
111       Name = Name + ".i128";
112     } else {
113       Args.push_back(Arg);
114       Types.push_back(Arg->getType());
115     }
116   }
117
118   if (!NeedToReplace) {
119     return;
120   }
121   Function *NewF = Mod->getFunction(Name);
122   if (!NewF) {
123     NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
124     NewF->setAttributes(F->getAttributes());
125   }
126   I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
127   I.eraseFromParent();
128 }
129
130 void SITypeRewriter::visitBitCast(BitCastInst &I) {
131   IRBuilder<> Builder(&I);
132   if (I.getDestTy() != i128) {
133     return;
134   }
135
136   if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
137     if (Op->getSrcTy() == i128) {
138       I.replaceAllUsesWith(Op->getOperand(0));
139       I.eraseFromParent();
140     }
141   }
142 }
143
144 FunctionPass *llvm::createSITypeRewriter() {
145   return new SITypeRewriter();
146 }