1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Definition of BranchProbability shared by IR and Machine Instructions.
11 //
12 //===----------------------------------------------------------------------===//
14 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
15 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
17 #include "llvm/Support/DataTypes.h"
18 #include <algorithm>
19 #include <cassert>
20 #include <climits>
21 #include <numeric>
23 namespace llvm {
25 class raw_ostream;
27 // This class represents Branch Probability as a non-negative fraction that is
28 // no greater than 1. It uses a fixed-point-like implementation, in which the
29 // denominator is always a constant value (here we use 1<<31 for maximum
30 // precision).
31 class BranchProbability {
32   // Numerator
33   uint32_t N;
35   // Denominator, which is a constant value.
36   static const uint32_t D = 1u << 31;
37   static const uint32_t UnknownN = UINT32_MAX;
39   // Construct a BranchProbability with only numerator assuming the denominator
40   // is 1<<31. For internal use only.
41   explicit BranchProbability(uint32_t n) : N(n) {}
43 public:
44   BranchProbability() : N(0) {}
45   BranchProbability(uint32_t Numerator, uint32_t Denominator);
47   bool isZero() const { return N == 0; }
48   bool isUnknown() const { return N == UnknownN; }
50   static BranchProbability getZero() { return BranchProbability(0); }
51   static BranchProbability getOne() { return BranchProbability(D); }
52   static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
53   // Create a BranchProbability object with the given numerator and 1<<31
54   // as denominator.
55   static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
57   // Normalize given probabilties so that the sum of them becomes approximate
58   // one.
59   template <class ProbabilityIter>
60   static void normalizeProbabilities(ProbabilityIter Begin,
61                                      ProbabilityIter End);
63   // Normalize a list of weights by scaling them down so that the sum of them
64   // doesn't exceed UINT32_MAX.
65   template <class WeightListIter>
66   static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End);
68   uint32_t getNumerator() const { return N; }
69   static uint32_t getDenominator() { return D; }
71   // Return (1 - Probability).
72   BranchProbability getCompl() const { return BranchProbability(D - N); }
74   raw_ostream &print(raw_ostream &OS) const;
76   void dump() const;
78   /// \brief Scale a large integer.
79   ///
80   /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
81   /// result.
82   ///
83   /// \return \c Num times \c this.
84   uint64_t scale(uint64_t Num) const;
86   /// \brief Scale a large integer by the inverse.
87   ///
88   /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
89   /// Returns the floor of the result.
90   ///
91   /// \return \c Num divided by \c this.
92   uint64_t scaleByInverse(uint64_t Num) const;
94   BranchProbability &operator+=(BranchProbability RHS) {
95     // Saturate the result in case of overflow.
96     N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
97     return *this;
98   }
100   BranchProbability &operator-=(BranchProbability RHS) {
101     // Saturate the result in case of underflow.
102     N = N < RHS.N ? 0 : N - RHS.N;
103     return *this;
104   }
106   BranchProbability &operator*=(BranchProbability RHS) {
107     N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
108     return *this;
109   }
111   BranchProbability operator+(BranchProbability RHS) const {
112     BranchProbability Prob(*this);
113     return Prob += RHS;
114   }
116   BranchProbability operator-(BranchProbability RHS) const {
117     BranchProbability Prob(*this);
118     return Prob -= RHS;
119   }
121   BranchProbability operator*(BranchProbability RHS) const {
122     BranchProbability Prob(*this);
123     return Prob *= RHS;
124   }
126   bool operator==(BranchProbability RHS) const { return N == RHS.N; }
127   bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
128   bool operator<(BranchProbability RHS) const { return N < RHS.N; }
129   bool operator>(BranchProbability RHS) const { return RHS < *this; }
130   bool operator<=(BranchProbability RHS) const { return !(RHS < *this); }
131   bool operator>=(BranchProbability RHS) const { return !(*this < RHS); }
132 };
134 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
135   return Prob.print(OS);
136 }
138 inline BranchProbability operator/(BranchProbability LHS, uint32_t RHS) {
139   return BranchProbability::getRaw(LHS.getNumerator() / RHS);
140 }
142 template <class ProbabilityIter>
143 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
144                                                ProbabilityIter End) {
145   if (Begin == End)
146     return;
148   uint64_t Sum = 0;
149   for (auto I = Begin; I != End; ++I)
150     Sum += I->N;
151   assert(Sum > 0);
152   for (auto I = Begin; I != End; ++I)
153     I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
154 }
156 template <class WeightListIter>
157 void BranchProbability::normalizeEdgeWeights(WeightListIter Begin,
158                                              WeightListIter End) {
159   // First we compute the sum with 64-bits of precision.
160   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
162   if (Sum > UINT32_MAX) {
163     // Compute the scale necessary to cause the weights to fit, and re-sum with
164     // that scale applied.
165     assert(Sum / UINT32_MAX < UINT32_MAX &&
166            "The sum of weights exceeds UINT32_MAX^2!");
167     uint32_t Scale = Sum / UINT32_MAX + 1;
168     for (auto I = Begin; I != End; ++I)
169       *I /= Scale;
170     Sum = std::accumulate(Begin, End, uint64_t(0));
171   }
173   // Eliminate zero weights.
174   auto ZeroWeightNum = std::count(Begin, End, 0u);
175   if (ZeroWeightNum > 0) {
176     // If all weights are zeros, replace them by 1.
177     if (Sum == 0)
178       std::fill(Begin, End, 1u);
179     else {
180       // We are converting zeros into ones, and here we need to make sure that
181       // after this the sum won't exceed UINT32_MAX.
182       if (Sum + ZeroWeightNum > UINT32_MAX) {
183         for (auto I = Begin; I != End; ++I)
184           *I /= 2;
185         ZeroWeightNum = std::count(Begin, End, 0u);
186         Sum = std::accumulate(Begin, End, uint64_t(0));
187       }
188       // Scale up non-zero weights and turn zero weights into ones.
189       uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum;
190       assert(ScalingFactor >= 1);
191       if (ScalingFactor > 1)
192         for (auto I = Begin; I != End; ++I)
193           *I *= ScalingFactor;
194       std::replace(Begin, End, 0u, 1u);
195     }
196   }
197 }
199 }
201 #endif