[PGO] MST based PGO instrumentation infrastructure
[oota-llvm.git] / lib / Transforms / Instrumentation / PGOInstrumentation.cpp
1 //===- PGOInstru.cpp - PGO Instrumentation --------===//
2 //
3 //                      The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements PGO instrumentation using a minimum spanning tree based
11 // on the  following paper.
12 //   [1] Donald E. Knuth, Francis R. Stevenson. Optimal measurement of points
13 //   for program frequency counts. BIT Numerical Mathematics 1973, Volume 13,
14 //   Issue 3, pp 313-322
15 // The idea of the algorithm based on the fact that for each node (except for
16 // the entry and exit), the sum of incoming edge counts equals the sum of
17 // outgoing edge counts. The count of edge on spanning tree can be derived from
18 // those edges not on the spanning tree. Knuth proves this method instruments
19 // the minimum number of edges.
20 //
21 // The minimal spanning tree here is actually a maximum weight tree -- on-tree
22 // edges have higher frequencies (most likely to execute). The idea is to
23 // instrument those less frequently executed edges which speeds up the
24 // instrumented binaries.
25 //
26 // This file contains two passes:
27 // (1) Pass PGOInstrumentationGen which instruments the IR to generate edge
28 // count profile, and
29 // (2) Pass PGOInstrumentationUse which reads the edge count profile and
30 // annotates the branch weight.
31 // These two passes are mutually exclusive, and they are called at the same
32 // compilation point (so they see the same IR). For PGOInstrumentationGen,
33 // the real work is done instrumentOneFunc(). For PGOInstrumentationUse, the
34 // real work in done in class PGOUseFunc and the profile is opened in module
35 // level and passed to each PGOUseFunc instance.
36 // The shared code for PGOInstrumentationGen and PGOInstrumentationUse is put
37 // in class FuncPGOInstrumentation.
38 //
39 // Class PGOEdge represents a CFG edge and some auxiliary information. Class
40 // BBInfo contains auxiliary information for a BB. These two classes are used
41 // in PGOGenFunc. Class PGOUseEdge and UseBBInfo are the derived class of
42 // PGOEdge and BBInfo, respectively. They contains extra data structure used
43 // in populating profile counters.
44 // The MST implementation is in Class CFGMST.
45 //
46 //===----------------------------------------------------------------------===//
47
48 #include "llvm/Transforms/Instrumentation.h"
49 #include "llvm/ADT/Statistic.h"
50 #include "llvm/ADT/DenseMap.h"
51 #include "llvm/ADT/STLExtras.h"
52 #include "llvm/IR/InstIterator.h"
53 #include "llvm/IR/Instructions.h"
54 #include "llvm/IR/IntrinsicInst.h"
55 #include "llvm/IR/IRBuilder.h"
56 #include "llvm/IR/MDBuilder.h"
57 #include "llvm/IR/DiagnosticInfo.h"
58 #include "llvm/Pass.h"
59 #include "llvm/IR/Module.h"
60 #include "llvm/Support/Debug.h"
61 #include "llvm/Support/BranchProbability.h"
62 #include "llvm/Support/JamCRC.h"
63 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
64 #include "llvm/ProfileData/InstrProfReader.h"
65 #include "llvm/Analysis/CFG.h"
66 #include "llvm/Analysis/BranchProbabilityInfo.h"
67 #include "llvm/Analysis/BlockFrequencyInfo.h"
68 #include <string>
69 #include <vector>
70 #include <utility>
71 #include "CFGMST.h"
72
73 using namespace llvm;
74
75 #define DEBUG_TYPE "pgo-instr"
76
77 STATISTIC(NumOfPGOInstrument, "Number of edges instrumented.");
78 STATISTIC(NumOfPGOEdge, "Number of edges.");
79 STATISTIC(NumOfPGOBB, "Number of basic-blocks.");
80 STATISTIC(NumOfPGOSplit, "Number of critical edge splits.");
81 STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts.");
82 STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile.");
83 STATISTIC(NumOfPGOMissing, "Number of functions without profile.");
84
85 static cl::opt<std::string>
86     PGOProfileFile("pgo-profile-file", cl::init(""), cl::Hidden,
87                    cl::value_desc("filename"),
88                    cl::desc("Specify the path of profile data file"));
89
90 namespace {
91 class PGOInstrumentationGen : public ModulePass {
92 public:
93   static char ID;
94
95   PGOInstrumentationGen() : ModulePass(ID) {
96     initializePGOInstrumentationGenPass(*PassRegistry::getPassRegistry());
97   }
98
99   const char *getPassName() const override {
100     return "PGOInstrumentationGenPass";
101   }
102
103 private:
104   bool runOnModule(Module &M) override;
105
106   void getAnalysisUsage(AnalysisUsage &AU) const override {
107     AU.addRequired<BlockFrequencyInfoWrapperPass>();
108     AU.addRequired<BranchProbabilityInfoWrapperPass>();
109   }
110 };
111
112 class PGOInstrumentationUse : public ModulePass {
113 public:
114   static char ID;
115
116   // Provide the profile filename as the parameter.
117   PGOInstrumentationUse(StringRef Filename = StringRef(""))
118       : ModulePass(ID), ProfileFileName(Filename) {
119     if (!PGOProfileFile.empty())
120       ProfileFileName = StringRef(PGOProfileFile);
121     initializePGOInstrumentationUsePass(*PassRegistry::getPassRegistry());
122   }
123
124   const char *getPassName() const override {
125     return "PGOInstrumentationUsePass";
126   }
127
128 private:
129   StringRef ProfileFileName;
130   std::unique_ptr<IndexedInstrProfReader> PGOReader;
131   bool runOnModule(Module &M) override;
132
133   void getAnalysisUsage(AnalysisUsage &AU) const override {
134     AU.addRequired<BlockFrequencyInfoWrapperPass>();
135     AU.addRequired<BranchProbabilityInfoWrapperPass>();
136   }
137 };
138 } // end anonymous namespace
139
140 char PGOInstrumentationGen::ID = 0;
141 INITIALIZE_PASS_BEGIN(PGOInstrumentationGen, "pgo-instr-gen",
142                       "PGO instrumentation.", false, false)
143 INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
144 INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
145 INITIALIZE_PASS_END(PGOInstrumentationGen, "pgo-instr-gen",
146                     "PGO instrumentation.", false, false)
147
148 ModulePass *llvm::createPGOInstrumentationGenPass() {
149   return new PGOInstrumentationGen();
150 }
151
152 char PGOInstrumentationUse::ID = 0;
153 INITIALIZE_PASS_BEGIN(PGOInstrumentationUse, "pgo-instr-use",
154                       "Read PGO instrumentation profile.", false, false)
155 INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
156 INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
157 INITIALIZE_PASS_END(PGOInstrumentationUse, "pgo-instr-use",
158                     "Read PGO instrumentation profile.", false, false)
159
160 ModulePass *llvm::createPGOInstrumentationUsePass(StringRef Filename) {
161   return new PGOInstrumentationUse(Filename);
162 }
163
164 namespace {
165 /// \brief An MST based instrumentation for PGO
166 ///
167 /// Implements a Minimum Spanning Tree (MST) based instrumentation for PGO
168 /// in the function level.
169 //
170 // This class implements the CFG edges. Note the CFG can be a multi-graph.
171 struct PGOEdge {
172   const BasicBlock *SrcBB;
173   const BasicBlock *DestBB;
174   uint64_t Weight;
175   bool InMST;
176   bool Removed;
177   bool IsCritical;
178   PGOEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1)
179       : SrcBB(Src), DestBB(Dest), Weight(W), InMST(false), Removed(false),
180         IsCritical(false) {}
181   // Return the information string of an edge.
182   const std::string infoString() const {
183     std::string Str = (Removed ? "-" : " ");
184     Str += (InMST ? " " : "*");
185     Str += (IsCritical ? "c" : " ");
186     Str += "  W=" + std::to_string(Weight);
187     return Str;
188   }
189 };
190
191 // This class stores the auxiliary information for each BB.
192 struct BBInfo {
193   BBInfo *Group;
194   uint32_t Index;
195   uint32_t Rank;
196
197   BBInfo(unsigned IX) : Group(this), Index(IX), Rank(0) {}
198
199   // Return the information string of this object.
200   const std::string infoString() const {
201     return "Index=" + std::to_string(Index);
202   }
203 };
204
205 // This class implements the CFG edges. Note the CFG can be a multi-graph.
206 template <class Edge, class BBInfo> class FuncPGOInstrumentation {
207 private:
208   Function &F;
209   void computeCFGHash();
210
211 public:
212   std::string FuncName;
213   GlobalVariable *FuncNameVar;
214   // CFG hash value for this function.
215   uint64_t FunctionHash;
216
217   // The Minimum Spanning Tree of function CFG.
218   CFGMST<Edge, BBInfo> MST;
219
220   // Give an edge, find the BB that will be instrumented.
221   // Return nullptr if there is no BB to be instrumented.
222   BasicBlock *getInstrBB(Edge *E);
223
224   // Return the auxiliary BB information.
225   BBInfo &getBBInfo(const BasicBlock *BB) const { return MST.getBBInfo(BB); }
226
227   // Dump edges and BB information.
228   void dumpInfo(std::string Str = "") const {
229     std::string Message = "Dump Function " + FuncName + " Hash: " +
230                           std::to_string(FunctionHash) + "\t" + Str;
231     MST.dumpEdges(dbgs(), Message);
232   }
233
234   FuncPGOInstrumentation(Function &Func, bool CreateGlobalVar = false,
235                          BranchProbabilityInfo *BPI_ = nullptr,
236                          BlockFrequencyInfo *BFI_ = nullptr)
237       : F(Func), FunctionHash(0), MST(F, BPI_, BFI_) {
238     FuncName = getPGOFuncName(F);
239     computeCFGHash();
240     DEBUG(dumpInfo("after CFGMST"));
241
242     NumOfPGOBB += MST.BBInfos.size();
243     for (auto &Ei : MST.AllEdges) {
244       if (Ei->Removed)
245         continue;
246       NumOfPGOEdge++;
247       if (!Ei->InMST)
248         NumOfPGOInstrument++;
249     }
250
251     if (CreateGlobalVar)
252       FuncNameVar = createPGOFuncNameVar(F, FuncName);
253   };
254 };
255
256 // Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index
257 // value of each BB in the CFG. The higher 32 bits record the number of edges.
258 template <class Edge, class BBInfo>
259 void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
260   std::vector<char> Indexes;
261   JamCRC JC;
262   for (auto &BB : F) {
263     const TerminatorInst *TI = BB.getTerminator();
264     for (unsigned s = 0, e = TI->getNumSuccessors(); s != e; ++s) {
265       BasicBlock *Succ = TI->getSuccessor(s);
266       uint32_t Index = getBBInfo(Succ).Index;
267       for (int i = 0; i < sizeof(uint32_t) / sizeof(char); i++)
268         Indexes.push_back((char)(Index >> (i * sizeof(char))));
269     }
270   }
271   JC.update(Indexes);
272   FunctionHash = MST.AllEdges.size() << 32 | JC.getCRC();
273 }
274
275 template <class Edge, class BBInfo>
276 BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) {
277   if (E->InMST || E->Removed)
278     return nullptr;
279
280   BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB);
281   BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB);
282   // For a fake edge, instrument the real BB.
283   if (SrcBB == nullptr)
284     return DestBB;
285   if (DestBB == nullptr)
286     return SrcBB;
287
288   // Instrument the SrcBB if it has a single successor,
289   // otherwise, the DestBB if this is not a critical edge.
290   TerminatorInst *TI = SrcBB->getTerminator();
291   if (TI->getNumSuccessors() <= 1)
292     return SrcBB;
293   if (!E->IsCritical)
294     return DestBB;
295
296   // For a critical edge, we have to split. Instrument the newly
297   // created BB.
298   NumOfPGOSplit++;
299   DEBUG(dbgs() << "Split critical edge: " << getBBInfo(SrcBB).Index << " --> "
300                << getBBInfo(DestBB).Index << "\n");
301   unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
302   BasicBlock *InstrBB = SplitCriticalEdge(TI, SuccNum);
303   assert(InstrBB && "Critical edge is not split");
304
305   E->Removed = true;
306   return InstrBB;
307 }
308
309 // Visit all edge and instrument the edges not in MST.
310 // Critical edges will be split.
311 static void instrumentOneFunc(Function &F, Module *M,
312                               BranchProbabilityInfo *BPI,
313                               BlockFrequencyInfo *BFI) {
314   unsigned NumCounters = 0;
315   FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, true, BPI, BFI);
316   for (auto &Ei : FuncInfo.MST.AllEdges) {
317     if (!Ei->InMST && !Ei->Removed)
318       NumCounters++;
319   }
320
321   uint32_t j = 0;
322   for (auto &Ei : FuncInfo.MST.AllEdges) {
323     BasicBlock *InstrBB = FuncInfo.getInstrBB(Ei.get());
324     if (!InstrBB)
325       continue;
326
327     IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt());
328     assert(Builder.GetInsertPoint() != InstrBB->end() &&
329            "Cannot get the Instrumentation point");
330     auto *I8PtrTy = Type::getInt8PtrTy(M->getContext());
331     Builder.CreateCall(
332         Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment),
333         {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
334          Builder.getInt64(FuncInfo.FunctionHash), Builder.getInt32(NumCounters),
335          Builder.getInt32(j++)});
336   }
337 }
338
339 struct PGOUseEdge : public PGOEdge {
340   bool CountValid;
341   uint64_t CountValue;
342   PGOUseEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1)
343       : PGOEdge(Src, Dest, W), CountValid(false), CountValue(0) {}
344
345   // Set edge count value
346   void setEdgeCount(uint64_t Value) {
347     CountValue = Value;
348     CountValid = true;
349   }
350
351   // Return the information string for this object.
352   const std::string infoString() const {
353     if (!CountValid)
354       return PGOEdge::infoString();
355     return PGOEdge::infoString() + "  Count=" + std::to_string(CountValue);
356   }
357 };
358
359 typedef SmallVector<PGOUseEdge *, 2> DirectEdges;
360
361 // This class stores the auxiliary information for each BB.
362 struct UseBBInfo : public BBInfo {
363   uint64_t CountValue;
364   bool CountValid;
365   int32_t UnknownCountInEdge;
366   int32_t UnknownCountOutEdge;
367   DirectEdges InEdges;
368   DirectEdges OutEdges;
369   UseBBInfo(unsigned IX)
370       : BBInfo(IX), CountValue(0), CountValid(false), UnknownCountInEdge(0),
371         UnknownCountOutEdge(0) {}
372   UseBBInfo(unsigned IX, uint64_t C)
373       : BBInfo(IX), CountValue(C), CountValid(true), UnknownCountInEdge(0),
374         UnknownCountOutEdge(0) {}
375
376   // Set the profile count value for this BB.
377   void setBBInfoCount(uint64_t Value) {
378     CountValue = Value;
379     CountValid = true;
380   }
381
382   // Return the information string of this object.
383   const std::string infoString() const {
384     if (!CountValid)
385       return BBInfo::infoString();
386     return BBInfo::infoString() + "  Count=" + std::to_string(CountValue);
387   }
388 };
389
390 // Sum up the count values for all the edges.
391 static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) {
392   uint64_t Total = 0;
393   for (auto &Ei : Edges) {
394     if (Ei->Removed)
395       continue;
396     Total += Ei->CountValue;
397   }
398   return Total;
399 }
400
401 class PGOUseFunc {
402 private:
403   Function &F;
404   Module *M;
405   // This member stores the shared information with class PGOGenFunc.
406   FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo;
407
408   // Return the auxiliary BB information.
409   UseBBInfo &getBBInfo(const BasicBlock *BB) const {
410     return FuncInfo.getBBInfo(BB);
411   }
412
413   // The maximum count value in the profile. This is only used in PGO use
414   // compilation.
415   uint64_t ProgramMaxCount;
416
417   // Find the Instrumented BB and set the value.
418   void setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile);
419
420   // Set the edge counter value for the unknown edge -- there should be only
421   // one unknown edge.
422   void setEdgeCount(DirectEdges &Edges, uint64_t Value);
423
424   // Return FuncName string;
425   const std::string getFuncName() const { return FuncInfo.FuncName; }
426
427   // Set the hot/cold inline hints based on the count values.
428   void applyFunctionAttributes(uint64_t EntryCount, uint64_t MaxCount) {
429     if (ProgramMaxCount == 0)
430       return;
431     // Threshold of the hot functions.
432     const BranchProbability HotFunctionThreshold(1, 100);
433     // Threshold of the cold functions.
434     const BranchProbability ColdFunctionThreshold(2, 10000);
435     if (EntryCount >= HotFunctionThreshold.scale(ProgramMaxCount))
436       F.addFnAttr(llvm::Attribute::InlineHint);
437     else if (MaxCount <= ColdFunctionThreshold.scale(ProgramMaxCount))
438       F.addFnAttr(llvm::Attribute::Cold);
439   }
440
441 public:
442   PGOUseFunc(Function &Func, Module *Modu,
443              BranchProbabilityInfo *BPI_ = nullptr,
444              BlockFrequencyInfo *BFI_ = nullptr)
445       : F(Func), M(Modu), FuncInfo(Func, false, BPI_, BFI_) {}
446
447   // Read counts for the instrumented BB from profile.
448   bool readCounters(IndexedInstrProfReader *PGOReader);
449
450   // Populate the counts for all BBs.
451   void populateCounters();
452
453   // Set the branch weights based on the count values.
454   void setBranchWeights();
455 };
456
457 // Visit all the edges and assign the count value for the instrumented
458 // edges and the BB.
459 void PGOUseFunc::setInstrumentedCounts(
460     const std::vector<uint64_t> &CountFromProfile) {
461
462   // Use a worklist as we will update the vector during the iteration.
463   std::vector<PGOUseEdge *> WorkList;
464   for (auto &Ei : FuncInfo.MST.AllEdges)
465     WorkList.push_back(Ei.get());
466
467   uint32_t j = 0;
468   for (auto &Ei : WorkList) {
469     BasicBlock *InstrBB = FuncInfo.getInstrBB(Ei);
470     if (!InstrBB)
471       continue;
472     uint64_t CountValue = CountFromProfile[j++];
473     if (!Ei->Removed) {
474       getBBInfo(InstrBB).setBBInfoCount(CountValue);
475       Ei->setEdgeCount(CountValue);
476       continue;
477     }
478
479     // Need to add two new edges.
480     BasicBlock *SrcBB = const_cast<BasicBlock *>(Ei->SrcBB);
481     BasicBlock *DestBB = const_cast<BasicBlock *>(Ei->DestBB);
482     // Add new edge of SrcBB->InstrBB.
483     PGOUseEdge &NewEdge = FuncInfo.MST.addEdge(SrcBB, InstrBB, 0);
484     NewEdge.setEdgeCount(CountValue);
485     // Add new edge of InstrBB->DestBB.
486     PGOUseEdge &NewEdge1 = FuncInfo.MST.addEdge(InstrBB, DestBB, 0);
487     NewEdge1.setEdgeCount(CountValue);
488     NewEdge1.InMST = true;
489     getBBInfo(InstrBB).setBBInfoCount(CountValue);
490   }
491 }
492
493 // Set the count value for the unknown edges. There should be one and only one
494 // unknown edge in Edges vector.
495 void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) {
496   for (auto &Ei : Edges) {
497     if (Ei->CountValid)
498       continue;
499     Ei->setEdgeCount(Value);
500
501     getBBInfo(Ei->SrcBB).UnknownCountOutEdge--;
502     getBBInfo(Ei->DestBB).UnknownCountInEdge--;
503     return;
504   }
505   llvm_unreachable("Cannot find the unknown count edge");
506 }
507
508 // Read the profile from ProfileFileName and assign the value to the
509 // instrumented BB and the edges. This function also updates ProgramMaxCount.
510 // Return true if the profile are successfully read, and false on errors.
511 bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) {
512   auto &Ctx = M->getContext();
513   ErrorOr<InstrProfRecord> Result =
514       PGOReader->getInstrProfRecord(FuncInfo.FuncName, FuncInfo.FunctionHash);
515   if (std::error_code EC = Result.getError()) {
516     if (EC == instrprof_error::unknown_function)
517       NumOfPGOMissing++;
518     else if (EC == instrprof_error::hash_mismatch ||
519              EC == llvm::instrprof_error::malformed)
520       NumOfPGOMismatch++;
521
522     std::string Msg = EC.message() + std::string(" ") + F.getName().str();
523     Ctx.diagnose(
524         DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
525     return false;
526   }
527   std::vector<uint64_t> &CountFromProfile = Result.get().Counts;
528
529   NumOfPGOFunc++;
530   DEBUG(dbgs() << CountFromProfile.size() << " counts\n");
531   uint64_t ValueSum = 0;
532   for (unsigned i = 0, e = CountFromProfile.size(); i < e; i++) {
533     DEBUG(dbgs() << "  " << i << ": " << CountFromProfile[i] << "\n");
534     ValueSum += CountFromProfile[i];
535   }
536
537   DEBUG(dbgs() << "SUM =  " << ValueSum << "\n");
538
539   getBBInfo(nullptr).UnknownCountOutEdge = 2;
540   getBBInfo(nullptr).UnknownCountInEdge = 2;
541
542   setInstrumentedCounts(CountFromProfile);
543   ProgramMaxCount = PGOReader->getMaximumFunctionCount();
544   return true;
545 }
546
547 // Populate the counters from instrumented BBs to all BBs.
548 // In the end of this operation, all BBs should have a valid count value.
549 void PGOUseFunc::populateCounters() {
550   // First set up Count variable for all BBs.
551   for (auto &Ei : FuncInfo.MST.AllEdges) {
552     if (Ei->Removed)
553       continue;
554
555     const BasicBlock *SrcBB = Ei->SrcBB;
556     const BasicBlock *DestBB = Ei->DestBB;
557     UseBBInfo &SrcInfo = getBBInfo(SrcBB);
558     UseBBInfo &DestInfo = getBBInfo(DestBB);
559     SrcInfo.OutEdges.push_back(Ei.get());
560     DestInfo.InEdges.push_back(Ei.get());
561     SrcInfo.UnknownCountOutEdge++;
562     DestInfo.UnknownCountInEdge++;
563
564     if (!Ei->CountValid)
565       continue;
566     DestInfo.UnknownCountInEdge--;
567     SrcInfo.UnknownCountOutEdge--;
568   }
569
570   bool Changes = true;
571   unsigned NumPasses = 0;
572   while (Changes) {
573     NumPasses++;
574     Changes = false;
575
576     // For efficient traversal, it's better to start from the end as most
577     // of the instrumented edges are at the end.
578     for (auto &BB : reverse(F)) {
579       UseBBInfo &Count = getBBInfo(&BB);
580       if (!Count.CountValid) {
581         if (Count.UnknownCountOutEdge == 0) {
582           Count.CountValue = sumEdgeCount(Count.OutEdges);
583           Count.CountValid = true;
584           Changes = true;
585         } else if (Count.UnknownCountInEdge == 0) {
586           Count.CountValue = sumEdgeCount(Count.InEdges);
587           Count.CountValid = true;
588           Changes = true;
589         }
590       }
591       if (Count.CountValid) {
592         if (Count.UnknownCountOutEdge == 1) {
593           uint64_t Total = Count.CountValue - sumEdgeCount(Count.OutEdges);
594           setEdgeCount(Count.OutEdges, Total);
595           Changes = true;
596         }
597         if (Count.UnknownCountInEdge == 1) {
598           uint64_t Total = Count.CountValue - sumEdgeCount(Count.InEdges);
599           setEdgeCount(Count.InEdges, Total);
600           Changes = true;
601         }
602       }
603     }
604   }
605
606   DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n");
607   // Assert every BB has a valid counter.
608   uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
609   uint64_t FuncMaxCount = FuncEntryCount;
610   for (auto &BB : F) {
611     assert(getBBInfo(&BB).CountValid && "BB count is not valid");
612     uint64_t Count = getBBInfo(&BB).CountValue;
613     if (Count > FuncMaxCount)
614       FuncMaxCount = Count;
615   }
616   applyFunctionAttributes(FuncEntryCount, FuncMaxCount);
617
618   DEBUG(FuncInfo.dumpInfo("after reading profile."));
619 }
620
621 // Assign the scaled count values to the BB with multiple out edges.
622 void PGOUseFunc::setBranchWeights() {
623   // Generate MD_prof metadata for every branch instruction.
624   DEBUG(dbgs() << "\nSetting branch weights.\n");
625   MDBuilder MDB(M->getContext());
626   for (auto &BB : F) {
627     TerminatorInst *TI = BB.getTerminator();
628     if (TI->getNumSuccessors() < 2)
629       continue;
630     if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI))
631       continue;
632     if (getBBInfo(&BB).CountValue == 0)
633       continue;
634
635     // We have a non-zero Branch BB.
636     const UseBBInfo &BBCountInfo = getBBInfo(&BB);
637     unsigned Size = BBCountInfo.OutEdges.size();
638     SmallVector<unsigned, 2> EdgeCounts(Size, 0);
639     uint64_t MaxCount = 0;
640     for (unsigned s = 0; s < Size; s++) {
641       const PGOUseEdge *E = BBCountInfo.OutEdges[s];
642       const BasicBlock *SrcBB = E->SrcBB;
643       const BasicBlock *DestBB = E->DestBB;
644       if (DestBB == 0)
645         continue;
646       unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
647       uint64_t EdgeCount = E->CountValue;
648       if (EdgeCount > MaxCount)
649         MaxCount = EdgeCount;
650       EdgeCounts[SuccNum] = EdgeCount;
651     }
652     assert(MaxCount > 0 && "Bad max count");
653     uint64_t Scale = calculateCountScale(MaxCount);
654     SmallVector<unsigned, 4> Weights;
655     for (const auto &ECI : EdgeCounts)
656       Weights.push_back(scaleBranchCount(ECI, Scale));
657
658     TI->setMetadata(llvm::LLVMContext::MD_prof,
659                     MDB.createBranchWeights(Weights));
660     DEBUG(dbgs() << "Weight is: "; for (const auto &W
661                                         : Weights) dbgs()
662                                    << W << " ";
663           dbgs() << "\n";);
664   }
665 }
666 } // end anonymous namespace
667
668 bool PGOInstrumentationGen::runOnModule(Module &M) {
669   for (auto &F : M) {
670     if (F.isDeclaration())
671       continue;
672     BranchProbabilityInfo *BPI =
673         &(getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI());
674     BlockFrequencyInfo *BFI =
675         &(getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI());
676     instrumentOneFunc(F, &M, BPI, BFI);
677   }
678   return true;
679 }
680
681 static void setPGOCountOnFunc(PGOUseFunc &Func,
682                               IndexedInstrProfReader *PGOReader) {
683   if (Func.readCounters(PGOReader)) {
684     Func.populateCounters();
685     Func.setBranchWeights();
686   }
687 }
688
689 bool PGOInstrumentationUse::runOnModule(Module &M) {
690   DEBUG(dbgs() << "Read in profile counters: ");
691   auto &Ctx = M.getContext();
692   // Read the counter array from file.
693   auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName);
694   if (std::error_code EC = ReaderOrErr.getError()) {
695     Ctx.diagnose(
696         DiagnosticInfoPGOProfile(ProfileFileName.data(), EC.message()));
697     return false;
698   }
699
700   PGOReader = std::move(ReaderOrErr.get());
701   if (!PGOReader) {
702     Ctx.diagnose(DiagnosticInfoPGOProfile(ProfileFileName.data(),
703                                           "Cannot get PGOReader"));
704     return false;
705   }
706
707   for (auto &F : M) {
708     if (F.isDeclaration())
709       continue;
710     BranchProbabilityInfo *BPI =
711         &(getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI());
712     BlockFrequencyInfo *BFI =
713         &(getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI());
714     PGOUseFunc Func(F, &M, BPI, BFI);
715     setPGOCountOnFunc(Func, PGOReader.get());
716   }
717   return true;
718 }