[C++11] Add predecessors(BasicBlock *) / successors(BasicBlock *) iterator ranges.
[oota-llvm.git] / lib / Transforms / Scalar / TailRecursionElimination.cpp
index 1b8ed4127c4b7537f67c5ead640ea4699c428873..d9280ac6c9c11084785a59316401606d6fa83a3c 100644 (file)
@@ -16,9 +16,9 @@
 //     transformation from taking place, though currently the analysis cannot
 //     support moving any really useful instructions (only dead ones).
 //  2. This pass transforms functions that are prevented from being tail
-//     recursive by an associative expression to use an accumulator variable,
-//     thus compiling the typical naive factorial or 'fib' implementation into
-//     efficient code.
+//     recursive by an associative and commutative expression to use an
+//     accumulator variable, thus compiling the typical naive factorial or
+//     'fib' implementation into efficient code.
 //  3. TRE is performed if the function returns void, if the return
 //     returns the result returned by the call, or if the function returns a
 //     run-time constant on all exits from the function.  It is possible, though
@@ -36,7 +36,7 @@
 //     evaluated each time through the tail recursion.  Safely keeping allocas
 //     in the entry block requires analysis to proves that the tail-called
 //     function does not read or write the stack object.
-//  2. Tail recursion is only performed if the call immediately preceeds the
+//  2. Tail recursion is only performed if the call immediately precedes the
 //     return instruction.  It's possible that there could be a jump between
 //     the call and the return.
 //  3. There can be intervening operations between the call and the return that
 //
 //===----------------------------------------------------------------------===//
 
-#define DEBUG_TYPE "tailcallelim"
 #include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Constants.h"
-#include "llvm/DerivedTypes.h"
-#include "llvm/Function.h"
-#include "llvm/Instructions.h"
-#include "llvm/Pass.h"
-#include "llvm/Analysis/CaptureTracking.h"
-#include "llvm/Support/CFG.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/InlineCost.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
 
+#define DEBUG_TYPE "tailcallelim"
+
 STATISTIC(NumEliminated, "Number of tail calls removed");
+STATISTIC(NumRetDuped,   "Number of return duplicated");
 STATISTIC(NumAccumAdded, "Number of accumulators introduced");
 
 namespace {
   struct TailCallElim : public FunctionPass {
+    const TargetTransformInfo *TTI;
+
     static char ID; // Pass identification, replacement for typeid
-    TailCallElim() : FunctionPass(&ID) {}
+    TailCallElim() : FunctionPass(ID) {
+      initializeTailCallElimPass(*PassRegistry::getPassRegistry());
+    }
+
+    void getAnalysisUsage(AnalysisUsage &AU) const override;
 
-    virtual bool runOnFunction(Function &F);
+    bool runOnFunction(Function &F) override;
 
   private:
+    bool runTRE(Function &F);
+    bool markTails(Function &F, bool &AllCallsAreTailCalls);
+
+    CallInst *FindTRECandidate(Instruction *I,
+                               bool CannotTailCallElimCallsMarkedTail);
+    bool EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret,
+                                    BasicBlock *&OldEntry,
+                                    bool &TailCallsAreMarkedTail,
+                                    SmallVectorImpl<PHINode *> &ArgumentPHIs,
+                                    bool CannotTailCallElimCallsMarkedTail);
+    bool FoldReturnAndProcessPred(BasicBlock *BB,
+                                  ReturnInst *Ret, BasicBlock *&OldEntry,
+                                  bool &TailCallsAreMarkedTail,
+                                  SmallVectorImpl<PHINode *> &ArgumentPHIs,
+                                  bool CannotTailCallElimCallsMarkedTail);
     bool ProcessReturningBlock(ReturnInst *RI, BasicBlock *&OldEntry,
                                bool &TailCallsAreMarkedTail,
-                               SmallVector<PHINode*, 8> &ArgumentPHIs,
+                               SmallVectorImpl<PHINode *> &ArgumentPHIs,
                                bool CannotTailCallElimCallsMarkedTail);
     bool CanMoveAboveCall(Instruction *I, CallInst *CI);
     Value *CanTransformAccumulatorRecursion(Instruction *I, CallInst *CI);
@@ -84,103 +122,313 @@ namespace {
 }
 
 char TailCallElim::ID = 0;
-static RegisterPass<TailCallElim> X("tailcallelim", "Tail Call Elimination");
+INITIALIZE_PASS_BEGIN(TailCallElim, "tailcallelim",
+                      "Tail Call Elimination", false, false)
+INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
+INITIALIZE_PASS_END(TailCallElim, "tailcallelim",
+                    "Tail Call Elimination", false, false)
 
 // Public interface to the TailCallElimination pass
 FunctionPass *llvm::createTailCallEliminationPass() {
   return new TailCallElim();
 }
 
-/// CheckForEscapingAllocas - Scan the specified basic block for alloca
-/// instructions.  If it contains any that might be accessed by calls, return
-/// true.
-static bool CheckForEscapingAllocas(BasicBlock *BB,
-                                    bool &CannotTCETailMarkedCall) {
-  bool RetVal = false;
-  for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
-    if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
-      RetVal |= PointerMayBeCaptured(AI, true);
-
-      // If this alloca is in the body of the function, or if it is a variable
-      // sized allocation, we cannot tail call eliminate calls marked 'tail'
-      // with this mechanism.
-      if (BB != &BB->getParent()->getEntryBlock() ||
-          !isa<ConstantInt>(AI->getArraySize()))
-        CannotTCETailMarkedCall = true;
+void TailCallElim::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<TargetTransformInfo>();
+}
+
+/// \brief Scan the specified function for alloca instructions.
+/// If it contains any dynamic allocas, returns false.
+static bool CanTRE(Function &F) {
+  // Because of PR962, we don't TRE dynamic allocas.
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) {
+        if (!AI->isStaticAlloca())
+          return false;
+      }
     }
