If a function always returns a constant, replace all calls sites with that
[oota-llvm.git] / lib / Transforms / IPO / IPConstantPropagation.cpp
1 //===-- IPConstantPropagation.cpp - Propagate constants through calls -----===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass implements an _extremely_ simple interprocedural constant
11 // propagation pass.  It could certainly be improved in many different ways,
12 // like using a worklist.  This pass makes arguments dead, but does not remove
13 // them.  The existing dead argument elimination pass should be run after this
14 // to clean up the mess.
15 //
16 //===----------------------------------------------------------------------===//
17
18 #include "llvm/Transforms/IPO.h"
19 #include "llvm/Constants.h"
20 #include "llvm/Instructions.h"
21 #include "llvm/Module.h"
22 #include "llvm/Pass.h"
23 #include "llvm/Support/CallSite.h"
24 #include "llvm/ADT/Statistic.h"
25 using namespace llvm;
26
27 namespace {
28   Statistic<> NumArgumentsProped("ipconstprop",
29                                  "Number of args turned into constants");
30   Statistic<> NumReturnValProped("ipconstprop",
31                                  "Number of return values turned into constants");
32
33   /// IPCP - The interprocedural constant propagation pass
34   ///
35   struct IPCP : public ModulePass {
36     bool runOnModule(Module &M);
37   private:
38     bool PropagateConstantsIntoArguments(Function &F);
39     bool PropagateConstantReturn(Function &F);
40   };
41   RegisterOpt<IPCP> X("ipconstprop", "Interprocedural constant propagation");
42 }
43
44 ModulePass *llvm::createIPConstantPropagationPass() { return new IPCP(); }
45
46 bool IPCP::runOnModule(Module &M) {
47   bool Changed = false;
48   bool LocalChange = true;
49
50   // FIXME: instead of using smart algorithms, we just iterate until we stop
51   // making changes.
52   while (LocalChange) {
53     LocalChange = false;
54     for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
55       if (!I->isExternal()) {
56         // Delete any klingons.
57         I->removeDeadConstantUsers();
58         if (I->hasInternalLinkage())
59           LocalChange |= PropagateConstantsIntoArguments(*I);
60         Changed |= PropagateConstantReturn(*I);
61       }
62     Changed |= LocalChange;
63   }
64   return Changed;
65 }
66
67 /// PropagateConstantsIntoArguments - Look at all uses of the specified
68 /// function.  If all uses are direct call sites, and all pass a particular
69 /// constant in for an argument, propagate that constant in as the argument.
70 ///
71 bool IPCP::PropagateConstantsIntoArguments(Function &F) {
72   if (F.aempty() || F.use_empty()) return false;  // No arguments?  Early exit.
73
74   std::vector<std::pair<Constant*, bool> > ArgumentConstants;
75   ArgumentConstants.resize(F.asize());
76
77   unsigned NumNonconstant = 0;
78
79   for (Value::use_iterator I = F.use_begin(), E = F.use_end(); I != E; ++I)
80     if (!isa<Instruction>(*I))
81       return false;  // Used by a non-instruction, do not transform
82     else {
83       CallSite CS = CallSite::get(cast<Instruction>(*I));
84       if (CS.getInstruction() == 0 || 
85           CS.getCalledFunction() != &F)
86         return false;  // Not a direct call site?
87       
88       // Check out all of the potentially constant arguments
89       CallSite::arg_iterator AI = CS.arg_begin();
90       Function::aiterator Arg = F.abegin();
91       for (unsigned i = 0, e = ArgumentConstants.size(); i != e;
92            ++i, ++AI, ++Arg) {
93         if (*AI == &F) return false;  // Passes the function into itself
94
95         if (!ArgumentConstants[i].second) {
96           if (Constant *C = dyn_cast<Constant>(*AI)) {
97             if (!ArgumentConstants[i].first)
98               ArgumentConstants[i].first = C;
99             else if (ArgumentConstants[i].first != C) {
100               // Became non-constant
101               ArgumentConstants[i].second = true;
102               ++NumNonconstant;
103               if (NumNonconstant == ArgumentConstants.size()) return false;
104             }
105           } else if (*AI != &*Arg) {    // Ignore recursive calls with same arg
106             // This is not a constant argument.  Mark the argument as
107             // non-constant.
108             ArgumentConstants[i].second = true;
109             ++NumNonconstant;
110             if (NumNonconstant == ArgumentConstants.size()) return false;
111           }
112         }
113       }
114     }
115
116   // If we got to this point, there is a constant argument!
117   assert(NumNonconstant != ArgumentConstants.size());
118   Function::aiterator AI = F.abegin();
119   bool MadeChange = false;
120   for (unsigned i = 0, e = ArgumentConstants.size(); i != e; ++i, ++AI)
121     // Do we have a constant argument!?
122     if (!ArgumentConstants[i].second && !AI->use_empty()) {
123       Value *V = ArgumentConstants[i].first;
124       if (V == 0) V = UndefValue::get(AI->getType());
125       AI->replaceAllUsesWith(V);
126       ++NumArgumentsProped;
127       MadeChange = true;
128     }
129   return MadeChange;
130 }
131
132
133 // Check to see if this function returns a constant.  If so, replace all callers
134 // that user the return value with the returned valued.  If we can replace ALL
135 // callers,
136 bool IPCP::PropagateConstantReturn(Function &F) {
137   if (F.getReturnType() == Type::VoidTy)
138     return false; // No return value.
139
140   // Check to see if this function returns a constant.
141   Value *RetVal = 0;
142   for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
143     if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator()))
144       if (isa<UndefValue>(RI->getOperand(0))) {
145         // Ignore.
146       } else if (Constant *C = dyn_cast<Constant>(RI->getOperand(0))) {
147         if (RetVal == 0)
148           RetVal = C;
149         else if (RetVal != C)
150           return false;  // Does not return the same constant.
151       } else {
152         return false;  // Does not return a constant.
153       }
154
155   if (RetVal == 0) RetVal = UndefValue::get(F.getReturnType());
156
157   // If we got here, the function returns a constant value.  Loop over all
158   // users, replacing any uses of the return value with the returned constant.
159   bool ReplacedAllUsers = true;
160   bool MadeChange = false;
161   for (Value::use_iterator I = F.use_begin(), E = F.use_end(); I != E; ++I)
162     if (!isa<Instruction>(*I))
163       ReplacedAllUsers = false;
164     else {
165       CallSite CS = CallSite::get(cast<Instruction>(*I));
166       if (CS.getInstruction() == 0 || 
167           CS.getCalledFunction() != &F) {
168         ReplacedAllUsers = false;
169       } else {
170         if (!CS.getInstruction()->use_empty()) {
171           CS.getInstruction()->replaceAllUsesWith(RetVal);
172           MadeChange = true;
173         }
174       }
175     }
176
177   // If we replace all users with the returned constant, and there can be no
178   // other callers of the function, replace the constant being returned in the
179   // function with an undef value.
180   if (ReplacedAllUsers && F.hasInternalLinkage() && !isa<UndefValue>(RetVal)) {
181     Value *RV = UndefValue::get(RetVal->getType());
182     for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
183       if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator()))
184         RI->setOperand(0, RV);
185     MadeChange = true;
186   }
187
188   if (MadeChange) ++NumReturnValProped;
189
190   // FIXME: DAE should remove dead return values if the result is an undef
191   // value... or if it is never used.
192
193   return MadeChange;
194 }