[C++11] Add range based accessors for the Use-Def chain of a Value.
[oota-llvm.git] / lib / Transforms / IPO / IPConstantPropagation.cpp
index 99ad8445b15bf246068b82ee0ad0cf2180f8d06e..8684796b4e7835b729c89301650ca2af680c7824 100644 (file)
@@ -1,10 +1,10 @@
 //===-- IPConstantPropagation.cpp - Propagate constants through calls -----===//
-// 
+//
 //                     The LLVM Compiler Infrastructure
 //
-// This file was developed by the LLVM research group and is distributed under
-// the University of Illinois Open Source License. See LICENSE.TXT for details.
-// 
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
 //===----------------------------------------------------------------------===//
 //
 // This pass implements an _extremely_ simple interprocedural constant
 //
 //===----------------------------------------------------------------------===//
 
+#define DEBUG_TYPE "ipconstprop"
 #include "llvm/Transforms/IPO.h"
-#include "llvm/Constants.h"
-#include "llvm/Instructions.h"
-#include "llvm/Module.h"
-#include "llvm/Pass.h"
-#include "llvm/Support/CallSite.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
 using namespace llvm;
 
-namespace {
-  Statistic<> NumArgumentsProped("ipconstprop",
-                                 "Number of args turned into constants");
-  Statistic<> NumReturnValProped("ipconstprop",
-                                 "Number of return values turned into constants");
+STATISTIC(NumArgumentsProped, "Number of args turned into constants");
+STATISTIC(NumReturnValProped, "Number of return values turned into constants");
 
+namespace {
   /// IPCP - The interprocedural constant propagation pass
   ///
   struct IPCP : public ModulePass {
-    bool runOnModule(Module &M);
+    static char ID; // Pass identification, replacement for typeid
+    IPCP() : ModulePass(ID) {
+      initializeIPCPPass(*PassRegistry::getPassRegistry());
+    }
+
+    bool runOnModule(Module &M) override;
   private:
     bool PropagateConstantsIntoArguments(Function &F);
     bool PropagateConstantReturn(Function &F);
   };
-  RegisterOpt<IPCP> X("ipconstprop", "Interprocedural constant propagation");
 }
 
+char IPCP::ID = 0;
+INITIALIZE_PASS(IPCP, "ipconstprop",
+                "Interprocedural constant propagation", false, false)
+
 ModulePass *llvm::createIPConstantPropagationPass() { return new IPCP(); }
 
 bool IPCP::runOnModule(Module &M) {
@@ -52,10 +61,10 @@ bool IPCP::runOnModule(Module &M) {
   while (LocalChange) {
     LocalChange = false;
     for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I)
-      if (!I->isExternal()) {
+      if (!I->isDeclaration()) {
         // Delete any klingons.
         I->removeDeadConstantUsers();
-        if (I->hasInternalLinkage())
+        if (I->hasLocalLinkage())
           LocalChange |= PropagateConstantsIntoArguments(*I);
         Changed |= PropagateConstantReturn(*I);
       }
@@ -69,120 +78,199 @@ bool IPCP::runOnModule(Module &M) {
 /// constant in for an argument, propagate that constant in as the argument.
 ///
 bool IPCP::PropagateConstantsIntoArguments(Function &F) {
-  if (F.aempty() || F.use_empty()) return false;  // No arguments?  Early exit.
+  if (F.arg_empty() || F.use_empty()) return false; // No arguments? Early exit.
 
-  std::vector<std::pair<Constant*, bool> > ArgumentConstants;
-  ArgumentConstants.resize(F.asize());
+  // For each argument, keep track of its constant value and whether it is a
+  // constant or not.  The bool is driven to true when found to be non-constant.
+  SmallVector<std::pair<Constant*, bool>, 16> ArgumentConstants;
+  ArgumentConstants.resize(F.arg_size());
 
   unsigned NumNonconstant = 0;
+  for (Use &U : F.uses()) {
+    User *UR = U.getUser();
+    // Ignore blockaddress uses.
+    if (isa<BlockAddress>(UR)) continue;
+    
+    // Used by a non-instruction, or not the callee of a function, do not
+    // transform.
+    if (!isa<CallInst>(UR) && !isa<InvokeInst>(UR))
+      return false;
+    
+    CallSite CS(cast<Instruction>(UR));
+    if (!CS.isCallee(&U))
+      return false;
 
-  for (Value::use_iterator I = F.use_begin(), E = F.use_end(); I != E; ++I)
-    if (!isa<Instruction>(*I))
-      return false;  // Used by a non-instruction, do not transform
-    else {
-      CallSite CS = CallSite::get(cast<Instruction>(*I));
-      if (CS.getInstruction() == 0 || 
-          CS.getCalledFunction() != &F)
-        return false;  // Not a direct call site?
+    // Check out all of the potentially constant arguments.  Note that we don't
+    // inspect varargs here.
+    CallSite::arg_iterator AI = CS.arg_begin();
+    Function::arg_iterator Arg = F.arg_begin();
+    for (unsigned i = 0, e = ArgumentConstants.size(); i != e;
+         ++i, ++AI, ++Arg) {
       
-      // Check out all of the potentially constant arguments
-      CallSite::arg_iterator AI = CS.arg_begin();
-      Function::aiterator Arg = F.abegin();
-      for (unsigned i = 0, e = ArgumentConstants.size(); i != e;
-           ++i, ++AI, ++Arg) {
-        if (*AI == &F) return false;  // Passes the function into itself
-
-        if (!ArgumentConstants[i].second) {
-          if (Constant *C = dyn_cast<Constant>(*AI)) {
-            if (!ArgumentConstants[i].first)
-              ArgumentConstants[i].first = C;
-            else if (ArgumentConstants[i].first != C) {
-              // Became non-constant
-              ArgumentConstants[i].second = true;
-              ++NumNonconstant;
-              if (NumNonconstant == ArgumentConstants.size()) return false;
-            }
-          } else if (*AI != &*Arg) {    // Ignore recursive calls with same arg
-            // This is not a constant argument.  Mark the argument as
-            // non-constant.
-            ArgumentConstants[i].second = true;
-            ++NumNonconstant;
-            if (NumNonconstant == ArgumentConstants.size()) return false;
-          }
-        }
+      // If this argument is known non-constant, ignore it.
+      if (ArgumentConstants[i].second)
+        continue;
+      
+      Constant *C = dyn_cast<Constant>(*AI);
+      if (C && ArgumentConstants[i].first == 0) {
+        ArgumentConstants[i].first = C;   // First constant seen.
+      } else if (C && ArgumentConstants[i].first == C) {
+        // Still the constant value we think it is.
+      } else if (*AI == &*Arg) {
+        // Ignore recursive calls passing argument down.
+      } else {
+        // Argument became non-constant.  If all arguments are non-constant now,
+        // give up on this function.
+        if (++NumNonconstant == ArgumentConstants.size())
+          return false;
+        ArgumentConstants[i].second = true;
       }
     }
+  }
 
   // If we got to this point, there is a constant argument!
   assert(NumNonconstant != ArgumentConstants.size());
-  Function::aiterator AI = F.abegin();
   bool MadeChange = false;
-  for (unsigned i = 0, e = ArgumentConstants.size(); i != e; ++i, ++AI)
-    // Do we have a constant argument!?
-    if (!ArgumentConstants[i].second && !AI->use_empty()) {
-      Value *V = ArgumentConstants[i].first;
-      if (V == 0) V = UndefValue::get(AI->getType());
-      AI->replaceAllUsesWith(V);
-      ++NumArgumentsProped;
-      MadeChange = true;
-    }
+  Function::arg_iterator AI = F.arg_begin();
+  for (unsigned i = 0, e = ArgumentConstants.size(); i != e; ++i, ++AI) {
+    // Do we have a constant argument?
+    if (ArgumentConstants[i].second || AI->use_empty() ||
+        AI->hasInAllocaAttr() || (AI->hasByValAttr() && !F.onlyReadsMemory()))
+      continue;
+  
+    Value *V = ArgumentConstants[i].first;
+    if (V == 0) V = UndefValue::get(AI->getType());
+    AI->replaceAllUsesWith(V);
+    ++NumArgumentsProped;
+    MadeChange = true;
+  }
   return MadeChange;
 }
 
 
-// Check to see if this function returns a constant.  If so, replace all callers
-// that user the return value with the returned valued.  If we can replace ALL
-// callers,
+// Check to see if this function returns one or more constants. If so, replace
+// all callers that use those return values with the constant value. This will
+// leave in the actual return values and instructions, but deadargelim will
+// clean that up.
+//
+// Additionally if a function always returns one of its arguments directly,
+// callers will be updated to use the value they pass in directly instead of
+// using the return value.
 bool IPCP::PropagateConstantReturn(Function &F) {
-  if (F.getReturnType() == Type::VoidTy)
+  if (F.getReturnType()->isVoidTy())
     return false; // No return value.
 
+  // If this function could be overridden later in the link stage, we can't
+  // propagate information about its results into callers.
+  if (F.mayBeOverridden())
+    return false;
+    
   // Check to see if this function returns a constant.
-  Value *RetVal = 0;
+  SmallVector<Value *,4> RetVals;
+  StructType *STy = dyn_cast<StructType>(F.getReturnType());
+  if (STy)
+    for (unsigned i = 0, e = STy->getNumElements(); i < e; ++i) 
+      RetVals.push_back(UndefValue::get(STy->getElementType(i)));
+  else
+    RetVals.push_back(UndefValue::get(F.getReturnType()));
+
+  unsigned NumNonConstant = 0;
   for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
-    if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator()))
-      if (isa<UndefValue>(RI->getOperand(0))) {
-        // Ignore.
-      } else if (Constant *C = dyn_cast<Constant>(RI->getOperand(0))) {
-        if (RetVal == 0)
-          RetVal = C;
-        else if (RetVal != C)
-          return false;  // Does not return the same constant.
-      } else {
-        return false;  // Does not return a constant.
-      }
+    if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) {
+      for (unsigned i = 0, e = RetVals.size(); i != e; ++i) {
+        // Already found conflicting return values?
+        Value *RV = RetVals[i];
+        if (!RV)
+          continue;
 
-  if (RetVal == 0) RetVal = UndefValue::get(F.getReturnType());
+        // Find the returned value
+        Value *V;
+        if (!STy)
+          V = RI->getOperand(0);
+        else
+          V = FindInsertedValue(RI->getOperand(0), i);
 
-  // If we got here, the function returns a constant value.  Loop over all
-  // users, replacing any uses of the return value with the returned constant.
-  bool ReplacedAllUsers = true;
-  bool MadeChange = false;
-  for (Value::use_iterator I = F.use_begin(), E = F.use_end(); I != E; ++I)
-    if (!isa<Instruction>(*I))
-      ReplacedAllUsers = false;
-    else {
-      CallSite CS = CallSite::get(cast<Instruction>(*I));
-      if (CS.getInstruction() == 0 || 
-          CS.getCalledFunction() != &F) {
-        ReplacedAllUsers = false;
-      } else {
-        if (!CS.getInstruction()->use_empty()) {
-          CS.getInstruction()->replaceAllUsesWith(RetVal);
-          MadeChange = true;
+        if (V) {
+          // Ignore undefs, we can change them into anything
+          if (isa<UndefValue>(V))
+            continue;
+          
+          // Try to see if all the rets return the same constant or argument.
+          if (isa<Constant>(V) || isa<Argument>(V)) {
+            if (isa<UndefValue>(RV)) {
+              // No value found yet? Try the current one.
+              RetVals[i] = V;
+              continue;
+            }
+            // Returning the same value? Good.
+            if (RV == V)
+              continue;
+          }
         }
+        // Different or no known return value? Don't propagate this return
+        // value.
+        RetVals[i] = 0;
+        // All values non-constant? Stop looking.
+        if (++NumNonConstant == RetVals.size())
+          return false;
       }
     }
 
-  // If we replace all users with the returned constant, and there can be no
-  // other callers of the function, replace the constant being returned in the
-  // function with an undef value.
-  if (ReplacedAllUsers && F.hasInternalLinkage() && !isa<UndefValue>(RetVal)) {
-    Value *RV = UndefValue::get(RetVal->getType());
-    for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
-      if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator()))
-        RI->setOperand(0, RV);
+  // If we got here, the function returns at least one constant value.  Loop
+  // over all users, replacing any uses of the return value with the returned
+  // constant.
+  bool MadeChange = false;
+  for (Use &U : F.uses()) {
+    CallSite CS(U.getUser());
+    Instruction* Call = CS.getInstruction();
+
+    // Not a call instruction or a call instruction that's not calling F
+    // directly?
+    if (!Call || !CS.isCallee(&U))
+      continue;
+    
+    // Call result not used?
+    if (Call->use_empty())
+      continue;
+
     MadeChange = true;
+
+    if (STy == 0) {
+      Value* New = RetVals[0];
+      if (Argument *A = dyn_cast<Argument>(New))
+        // Was an argument returned? Then find the corresponding argument in
+        // the call instruction and use that.
+        New = CS.getArgument(A->getArgNo());
+      Call->replaceAllUsesWith(New);
+      continue;
+    }
+
+    for (auto I = Call->user_begin(), E = Call->user_end(); I != E;) {
+      Instruction *Ins = cast<Instruction>(*I);
+
+      // Increment now, so we can remove the use
+      ++I;
+
+      // Find the index of the retval to replace with
+      int index = -1;
+      if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(Ins))
+        if (EV->hasIndices())
+          index = *EV->idx_begin();
+
+      // If this use uses a specific return value, and we have a replacement,
+      // replace it.
+      if (index != -1) {
+        Value *New = RetVals[index];
+        if (New) {
+          if (Argument *A = dyn_cast<Argument>(New))
+            // Was an argument returned? Then find the corresponding argument in
+            // the call instruction and use that.
+            New = CS.getArgument(A->getArgNo());
+          Ins->replaceAllUsesWith(New);
+          Ins->eraseFromParent();
+        }
+      }
+    }
   }
 
   if (MadeChange) ++NumReturnValProped;