-  return RetVal;
+  }
+
+  return true;
 }
 
 bool TailCallElim::runOnFunction(Function &F) {
+  if (skipOptnoneFunction(F))
+    return false;
+
+  bool AllCallsAreTailCalls = false;
+  bool Modified = markTails(F, AllCallsAreTailCalls);
+  if (AllCallsAreTailCalls)
+    Modified |= runTRE(F);
+  return Modified;
+}
+
+namespace {
+struct AllocaDerivedValueTracker {
+  // Start at a root value and walk its use-def chain to mark calls that use the
+  // value or a derived value in AllocaUsers, and places where it may escape in
+  // EscapePoints.
+  void walk(Value *Root) {
+    SmallVector<Use *, 32> Worklist;
+    SmallPtrSet<Use *, 32> Visited;
+
+    auto AddUsesToWorklist = [&](Value *V) {
+      for (auto &U : V->uses()) {
+        if (!Visited.insert(&U))
+          continue;
+        Worklist.push_back(&U);
+      }
+    };
+
+    AddUsesToWorklist(Root);
+
+    while (!Worklist.empty()) {
+      Use *U = Worklist.pop_back_val();
+      Instruction *I = cast<Instruction>(U->getUser());
+
+      switch (I->getOpcode()) {
+      case Instruction::Call:
+      case Instruction::Invoke: {
+        CallSite CS(I);
+        bool IsNocapture = !CS.isCallee(U) &&
+                           CS.doesNotCapture(CS.getArgumentNo(U));
+        callUsesLocalStack(CS, IsNocapture);
+        if (IsNocapture) {
+          // If the alloca-derived argument is passed in as nocapture, then it
+          // can't propagate to the call's return. That would be capturing.
+          continue;
+        }
+        break;
+      }
+      case Instruction::Load: {
+        // The result of a load is not alloca-derived (unless an alloca has
+        // otherwise escaped, but this is a local analysis).
+        continue;
+      }
+      case Instruction::Store: {
+        if (U->getOperandNo() == 0)
+          EscapePoints.insert(I);
+        continue;  // Stores have no users to analyze.
+      }
+      case Instruction::BitCast:
+      case Instruction::GetElementPtr:
+      case Instruction::PHI:
+      case Instruction::Select:
+      case Instruction::AddrSpaceCast:
+        break;
+      default:
+        EscapePoints.insert(I);
+        break;
+      }
+
+      AddUsesToWorklist(I);
+    }
+  }
+
+  void callUsesLocalStack(CallSite CS, bool IsNocapture) {
+    // Add it to the list of alloca users. If it's already there, skip further
+    // processing.
+    if (!AllocaUsers.insert(CS.getInstruction()))
+      return;
+
+    // If it's nocapture then it can't capture the alloca.
+    if (IsNocapture)
+      return;
+
+    // If it can write to memory, it can leak the alloca value.
+    if (!CS.onlyReadsMemory())
+      EscapePoints.insert(CS.getInstruction());
+  }
+
+  SmallPtrSet<Instruction *, 32> AllocaUsers;
+  SmallPtrSet<Instruction *, 32> EscapePoints;
+};
+}
+
+bool TailCallElim::markTails(Function &F, bool &AllCallsAreTailCalls) {
+  if (F.callsFunctionThatReturnsTwice())
+    return false;
+  AllCallsAreTailCalls = true;
+
+  // The local stack holds all alloca instructions and all byval arguments.
+  AllocaDerivedValueTracker Tracker;
+  for (Argument &Arg : F.args()) {
+    if (Arg.hasByValAttr())
+      Tracker.walk(&Arg);
+  }
+  for (auto &BB : F) {
+    for (auto &I : BB)
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(&I))
+        Tracker.walk(AI);
+  }
+
+  bool Modified = false;
+
+  // Track whether a block is reachable after an alloca has escaped. Blocks that
+  // contain the escaping instruction will be marked as being visited without an
+  // escaped alloca, since that is how the block began.
+  enum VisitType {
+    UNVISITED,
+    UNESCAPED,
+    ESCAPED
+  };
+  DenseMap<BasicBlock *, VisitType> Visited;
+
+  // We propagate the fact that an alloca has escaped from block to successor.
+  // Visit the blocks that are propagating the escapedness first. To do this, we
+  // maintain two worklists.
+  SmallVector<BasicBlock *, 32> WorklistUnescaped, WorklistEscaped;
+
+  // We may enter a block and visit it thinking that no alloca has escaped yet,
+  // then see an escape point and go back around a loop edge and come back to
+  // the same block twice. Because of this, we defer setting tail on calls when
+  // we first encounter them in a block. Every entry in this list does not
+  // statically use an alloca via use-def chain analysis, but may find an alloca
+  // through other means if the block turns out to be reachable after an escape
+  // point.
+  SmallVector<CallInst *, 32> DeferredTails;
+
+  BasicBlock *BB = &F.getEntryBlock();
+  VisitType Escaped = UNESCAPED;
+  do {
+    for (auto &I : *BB) {
+      if (Tracker.EscapePoints.count(&I))
+        Escaped = ESCAPED;
+
+      CallInst *CI = dyn_cast<CallInst>(&I);
+      if (!CI || CI->isTailCall())
+        continue;
+
+      if (CI->doesNotAccessMemory()) {
+        // A call to a readnone function whose arguments are all things computed
+        // outside this function can be marked tail. Even if you stored the
+        // alloca address into a global, a readnone function can't load the
+        // global anyhow.
+        //
+        // Note that this runs whether we know an alloca has escaped or not. If
+        // it has, then we can't trust Tracker.AllocaUsers to be accurate.
+        bool SafeToTail = true;
+        for (auto &Arg : CI->arg_operands()) {
+          if (isa<Constant>(Arg.getUser()))
+            continue;
+          if (Argument *A = dyn_cast<Argument>(Arg.getUser()))
+            if (!A->hasByValAttr())
+              continue;
+          SafeToTail = false;
+          break;
+        }
+        if (SafeToTail) {
+          emitOptimizationRemark(
+              F.getContext(), "tailcallelim", F, CI->getDebugLoc(),
+              "marked this readnone call a tail call candidate");
+          CI->setTailCall();
+          Modified = true;
+          continue;
+        }
+      }
+
+      if (Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI)) {
+        DeferredTails.push_back(CI);
+      } else {
+        AllCallsAreTailCalls = false;
+      }
+    }
+
+    for (auto *SuccBB : successors(BB)) {
+      auto &State = Visited[SuccBB];
+      if (State < Escaped) {
+        State = Escaped;
+        if (State == ESCAPED)
+          WorklistEscaped.push_back(SuccBB);
+        else
+          WorklistUnescaped.push_back(SuccBB);
+      }
+    }
+
+    if (!WorklistEscaped.empty()) {
+      BB = WorklistEscaped.pop_back_val();
+      Escaped = ESCAPED;
+    } else {
+      BB = nullptr;
+      while (!WorklistUnescaped.empty()) {
+        auto *NextBB = WorklistUnescaped.pop_back_val();
+        if (Visited[NextBB] == UNESCAPED) {
+          BB = NextBB;
+          Escaped = UNESCAPED;
+          break;
+        }
+      }
+    }
+  } while (BB);
+
+  for (CallInst *CI : DeferredTails) {
+    if (Visited[CI->getParent()] != ESCAPED) {
+      // If the escape point was part way through the block, calls after the
+      // escape point wouldn't have been put into DeferredTails.
+      emitOptimizationRemark(F.getContext(), "tailcallelim", F,
+                             CI->getDebugLoc(),
+                             "marked this call a tail call candidate");
+      CI->setTailCall();
+      Modified = true;
+    } else {
+      AllCallsAreTailCalls = false;
+    }
+  }
+
+  return Modified;
+}
+
+bool TailCallElim::runTRE(Function &F) {
   // If this function is a varargs function, we won't be able to PHI the args
   // right, so don't even try to convert it...
   if (F.getFunctionType()->isVarArg()) return false;
 
-  BasicBlock *OldEntry = 0;
+  TTI = &getAnalysis<TargetTransformInfo>();
+  BasicBlock *OldEntry = nullptr;
   bool TailCallsAreMarkedTail = false;
   SmallVector<PHINode*, 8> ArgumentPHIs;
   bool MadeChange = false;
 
-  bool FunctionContainsEscapingAllocas = false;
-
-  // CannotTCETailMarkedCall - If true, we cannot perform TCE on tail calls
+  // CanTRETailMarkedCall - If false, we cannot perform TRE on tail calls
   // marked with the 'tail' attribute, because doing so would cause the stack
-  // size to increase (real TCE would deallocate variable sized allocas, TCE
+  // size to increase (real TRE would deallocate variable sized allocas, TRE
   // doesn't).
-  bool CannotTCETailMarkedCall = false;
+  bool CanTRETailMarkedCall = CanTRE(F);
 
-  // Loop over the function, looking for any returning blocks, and keeping track
-  // of whether this function has any non-trivially used allocas.
+  // Change any tail recursive calls to loops.
+  //
+  // FIXME: The code generator produces really bad code when an 'escaping
+  // alloca' is changed from being a static alloca to being a dynamic alloca.
+  // Until this is resolved, disable this transformation if that would ever
+  // happen.  This bug is PR962.
   for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) {
-    if (FunctionContainsEscapingAllocas && CannotTCETailMarkedCall)
-      break;
-
-    FunctionContainsEscapingAllocas |=
-      CheckForEscapingAllocas(BB, CannotTCETailMarkedCall);
+    if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) {
+      bool Change = ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
+                                          ArgumentPHIs, !CanTRETailMarkedCall);
+      if (!Change && BB->getFirstNonPHIOrDbg() == Ret)
+        Change = FoldReturnAndProcessPred(BB, Ret, OldEntry,
+                                          TailCallsAreMarkedTail, ArgumentPHIs,
+                                          !CanTRETailMarkedCall);
+      MadeChange |= Change;
+    }
   }
