Update the branch weight metadata in JumpThreading pass.
authorCong Hou <congh@google.com>
Wed, 14 Oct 2015 23:14:17 +0000 (23:14 +0000)
committerCong Hou <congh@google.com>
Wed, 14 Oct 2015 23:14:17 +0000 (23:14 +0000)
Currently in JumpThreading pass, the branch weight metadata is not updated after CFG modification. Consider the jump threading on PredBB, BB, and SuccBB. After jump threading, the weight on BB->SuccBB should be adjusted as some of it is contributed by the edge PredBB->BB, which doesn't exist anymore. This patch tries to update the edge weight in metadata on BB->SuccBB by scaling it by 1 - Freq(PredBB->BB) / Freq(BB->SuccBB).

This is the third attempt to submit this patch, while the first two led to failures in some FDO tests. After investigation, it is the edge weight normalization that caused those failures. In this patch the edge weight normalization is fixed so that there is no zero weight in the output and the sum of all weights can fit in 32-bit integer. Several unit tests are added.

Differential revision: http://reviews.llvm.org/D10979

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@250345 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/BlockFrequencyInfo.h
include/llvm/Analysis/BlockFrequencyInfoImpl.h
include/llvm/Support/BranchProbability.h
lib/Analysis/BlockFrequencyInfo.cpp
lib/Analysis/BlockFrequencyInfoImpl.cpp
lib/Transforms/Scalar/JumpThreading.cpp
test/Transforms/JumpThreading/update-edge-weight.ll [new file with mode: 0644]
unittests/Support/BranchProbabilityTest.cpp

index 2b6b16a3fb9b35fd903039667eecb0048293af0e..6f2a2b52276967a250935b86a7eb24188aad89e3 100644 (file)
@@ -45,6 +45,9 @@ public:
   /// floating points.
   BlockFrequency getBlockFreq(const BasicBlock *BB) const;
 
+  // Set the frequency of the given basic block.
+  void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
+
   /// calculate - compute block frequency info for the given function.
   void calculate(const Function &F, const BranchProbabilityInfo &BPI,
                  const LoopInfo &LI);
index 9519709b845de189400bf35fe13f49693adaf696..d7379b8fbeaa8cf14e55bec82593a01086af3efe 100644 (file)
@@ -477,6 +477,8 @@ public:
 
   BlockFrequency getBlockFreq(const BlockNode &Node) const;
 
+  void setBlockFreq(const BlockNode &Node, uint64_t Freq);
+
   raw_ostream &printBlockFreq(raw_ostream &OS, const BlockNode &Node) const;
   raw_ostream &printBlockFreq(raw_ostream &OS,
                               const BlockFrequency &Freq) const;
@@ -913,6 +915,7 @@ public:
   BlockFrequency getBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getBlockFreq(getNode(BB));
   }
+  void setBlockFreq(const BlockT *BB, uint64_t Freq);
   Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
   }
@@ -965,6 +968,21 @@ void BlockFrequencyInfoImpl<BT>::calculate(const FunctionT &F,
   finalizeMetrics();
 }
 
