X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=include%2Fllvm%2FSupport%2FBranchProbability.h;h=f204777066e9be5b0dfd2b3eff63822884e512b7;hp=98b2a3d7d9a2460c26ad964cf239d957faf3c81f;hb=a99158bbd8d2f9c4822f2e61173807e2039f61b3;hpb=73e2613bcae85f944720761d6a770cf55908ce3c diff --git a/include/llvm/Support/BranchProbability.h b/include/llvm/Support/BranchProbability.h index 98b2a3d7d9a..f204777066e 100644 --- a/include/llvm/Support/BranchProbability.h +++ b/include/llvm/Support/BranchProbability.h @@ -15,7 +15,10 @@ #define LLVM_SUPPORT_BRANCHPROBABILITY_H #include "llvm/Support/DataTypes.h" +#include #include +#include +#include namespace llvm { @@ -53,6 +56,11 @@ public: template static void normalizeProbabilities(ProbabilityList &Probs); + // Normalize a list of weights by scaling them down so that the sum of them + // doesn't exceed UINT32_MAX. + template + static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End); + uint32_t getNumerator() const { return N; } static uint32_t getDenominator() { return D; } @@ -135,6 +143,49 @@ void BranchProbability::normalizeProbabilities(ProbabilityList &Probs) { Prob.N = (Prob.N * uint64_t(D) + Sum / 2) / Sum; } +template +void BranchProbability::normalizeEdgeWeights(WeightListIter Begin, + WeightListIter End) { + // First we compute the sum with 64-bits of precision. + uint64_t Sum = std::accumulate(Begin, End, uint64_t(0)); + + if (Sum > UINT32_MAX) { + // Compute the scale necessary to cause the weights to fit, and re-sum with + // that scale applied. + assert(Sum / UINT32_MAX < UINT32_MAX && + "The sum of weights exceeds UINT32_MAX^2!"); + uint32_t Scale = Sum / UINT32_MAX + 1; + for (auto I = Begin; I != End; ++I) + *I /= Scale; + Sum = std::accumulate(Begin, End, uint64_t(0)); + } + + // Eliminate zero weights. + auto ZeroWeightNum = std::count(Begin, End, 0u); + if (ZeroWeightNum > 0) { + // If all weights are zeros, replace them by 1. + if (Sum == 0) + std::fill(Begin, End, 1u); + else { + // We are converting zeros into ones, and here we need to make sure that + // after this the sum won't exceed UINT32_MAX. + if (Sum + ZeroWeightNum > UINT32_MAX) { + for (auto I = Begin; I != End; ++I) + *I /= 2; + ZeroWeightNum = std::count(Begin, End, 0u); + Sum = std::accumulate(Begin, End, uint64_t(0)); + } + // Scale up non-zero weights and turn zero weights into ones. + uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum; + assert(ScalingFactor >= 1); + if (ScalingFactor > 1) + for (auto I = Begin; I != End; ++I) + *I *= ScalingFactor; + std::replace(Begin, End, 0u, 1u); + } + } +} + } #endif