-  
-  /// FIXME: The code generator produces really bad code when an 'escaping
-  /// alloca' is changed from being a static alloca to being a dynamic alloca.
-  /// Until this is resolved, disable this transformation if that would ever
-  /// happen.  This bug is PR962.
-  if (FunctionContainsEscapingAllocas)
-    return false;
-  
-
-  // Second pass, change any tail calls to loops.
-  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
-    if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator()))
-      MadeChange |= ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
-                                          ArgumentPHIs,CannotTCETailMarkedCall);
 
   // If we eliminated any tail recursions, it's possible that we inserted some
   // silly PHI nodes which just merge an initial value (the incoming operand)
   // with themselves.  Check to see if we did and clean up our mess if so.  This
   // occurs when a function passes an argument straight through to its tail
   // call.
-  if (!ArgumentPHIs.empty()) {
-    for (unsigned i = 0, e = ArgumentPHIs.size(); i != e; ++i) {
-      PHINode *PN = ArgumentPHIs[i];
-
-      // If the PHI Node is a dynamic constant, replace it with the value it is.
-      if (Value *PNV = PN->hasConstantValue()) {
-        PN->replaceAllUsesWith(PNV);
-        PN->eraseFromParent();
-      }
+  for (unsigned i = 0, e = ArgumentPHIs.size(); i != e; ++i) {
+    PHINode *PN = ArgumentPHIs[i];
+
+    // If the PHI Node is a dynamic constant, replace it with the value it is.
+    if (Value *PNV = SimplifyInstruction(PN)) {
+      PN->replaceAllUsesWith(PNV);
+      PN->eraseFromParent();
     }
   }
 
-  // Finally, if this function contains no non-escaping allocas, mark all calls
-  // in the function as eligible for tail calls (there is no stack memory for
-  // them to access).
-  if (!FunctionContainsEscapingAllocas)
-    for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
-      for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
-        if (CallInst *CI = dyn_cast<CallInst>(I)) {
-          CI->setTailCall();
-          MadeChange = true;
-        }
-
   return MadeChange;
 }
 
@@ -194,7 +442,7 @@ bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) {
   // call does not mod/ref the memory location being processed.
   if (I->mayHaveSideEffects())  // This also handles volatile loads.
     return false;
-  
+
   if (LoadInst *L = dyn_cast<LoadInst>(I)) {
     // Loads may always be moved above calls without side effects.
     if (CI->mayHaveSideEffects()) {
@@ -203,7 +451,8 @@ bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) {
       // FIXME: Writes to memory only matter if they may alias the pointer
       // being loaded from.
       if (CI->mayWriteToMemory() ||
-          !isSafeToLoadUnconditionally(L->getPointerOperand(), L))
+          !isSafeToLoadUnconditionally(L->getPointerOperand(), L,
+                                       L->getAlignment()))
         return false;
     }
   }
