f194d8b56dc6b433db0f73a2f7578a82e9a2d661
[oota-llvm.git] / lib / Target / R600 / SITypeRewriter.cpp
1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 /// \file
11 /// This pass removes performs the following type substitution on all
12 /// non-compute shaders:
13 ///
14 /// v16i8 => i128
15 ///   - v16i8 is used for constant memory resource descriptors.  This type is
16 ///      legal for some compute APIs, and we don't want to declare it as legal
17 ///      in the backend, because we want the legalizer to expand all v16i8
18 ///      operations.
19 /// v1* => *
20 ///   - Having v1* types complicates the legalizer and we can easily replace
21 ///   - them with the element type.
22 //===----------------------------------------------------------------------===//
23
24 #include "AMDGPU.h"
25
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/InstVisitor.h"
28
29 using namespace llvm;
30
31 namespace {
32
33 class SITypeRewriter : public FunctionPass,
34                        public InstVisitor<SITypeRewriter> {
35
36   static char ID;
37   Module *Mod;
38   Type *v16i8;
39   Type *i128;
40
41 public:
42   SITypeRewriter() : FunctionPass(ID) { }
43   virtual bool doInitialization(Module &M);
44   virtual bool runOnFunction(Function &F);
45   virtual const char *getPassName() const {
46     return "SI Type Rewriter";
47   }
48   void visitLoadInst(LoadInst &I);
49   void visitCallInst(CallInst &I);
50   void visitBitCast(BitCastInst &I);
51 };
52
53 } // End anonymous namespace
54
55 char SITypeRewriter::ID = 0;
56
57 bool SITypeRewriter::doInitialization(Module &M) {
58   Mod = &M;
59   v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
60   i128 = Type::getIntNTy(M.getContext(), 128);
61   return false;
62 }
63
64 bool SITypeRewriter::runOnFunction(Function &F) {
65   AttributeSet Set = F.getAttributes();
66   Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
67
68   unsigned ShaderType = ShaderType::COMPUTE;
69   if (A.isStringAttribute()) {
70     StringRef Str = A.getValueAsString();
71     Str.getAsInteger(0, ShaderType);
72   }
73   if (ShaderType != ShaderType::COMPUTE) {
74     visit(F);
75   }
76
77   visit(F);
78
79   return false;
80 }
81
82 void SITypeRewriter::visitLoadInst(LoadInst &I) {
83   Value *Ptr = I.getPointerOperand();
84   Type *PtrTy = Ptr->getType();
85   Type *ElemTy = PtrTy->getPointerElementType();
86   IRBuilder<> Builder(&I);
87   if (ElemTy == v16i8)  {
88     Value *BitCast = Builder.CreateBitCast(Ptr, Type::getIntNPtrTy(I.getContext(), 128, 2));
89     LoadInst *Load = Builder.CreateLoad(BitCast);
90     SmallVector <std::pair<unsigned, MDNode*>, 8> MD;
91     I.getAllMetadataOtherThanDebugLoc(MD);
92     for (unsigned i = 0, e = MD.size(); i != e; ++i) {
93       Load->setMetadata(MD[i].first, MD[i].second);
94     }
95     Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
96     I.replaceAllUsesWith(BitCastLoad);
97     I.eraseFromParent();
98   }
99 }
100
101 void SITypeRewriter::visitCallInst(CallInst &I) {
102   IRBuilder<> Builder(&I);
103   SmallVector <Value*, 8> Args;
104   SmallVector <Type*, 8> Types;
105   bool NeedToReplace = false;
106   Function *F = I.getCalledFunction();
107   std::string Name = F->getName().str();
108   for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
109     Value *Arg = I.getArgOperand(i);
110     if (Arg->getType() == v16i8) {
111       Args.push_back(Builder.CreateBitCast(Arg, i128));
112       Types.push_back(i128);
113       NeedToReplace = true;
114       Name = Name + ".i128";
115     } else if (Arg->getType()->isVectorTy() &&
116                Arg->getType()->getVectorNumElements() == 1 &&
117                Arg->getType()->getVectorElementType() ==
118                                               Type::getInt32Ty(I.getContext())){
119       Type *ElementTy = Arg->getType()->getVectorElementType();
120       std::string TypeName = "i32";
121       InsertElementInst *Def = dyn_cast<InsertElementInst>(Arg);
122       assert(Def);
123       Args.push_back(Def->getOperand(1));
124       Types.push_back(ElementTy);
125       std::string VecTypeName = "v1" + TypeName;
126       Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
127       NeedToReplace = true;
128     } else {
129       Args.push_back(Arg);
130       Types.push_back(Arg->getType());
131     }
132   }
133
134   if (!NeedToReplace) {
135     return;
136   }
137   Function *NewF = Mod->getFunction(Name);
138   if (!NewF) {
139     NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
140     NewF->setAttributes(F->getAttributes());
141   }
142   I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
143   I.eraseFromParent();
144 }
145
146 void SITypeRewriter::visitBitCast(BitCastInst &I) {
147   IRBuilder<> Builder(&I);
148   if (I.getDestTy() != i128) {
149     return;
150   }
151
152   if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
153     if (Op->getSrcTy() == i128) {
154       I.replaceAllUsesWith(Op->getOperand(0));
155       I.eraseFromParent();
156     }
157   }
158 }
159
160 FunctionPass *llvm::createSITypeRewriter() {
161   return new SITypeRewriter();
162 }