+template <class BT>
+void BlockFrequencyInfoImpl<BT>::setBlockFreq(const BlockT *BB, uint64_t Freq) {
+  if (Nodes.count(BB))
+    BlockFrequencyInfoImplBase::setBlockFreq(getNode(BB), Freq);
+  else {
+    // If BB is a newly added block after BFI is done, we need to create a new
+    // BlockNode for it assigned with a new index. The index can be determined
+    // by the size of Freqs.
+    BlockNode NewNode(Freqs.size());
+    Nodes[BB] = NewNode;
+    Freqs.emplace_back();
+    BlockFrequencyInfoImplBase::setBlockFreq(NewNode, Freq);
+  }
+}
+
 template <class BT> void BlockFrequencyInfoImpl<BT>::initializeRPOT() {
   const BlockT *Entry = &F->front();
   RPOT.reserve(F->size());
index 98b2a3d7d9a2460c26ad964cf239d957faf3c81f..f204777066e9be5b0dfd2b3eff63822884e512b7 100644 (file)
 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
 
 #include "llvm/Support/DataTypes.h"
+#include <algorithm>
 #include <cassert>
+#include <climits>
+#include <numeric>
 
 namespace llvm {
 
@@ -53,6 +56,11 @@ public:
   template <class ProbabilityList>
   static void normalizeProbabilities(ProbabilityList &Probs);
 
+  // Normalize a list of weights by scaling them down so that the sum of them
+  // doesn't exceed UINT32_MAX.
+  template <class WeightListIter>
+  static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End);
+
   uint32_t getNumerator() const { return N; }
   static uint32_t getDenominator() { return D; }
 
@@ -135,6 +143,49 @@ void BranchProbability::normalizeProbabilities(ProbabilityList &Probs) {
     Prob.N = (Prob.N * uint64_t(D) + Sum / 2) / Sum;
 }
 
+template <class WeightListIter>
+void BranchProbability::normalizeEdgeWeights(WeightListIter Begin,
+                                             WeightListIter End) {
+  // First we compute the sum with 64-bits of precision.
+  uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
+
+  if (Sum > UINT32_MAX) {
+    // Compute the scale necessary to cause the weights to fit, and re-sum with
+    // that scale applied.
+    assert(Sum / UINT32_MAX < UINT32_MAX &&
+           "The sum of weights exceeds UINT32_MAX^2!");
+    uint32_t Scale = Sum / UINT32_MAX + 1;
+    for (auto I = Begin; I != End; ++I)
+      *I /= Scale;
+    Sum = std::accumulate(Begin, End, uint64_t(0));
+  }
+
+  // Eliminate zero weights.
+  auto ZeroWeightNum = std::count(Begin, End, 0u);
+  if (ZeroWeightNum > 0) {
+    // If all weights are zeros, replace them by 1.
+    if (Sum == 0)
+      std::fill(Begin, End, 1u);
+    else {
+      // We are converting zeros into ones, and here we need to make sure that
+      // after this the sum won't exceed UINT32_MAX.
+      if (Sum + ZeroWeightNum > UINT32_MAX) {
+        for (auto I = Begin; I != End; ++I)
+          *I /= 2;
+        ZeroWeightNum = std::count(Begin, End, 0u);
+        Sum = std::accumulate(Begin, End, uint64_t(0));
+      }
+      // Scale up non-zero weights and turn zero weights into ones.
+      uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum;
+      assert(ScalingFactor >= 1);
+      if (ScalingFactor > 1)
+        for (auto I = Begin; I != End; ++I)
+          *I *= ScalingFactor;
+      std::replace(Begin, End, 0u, 1u);
+    }
+  }
+}
+
 }
 
 #endif
index ac4ee8f11e0a0e5189fb921716ba13b61d99a81c..90b7a339a0fe2d13b37ed0420c74565fe72ac63c 100644 (file)
@@ -129,6 +129,12 @@ BlockFrequency BlockFrequencyInfo::getBlockFreq(const BasicBlock *BB) const {
   return BFI ? BFI->getBlockFreq(BB) : 0;
 }
 
