EarlyCSE: Replace custom hash mixing with Hashing.h
[oota-llvm.git] / lib / Transforms / Scalar / LoopUnrollPass.cpp
index 22dd65c7cf69094b21d05c3797d5d0e14ac97f71..1f9cf45152502d954f6c84d2ac9365ab773edb10 100644 (file)
@@ -13,7 +13,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/FunctionTargetTransformInfo.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -102,15 +104,17 @@ namespace {
     /// loop preheaders be inserted into the CFG...
     ///
     void getAnalysisUsage(AnalysisUsage &AU) const override {
-      AU.addRequired<LoopInfo>();
-      AU.addPreserved<LoopInfo>();
+      AU.addRequired<AssumptionCacheTracker>();
+      AU.addRequired<LoopInfoWrapperPass>();
+      AU.addPreserved<LoopInfoWrapperPass>();
       AU.addRequiredID(LoopSimplifyID);
       AU.addPreservedID(LoopSimplifyID);
       AU.addRequiredID(LCSSAID);
       AU.addPreservedID(LCSSAID);
       AU.addRequired<ScalarEvolution>();
       AU.addPreserved<ScalarEvolution>();
-      AU.addRequired<TargetTransformInfo>();
+      AU.addRequired<TargetTransformInfoWrapperPass>();
+      AU.addRequired<FunctionTargetTransformInfo>();
       // FIXME: Loop unroll requires LCSSA. And LCSSA requires dom info.
       // If loop unroll does not preserve dom info then LCSSA pass on next
       // loop will receive invalid dom info.
@@ -120,7 +124,7 @@ namespace {
 
     // Fill in the UnrollingPreferences parameter with values from the
     // TargetTransformationInfo.
-    void getUnrollingPreferences(Loop *L, const TargetTransformInfo &TTI,
+    void getUnrollingPreferences(Loop *L, const FunctionTargetTransformInfo &FTTI,
                                  TargetTransformInfo::UnrollingPreferences &UP) {
       UP.Threshold = CurrentThreshold;
       UP.OptSizeThreshold = OptSizeUnrollThreshold;
@@ -130,7 +134,7 @@ namespace {
       UP.MaxCount = UINT_MAX;
       UP.Partial = CurrentAllowPartial;
       UP.Runtime = CurrentRuntime;
-      TTI.getUnrollingPreferences(L, UP);
+      FTTI.getUnrollingPreferences(L, UP);
     }
 
     // Select and return an unroll count based on parameters from
@@ -181,8 +185,10 @@ namespace {
 
 char LoopUnroll::ID = 0;
 INITIALIZE_PASS_BEGIN(LoopUnroll, "loop-unroll", "Unroll loops", false, false)
-INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
-INITIALIZE_PASS_DEPENDENCY(LoopInfo)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(FunctionTargetTransformInfo)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
 INITIALIZE_PASS_DEPENDENCY(LCSSA)
 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
@@ -200,11 +206,15 @@ Pass *llvm::createSimpleLoopUnrollPass() {
 /// ApproximateLoopSize - Approximate the size of the loop.
 static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls,
                                     bool &NotDuplicatable,
-                                    const TargetTransformInfo &TTI) {
+                                    const TargetTransformInfo &TTI,
+                                    AssumptionCache *AC) {
+  SmallPtrSet<const Value *, 32> EphValues;
+  CodeMetrics::collectEphemeralValues(L, AC, EphValues);
+
   CodeMetrics Metrics;
   for (Loop::block_iterator I = L->block_begin(), E = L->block_end();
        I != E; ++I)
-    Metrics.analyzeBasicBlock(*I, TTI);
+    Metrics.analyzeBasicBlock(*I, TTI, EphValues);
   NumCalls = Metrics.NumInlineCandidates;
   NotDuplicatable = Metrics.notDuplicatable;
 
@@ -212,8 +222,11 @@ static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls,
 
   // Don't allow an estimate of size zero.  This would allows unrolling of loops
   // with huge iteration counts, which is a compile time problem even if it's
-  // not a problem for code quality.
-  if (LoopSize == 0) LoopSize = 1;
+  // not a problem for code quality. Also, the code using this size may assume
+  // that each loop has at least three instructions (likely a conditional
+  // branch, a comparison feeding that branch, and some kind of loop increment
+  // feeding that comparison instruction).
+  LoopSize = std::max(LoopSize, 3u);
 
   return LoopSize;
 }
@@ -221,48 +234,32 @@ static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls,
 // Returns the loop hint metadata node with the given name (for example,
 // "llvm.loop.unroll.count").  If no such metadata node exists, then nullptr is
 // returned.
-static const MDNode *GetUnrollMetadata(const Loop *L, StringRef Name) {
+static const MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) {
   MDNode *LoopID = L->getLoopID();
   if (!LoopID)
     return nullptr;
-
-  // First operand should refer to the loop id itself.
-  assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
-  assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
-
-  for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) {
-    const MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i));
-    if (!MD)
-      continue;
-
-    const MDString *S = dyn_cast<MDString>(MD->getOperand(0));
-    if (!S)
-      continue;
-
-    if (Name.equals(S->getString()))
-      return MD;
-  }
-  return nullptr;
+  return GetUnrollMetadata(LoopID, Name);
 }
 
 // Returns true if the loop has an unroll(full) pragma.
 static bool HasUnrollFullPragma(const Loop *L) {
-  return GetUnrollMetadata(L, "llvm.loop.unroll.full");
+  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.full");
 }
 
 // Returns true if the loop has an unroll(disable) pragma.
 static bool HasUnrollDisablePragma(const Loop *L) {
-  return GetUnrollMetadata(L, "llvm.loop.unroll.disable");
+  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.disable");
 }
 
 // If loop has an unroll_count pragma return the (necessarily
 // positive) value from the pragma.  Otherwise return 0.
 static unsigned UnrollCountPragmaValue(const Loop *L) {
-  const MDNode *MD = GetUnrollMetadata(L, "llvm.loop.unroll.count");
+  const MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll.count");
   if (MD) {
     assert(MD->getNumOperands() == 2 &&
            "Unroll count hint metadata should have two operands.");
-    unsigned Count = cast<ConstantInt>(MD->getOperand(1))->getZExtValue();
+    unsigned Count =
+        mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
     assert(Count >= 1 && "Unroll count must be positive.");
     return Count;
   }
@@ -278,9 +275,9 @@ static void SetLoopAlreadyUnrolled(Loop *L) {
   if (!LoopID) return;
 
   // First remove any existing loop unrolling metadata.
-  SmallVector<Value *, 4> Vals;
+  SmallVector<Metadata *, 4> MDs;
   // Reserve first location for self reference to the LoopID metadata node.
-  Vals.push_back(nullptr);
+  MDs.push_back(nullptr);
   for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
     bool IsUnrollMetadata = false;
     MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i));
@@ -288,17 +285,18 @@ static void SetLoopAlreadyUnrolled(Loop *L) {
       const MDString *S = dyn_cast<MDString>(MD->getOperand(0));
       IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll.");
     }
-    if (!IsUnrollMetadata) Vals.push_back(LoopID->getOperand(i));
+    if (!IsUnrollMetadata)
+      MDs.push_back(LoopID->getOperand(i));
   }
 
   // Add unroll(disable) metadata to disable future unrolling.
   LLVMContext &Context = L->getHeader()->getContext();
-  SmallVector<Value *, 1> DisableOperands;
+  SmallVector<Metadata *, 1> DisableOperands;
   DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.disable"));
   MDNode *DisableNode = MDNode::get(Context, DisableOperands);
-  Vals.push_back(DisableNode);
+  MDs.push_back(DisableNode);
 
-  MDNode *NewLoopID = MDNode::get(Context, Vals);
+  MDNode *NewLoopID = MDNode::get(Context, MDs);
   // Set operand 0 to refer to the loop id itself.
   NewLoopID->replaceOperandWith(0, NewLoopID);
   L->setLoopID(NewLoopID);
@@ -348,9 +346,15 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   if (skipOptnoneFunction(L))
     return false;
 
-  LoopInfo *LI = &getAnalysis<LoopInfo>();
+  Function &F = *L->getHeader()->getParent();
+
+  LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   ScalarEvolution *SE = &getAnalysis<ScalarEvolution>();
-  const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfo>();
+  const TargetTransformInfo &TTI =
+      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  const FunctionTargetTransformInfo &FTTI =
+      getAnalysis<FunctionTargetTransformInfo>();
+  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
 
   BasicBlock *Header = L->getHeader();
   DEBUG(dbgs() << "Loop Unroll: F[" << Header->getParent()->getName()
@@ -364,18 +368,20 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   bool HasPragma = PragmaFullUnroll || PragmaCount > 0;
 
   TargetTransformInfo::UnrollingPreferences UP;
-  getUnrollingPreferences(L, TTI, UP);
+  getUnrollingPreferences(L, FTTI, UP);
 
   // Find trip count and trip multiple if count is not available
   unsigned TripCount = 0;
   unsigned TripMultiple = 1;
-  // Find "latch trip count". UnrollLoop assumes that control cannot exit
-  // via the loop latch on any iteration prior to TripCount. The loop may exit
-  // early via an earlier branch.
-  BasicBlock *LatchBlock = L->getLoopLatch();
-  if (LatchBlock) {
-    TripCount = SE->getSmallConstantTripCount(L, LatchBlock);
-    TripMultiple = SE->getSmallConstantTripMultiple(L, LatchBlock);
+  // If there are multiple exiting blocks but one of them is the latch, use the
+  // latch for the trip count estimation. Otherwise insist on a single exiting
+  // block for the trip count estimation.
+  BasicBlock *ExitingBlock = L->getLoopLatch();
+  if (!ExitingBlock || !L->isLoopExiting(ExitingBlock))
+    ExitingBlock = L->getExitingBlock();
+  if (ExitingBlock) {
+    TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
+    TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock);
   }
 
   // Select an initial unroll count.  This may be reduced later based
@@ -387,9 +393,13 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   unsigned NumInlineCandidates;
   bool notDuplicatable;
   unsigned LoopSize =
-      ApproximateLoopSize(L, NumInlineCandidates, notDuplicatable, TTI);
+      ApproximateLoopSize(L, NumInlineCandidates, notDuplicatable, TTI, &AC);
   DEBUG(dbgs() << "  Loop Size = " << LoopSize << "\n");
-  uint64_t UnrolledSize = (uint64_t)LoopSize * Count;
+
+  // When computing the unrolled size, note that the conditional branch on the
+  // backedge and the comparison feeding it are not replicated like the rest of
+  // the loop body (which is why 2 is subtracted).
+  uint64_t UnrolledSize = (uint64_t)(LoopSize-2) * Count + 2;
   if (notDuplicatable) {
     DEBUG(dbgs() << "  Not unrolling loop which contains non-duplicatable"
                  << " instructions.\n");
@@ -434,7 +444,7 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
     }
     if (PartialThreshold != NoThreshold && UnrolledSize > PartialThreshold) {
       // Reduce unroll count to be modulo of TripCount for partial unrolling.
-      Count = PartialThreshold / LoopSize;
+      Count = (std::max(PartialThreshold, 3u)-2) / (LoopSize-2);
       while (Count != 0 && TripCount % Count != 0)
         Count--;
     }
@@ -448,7 +458,7 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
     // the original count which satisfies the threshold limit.
     while (Count != 0 && UnrolledSize > PartialThreshold) {
       Count >>= 1;
-      UnrolledSize = LoopSize * Count;
+      UnrolledSize = (LoopSize-2) * Count + 2;
     }
     if (Count > UP.MaxCount)
       Count = UP.MaxCount;
@@ -493,7 +503,8 @@ bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
   }
 
   // Unroll the loop.
-  if (!UnrollLoop(L, Count, TripCount, AllowRuntime, TripMultiple, LI, this, &LPM))
+  if (!UnrollLoop(L, Count, TripCount, AllowRuntime, TripMultiple, LI, this,
+                  &LPM, &AC))
     return false;
 
   return true;