Update the branch weight metadata in JumpThreading pass.
[oota-llvm.git] / include / llvm / CodeGen / MachineBranchProbabilityInfo.h
index 058ab32f3aa993683b0edaf6ee2cdcd7949bfe13..26f0d99373871e0ec791934c68535379f6289c5b 100644 (file)
@@ -17,6 +17,7 @@
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/BranchProbability.h"
+#include <algorithm>
 #include <climits>
 #include <numeric>
 
@@ -83,8 +84,39 @@ public:
   raw_ostream &printEdgeProbability(raw_ostream &OS,
                                     const MachineBasicBlock *Src,
                                     const MachineBasicBlock *Dst) const;
+
+  // Normalize a list of weights by scaling them down so that the sum of them
+  // doesn't exceed UINT32_MAX. Return the scale.
+  template <class WeightListIter>
+  static uint32_t normalizeEdgeWeights(WeightListIter Begin,
+                                       WeightListIter End);
 };
 
+template <class WeightListIter>
+uint32_t
+MachineBranchProbabilityInfo::normalizeEdgeWeights(WeightListIter Begin,
+                                                   WeightListIter End) {
+  // First we compute the sum with 64-bits of precision.
+  uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
+
+  // If Sum is zero, set all weights to 1.
+  if (Sum == 0)
+    std::fill(Begin, End, uint64_t(1));
+
+  // If the computed sum fits in 32-bits, we're done.
+  if (Sum <= UINT32_MAX)
+    return 1;
+
+  // Otherwise, compute the scale necessary to cause the weights to fit, and
+  // re-sum with that scale applied.
+  assert((Sum / UINT32_MAX) < UINT32_MAX &&
+         "The sum of weights exceeds UINT32_MAX^2!");
+  uint32_t Scale = (Sum / UINT32_MAX) + 1;
+  for (auto I = Begin; I != End; ++I)
+    *I /= Scale;
+  return Scale;
+}
+
 }