@@ -226,7 +475,7 @@ bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) {
 // We currently handle static constants and arguments that are not modified as
 // part of the recursion.
 //
-static bool isDynamicConstant(Value *V, CallInst *CI) {
+static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) {
   if (isa<Constant>(V)) return true; // Static constants are always dyn consts
 
   // Check to see if this is an immutable argument, if so, the value
@@ -241,37 +490,46 @@ static bool isDynamicConstant(Value *V, CallInst *CI) {
     // If we are passing this argument into call as the corresponding
     // argument operand, then the argument is dynamically constant.
     // Otherwise, we cannot transform this function safely.
-    if (CI->getOperand(ArgNo+1) == Arg)
+    if (CI->getArgOperand(ArgNo) == Arg)
       return true;
   }
+
+  // Switch cases are always constant integers. If the value is being switched
+  // on and the return is only reachable from one of its cases, it's
+  // effectively constant.
+  if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor())
+    if (SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator()))
+      if (SI->getCondition() == V)
+        return SI->getDefaultDest() != RI->getParent();
+
   // Not a constant or immutable argument, we can't safely transform.
   return false;
 }
 
 // getCommonReturnValue - Check to see if the function containing the specified
-// return instruction and tail call consistently returns the same
-// runtime-constant value at all exit points.  If so, return the returned value.
+// tail call consistently returns the same runtime-constant value at all exit
+// points except for IgnoreRI.  If so, return the returned value.
 //