+void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB,
+                                      uint64_t Freq) {
+  assert(BFI && "Expected analysis to be available");
+  BFI->setBlockFreq(BB, Freq);
+}
+
 /// Pop up a ghostview window with the current block frequency propagation
 /// rendered using dot.
 void BlockFrequencyInfo::view() const {
index 903a263a65fee860def7cd346f5c107524fecbf0..48e23af2690a729f5e862a5d5caf1d0a78b5e5fc 100644 (file)
@@ -530,6 +530,13 @@ BlockFrequencyInfoImplBase::getFloatingBlockFreq(const BlockNode &Node) const {
   return Freqs[Node.Index].Scaled;
 }
 
+void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node,
+                                              uint64_t Freq) {
+  assert(Node.isValid() && "Expected valid node");
+  assert(Node.Index < Freqs.size() && "Expected legal index");
+  Freqs[Node.Index].Integer = Freq;
+}
+
 std::string
 BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const {
   return std::string();
index 2440a76224b212d2e93f035f85d96f3bc6f9952c..805f4cdcfcc6f421aa002087421e55911c61e59e 100644 (file)
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LazyValueInfo.h"
 #include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/Pass.h"
@@ -37,6 +42,8 @@
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/SSAUpdater.h"
+#include <algorithm>
+#include <memory>
 using namespace llvm;
 
 #define DEBUG_TYPE "jump-threading"
@@ -81,6 +88,9 @@ namespace {
   class JumpThreading : public FunctionPass {
     TargetLibraryInfo *TLI;
     LazyValueInfo *LVI;
+    std::unique_ptr<BlockFrequencyInfo> BFI;
+    std::unique_ptr<BranchProbabilityInfo> BPI;
+    bool HasProfileData;
 #ifdef NDEBUG
     SmallPtrSet<BasicBlock*, 16> LoopHeaders;
 #else
@@ -119,6 +129,11 @@ namespace {
       AU.addRequired<TargetLibraryInfoWrapperPass>();
     }
 
+    void releaseMemory() override {
+      BFI.reset();
+      BPI.reset();
+    }
+
     void FindLoopHeaders(Function &F);
     bool ProcessBlock(BasicBlock *BB);
     bool ThreadEdge(BasicBlock *BB, const SmallVectorImpl<BasicBlock*> &PredBBs,
@@ -139,6 +154,12 @@ namespace {
 
     bool SimplifyPartiallyRedundantLoad(LoadInst *LI);
     bool TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB);
+
+  private:
+    BasicBlock *SplitBlockPreds(BasicBlock *BB, ArrayRef<BasicBlock *> Preds,
+                                const char *Suffix);
+    void UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, BasicBlock *BB,
+                                      BasicBlock *NewBB, BasicBlock *SuccBB);
   };
 }
 
@@ -162,6 +183,16 @@ bool JumpThreading::runOnFunction(Function &F) {
   DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n");
   TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
   LVI = &getAnalysis<LazyValueInfo>();
+  BFI.reset();
+  BPI.reset();
+  // When profile data is available, we need to update edge weights after
+  // successful jump threading, which requires both BPI and BFI being available.
+  HasProfileData = F.getEntryCount().hasValue();
+  if (HasProfileData) {
+    LoopInfo LI{DominatorTree(F)};
+    BPI.reset(new BranchProbabilityInfo(F, LI));
+    BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));
+  }
 
   // Remove unreachable blocks from function as they may result in infinite
   // loop. We do threading if we found something profitable. Jump threading a
@@ -977,8 +1008,7 @@ bool JumpThreading::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
     }
 
     // Split them out to their own block.
-    UnavailablePred =
-      SplitBlockPredecessors(LoadBB, PredsToSplit, "thread-pre-split");
+    UnavailablePred = SplitBlockPreds(LoadBB, PredsToSplit, "thread-pre-split");
   }
 
   // If the value isn't available in all predecessors, then there will be
@@ -1403,7 +1433,7 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB,
   else {
     DEBUG(dbgs() << "  Factoring out " << PredBBs.size()
           << " common predecessors.\n");
-    PredBB = SplitBlockPredecessors(BB, PredBBs, ".thr_comm");
+    PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm");
   }
 
   // And finally, do it!
@@ -1424,6 +1454,13 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB,
                                          BB->getParent(), BB);
   NewBB->moveAfter(PredBB);
 
+  // Set the block frequency of NewBB.
+  if (HasProfileData) {
+    auto NewBBFreq =
+        BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB);
+    BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
+  }
+
   BasicBlock::iterator BI = BB->begin();
   for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI)
     ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB);
@@ -1447,7 +1484,7 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB,
 
   // We didn't copy the terminator from BB over to NewBB, because there is now
   // an unconditional jump to SuccBB.  Insert the unconditional jump.
-  BranchInst *NewBI =BranchInst::Create(SuccBB, NewBB);
+  BranchInst *NewBI = BranchInst::Create(SuccBB, NewBB);
   NewBI->setDebugLoc(BB->getTerminator()->getDebugLoc());
 
   // Check to see if SuccBB has PHI nodes. If so, we need to add entries to the
