Modify the interface BranchProbability::normalizeProbabilities to let it accept a...
[oota-llvm.git] / include / llvm / Support / BranchProbability.h
1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Definition of BranchProbability shared by IR and Machine Instructions.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
15 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
16
17 #include "llvm/Support/DataTypes.h"
18 #include <algorithm>
19 #include <cassert>
20 #include <climits>
21 #include <numeric>
22
23 namespace llvm {
24
25 class raw_ostream;
26
27 // This class represents Branch Probability as a non-negative fraction that is
28 // no greater than 1. It uses a fixed-point-like implementation, in which the
29 // denominator is always a constant value (here we use 1<<31 for maximum
30 // precision).
31 class BranchProbability {
32   // Numerator
33   uint32_t N;
34
35   // Denominator, which is a constant value.
36   static const uint32_t D = 1u << 31;
37   static const uint32_t UnknownN = UINT32_MAX;
38
39   // Construct a BranchProbability with only numerator assuming the denominator
40   // is 1<<31. For internal use only.
41   explicit BranchProbability(uint32_t n) : N(n) {}
42
43 public:
44   BranchProbability() : N(0) {}
45   BranchProbability(uint32_t Numerator, uint32_t Denominator);
46
47   bool isZero() const { return N == 0; }
48   bool isUnknown() const { return N == UnknownN; }
49
50   static BranchProbability getZero() { return BranchProbability(0); }
51   static BranchProbability getOne() { return BranchProbability(D); }
52   static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
53   // Create a BranchProbability object with the given numerator and 1<<31
54   // as denominator.
55   static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
56
57   // Normalize given probabilties so that the sum of them becomes approximate
58   // one.
59   template <class ProbabilityIter>
60   static void normalizeProbabilities(ProbabilityIter Begin,
61                                      ProbabilityIter End);
62
63   // Normalize a list of weights by scaling them down so that the sum of them
64   // doesn't exceed UINT32_MAX.
65   template <class WeightListIter>
66   static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End);
67
68   uint32_t getNumerator() const { return N; }
69   static uint32_t getDenominator() { return D; }
70
71   // Return (1 - Probability).
72   BranchProbability getCompl() const { return BranchProbability(D - N); }
73
74   raw_ostream &print(raw_ostream &OS) const;
75
76   void dump() const;
77
78   /// \brief Scale a large integer.
79   ///
80   /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
81   /// result.
82   ///
83   /// \return \c Num times \c this.
84   uint64_t scale(uint64_t Num) const;
85
86   /// \brief Scale a large integer by the inverse.
87   ///
88   /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
89   /// Returns the floor of the result.
90   ///
91   /// \return \c Num divided by \c this.
92   uint64_t scaleByInverse(uint64_t Num) const;
93
94   BranchProbability &operator+=(BranchProbability RHS) {
95     assert(N <= D - RHS.N &&
96            "The sum of branch probabilities should not exceed one!");
97     N += RHS.N;
98     return *this;
99   }
100
101   BranchProbability &operator-=(BranchProbability RHS) {
102     assert(N >= RHS.N &&
103            "Can only subtract a smaller probability from a larger one!");
104     N -= RHS.N;
105     return *this;
106   }
107
108   BranchProbability &operator*=(BranchProbability RHS) {
109     N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
110     return *this;
111   }
112
113   BranchProbability operator+(BranchProbability RHS) const {
114     BranchProbability Prob(*this);
115     return Prob += RHS;
116   }
117
118   BranchProbability operator-(BranchProbability RHS) const {
119     BranchProbability Prob(*this);
120     return Prob -= RHS;
121   }
122
123   BranchProbability operator*(BranchProbability RHS) const {
124     BranchProbability Prob(*this);
125     return Prob *= RHS;
126   }
127
128   bool operator==(BranchProbability RHS) const { return N == RHS.N; }
129   bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
130   bool operator<(BranchProbability RHS) const { return N < RHS.N; }
131   bool operator>(BranchProbability RHS) const { return RHS < *this; }
132   bool operator<=(BranchProbability RHS) const { return !(RHS < *this); }
133   bool operator>=(BranchProbability RHS) const { return !(*this < RHS); }
134 };
135
136 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
137   return Prob.print(OS);
138 }
139
140 inline BranchProbability operator/(BranchProbability LHS, uint32_t RHS) {
141   return BranchProbability::getRaw(LHS.getNumerator() / RHS);
142 }
143
144 template <class ProbabilityIter>
145 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
146                                                ProbabilityIter End) {
147   if (Begin == End)
148     return;
149
150   uint64_t Sum = 0;
151   for (auto I = Begin; I != End; ++I)
152     Sum += I->N;
153   assert(Sum > 0);
154   for (auto I = Begin; I != End; ++I)
155     I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
156 }
157
158 template <class WeightListIter>
159 void BranchProbability::normalizeEdgeWeights(WeightListIter Begin,
160                                              WeightListIter End) {
161   // First we compute the sum with 64-bits of precision.
162   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
163
164   if (Sum > UINT32_MAX) {
165     // Compute the scale necessary to cause the weights to fit, and re-sum with
166     // that scale applied.
167     assert(Sum / UINT32_MAX < UINT32_MAX &&
168            "The sum of weights exceeds UINT32_MAX^2!");
169     uint32_t Scale = Sum / UINT32_MAX + 1;
170     for (auto I = Begin; I != End; ++I)
171       *I /= Scale;
172     Sum = std::accumulate(Begin, End, uint64_t(0));
173   }
174
175   // Eliminate zero weights.
176   auto ZeroWeightNum = std::count(Begin, End, 0u);
177   if (ZeroWeightNum > 0) {
178     // If all weights are zeros, replace them by 1.
179     if (Sum == 0)
180       std::fill(Begin, End, 1u);
181     else {
182       // We are converting zeros into ones, and here we need to make sure that
183       // after this the sum won't exceed UINT32_MAX.
184       if (Sum + ZeroWeightNum > UINT32_MAX) {
185         for (auto I = Begin; I != End; ++I)
186           *I /= 2;
187         ZeroWeightNum = std::count(Begin, End, 0u);
188         Sum = std::accumulate(Begin, End, uint64_t(0));
189       }
190       // Scale up non-zero weights and turn zero weights into ones.
191       uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum;
192       assert(ScalingFactor >= 1);
193       if (ScalingFactor > 1)
194         for (auto I = Begin; I != End; ++I)
195           *I *= ScalingFactor;
196       std::replace(Begin, End, 0u, 1u);
197     }
198   }
199 }
200
201 }
202
203 #endif