-static Value *getCommonReturnValue(ReturnInst *TheRI, CallInst *CI) {
-  Function *F = TheRI->getParent()->getParent();
-  Value *ReturnedValue = 0;
-
-  for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI)
-    if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator()))
-      if (RI != TheRI) {
-        Value *RetOp = RI->getOperand(0);
-
-        // We can only perform this transformation if the value returned is
-        // evaluatable at the start of the initial invocation of the function,
-        // instead of at the end of the evaluation.
-        //
-        if (!isDynamicConstant(RetOp, CI))
-          return 0;
-
-        if (ReturnedValue && RetOp != ReturnedValue)
-          return 0;     // Cannot transform if differing values are returned.
-        ReturnedValue = RetOp;
-      }
+static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) {
+  Function *F = CI->getParent()->getParent();
+  Value *ReturnedValue = nullptr;
+
+  for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI) {
+    ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator());
+    if (RI == nullptr || RI == IgnoreRI) continue;
+
+    // We can only perform this transformation if the value returned is
+    // evaluatable at the start of the initial invocation of the function,
+    // instead of at the end of the evaluation.
+    //
+    Value *RetOp = RI->getOperand(0);
+    if (!isDynamicConstant(RetOp, CI, RI))
+      return nullptr;
+
+    if (ReturnedValue && RetOp != ReturnedValue)
+      return nullptr;     // Cannot transform if differing values are returned.
+    ReturnedValue = RetOp;
+  }
   return ReturnedValue;
 }
 
@@ -281,90 +539,122 @@ static Value *getCommonReturnValue(ReturnInst *TheRI, CallInst *CI) {
 ///
 Value *TailCallElim::CanTransformAccumulatorRecursion(Instruction *I,
                                                       CallInst *CI) {
-  if (!I->isAssociative()) return 0;
+  if (!I->isAssociative() || !I->isCommutative()) return nullptr;
   assert(I->getNumOperands() == 2 &&
-         "Associative operations should have 2 args!");
+         "Associative/commutative operations should have 2 args!");
 
-  // Exactly one operand should be the result of the call instruction...
+  // Exactly one operand should be the result of the call instruction.
   if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
       (I->getOperand(0) != CI && I->getOperand(1) != CI))
-    return 0;
+    return nullptr;
 
   // The only user of this instruction we allow is a single return instruction.
-  if (!I->hasOneUse() || !isa<ReturnInst>(I->use_back()))
-    return 0;
+  if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back()))
+    return nullptr;
 
   // Ok, now we have to check all of the other return instructions in this
   // function.  If they return non-constants or differing values, then we cannot
   // transform the function safely.
-  return getCommonReturnValue(cast<ReturnInst>(I->use_back()), CI);
+  return getCommonReturnValue(cast<ReturnInst>(I->user_back()), CI);
 }
 
-bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
-                                         bool &TailCallsAreMarkedTail,
-                                         SmallVector<PHINode*, 8> &ArgumentPHIs,
-                                       bool CannotTailCallElimCallsMarkedTail) {
-  BasicBlock *BB = Ret->getParent();
+static Instruction *FirstNonDbg(BasicBlock::iterator I) {
+  while (isa<DbgInfoIntrinsic>(I))
+    ++I;
+  return &*I;
+}
+
+CallInst*
+TailCallElim::FindTRECandidate(Instruction *TI,
+                               bool CannotTailCallElimCallsMarkedTail) {
+  BasicBlock *BB = TI->getParent();
   Function *F = BB->getParent();
 
-  if (&BB->front() == Ret) // Make sure there is something before the ret...
-    return false;
-  
-  // If the return is in the entry block, then making this transformation would
-  // turn infinite recursion into an infinite loop.  This transformation is ok
-  // in theory, but breaks some code like:
-  //   double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call
-  // disable this xform in this case, because the code generator will lower the
-  // call to fabs into inline code.
-  if (BB == &F->getEntryBlock())
-    return false;
+  if (&BB->front() == TI) // Make sure there is something before the terminator.
+    return nullptr;
 
   // Scan backwards from the return, checking to see if there is a tail call in
   // this block.  If so, set CI to it.
-  CallInst *CI;
-  BasicBlock::iterator BBI = Ret;
-  while (1) {
+  CallInst *CI = nullptr;
+  BasicBlock::iterator BBI = TI;
+  while (true) {
     CI = dyn_cast<CallInst>(BBI);
     if (CI && CI->getCalledFunction() == F)
       break;
 
     if (BBI == BB->begin())
-      return false;          // Didn't find a potential tail call.
+      return nullptr;          // Didn't find a potential tail call.
     --BBI;
   }
 
   // If this call is marked as a tail call, and if there are dynamic allocas in
   // the function, we cannot perform this optimization.
   if (CI->isTailCall() && CannotTailCallElimCallsMarkedTail)
-    return false;
+    return nullptr;
 
-  // If we are introducing accumulator recursion to eliminate associative
-  // operations after the call instruction, this variable contains the initial
-  // value for the accumulator.  If this value is set, we actually perform
-  // accumulator recursion elimination instead of simple tail recursion
-  // elimination.
-  Value *AccumulatorRecursionEliminationInitVal = 0;
-  Instruction *AccumulatorRecursionInstr = 0;
+  // As a special case, detect code like this:
+  //   double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call
+  // and disable this xform in this case, because the code generator will
+  // lower the call to fabs into inline code.
+  if (BB == &F->getEntryBlock() &&
+      FirstNonDbg(BB->front()) == CI &&
+      FirstNonDbg(std::next(BB->begin())) == TI &&
+      CI->getCalledFunction() &&
+      !TTI->isLoweredToCall(CI->getCalledFunction())) {
+    // A single-block function with just a call and a return. Check that
+    // the arguments match.
+    CallSite::arg_iterator I = CallSite(CI).arg_begin(),
+                           E = CallSite(CI).arg_end();
+    Function::arg_iterator FI = F->arg_begin(),
+                           FE = F->arg_end();
+    for (; I != E && FI != FE; ++I, ++FI)
+      if (*I != &*FI) break;
+    if (I == E && FI == FE)
+      return nullptr;
+  }
+
+  return CI;
+}
+
+bool TailCallElim::EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret,
+                                       BasicBlock *&OldEntry,
+                                       bool &TailCallsAreMarkedTail,
+                                       SmallVectorImpl<PHINode *> &ArgumentPHIs,
+                                       bool CannotTailCallElimCallsMarkedTail) {
+  // If we are introducing accumulator recursion to eliminate operations after
+  // the call instruction that are both associative and commutative, the initial
+  // value for the accumulator is placed in this variable.  If this value is set
+  // then we actually perform accumulator recursion elimination instead of
+  // simple tail recursion elimination.  If the operation is an LLVM instruction
+  // (eg: "add") then it is recorded in AccumulatorRecursionInstr.  If not, then
+  // we are handling the case when the return instruction returns a constant C
+  // which is different to the constant returned by other return instructions
+  // (which is recorded in AccumulatorRecursionEliminationInitVal).  This is a
+  // special case of accumulator recursion, the operation being "return C".
+  Value *AccumulatorRecursionEliminationInitVal = nullptr;
+  Instruction *AccumulatorRecursionInstr = nullptr;
 
   // Ok, we found a potential tail call.  We can currently only transform the
   // tail call if all of the instructions between the call and the return are
   // movable to above the call itself, leaving the call next to the return.
   // Check that this is the case now.
-  for (BBI = CI, ++BBI; &*BBI != Ret; ++BBI)
-    if (!CanMoveAboveCall(BBI, CI)) {
-      // If we can't move the instruction above the call, it might be because it
-      // is an associative operation that could be tranformed using accumulator
-      // recursion elimination.  Check to see if this is the case, and if so,
-      // remember the initial accumulator value for later.
-      if ((AccumulatorRecursionEliminationInitVal =
-                             CanTransformAccumulatorRecursion(BBI, CI))) {
-        // Yes, this is accumulator recursion.  Remember which instruction
-        // accumulates.
-        AccumulatorRecursionInstr = BBI;
-      } else {
-        return false;   // Otherwise, we cannot eliminate the tail recursion!
-      }
+  BasicBlock::iterator BBI = CI;
+  for (++BBI; &*BBI != Ret; ++BBI) {
+    if (CanMoveAboveCall(BBI, CI)) continue;
+
+    // If we can't move the instruction above the call, it might be because it
+    // is an associative and commutative operation that could be transformed
+    // using accumulator recursion elimination.  Check to see if this is the
+    // case, and if so, remember the initial accumulator value for later.
+    if ((AccumulatorRecursionEliminationInitVal =
+                           CanTransformAccumulatorRecursion(BBI, CI))) {
+      // Yes, this is accumulator recursion.  Remember which instruction
+      // accumulates.
+      AccumulatorRecursionInstr = BBI;
+    } else {
+      return false;   // Otherwise, we cannot eliminate the tail recursion!
     }
+  }
 
   // We can only transform call/return pairs that either ignore the return value
   // of the call and return void, ignore the value of the call and return a
@@ -372,13 +662,29 @@ bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
   // accumulator recursion variable eliminated.
   if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI &&
       !isa<UndefValue>(Ret->getReturnValue()) &&
-      AccumulatorRecursionEliminationInitVal == 0 &&
-      !getCommonReturnValue(Ret, CI))
-    return false;
+      AccumulatorRecursionEliminationInitVal == nullptr &&
+      !getCommonReturnValue(nullptr, CI)) {
+    // One case remains that we are able to handle: the current return
+    // instruction returns a constant, and all other return instructions
+    // return a different constant.
+    if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret))
+      return false; // Current return instruction does not return a constant.
+    // Check that all other return instructions return a common constant.  If
+    // so, record it in AccumulatorRecursionEliminationInitVal.
+    AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI);
+    if (!AccumulatorRecursionEliminationInitVal)
+      return false;
+  }
+
+  BasicBlock *BB = Ret->getParent();
+  Function *F = BB->getParent();
+
+  emitOptimizationRemark(F->getContext(), "tailcallelim", *F, CI->getDebugLoc(),
+                         "transforming tail recursion to loop");
 
   // OK! We can transform this tail call.  If this is the first one found,
   // create the new entry block, allowing us to branch back to the old entry.