@@ -1508,11 +1545,85 @@ bool JumpThreading::ThreadEdge(BasicBlock *BB,
   // frequently happens because of phi translation.
   SimplifyInstructionsInBlock(NewBB, TLI);
 
+  // Update the edge weight from BB to SuccBB, which should be less than before.
+  UpdateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB);
+
   // Threaded an edge!
   ++NumThreads;
   return true;
 }
 
+/// Create a new basic block that will be the predecessor of BB and successor of
+/// all blocks in Preds. When profile data is availble, update the frequency of
+/// this new block.
+BasicBlock *JumpThreading::SplitBlockPreds(BasicBlock *BB,
+                                           ArrayRef<BasicBlock *> Preds,
+                                           const char *Suffix) {
+  // Collect the frequencies of all predecessors of BB, which will be used to
+  // update the edge weight on BB->SuccBB.
+  BlockFrequency PredBBFreq(0);
+  if (HasProfileData)
+    for (auto Pred : Preds)
+      PredBBFreq += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB);
+
+  BasicBlock *PredBB = SplitBlockPredecessors(BB, Preds, Suffix);
+
+  // Set the block frequency of the newly created PredBB, which is the sum of
+  // frequencies of Preds.
+  if (HasProfileData)
+    BFI->setBlockFreq(PredBB, PredBBFreq.getFrequency());
+  return PredBB;
+}
+
+/// Update the block frequency of BB and branch weight and the metadata on the
+/// edge BB->SuccBB. This is done by scaling the weight of BB->SuccBB by 1 -
+/// Freq(PredBB->BB) / Freq(BB->SuccBB).
+void JumpThreading::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
+                                                 BasicBlock *BB,
+                                                 BasicBlock *NewBB,
+                                                 BasicBlock *SuccBB) {
+  if (!HasProfileData)
+    return;
+
+  assert(BFI && BPI && "BFI & BPI should have been created here");
+
+  // As the edge from PredBB to BB is deleted, we have to update the block
+  // frequency of BB.
+  auto BBOrigFreq = BFI->getBlockFreq(BB);
+  auto NewBBFreq = BFI->getBlockFreq(NewBB);
+  auto BB2SuccBBFreq = BBOrigFreq * BPI->getEdgeProbability(BB, SuccBB);
+  auto BBNewFreq = BBOrigFreq - NewBBFreq;
+  BFI->setBlockFreq(BB, BBNewFreq.getFrequency());
+
+  // Collect updated outgoing edges' frequencies from BB and use them to update
+  // edge weights.
+  SmallVector<uint64_t, 4> BBSuccFreq;
+  for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
+    auto SuccFreq = (*I == SuccBB)
+                        ? BB2SuccBBFreq - NewBBFreq
+                        : BBOrigFreq * BPI->getEdgeProbability(BB, *I);
+    BBSuccFreq.push_back(SuccFreq.getFrequency());
+  }
+
+  // Normalize edge weights in Weights64 so that the sum of them can fit in
+  BranchProbability::normalizeEdgeWeights(BBSuccFreq.begin(), BBSuccFreq.end());
+
+  SmallVector<uint32_t, 4> Weights;
+  for (auto Freq : BBSuccFreq)
+    Weights.push_back(static_cast<uint32_t>(Freq));
+
+  // Update edge weights in BPI.
+  for (int I = 0, E = Weights.size(); I < E; I++)
+    BPI->setEdgeWeight(BB, I, Weights[I]);
+
+  if (Weights.size() >= 2) {
+    auto TI = BB->getTerminator();
+    TI->setMetadata(
+        LLVMContext::MD_prof,
+        MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+  }
+}
+
 /// DuplicateCondBranchOnPHIIntoPred - PredBB contains an unconditional branch
 /// to BB which contains an i1 PHI node and a conditional branch on that PHI.
 /// If we can duplicate the contents of BB up into PredBB do so now, this
