1 //===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
11 /// This pass performs the following type substitution on all
12 /// non-compute shaders: v16i8 => i128.
15 /// - v16i8 is used for constant memory resource descriptors. This type is
16 /// legal for some compute APIs, and we don't want to declare it as legal
17 /// in the backend, because we want the legalizer to expand all v16i8 operations.
19 //===----------------------------------------------------------------------===//
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/InstVisitor.h"
/// Function pass that rewrites uses of the <16 x i8> (v16i8) type into i128.
/// Per the file header, this is done for non-compute shaders only (see
/// runOnFunction). The class is also an InstVisitor, so the visit* methods
/// below act as per-instruction hooks when the pass walks a function.
///
/// NOTE(review): this listing omits several lines of the class. Missing are
/// the data members referenced elsewhere (ID, Mod, v16i8, i128), the closing
/// brace of getPassName, and the class's closing brace. Only the visible
/// declarations are documented here.
30 class SITypeRewriter : public FunctionPass,
31 public InstVisitor<SITypeRewriter> {
// Hands the pass identifier (defined out of line) to FunctionPass.
39 SITypeRewriter() : FunctionPass(ID) { }
// Caches the IR types the rewrite compares against, from the module's context.
40 virtual bool doInitialization(Module &M);
// Per-function entry point; consults the "ShaderType" function attribute.
41 virtual bool runOnFunction(Function &F);
// Human-readable name of the pass.
42 virtual const char *getPassName() const {
43 return "SI Type Rewriter";
// Instruction hooks:
//  - visitLoadInst rewrites v16i8 loads as i128 loads.
//  - visitCallInst rewrites calls that pass v16i8 arguments.
//  - visitBitCast folds i128 -> T -> i128 bitcast round trips.
45 void visitLoadInst(LoadInst &I);
46 void visitCallInst(CallInst &I);
47 void visitBitCast(BitCastInst &I);
50 } // End anonymous namespace
// Pass identifier, handed to FunctionPass(ID) by the constructor. By LLVM
// convention a pass is identified by the address of this variable, so the
// initializer value itself is not significant.
52 char SITypeRewriter::ID = 0;
/// Caches the two types the rewrite works with, both created in M's context:
/// v16i8 = <16 x i8> and i128 = the 128-bit integer type.
///
/// NOTE(review): the rest of this function, including its return value, is
/// not visible in this listing. The member Mod (used by visitCallInst) is not
/// assigned in the visible lines; presumably it is set here — confirm.
54 bool SITypeRewriter::doInitialization(Module &M) {
56 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
57 i128 = Type::getIntNTy(M.getContext(), 128);
/// Reads the function-level "ShaderType" string attribute to decide whether F
/// is a compute shader. Per the file header, only non-compute shaders are
/// rewritten.
///
/// NOTE(review): the body of the non-compute branch and the remainder of the
/// function are not visible in this listing. Confirm that instruction
/// visitation really happens only inside that branch; an unconditional visit
/// after it would also rewrite compute shaders.
61 bool SITypeRewriter::runOnFunction(Function &F) {
62 AttributeSet Set = F.getAttributes();
63 Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
// If the attribute is absent or not a string attribute, ShaderType keeps this
// COMPUTE default.
65 unsigned ShaderType = ShaderType::COMPUTE;
66 if (A.isStringAttribute()) {
67 StringRef Str = A.getValueAsString();
// Radix 0 lets getAsInteger auto-detect the base. Its return value (true on a
// parse failure) is ignored. NOTE(review): on failure ShaderType presumably
// keeps the COMPUTE default, which would skip the rewrite — confirm intended.
68 Str.getAsInteger(0, ShaderType);
70 if (ShaderType != ShaderType::COMPUTE) {
/// Rewrites a load whose pointee type is v16i8 as an i128 load. The steps are:
///   1. bitcast the pointer to an i128 pointer;
///   2. load the i128 value;
///   3. copy every metadata node of the original load except its debug location;
///   4. bitcast the result back to the original load's type;
///   5. replace all uses of the original load with that value.
/// In the visible lines, loads of any other pointee type are not touched.
///
/// NOTE(review): the lines after replaceAllUsesWith (e.g. removal of the old
/// load and the closing braces) are not visible in this listing.
79 void SITypeRewriter::visitLoadInst(LoadInst &I) {
80 Value *Ptr = I.getPointerOperand();
81 Type *PtrTy = Ptr->getType();
82 Type *ElemTy = PtrTy->getPointerElementType();
83 IRBuilder<> Builder(&I);
84 if (ElemTy == v16i8) {
// Address space 2 is hard-coded rather than taken from PtrTy.
// NOTE(review): a pointer bitcast cannot change address space. This
// presumably assumes v16i8 loads only come from address space 2 — confirm.
85 Value *BitCast = Builder.CreateBitCast(Ptr, Type::getIntNPtrTy(I.getContext(), 128, 2));
// NOTE(review): no alignment or volatile flag is passed here, and none is
// copied from I in the visible lines, so those properties of the original
// load are not carried over.
86 LoadInst *Load = Builder.CreateLoad(BitCast);
87 SmallVector <std::pair<unsigned, MDNode*>, 8> MD;
88 I.getAllMetadataOtherThanDebugLoc(MD);
89 for (unsigned i = 0, e = MD.size(); i != e; ++i) {
90 Load->setMetadata(MD[i].first, MD[i].second);
// Users still expect the original (v16i8) type, so convert the i128 back.
92 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
93 I.replaceAllUsesWith(BitCastLoad);
/// Rewrites a call that passes one or more v16i8 arguments.
///   - Each such argument is bitcast to i128, and its parameter type becomes
///     i128.
///   - The call is redirected to a function named after the callee, with
///     ".i128" appended once per rewritten argument.
///   - That function is looked up in Mod. Function::Create builds it as an
///     external declaration with the callee's return type, the rewritten
///     parameter types and the callee's attributes.
///   - All uses of the old call are replaced with the new call.
///
/// NOTE(review): several lines are not visible in this listing: the branch
/// for arguments that are not v16i8, the body of the !NeedToReplace check,
/// and the condition guarding Function::Create.
98 void SITypeRewriter::visitCallInst(CallInst &I) {
99 IRBuilder<> Builder(&I);
100 SmallVector <Value*, 8> Args;
101 SmallVector <Type*, 8> Types;
102 bool NeedToReplace = false;
// NOTE(review): getCalledFunction() returns null for indirect calls (e.g.
// through a function pointer). F is dereferenced on the next line without a
// check, so an indirect call reaching this visitor would crash. TODO guard F.
103 Function *F = I.getCalledFunction();
104 std::string Name = F->getName().str();
105 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
106 Value *Arg = I.getArgOperand(i);
107 if (Arg->getType() == v16i8) {
108 Args.push_back(Builder.CreateBitCast(Arg, i128));
109 Types.push_back(i128);
110 NeedToReplace = true;
// The suffix records how many arguments were rewritten, not which ones.
111 Name = Name + ".i128";
114 Types.push_back(Arg->getType());
118 if (!NeedToReplace) {
// Mod is a member not declared in the visible lines; presumably it is the
// module being processed — confirm.
121 Function *NewF = Mod->getFunction(Name);
// NOTE(review): only parameter types are rewritten. A v16i8 return type is
// kept as-is via F->getReturnType().
123 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
124 NewF->setAttributes(F->getAttributes());
126 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
/// Folds a round trip through i128. When this bitcast produces i128 and its
/// operand is itself a bitcast from i128, all uses of this bitcast are
/// redirected to the inner bitcast's i128 source.
///
/// Example: a load rewritten to i128 is bitcast to v16i8 by visitLoadInst,
/// then bitcast back to i128 when passed to a rewritten call.
///
/// NOTE(review): the body of the destination-type check (presumably an early
/// return) and the rest of the function are not visible in this listing.
/// Builder is not used in the visible lines.
130 void SITypeRewriter::visitBitCast(BitCastInst &I) {
131 IRBuilder<> Builder(&I);
// Only bitcasts that produce i128 are candidates for folding.
132 if (I.getDestTy() != i128) {
136 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
137 if (Op->getSrcTy() == i128) {
138 I.replaceAllUsesWith(Op->getOperand(0));
/// Factory for the pass: creates a new SITypeRewriter on the heap and returns
/// it as a FunctionPass. The caller (presumably the pass manager it is added
/// to) is responsible for the returned object.
///
/// NOTE(review): the closing brace of this function is not visible in this
/// listing.
144 FunctionPass *llvm::createSITypeRewriter() {
145 return new SITypeRewriter();