-  if (OldEntry == 0) {
+  if (!OldEntry) {
     OldEntry = &F->getEntryBlock();
     BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry);
     NewEntry->takeName(OldEntry);
@@ -403,7 +709,7 @@ bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
     Instruction *InsertPos = OldEntry->begin();
     for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
          I != E; ++I) {
-      PHINode *PN = PHINode::Create(I->getType(),
+      PHINode *PN = PHINode::Create(I->getType(), 2,
                                     I->getName() + ".tr", InsertPos);
       I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
       PN->addIncoming(I, NewEntry);
@@ -423,8 +729,8 @@ bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
   // Ok, now that we know we have a pseudo-entry block WITH all of the
   // required PHI nodes, add entries into the PHI node for the actual
   // parameters passed into the tail-recursive call.
-  for (unsigned i = 0, e = CI->getNumOperands()-1; i != e; ++i)
-    ArgumentPHIs[i]->addIncoming(CI->getOperand(i+1), BB);
+  for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i)
+    ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB);
 
   // If we are introducing an accumulator variable to eliminate the recursion,
   // do so now.  Note that we _know_ that no subsequent tail recursion
@@ -434,8 +740,11 @@ bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
   if (AccumulatorRecursionEliminationInitVal) {
     Instruction *AccRecInstr = AccumulatorRecursionInstr;
     // Start by inserting a new PHI node for the accumulator.
-    PHINode *AccPN = PHINode::Create(AccRecInstr->getType(), "accumulator.tr",
-                                     OldEntry->begin());
+    pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry);
+    PHINode *AccPN =
+      PHINode::Create(AccumulatorRecursionEliminationInitVal->getType(),
+                      std::distance(PB, PE) + 1,
+                      "accumulator.tr", OldEntry->begin());
 
     // Loop over all of the predecessors of the tail recursion block.  For the
     // real entry into the function we seed the PHI with the initial value,
