Let SelectionDAG start to use probability-based interface to add successors.
[oota-llvm.git] / include / llvm / Support / BranchProbability.h
1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Definition of BranchProbability shared by IR and Machine Instructions.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
15 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
16
17 #include "llvm/Support/DataTypes.h"
18 #include <algorithm>
19 #include <cassert>
20 #include <climits>
21 #include <numeric>
22
23 namespace llvm {
24
25 class raw_ostream;
26
27 // This class represents Branch Probability as a non-negative fraction that is
28 // no greater than 1. It uses a fixed-point-like implementation, in which the
29 // denominator is always a constant value (here we use 1<<31 for maximum
30 // precision).
31 class BranchProbability {
32   // Numerator
33   uint32_t N;
34
35   // Denominator, which is a constant value.
36   static const uint32_t D = 1u << 31;
37   static const uint32_t UnknownN = UINT32_MAX;
38
39   // Construct a BranchProbability with only numerator assuming the denominator
40   // is 1<<31. For internal use only.
41   explicit BranchProbability(uint32_t n) : N(n) {}
42
43 public:
44   BranchProbability() : N(UnknownN) {}
45   BranchProbability(uint32_t Numerator, uint32_t Denominator);
46
47   bool isZero() const { return N == 0; }
48   bool isUnknown() const { return N == UnknownN; }
49
50   static BranchProbability getZero() { return BranchProbability(0); }
51   static BranchProbability getOne() { return BranchProbability(D); }
52   static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
53   // Create a BranchProbability object with the given numerator and 1<<31
54   // as denominator.
55   static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
56
57   // Normalize given probabilties so that the sum of them becomes approximate
58   // one.
59   template <class ProbabilityIter>
60   static void normalizeProbabilities(ProbabilityIter Begin,
61                                      ProbabilityIter End);
62
63   // Normalize a list of weights by scaling them down so that the sum of them
64   // doesn't exceed UINT32_MAX.
65   template <class WeightListIter>
66   static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End);
67
68   uint32_t getNumerator() const { return N; }
69   static uint32_t getDenominator() { return D; }
70
71   // Return (1 - Probability).
72   BranchProbability getCompl() const { return BranchProbability(D - N); }
73
74   raw_ostream &print(raw_ostream &OS) const;
75
76   void dump() const;
77
78   /// \brief Scale a large integer.
79   ///
80   /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
81   /// result.
82   ///
83   /// \return \c Num times \c this.
84   uint64_t scale(uint64_t Num) const;
85
86   /// \brief Scale a large integer by the inverse.
87   ///
88   /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
89   /// Returns the floor of the result.
90   ///
91   /// \return \c Num divided by \c this.
92   uint64_t scaleByInverse(uint64_t Num) const;
93
94   BranchProbability &operator+=(BranchProbability RHS) {
95     assert(N != UnknownN && RHS.N != UnknownN &&
96            "Unknown probability cannot participate in arithmetics.");
97     // Saturate the result in case of overflow.
98     N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
99     return *this;
100   }
101
102   BranchProbability &operator-=(BranchProbability RHS) {
103     assert(N != UnknownN && RHS.N != UnknownN &&
104            "Unknown probability cannot participate in arithmetics.");
105     // Saturate the result in case of underflow.
106     N = N < RHS.N ? 0 : N - RHS.N;
107     return *this;
108   }
109
110   BranchProbability &operator*=(BranchProbability RHS) {
111     assert(N != UnknownN && RHS.N != UnknownN &&
112            "Unknown probability cannot participate in arithmetics.");
113     N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
114     return *this;
115   }
116
117   BranchProbability operator+(BranchProbability RHS) const {
118     BranchProbability Prob(*this);
119     return Prob += RHS;
120   }
121
122   BranchProbability operator-(BranchProbability RHS) const {
123     BranchProbability Prob(*this);
124     return Prob -= RHS;
125   }
126
127   BranchProbability operator*(BranchProbability RHS) const {
128     BranchProbability Prob(*this);
129     return Prob *= RHS;
130   }
131
132   bool operator==(BranchProbability RHS) const { return N == RHS.N; }
133   bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
134   bool operator<(BranchProbability RHS) const { return N < RHS.N; }
135   bool operator>(BranchProbability RHS) const { return RHS < *this; }
136   bool operator<=(BranchProbability RHS) const { return !(RHS < *this); }
137   bool operator>=(BranchProbability RHS) const { return !(*this < RHS); }
138 };
139
140 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
141   return Prob.print(OS);
142 }
143
144 inline BranchProbability operator/(BranchProbability LHS, uint32_t RHS) {
145   assert(LHS != BranchProbability::getUnknown() &&
146          "Unknown probability cannot participate in arithmetics.");
147   return BranchProbability::getRaw(LHS.getNumerator() / RHS);
148 }
149
150 template <class ProbabilityIter>
151 void BranchProbability::normalizeProbabilities(ProbabilityIter Begin,
152                                                ProbabilityIter End) {
153   if (Begin == End)
154     return;
155
156   auto UnknownProbCount =
157       std::count(Begin, End, BranchProbability::getUnknown());
158   assert((UnknownProbCount == 0 ||
159           UnknownProbCount == std::distance(Begin, End)) &&
160          "Cannot normalize probabilities with known and unknown ones.");
161   (void)UnknownProbCount;
162
163   uint64_t Sum = std::accumulate(
164       Begin, End, uint64_t(0),
165       [](uint64_t S, const BranchProbability &BP) { return S + BP.N; });
166
167   if (Sum == 0) {
168     BranchProbability BP(1, std::distance(Begin, End));
169     std::fill(Begin, End, BP);
170     return;
171   }
172
173   for (auto I = Begin; I != End; ++I)
174     I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
175 }
176
177 template <class WeightListIter>
178 void BranchProbability::normalizeEdgeWeights(WeightListIter Begin,
179                                              WeightListIter End) {
180   // First we compute the sum with 64-bits of precision.
181   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
182
183   if (Sum > UINT32_MAX) {
184     // Compute the scale necessary to cause the weights to fit, and re-sum with
185     // that scale applied.
186     assert(Sum / UINT32_MAX < UINT32_MAX &&
187            "The sum of weights exceeds UINT32_MAX^2!");
188     uint32_t Scale = Sum / UINT32_MAX + 1;
189     for (auto I = Begin; I != End; ++I)
190       *I /= Scale;
191     Sum = std::accumulate(Begin, End, uint64_t(0));
192   }
193
194   // Eliminate zero weights.
195   auto ZeroWeightNum = std::count(Begin, End, 0u);
196   if (ZeroWeightNum > 0) {
197     // If all weights are zeros, replace them by 1.
198     if (Sum == 0)
199       std::fill(Begin, End, 1u);
200     else {
201       // We are converting zeros into ones, and here we need to make sure that
202       // after this the sum won't exceed UINT32_MAX.
203       if (Sum + ZeroWeightNum > UINT32_MAX) {
204         for (auto I = Begin; I != End; ++I)
205           *I /= 2;
206         ZeroWeightNum = std::count(Begin, End, 0u);
207         Sum = std::accumulate(Begin, End, uint64_t(0));
208       }
209       // Scale up non-zero weights and turn zero weights into ones.
210       uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum;
211       assert(ScalingFactor >= 1);
212       if (ScalingFactor > 1)
213         for (auto I = Begin; I != End; ++I)
214           *I *= ScalingFactor;
215       std::replace(Begin, End, 0u, 1u);
216     }
217   }
218 }
219
220 }
221
222 #endif