Unbreak autouprade of llvm.sqrt, simplify some code.
[oota-llvm.git] / lib / VMCore / AutoUpgrade.cpp
1 //===-- AutoUpgrade.cpp - Implement auto-upgrade helper functions ---------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Reid Spencer and is distributed under the 
6 // University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the auto-upgrade helper functions 
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/Assembly/AutoUpgrade.h"
15 #include "llvm/DerivedTypes.h"
16 #include "llvm/Function.h"
17 #include "llvm/Module.h"
18 #include "llvm/Instructions.h"
19 #include "llvm/Intrinsics.h"
20 #include "llvm/SymbolTable.h"
21 #include <iostream>
22 using namespace llvm;
23
24 static Function *getUpgradedUnaryFn(Function *F) {
25   const std::string &Name = F->getName();
26   Module *M = F->getParent();
27   switch (F->getReturnType()->getTypeID()) {
28   default: return 0;
29   case Type::UByteTyID:
30   case Type::SByteTyID:
31     return M->getOrInsertFunction(Name+".i8", 
32                                   Type::UByteTy, Type::UByteTy, NULL);
33   case Type::UShortTyID:
34   case Type::ShortTyID:
35     return M->getOrInsertFunction(Name+".i16", 
36                                   Type::UShortTy, Type::UShortTy, NULL);
37   case Type::UIntTyID:
38   case Type::IntTyID:
39     return M->getOrInsertFunction(Name+".i32", 
40                                   Type::UIntTy, Type::UIntTy, NULL);
41   case Type::ULongTyID:
42   case Type::LongTyID:
43     return M->getOrInsertFunction(Name+".i64",
44                                   Type::ULongTy, Type::ULongTy, NULL);
45   case Type::FloatTyID:
46     return M->getOrInsertFunction(Name+".f32",
47                                   Type::FloatTy, Type::FloatTy, NULL);
48   case Type::DoubleTyID:
49     return M->getOrInsertFunction(Name+".f64",
50                                   Type::DoubleTy, Type::DoubleTy, NULL);
51   }
52 }
53
54 static Function *getUpgradedIntrinsic(Function *F) {
55   // If there's no function, we can't get the argument type.
56   if (!F) return 0;
57
58   // Get the Function's name.
59   const std::string& Name = F->getName();
60
61   // Quickly eliminate it, if it's not a candidate.
62   if (Name.length() <= 8 || Name[0] != 'l' || Name[1] != 'l' || 
63       Name[2] != 'v' || Name[3] != 'm' || Name[4] != '.')
64     return 0;
65
66   Module *M = F->getParent();
67   switch (Name[5]) {
68   default: break;
69   case 'b':
70     if (Name == "llvm.bswap") return getUpgradedUnaryFn(F);
71     break;
72   case 'c':
73     if (Name == "llvm.ctpop" || Name == "llvm.ctlz" || Name == "llvm.cttz")
74       return getUpgradedUnaryFn(F);
75     break;
76   case 'i':
77     if (Name == "llvm.isunordered" && F->arg_begin() != F->arg_end()) {
78       if (F->arg_begin()->getType() == Type::FloatTy)
79         return M->getOrInsertFunction(Name+".f32", F->getFunctionType());
80       if (F->arg_begin()->getType() == Type::DoubleTy)
81         return M->getOrInsertFunction(Name+".f64", F->getFunctionType());
82     }
83     break;
84   case 'm':
85     if (Name == "llvm.memcpy" || Name == "llvm.memset" || 
86         Name == "llvm.memmove") {
87       if (F->getFunctionType()->getParamType(2) == Type::UIntTy)
88         return M->getOrInsertFunction(Name+".i32", F->getFunctionType());
89       if (F->getFunctionType()->getParamType(2) == Type::ULongTy)
90         return M->getOrInsertFunction(Name+".i64", F->getFunctionType());
91     }
92     break;
93   case 's':
94     if (Name == "llvm.sqrt")
95       return getUpgradedUnaryFn(F);
96     break;
97   }
98   return 0;
99 }
100
101 // UpgradeIntrinsicFunction - Convert overloaded intrinsic function names to
102 // their non-overloaded variants by appending the appropriate suffix based on
103 // the argument types.
104 Function *llvm::UpgradeIntrinsicFunction(Function* F) {
105   // See if its one of the name's we're interested in.
106   if (Function *R = getUpgradedIntrinsic(F)) {
107     std::cerr << "WARNING: change " << F->getName() << " to "
108               << R->getName() << "\n";
109     return R;
110   }
111   return 0;
112 }
113
114
115 Instruction* llvm::MakeUpgradedCall(Function *F, 
116                                     const std::vector<Value*> &Params,
117                                     BasicBlock *BB, bool isTailCall,
118                                     unsigned CallingConv) {
119   assert(F && "Need a Function to make a CallInst");
120   assert(BB && "Need a BasicBlock to make a CallInst");
121
122   // Convert the params
123   bool signedArg = false;
124   std::vector<Value*> Oprnds;
125   for (std::vector<Value*>::const_iterator PI = Params.begin(), 
126        PE = Params.end(); PI != PE; ++PI) {
127     const Type* opTy = (*PI)->getType();
128     if (opTy->isSigned()) {
129       signedArg = true;
130       CastInst* cast = 
131         new CastInst(*PI,opTy->getUnsignedVersion(), "autoupgrade_cast");
132       BB->getInstList().push_back(cast);
133       Oprnds.push_back(cast);
134     }
135     else
136       Oprnds.push_back(*PI);
137   }
138
139   Instruction *result = new CallInst(F, Oprnds);
140   if (result->getType() != Type::VoidTy) result->setName("autoupgrade_call");
141   if (isTailCall) cast<CallInst>(result)->setTailCall();
142   if (CallingConv) cast<CallInst>(result)->setCallingConv(CallingConv);
143   if (signedArg) {
144     const Type* newTy = F->getReturnType()->getUnsignedVersion();
145     CastInst* final = new CastInst(result, newTy, "autoupgrade_uncast");
146     BB->getInstList().push_back(result);
147     result = final;
148   }
149   return result;
150 }
151
152 // UpgradeIntrinsicCall - In the BC reader, change a call to some intrinsic to
153 // be a called to the specified intrinsic.  We expect the callees to have the
154 // same number of arguments, but their types may be different.
155 void llvm::UpgradeIntrinsicCall(CallInst *CI, Function *NewFn) {
156   Function *F = CI->getCalledFunction();
157
158   const FunctionType *NewFnTy = NewFn->getFunctionType();
159   std::vector<Value*> Oprnds;
160   for (unsigned i = 1, e = CI->getNumOperands(); i != e; ++i) {
161     Value *V = CI->getOperand(i);
162     if (V->getType() != NewFnTy->getParamType(i-1))
163       V = new CastInst(V, NewFnTy->getParamType(i-1), V->getName(), CI);
164     Oprnds.push_back(V);
165   }
166   CallInst *NewCI = new CallInst(NewFn, Oprnds, CI->getName(), CI);
167   NewCI->setTailCall(CI->isTailCall());
168   NewCI->setCallingConv(CI->getCallingConv());
169   
170   if (!CI->use_empty()) {
171     Instruction *RetVal = NewCI;
172     if (F->getReturnType() != NewFn->getReturnType()) {
173       RetVal = new CastInst(NewCI, NewFn->getReturnType(), 
174                             NewCI->getName(), CI);
175       NewCI->moveBefore(RetVal);
176     }
177     CI->replaceAllUsesWith(RetVal);
178   }
179   CI->eraseFromParent();
180 }
181
182 bool llvm::UpgradeCallsToIntrinsic(Function* F) {
183   if (Function* newF = UpgradeIntrinsicFunction(F)) {
184     for (Value::use_iterator UI = F->use_begin(), UE = F->use_end();
185          UI != UE; ) {
186       if (CallInst* CI = dyn_cast<CallInst>(*UI++)) {
187         std::vector<Value*> Oprnds;
188         User::op_iterator OI = CI->op_begin();
189         ++OI;
190         for (User::op_iterator OE = CI->op_end(); OI != OE; ++OI) {
191           const Type* opTy = OI->get()->getType();
192           if (opTy->isSigned()) {
193             Oprnds.push_back(
194               new CastInst(OI->get(),opTy->getUnsignedVersion(), 
195                   "autoupgrade_cast",CI));
196           } else {
197             Oprnds.push_back(*OI);
198           }
199         }
200         CallInst* newCI = new CallInst(newF, Oprnds,
201                                        CI->hasName() ? "autoupcall" : "", CI);
202         newCI->setTailCall(CI->isTailCall());
203         newCI->setCallingConv(CI->getCallingConv());
204         if (CI->use_empty()) {
205           // noop
206         } else if (CI->getType() != newCI->getType()) {
207           CastInst *final = new CastInst(newCI, CI->getType(),
208                                          "autoupgrade_uncast", newCI);
209           newCI->moveBefore(final);
210           CI->replaceAllUsesWith(final);
211         } else {
212           CI->replaceAllUsesWith(newCI);
213         }
214         CI->eraseFromParent();
215       }
216     }
217     if (newF != F)
218       F->eraseFromParent();
219     return true;
220   }
221   return false;
222 }