@@ -443,22 +752,28 @@ bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
     // other tail recursions eliminated) the accumulator is not modified.
     // Because we haven't added the branch in the current block to OldEntry yet,
     // it will not show up as a predecessor.
-    for (pred_iterator PI = pred_begin(OldEntry), PE = pred_end(OldEntry);
-         PI != PE; ++PI) {
-      if (*PI == &F->getEntryBlock())
-        AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, *PI);
+    for (pred_iterator PI = PB; PI != PE; ++PI) {
+      BasicBlock *P = *PI;
+      if (P == &F->getEntryBlock())
+        AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P);
       else
-        AccPN->addIncoming(AccPN, *PI);
+        AccPN->addIncoming(AccPN, P);
     }
 
-    // Add an incoming argument for the current block, which is computed by our
-    // associative accumulator instruction.
-    AccPN->addIncoming(AccRecInstr, BB);
-
-    // Next, rewrite the accumulator recursion instruction so that it does not
-    // use the result of the call anymore, instead, use the PHI node we just
-    // inserted.
-    AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
+    if (AccRecInstr) {
+      // Add an incoming argument for the current block, which is computed by
+      // our associative and commutative accumulator instruction.
+      AccPN->addIncoming(AccRecInstr, BB);
+
+      // Next, rewrite the accumulator recursion instruction so that it does not
+      // use the result of the call anymore, instead, use the PHI node we just
+      // inserted.
+      AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
+    } else {
+      // Add an incoming argument for the current block, which is just the
+      // constant returned by the current return instruction.
+      AccPN->addIncoming(Ret->getReturnValue(), BB);
+    }
 
     // Finally, rewrite any return instructions in the program to return the PHI
     // node instead of the "initval" that they do currently.  This loop will
@@ -471,9 +786,61 @@ bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
 
   // Now that all of the PHI nodes are in place, remove the call and
   // ret instructions, replacing them with an unconditional branch.
-  BranchInst::Create(OldEntry, Ret);
+  BranchInst *NewBI = BranchInst::Create(OldEntry, Ret);
+  NewBI->setDebugLoc(CI->getDebugLoc());
+
   BB->getInstList().erase(Ret);  // Remove return.
   BB->getInstList().erase(CI);   // Remove call.
   ++NumEliminated;
   return true;
 }
+
+bool TailCallElim::FoldReturnAndProcessPred(BasicBlock *BB,
+                                       ReturnInst *Ret, BasicBlock *&OldEntry,
+                                       bool &TailCallsAreMarkedTail,
+                                       SmallVectorImpl<PHINode *> &ArgumentPHIs,
+                                       bool CannotTailCallElimCallsMarkedTail) {
+  bool Change = false;
+
+  // If the return block contains nothing but the return and PHI's,
+  // there might be an opportunity to duplicate the return in its
+  // predecessors and perform TRC there. Look for predecessors that end
+  // in unconditional branch and recursive call(s).
+  SmallVector<BranchInst*, 8> UncondBranchPreds;
+  for (BasicBlock *Pred : predecessors(BB)) {
+    TerminatorInst *PTI = Pred->getTerminator();
+    if (BranchInst *BI = dyn_cast<BranchInst>(PTI))
+      if (BI->isUnconditional())
+        UncondBranchPreds.push_back(BI);
+  }
+
+  while (!UncondBranchPreds.empty()) {
+    BranchInst *BI = UncondBranchPreds.pop_back_val();
+    BasicBlock *Pred = BI->getParent();
+    if (CallInst *CI = FindTRECandidate(BI, CannotTailCallElimCallsMarkedTail)){
+      DEBUG(dbgs() << "FOLDING: " << *BB
+            << "INTO UNCOND BRANCH PRED: " << *Pred);
+      EliminateRecursiveTailCall(CI, FoldReturnIntoUncondBranch(Ret, BB, Pred),
+                                 OldEntry, TailCallsAreMarkedTail, ArgumentPHIs,
+                                 CannotTailCallElimCallsMarkedTail);
+      ++NumRetDuped;
+      Change = true;
+    }
+  }
+
+  return Change;
+}
+
+bool
+TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
+                                    bool &TailCallsAreMarkedTail,
+                                    SmallVectorImpl<PHINode *> &ArgumentPHIs,
+                                    bool CannotTailCallElimCallsMarkedTail) {
+  CallInst *CI = FindTRECandidate(Ret, CannotTailCallElimCallsMarkedTail);
+  if (!CI)
+    return false;
+
+  return EliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail,
+                                    ArgumentPHIs,
+                                    CannotTailCallElimCallsMarkedTail);
+}