@@ -1546,7 +1657,7 @@ bool JumpThreading::DuplicateCondBranchOnPHIIntoPred(BasicBlock *BB,
   else {
     DEBUG(dbgs() << "  Factoring out " << PredBBs.size()
           << " common predecessors.\n");
-    PredBB = SplitBlockPredecessors(BB, PredBBs, ".thr_comm");
+    PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm");
   }
 
   // Okay, we decided to do this!  Clone all the instructions in BB onto the end
diff --git a/test/Transforms/JumpThreading/update-edge-weight.ll b/test/Transforms/JumpThreading/update-edge-weight.ll
new file mode 100644 (file)
index 0000000..b5c5d01
--- /dev/null
@@ -0,0 +1,43 @@
+; RUN: opt -S -jump-threading %s | FileCheck %s
+
+; Test if edge weights are properly updated after jump threading.
+
+; CHECK: !2 = !{!"branch_weights", i32 22, i32 7}
+
+define void @foo(i32 %n) !prof !0 {
+entry:
+  %cmp = icmp sgt i32 %n, 10
+  br i1 %cmp, label %if.then.1, label %if.else.1, !prof !1
+
+if.then.1:
+  tail call void @a()
+  br label %if.cond
+
+if.else.1:
+  tail call void @b()
+  br label %if.cond
+
+if.cond:
+  %cmp1 = icmp sgt i32 %n, 5
+  br i1 %cmp1, label %if.then.2, label %if.else.2, !prof !2
+
+if.then.2:
+  tail call void @c()
+  br label %if.end
+
+if.else.2:
+  tail call void @d()
+  br label %if.end
+
+if.end:
+  ret void
+}
+
+declare void @a()
+declare void @b()
+declare void @c()
+declare void @d()
+
+!0 = !{!"function_entry_count", i64 1}
+!1 = !{!"branch_weights", i32 10, i32 5}
+!2 = !{!"branch_weights", i32 10, i32 1}
index 87a250919475aad3b32c91e024d9fcddc5429b43..37a5c3f0dc877a1b740aaeb3e3a7a22019dc488e 100644 (file)
@@ -287,4 +287,45 @@ TEST(BranchProbabilityTest, scaleBruteForce) {
   }
 }
 
+TEST(BranchProbabilityTest, NormalizeEdgeWeights) {
+  {
+    SmallVector<uint32_t, 2> Weights{0, 0};
+    BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
+    EXPECT_EQ(1u, Weights[0]);
+    EXPECT_EQ(1u, Weights[1]);
+  }
+  {
+    SmallVector<uint32_t, 2> Weights{0, UINT32_MAX};
+    BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
+    EXPECT_EQ(1u, Weights[0]);
+    EXPECT_EQ(UINT32_MAX - 1u, Weights[1]);
+  }
+  {
+    SmallVector<uint32_t, 2> Weights{1, UINT32_MAX};
+    BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
+    EXPECT_EQ(1u, Weights[0]);
+    EXPECT_EQ(UINT32_MAX - 1u, Weights[1]);
+  }
+  {
+    SmallVector<uint32_t, 3> Weights{0, 0, UINT32_MAX};
+    BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
+    EXPECT_EQ(1u, Weights[0]);
+    EXPECT_EQ(1u, Weights[1]);
+    EXPECT_EQ(UINT32_MAX / 2u, Weights[2]);
+  }
+  {
+    SmallVector<uint32_t, 2> Weights{UINT32_MAX, UINT32_MAX};
+    BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
+    EXPECT_EQ(UINT32_MAX / 3u, Weights[0]);
+    EXPECT_EQ(UINT32_MAX / 3u, Weights[1]);
+  }
+  {
+    SmallVector<uint32_t, 3> Weights{UINT32_MAX, UINT32_MAX, UINT32_MAX};
+    BranchProbability::normalizeEdgeWeights(Weights.begin(), Weights.end());
+    EXPECT_EQ(UINT32_MAX / 4u, Weights[0]);
+    EXPECT_EQ(UINT32_MAX / 4u, Weights[1]);
+    EXPECT_EQ(UINT32_MAX / 4u, Weights[2]);
+  }
+}
+
 }