SamplePGO - Count sample records in embedded profiles when computing coverage.
[oota-llvm.git] / lib / Transforms / IPO / SampleProfile.cpp
index 51e95a5887a27e292aad45fb246ab19c6ddb515f..7c01a8672feac6dd9cfa3781b576984e44769bb6 100644 (file)
@@ -183,10 +183,11 @@ class SampleCoverageTracker {
 public:
   SampleCoverageTracker() : SampleCoverage() {}
 
-  void markSamplesUsed(const FunctionSamples *Samples, uint32_t LineOffset,
+  bool markSamplesUsed(const FunctionSamples *Samples, uint32_t LineOffset,
                        uint32_t Discriminator);
-  unsigned computeCoverage(const FunctionSamples *Samples) const;
-  unsigned getNumUsedSamples(const FunctionSamples *Samples) const;
+  unsigned computeCoverage(unsigned Used, unsigned Total) const;
+  unsigned countUsedSamples(const FunctionSamples *Samples) const;
+  unsigned countBodySamples(const FunctionSamples *Samples) const;
 
 private:
   typedef DenseMap<LineLocation, unsigned> BodySampleCoverageMap;
@@ -210,18 +211,35 @@ SampleCoverageTracker CoverageTracker;
 
 /// Mark as used the sample record for the given function samples at
 /// (LineOffset, Discriminator).
-void SampleCoverageTracker::markSamplesUsed(const FunctionSamples *Samples,
+///
+/// \returns true if this is the first time we mark the given record.
+bool SampleCoverageTracker::markSamplesUsed(const FunctionSamples *Samples,
                                             uint32_t LineOffset,
                                             uint32_t Discriminator) {
-  BodySampleCoverageMap &Coverage = SampleCoverage[Samples];
-  Coverage[LineLocation(LineOffset, Discriminator)]++;
+  LineLocation Loc(LineOffset, Discriminator);
+  unsigned &Count = SampleCoverage[Samples][Loc];
+  return ++Count == 1;
 }
 
 /// Return the number of sample records that were applied from this profile.
 unsigned
-SampleCoverageTracker::getNumUsedSamples(const FunctionSamples *Samples) const {
+SampleCoverageTracker::countUsedSamples(const FunctionSamples *Samples) const {
   auto I = SampleCoverage.find(Samples);
-  return (I != SampleCoverage.end()) ? I->second.size() : 0;
+  unsigned Count = (I != SampleCoverage.end()) ? I->second.size() : 0;
+  for (const auto &I : Samples->getCallsiteSamples())
+    Count += countUsedSamples(&I.second);
+  return Count;
+}
+
+/// Return the number of sample records in the body of this profile.
+///
+/// The count includes all the samples in inlined callees.
+unsigned
+SampleCoverageTracker::countBodySamples(const FunctionSamples *Samples) const {
+  unsigned Count = Samples->getBodySamples().size();
+  for (const auto &I : Samples->getCallsiteSamples())
+    Count += countBodySamples(&I.second);
+  return Count;
 }
 
 /// Return the fraction of sample records used in this profile.
@@ -229,13 +247,11 @@ SampleCoverageTracker::getNumUsedSamples(const FunctionSamples *Samples) const {
 /// The returned value is an unsigned integer in the range 0-100 indicating
 /// the percentage of sample records that were used while applying this
 /// profile to the associated function.
-unsigned
-SampleCoverageTracker::computeCoverage(const FunctionSamples *Samples) const {
-  uint32_t NumTotalRecords = Samples->getBodySamples().size();
-  uint32_t NumUsedRecords = getNumUsedSamples(Samples);
-  assert(NumUsedRecords <= NumTotalRecords &&
+unsigned SampleCoverageTracker::computeCoverage(unsigned Used,
+                                                unsigned Total) const {
+  assert(Used <= Total &&
          "number of used records cannot exceed the total number of records");
-  return NumTotalRecords > 0 ? NumUsedRecords * 100 / NumTotalRecords : 100;
+  return Total > 0 ? Used * 100 / Total : 100;
 }
 
 /// Clear all the per-function data used to load samples and propagate weights.
@@ -323,8 +339,15 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) const {
   uint32_t Discriminator = DIL->getDiscriminator();
   ErrorOr<uint64_t> R = FS->findSamplesAt(LineOffset, Discriminator);
   if (R) {
-    if (SampleProfileCoverage)
-      CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator);
+    bool FirstMark =
+        CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator);
+    if (FirstMark) {
+      const Function *F = Inst.getParent()->getParent();
+      LLVMContext &Ctx = F->getContext();
+      emitOptimizationRemark(Ctx, DEBUG_TYPE, *F, DLoc,
+                             Twine("Applied ") + Twine(*R) +
+                                 " samples from profile");
+    }
     DEBUG(dbgs() << "    " << Lineno << "." << DIL->getDiscriminator() << ":"
                  << Inst << " (line offset: " << Lineno - HeaderLineno << "."
                  << DIL->getDiscriminator() << " - weight: " << R.get()
@@ -377,20 +400,6 @@ bool SampleProfileLoader::computeBlockWeights(Function &F) {
     DEBUG(printBlockWeight(dbgs(), &BB));
   }
 
-  if (SampleProfileCoverage) {
-    unsigned Coverage = CoverageTracker.computeCoverage(Samples);
-    if (Coverage < SampleProfileCoverage) {
-      StringRef Filename = getDISubprogram(&F)->getFilename();
-      F.getContext().diagnose(DiagnosticInfoSampleProfile(
-          Filename.str().c_str(), getFunctionLoc(F),
-          Twine(CoverageTracker.getNumUsedSamples(Samples)) + " of " +
-              Twine(Samples->getBodySamples().size()) +
-              " available profile records (" + Twine(Coverage) +
-              "%) were applied",
-          DS_Warning));
-    }
-  }
-
   return Changed;
 }
 
@@ -994,6 +1003,21 @@ bool SampleProfileLoader::emitAnnotations(Function &F) {
     propagateWeights(F);
   }
 
+  // If coverage checking was requested, compute it now.
+  if (SampleProfileCoverage) {
+    unsigned Used = CoverageTracker.countUsedSamples(Samples);
+    unsigned Total = CoverageTracker.countBodySamples(Samples);
+    unsigned Coverage = CoverageTracker.computeCoverage(Used, Total);
+    if (Coverage < SampleProfileCoverage) {
+      StringRef Filename = getDISubprogram(&F)->getFilename();
+      F.getContext().diagnose(DiagnosticInfoSampleProfile(
+          Filename.str().c_str(), getFunctionLoc(F),
+          Twine(Used) + " of " + Twine(Total) + " available profile records (" +
+              Twine(Coverage) + "%) were applied",
+          DS_Warning));
+    }
+  }
+
   return Changed;
 }