#include "llvm/Pass.h"
#include "llvm/Support/BranchProbability.h"
#include <climits>
+#include <numeric>
namespace llvm {
// adjustment. Any edge weights used with the sum should be divided by Scale.
uint32_t getSumForBlock(const MachineBasicBlock *MBB, uint32_t &Scale) const;
+ // Get the sum of the block successors' weights, forcibly normalizing the
+ // successors' weights of MBB so that their sum fits within 32 bits.
+ uint32_t getSumForBlock(MachineBasicBlock *MBB) const;
+
// A 'Hot' edge is an edge whose probability is >= 80%.
- bool isEdgeHot(MachineBasicBlock *Src, MachineBasicBlock *Dst) const;
+ bool isEdgeHot(const MachineBasicBlock *Src,
+ const MachineBasicBlock *Dst) const;
// Return a hot successor for the block BB or null if there isn't one.
// NB: This routine's complexity is linear on the number of successors.
// NB: This routine's complexity is linear on the number of successors of
// Src. Querying sequentially for each successor's probability is a quadratic
// query pattern.
- BranchProbability getEdgeProbability(MachineBasicBlock *Src,
- MachineBasicBlock *Dst) const;
+ BranchProbability getEdgeProbability(const MachineBasicBlock *Src,
+ const MachineBasicBlock *Dst) const;
// Print a value between 0 (0% probability) and 1 (100% probability).
// The value is never equal to 0, and can be 1 if and only if the Src block
// has only one successor.
- raw_ostream &printEdgeProbability(raw_ostream &OS, MachineBasicBlock *Src,
- MachineBasicBlock *Dst) const;
+ raw_ostream &printEdgeProbability(raw_ostream &OS,
+ const MachineBasicBlock *Src,
+ const MachineBasicBlock *Dst) const;
+
+ // Scale a list of weights down in place so that their sum doesn't exceed
+ // UINT32_MAX. Return the scale used, which is 1 if nothing was changed.
+ template <class WeightList>
+ static uint32_t normalizeEdgeWeights(WeightList &Weights);
};
+// Scales the entries of Weights down, in place, so that their total is no
+// larger than UINT32_MAX.
+//
+// Returns the divisor that was applied to every entry. When the total already
+// fits in 32 bits the list is left untouched and 1 is returned.
+template <class WeightList>
+uint32_t
+MachineBranchProbabilityInfo::normalizeEdgeWeights(WeightList &Weights) {
+  assert(Weights.size() < UINT32_MAX && "Too many weights in the list!");
+
+  // Total the weights using a 64-bit accumulator.
+  const uint64_t Sum =
+      std::accumulate(Weights.begin(), Weights.end(), uint64_t(0));
+
+  // A total that already fits in 32 bits needs no scaling.
+  const bool FitsIn32Bits = Sum <= UINT32_MAX;
+  if (FitsIn32Bits)
+    return 1;
+
+  // Use (Sum / UINT32_MAX) + 1 as the common divisor and divide every weight
+  // by it in place. The weights are not summed again afterwards.
+  assert((Sum / UINT32_MAX) < UINT32_MAX &&
+         "The sum of weights exceeds UINT32_MAX^2!");
+  const uint32_t Scale = (Sum / UINT32_MAX) + 1;
+  for (auto It = Weights.begin(), End = Weights.end(); It != End; ++It)
+    *It /= Scale;
+  return Scale;
+}
+
}