LoopIdiom: Recognize memmove loops.
[oota-llvm.git] / lib / Transforms / Scalar / LoopStrengthReduce.cpp
index c69abcb6409d3dd57c06e020f8a12c093cbd37a8..958348d9faad141bcaa26469e5420426bd36bfcc 100644 (file)
@@ -54,7 +54,7 @@
 //===----------------------------------------------------------------------===//
 
 #define DEBUG_TYPE "loop-reduce"
-#include "llvm/Transforms/Scalar.h"
+#include "llvm/AddressingMode.h"
 #include "llvm/Constants.h"
 #include "llvm/Instructions.h"
 #include "llvm/IntrinsicInst.h"
@@ -64,6 +64,7 @@
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Assembly/Writer.h"
+#include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -121,9 +122,11 @@ void RegSortData::print(raw_ostream &OS) const {
   OS << "[NumUses=" << UsedByIndices.count() << ']';
 }
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void RegSortData::dump() const {
   print(errs()); errs() << '\n';
 }
+#endif
 
 namespace {
 
@@ -223,7 +226,7 @@ namespace {
 struct Formula {
   /// AM - This is used to represent complex addressing, as well as other kinds
   /// of interesting uses.
-  TargetLowering::AddrMode AM;
+  AddrMode AM;
 
   /// BaseRegs - The list of "base" registers for this use. When this is
   /// non-empty, AM.HasBaseReg should be set to true.
@@ -414,9 +417,11 @@ void Formula::print(raw_ostream &OS) const {
   }
 }
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void Formula::dump() const {
   print(errs()); errs() << '\n';
 }
+#endif
 
 /// isAddRecSExtable - Return true if the given addrec can be sign-extended
 /// without changing its value.
@@ -738,7 +743,8 @@ DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakVH> &DeadInsts) {
   bool Changed = false;
 
   while (!DeadInsts.empty()) {
-    Instruction *I = dyn_cast_or_null<Instruction>(&*DeadInsts.pop_back_val());
+    Value *V = DeadInsts.pop_back_val();
+    Instruction *I = dyn_cast_or_null<Instruction>(V);
 
     if (I == 0 || !isInstructionTriviallyDead(I))
       continue;
@@ -973,9 +979,11 @@ void Cost::print(raw_ostream &OS) const {
     OS << ", plus " << SetupCost << " setup cost";
 }
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void Cost::dump() const {
   print(errs()); errs() << '\n';
 }
+#endif
 
 namespace {
 
@@ -1059,9 +1067,11 @@ void LSRFixup::print(raw_ostream &OS) const {
     OS << ", Offset=" << Offset;
 }
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void LSRFixup::dump() const {
   print(errs()); errs() << '\n';
 }
+#endif
 
 namespace {
 
@@ -1251,14 +1261,16 @@ void LSRUse::print(raw_ostream &OS) const {
     OS << ", widest fixup type: " << *WidestFixupType;
 }
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void LSRUse::dump() const {
   print(errs()); errs() << '\n';
 }
+#endif
 
 /// isLegalUse - Test whether the use described by AM is "legal", meaning it can
 /// be completely folded into the user instruction at isel time. This includes
 /// address-mode folding and special icmp tricks.
-static bool isLegalUse(const TargetLowering::AddrMode &AM,
+static bool isLegalUse(const AddrMode &AM,
                        LSRUse::KindType Kind, Type *AccessTy,
                        const TargetLowering *TLI) {
   switch (Kind) {
@@ -1308,14 +1320,14 @@ static bool isLegalUse(const TargetLowering::AddrMode &AM,
     return !AM.BaseGV && AM.Scale == 0 && AM.BaseOffs == 0;
 
   case LSRUse::Special:
-    // Only handle -1 scales, or no scale.
-    return AM.Scale == 0 || AM.Scale == -1;
+    // Special case Basic to handle -1 scales.
+    return !AM.BaseGV && (AM.Scale == 0 || AM.Scale == -1) && AM.BaseOffs == 0;
   }
 
   llvm_unreachable("Invalid LSRUse Kind!");
 }
 
-static bool isLegalUse(TargetLowering::AddrMode AM,
+static bool isLegalUse(AddrMode AM,
                        int64_t MinOffset, int64_t MaxOffset,
                        LSRUse::KindType Kind, Type *AccessTy,
                        const TargetLowering *TLI) {
@@ -1346,7 +1358,7 @@ static bool isAlwaysFoldable(int64_t BaseOffs,
 
   // Conservatively, create an address with an immediate and a
   // base and a scale.
-  TargetLowering::AddrMode AM;
+  AddrMode AM;
   AM.BaseOffs = BaseOffs;
   AM.BaseGV = BaseGV;
   AM.HasBaseReg = HasBaseReg;
@@ -1384,7 +1396,7 @@ static bool isAlwaysFoldable(const SCEV *S,
 
   // Conservatively, create an address with an immediate and a
   // base and a scale.
-  TargetLowering::AddrMode AM;
+  AddrMode AM;
   AM.BaseOffs = BaseOffs;
   AM.BaseGV = BaseGV;
   AM.HasBaseReg = HasBaseReg;
@@ -2009,7 +2021,7 @@ LSRInstance::OptimizeLoopTermCond() {
               goto decline_post_inc;
             // Check for possible scaled-address reuse.
             Type *AccessTy = getAccessType(UI->getUser());
-            TargetLowering::AddrMode AM;
+            AddrMode AM;
             AM.Scale = C->getSExtValue();
             if (TLI->isLegalAddressingMode(AM, AccessTy))
               goto decline_post_inc;
@@ -2194,7 +2206,7 @@ LSRInstance::FindUseWithSimilarFormula(const Formula &OrigF,
             return &LU;
           // This is the formula where all the registers and symbols matched;
           // there aren't going to be any others. Since we declined it, we
-          // can skip the rest of the formulae and procede to the next LSRUse.
+          // can skip the rest of the formulae and proceed to the next LSRUse.
           break;
         }
       }
@@ -2836,7 +2848,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
 
         // x == y  -->  x - y == 0
         const SCEV *N = SE.getSCEV(NV);
-        if (SE.isLoopInvariant(N, L)) {
+        if (SE.isLoopInvariant(N, L) && isSafeToExpand(N)) {
           // S is normalized, so normalize N before folding it into S
           // to keep the result normalized.
           N = TransformForPostIncUse(Normalize, N, CI, 0,
@@ -3006,42 +3018,64 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() {
 
 /// CollectSubexprs - Split S into subexpressions which can be pulled out into
 /// separate registers. If C is non-null, multiply each subexpression by C.
-static void CollectSubexprs(const SCEV *S, const SCEVConstant *C,
-                            SmallVectorImpl<const SCEV *> &Ops,
-                            const Loop *L,
-                            ScalarEvolution &SE) {
+///
+/// Return remainder expression after factoring the subexpressions captured by
+/// Ops. If Ops is complete, return NULL.
+static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
+                                   SmallVectorImpl<const SCEV *> &Ops,
+                                   const Loop *L,
+                                   ScalarEvolution &SE,
+                                   unsigned Depth = 0) {
+  // Arbitrarily cap recursion to protect compile time.
+  if (Depth >= 3)
+    return S;
+
   if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
     // Break out add operands.
     for (SCEVAddExpr::op_iterator I = Add->op_begin(), E = Add->op_end();
-         I != E; ++I)
-      CollectSubexprs(*I, C, Ops, L, SE);
-    return;
+         I != E; ++I) {
+      const SCEV *Remainder = CollectSubexprs(*I, C, Ops, L, SE, Depth+1);
+      if (Remainder)
+        Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
+    }
+    return NULL;
   } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
     // Split a non-zero base out of an addrec.
-    if (!AR->getStart()->isZero()) {
-      CollectSubexprs(SE.getAddRecExpr(SE.getConstant(AR->getType(), 0),
-                                       AR->getStepRecurrence(SE),
-                                       AR->getLoop(),
-                                       //FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
-                                       SCEV::FlagAnyWrap),
-                      C, Ops, L, SE);
-      CollectSubexprs(AR->getStart(), C, Ops, L, SE);
-      return;
+    if (AR->getStart()->isZero())
+      return S;
+
+    const SCEV *Remainder = CollectSubexprs(AR->getStart(),
+                                            C, Ops, L, SE, Depth+1);
+    // Split the non-zero AddRec unless it is part of a nested recurrence that
+    // does not pertain to this loop.
+    if (Remainder && (AR->getLoop() == L || !isa<SCEVAddRecExpr>(Remainder))) {
+      Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
+      Remainder = NULL;
+    }
+    if (Remainder != AR->getStart()) {
+      if (!Remainder)
+        Remainder = SE.getConstant(AR->getType(), 0);
+      return SE.getAddRecExpr(Remainder,
+                              AR->getStepRecurrence(SE),
+                              AR->getLoop(),
+                              //FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
+                              SCEV::FlagAnyWrap);
     }
   } else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
     // Break (C * (a + b + c)) into C*a + C*b + C*c.
-    if (Mul->getNumOperands() == 2)
-      if (const SCEVConstant *Op0 =
-            dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
-        CollectSubexprs(Mul->getOperand(1),
-                        C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0,
-                        Ops, L, SE);
-        return;
-      }
+    if (Mul->getNumOperands() != 2)
+      return S;
+    if (const SCEVConstant *Op0 =
+        dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
+      C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0;
+      const SCEV *Remainder =
+        CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1);
+      if (Remainder)
+        Ops.push_back(SE.getMulExpr(C, Remainder));
+      return NULL;
+    }
   }
-
-  // Otherwise use the value itself, optionally with a scale applied.
-  Ops.push_back(C ? SE.getMulExpr(C, S) : S);
+  return S;
 }
 
 /// GenerateReassociations - Split out subexpressions from adds and the bases of
@@ -3056,7 +3090,9 @@ void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx,
     const SCEV *BaseReg = Base.BaseRegs[i];
 
     SmallVector<const SCEV *, 8> AddOps;
-    CollectSubexprs(BaseReg, 0, AddOps, L, SE);
+    const SCEV *Remainder = CollectSubexprs(BaseReg, 0, AddOps, L, SE);
+    if (Remainder)
+      AddOps.push_back(Remainder);
 
     if (AddOps.size() == 1) continue;
 
@@ -3411,9 +3447,11 @@ void WorkItem::print(raw_ostream &OS) const {
      << " , add offset " << Imm;
 }
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void WorkItem::dump() const {
   print(errs()); errs() << '\n';
 }
+#endif
 
 /// GenerateCrossUseConstantOffsets - Look for registers which are a constant
 /// distance apart and try to form reuse opportunities between them.
@@ -4268,13 +4306,6 @@ Value *LSRInstance::Expand(const LSRFixup &LF,
     Ops.push_back(SE.getUnknown(Rewriter.expandCodeFor(Reg, 0, IP)));
   }
 
-  // Flush the operand list to suppress SCEVExpander hoisting.
-  if (!Ops.empty()) {
-    Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty, IP);
-    Ops.clear();
-    Ops.push_back(SE.getUnknown(FullV));
-  }
-
   // Expand the ScaledReg portion.
   Value *ICmpScaledV = 0;
   if (F.AM.Scale != 0) {
@@ -4296,23 +4327,34 @@ Value *LSRInstance::Expand(const LSRFixup &LF,
     } else {
       // Otherwise just expand the scaled register and an explicit scale,
       // which is expected to be matched as part of the address.
+
+      // Flush the operand list to suppress SCEVExpander hoisting address modes.
+      if (!Ops.empty() && LU.Kind == LSRUse::Address) {
+        Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty, IP);
+        Ops.clear();
+        Ops.push_back(SE.getUnknown(FullV));
+      }
       ScaledS = SE.getUnknown(Rewriter.expandCodeFor(ScaledS, 0, IP));
       ScaledS = SE.getMulExpr(ScaledS,
                               SE.getConstant(ScaledS->getType(), F.AM.Scale));
       Ops.push_back(ScaledS);
-
-      // Flush the operand list to suppress SCEVExpander hoisting.
-      Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty, IP);
-      Ops.clear();
-      Ops.push_back(SE.getUnknown(FullV));
     }
   }
 
   // Expand the GV portion.
   if (F.AM.BaseGV) {
+    // Flush the operand list to suppress SCEVExpander hoisting.
+    if (!Ops.empty()) {
+      Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty, IP);
+      Ops.clear();
+      Ops.push_back(SE.getUnknown(FullV));
+    }
     Ops.push_back(SE.getUnknown(F.AM.BaseGV));
+  }
 
-    // Flush the operand list to suppress SCEVExpander hoisting.
+  // Flush the operand list to suppress SCEVExpander hoisting of both folded and
+  // unfolded offsets. LSR assumes they both live next to their uses.
+  if (!Ops.empty()) {
     Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty, IP);
     Ops.clear();
     Ops.push_back(SE.getUnknown(FullV));
@@ -4423,17 +4465,21 @@ void LSRInstance::RewriteForPHI(PHINode *PN,
             SplitLandingPadPredecessors(Parent, BB, "", "", P, NewBBs);
             NewBB = NewBBs[0];
           }
-
-          // If PN is outside of the loop and BB is in the loop, we want to
-          // move the block to be immediately before the PHI block, not
-          // immediately after BB.
-          if (L->contains(BB) && !L->contains(PN))
-            NewBB->moveBefore(PN->getParent());
-
-          // Splitting the edge can reduce the number of PHI entries we have.
-          e = PN->getNumIncomingValues();
-          BB = NewBB;
-          i = PN->getBasicBlockIndex(BB);
+          // If NewBB==NULL, then SplitCriticalEdge refused to split because all
+          // phi predecessors are identical. The simple thing to do is skip
+          // splitting in this case rather than complicate the API.
+          if (NewBB) {
+            // If PN is outside of the loop and BB is in the loop, we want to
+            // move the block to be immediately before the PHI block, not
+            // immediately after BB.
+            if (L->contains(BB) && !L->contains(PN))
+              NewBB->moveBefore(PN->getParent());
+
+            // Splitting the edge can reduce the number of PHI entries we have.
+            e = PN->getNumIncomingValues();
+            BB = NewBB;
+            i = PN->getBasicBlockIndex(BB);
+          }
         }
       }
 
@@ -4702,9 +4748,11 @@ void LSRInstance::print(raw_ostream &OS) const {
   print_uses(OS);
 }
 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void LSRInstance::dump() const {
   print(errs()); errs() << '\n';
 }
+#endif
 
 namespace {