Update the branch weight metadata in JumpThreading pass.
[oota-llvm.git] / include / llvm / Support / BranchProbability.h
1 //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Definition of BranchProbability shared by IR and Machine Instructions.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
15 #define LLVM_SUPPORT_BRANCHPROBABILITY_H
16
17 #include "llvm/Support/DataTypes.h"
18 #include <algorithm>
19 #include <cassert>
20 #include <climits>
21 #include <numeric>
22
23 namespace llvm {
24
25 class raw_ostream;
26
27 // This class represents Branch Probability as a non-negative fraction that is
28 // no greater than 1. It uses a fixed-point-like implementation, in which the
29 // denominator is always a constant value (here we use 1<<31 for maximum
30 // precision).
31 class BranchProbability {
32   // Numerator
33   uint32_t N;
34
35   // Denominator, which is a constant value.
36   static const uint32_t D = 1u << 31;
37
38   // Construct a BranchProbability with only numerator assuming the denominator
39   // is 1<<31. For internal use only.
40   explicit BranchProbability(uint32_t n) : N(n) {}
41
42 public:
43   BranchProbability() : N(0) {}
44   BranchProbability(uint32_t Numerator, uint32_t Denominator);
45
46   bool isZero() const { return N == 0; }
47
48   static BranchProbability getZero() { return BranchProbability(0); }
49   static BranchProbability getOne() { return BranchProbability(D); }
50   // Create a BranchProbability object with the given numerator and 1<<31
51   // as denominator.
52   static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
53
54   // Normalize given probabilties so that the sum of them becomes approximate
55   // one.
56   template <class ProbabilityList>
57   static void normalizeProbabilities(ProbabilityList &Probs);
58
59   // Normalize a list of weights by scaling them down so that the sum of them
60   // doesn't exceed UINT32_MAX.
61   template <class WeightListIter>
62   static void normalizeEdgeWeights(WeightListIter Begin, WeightListIter End);
63
64   uint32_t getNumerator() const { return N; }
65   static uint32_t getDenominator() { return D; }
66
67   // Return (1 - Probability).
68   BranchProbability getCompl() const { return BranchProbability(D - N); }
69
70   raw_ostream &print(raw_ostream &OS) const;
71
72   void dump() const;
73
74   /// \brief Scale a large integer.
75   ///
76   /// Scales \c Num.  Guarantees full precision.  Returns the floor of the
77   /// result.
78   ///
79   /// \return \c Num times \c this.
80   uint64_t scale(uint64_t Num) const;
81
82   /// \brief Scale a large integer by the inverse.
83   ///
84   /// Scales \c Num by the inverse of \c this.  Guarantees full precision.
85   /// Returns the floor of the result.
86   ///
87   /// \return \c Num divided by \c this.
88   uint64_t scaleByInverse(uint64_t Num) const;
89
90   BranchProbability &operator+=(BranchProbability RHS) {
91     assert(N <= D - RHS.N &&
92            "The sum of branch probabilities should not exceed one!");
93     N += RHS.N;
94     return *this;
95   }
96
97   BranchProbability &operator-=(BranchProbability RHS) {
98     assert(N >= RHS.N &&
99            "Can only subtract a smaller probability from a larger one!");
100     N -= RHS.N;
101     return *this;
102   }
103
104   BranchProbability &operator*=(BranchProbability RHS) {
105     N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
106     return *this;
107   }
108
109   BranchProbability operator+(BranchProbability RHS) const {
110     BranchProbability Prob(*this);
111     return Prob += RHS;
112   }
113
114   BranchProbability operator-(BranchProbability RHS) const {
115     BranchProbability Prob(*this);
116     return Prob -= RHS;
117   }
118
119   BranchProbability operator*(BranchProbability RHS) const {
120     BranchProbability Prob(*this);
121     return Prob *= RHS;
122   }
123
124   bool operator==(BranchProbability RHS) const { return N == RHS.N; }
125   bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
126   bool operator<(BranchProbability RHS) const { return N < RHS.N; }
127   bool operator>(BranchProbability RHS) const { return RHS < *this; }
128   bool operator<=(BranchProbability RHS) const { return !(RHS < *this); }
129   bool operator>=(BranchProbability RHS) const { return !(*this < RHS); }
130 };
131
132 inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) {
133   return Prob.print(OS);
134 }
135
136 template <class ProbabilityList>
137 void BranchProbability::normalizeProbabilities(ProbabilityList &Probs) {
138   uint64_t Sum = 0;
139   for (auto Prob : Probs)
140     Sum += Prob.N;
141   assert(Sum > 0);
142   for (auto &Prob : Probs)
143     Prob.N = (Prob.N * uint64_t(D) + Sum / 2) / Sum;
144 }
145
146 template <class WeightListIter>
147 void BranchProbability::normalizeEdgeWeights(WeightListIter Begin,
148                                              WeightListIter End) {
149   // First we compute the sum with 64-bits of precision.
150   uint64_t Sum = std::accumulate(Begin, End, uint64_t(0));
151
152   if (Sum > UINT32_MAX) {
153     // Compute the scale necessary to cause the weights to fit, and re-sum with
154     // that scale applied.
155     assert(Sum / UINT32_MAX < UINT32_MAX &&
156            "The sum of weights exceeds UINT32_MAX^2!");
157     uint32_t Scale = Sum / UINT32_MAX + 1;
158     for (auto I = Begin; I != End; ++I)
159       *I /= Scale;
160     Sum = std::accumulate(Begin, End, uint64_t(0));
161   }
162
163   // Eliminate zero weights.
164   auto ZeroWeightNum = std::count(Begin, End, 0u);
165   if (ZeroWeightNum > 0) {
166     // If all weights are zeros, replace them by 1.
167     if (Sum == 0)
168       std::fill(Begin, End, 1u);
169     else {
170       // We are converting zeros into ones, and here we need to make sure that
171       // after this the sum won't exceed UINT32_MAX.
172       if (Sum + ZeroWeightNum > UINT32_MAX) {
173         for (auto I = Begin; I != End; ++I)
174           *I /= 2;
175         ZeroWeightNum = std::count(Begin, End, 0u);
176         Sum = std::accumulate(Begin, End, uint64_t(0));
177       }
178       // Scale up non-zero weights and turn zero weights into ones.
179       uint64_t ScalingFactor = (UINT32_MAX - ZeroWeightNum) / Sum;
180       assert(ScalingFactor >= 1);
181       if (ScalingFactor > 1)
182         for (auto I = Begin; I != End; ++I)
183           *I *= ScalingFactor;
184       std::replace(Begin, End, 0u, 1u);
185     }
186   }
187 }
188
189 }
190
191 #endif