Update SetVector to rely on the underlying set's insert to return a pair<iterator...
[oota-llvm.git] / lib / Transforms / Scalar / TailRecursionElimination.cpp
index f8eafc845db5f8a8ead61071cc7405ff1583dc41..65b1f1428213f6a48b04778777a5edb4c584dc5c 100644 (file)
@@ -63,6 +63,7 @@
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/Function.h"
@@ -86,6 +87,7 @@ STATISTIC(NumAccumAdded, "Number of accumulators introduced");
 namespace {
   struct TailCallElim : public FunctionPass {
     const TargetTransformInfo *TTI;
+    const DataLayout *DL;
 
     static char ID; // Pass identification, replacement for typeid
     TailCallElim() : FunctionPass(ID) {
@@ -157,6 +159,8 @@ bool TailCallElim::runOnFunction(Function &F) {
   if (skipOptnoneFunction(F))
     return false;
 
+  DL = F.getParent()->getDataLayout();
+
   bool AllCallsAreTailCalls = false;
   bool Modified = markTails(F, AllCallsAreTailCalls);
   if (AllCallsAreTailCalls)
@@ -175,7 +179,7 @@ struct AllocaDerivedValueTracker {
 
     auto AddUsesToWorklist = [&](Value *V) {
       for (auto &U : V->uses()) {
-        if (!Visited.insert(&U))
+        if (!Visited.insert(&U).second)
           continue;
         Worklist.push_back(&U);
       }
@@ -249,7 +253,12 @@ bool TailCallElim::markTails(Function &F, bool &AllCallsAreTailCalls) {
     return false;
   AllCallsAreTailCalls = true;
 
+  // The local stack holds all alloca instructions and all byval arguments.
   AllocaDerivedValueTracker Tracker;
+  for (Argument &Arg : F.args()) {
+    if (Arg.hasByValAttr())
+      Tracker.walk(&Arg);
+  }
   for (auto &BB : F) {
     for (auto &I : BB)
       if (AllocaInst *AI = dyn_cast<AllocaInst>(&I))
@@ -305,8 +314,9 @@ bool TailCallElim::markTails(Function &F, bool &AllCallsAreTailCalls) {
         for (auto &Arg : CI->arg_operands()) {
           if (isa<Constant>(Arg.getUser()))
             continue;
-          if (isa<Argument>(Arg.getUser()))
-            continue;
+          if (Argument *A = dyn_cast<Argument>(Arg.getUser()))
+            if (!A->hasByValAttr())
+              continue;
           SafeToTail = false;
           break;
         }
@@ -444,7 +454,7 @@ bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) {
       // being loaded from.
       if (CI->mayWriteToMemory() ||
           !isSafeToLoadUnconditionally(L->getPointerOperand(), L,
-                                       L->getAlignment()))
+                                       L->getAlignment(), DL))
         return false;
     }
   }