[PGO] Handle and report overflow during profile merge for all types of data
[oota-llvm.git] / include / llvm / ProfileData / SampleProf.h
index 7607e24ec1c84714cc38c94b60af1c384ea52bf0..8df3fe80320930d38464e72cf4ed08e0fb6dc1d6 100644 (file)
@@ -38,13 +38,24 @@ enum class sampleprof_error {
   unrecognized_format,
   unsupported_writing_format,
   truncated_name_table,
-  not_implemented
+  not_implemented,
+  counter_overflow
 };
 
 inline std::error_code make_error_code(sampleprof_error E) {
   return std::error_code(static_cast<int>(E), sampleprof_category());
 }
 
+inline sampleprof_error MergeResult(sampleprof_error &Accumulator,
+                                    sampleprof_error Result) {
+  // Prefer first error encountered as later errors may be secondary effects of
+  // the initial problem.
+  if (Accumulator == sampleprof_error::success &&
+      Result != sampleprof_error::success)
+    Accumulator = Result;
+  return Accumulator;
+}
+
 } // end namespace llvm
 
 namespace std {
@@ -127,15 +138,18 @@ public:
   ///
   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
   /// around unsigned integers.
-  void addSamples(uint64_t S, uint64_t Weight = 1) {
-    // FIXME: Improve handling of counter overflow.
+  sampleprof_error addSamples(uint64_t S, uint64_t Weight = 1) {
     bool Overflowed;
     if (Weight > 1) {
       S = SaturatingMultiply(S, Weight, &Overflowed);
-      assert(!Overflowed && "Sample counter overflowed!");
+      if (Overflowed)
+        return sampleprof_error::counter_overflow;
     }
     NumSamples = SaturatingAdd(NumSamples, S, &Overflowed);
-    assert(!Overflowed && "Sample counter overflowed!");
+    if (Overflowed)
+      return sampleprof_error::counter_overflow;
+
+    return sampleprof_error::success;
   }
 
   /// Add called function \p F with samples \p S.
@@ -143,16 +157,20 @@ public:
   ///
   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
   /// around unsigned integers.
-  void addCalledTarget(StringRef F, uint64_t S, uint64_t Weight = 1) {
-    // FIXME: Improve handling of counter overflow.
+  sampleprof_error addCalledTarget(StringRef F, uint64_t S,
+                                   uint64_t Weight = 1) {
     uint64_t &TargetSamples = CallTargets[F];
     bool Overflowed;
     if (Weight > 1) {
       S = SaturatingMultiply(S, Weight, &Overflowed);
-      assert(!Overflowed && "Called target counter overflowed!");
+      if (Overflowed)
+        return sampleprof_error::counter_overflow;
     }
     TargetSamples = SaturatingAdd(TargetSamples, S, &Overflowed);
-    assert(!Overflowed && "Called target counter overflowed!");
+    if (Overflowed)
+      return sampleprof_error::counter_overflow;
+
+    return sampleprof_error::success;
   }
 
   /// Return true if this sample record contains function calls.
@@ -163,10 +181,12 @@ public:
 
   /// Merge the samples in \p Other into this record.
   /// Optionally scale sample counts by \p Weight.
-  void merge(const SampleRecord &Other, uint64_t Weight = 1) {
-    addSamples(Other.getSamples(), Weight);
-    for (const auto &I : Other.getCallTargets())
-      addCalledTarget(I.first(), I.second, Weight);
+  sampleprof_error merge(const SampleRecord &Other, uint64_t Weight = 1) {
+    sampleprof_error Result = addSamples(Other.getSamples(), Weight);
+    for (const auto &I : Other.getCallTargets()) {
+      MergeResult(Result, addCalledTarget(I.first(), I.second, Weight));
+    }
+    return Result;
   }
 
   void print(raw_ostream &OS, unsigned Indent) const;
@@ -193,35 +213,42 @@ public:
   FunctionSamples() : TotalSamples(0), TotalHeadSamples(0) {}
   void print(raw_ostream &OS = dbgs(), unsigned Indent = 0) const;
   void dump() const;
-  void addTotalSamples(uint64_t Num, uint64_t Weight = 1) {
-    // FIXME: Improve handling of counter overflow.
+  sampleprof_error addTotalSamples(uint64_t Num, uint64_t Weight = 1) {
     bool Overflowed;
     if (Weight > 1) {
       Num = SaturatingMultiply(Num, Weight, &Overflowed);
-      assert(!Overflowed && "Total samples counter overflowed!");
+      if (Overflowed)
+        return sampleprof_error::counter_overflow;
     }
     TotalSamples = SaturatingAdd(TotalSamples, Num, &Overflowed);
-    assert(!Overflowed && "Total samples counter overflowed!");
+    if (Overflowed)
+      return sampleprof_error::counter_overflow;
+
+    return sampleprof_error::success;
   }
-  void addHeadSamples(uint64_t Num, uint64_t Weight = 1) {
-    // FIXME: Improve handling of counter overflow.
+  sampleprof_error addHeadSamples(uint64_t Num, uint64_t Weight = 1) {
     bool Overflowed;
     if (Weight > 1) {
       Num = SaturatingMultiply(Num, Weight, &Overflowed);
-      assert(!Overflowed && "Total head samples counter overflowed!");
+      if (Overflowed)
+        return sampleprof_error::counter_overflow;
     }
     TotalHeadSamples = SaturatingAdd(TotalHeadSamples, Num, &Overflowed);
-    assert(!Overflowed && "Total head samples counter overflowed!");
+    if (Overflowed)
+      return sampleprof_error::counter_overflow;
+
+    return sampleprof_error::success;
   }
-  void addBodySamples(uint32_t LineOffset, uint32_t Discriminator, uint64_t Num,
-                      uint64_t Weight = 1) {
-    BodySamples[LineLocation(LineOffset, Discriminator)].addSamples(Num,
-                                                                    Weight);
+  sampleprof_error addBodySamples(uint32_t LineOffset, uint32_t Discriminator,
+                                  uint64_t Num, uint64_t Weight = 1) {
+    return BodySamples[LineLocation(LineOffset, Discriminator)].addSamples(
+        Num, Weight);
   }
-  void addCalledTargetSamples(uint32_t LineOffset, uint32_t Discriminator,
-                              std::string FName, uint64_t Num,
-                              uint64_t Weight = 1) {
-    BodySamples[LineLocation(LineOffset, Discriminator)].addCalledTarget(
+  sampleprof_error addCalledTargetSamples(uint32_t LineOffset,
+                                          uint32_t Discriminator,
+                                          std::string FName, uint64_t Num,
+                                          uint64_t Weight = 1) {
+    return BodySamples[LineLocation(LineOffset, Discriminator)].addCalledTarget(
         FName, Num, Weight);
   }
 
@@ -272,19 +299,21 @@ public:
 
   /// Merge the samples in \p Other into this one.
   /// Optionally scale samples by \p Weight.
-  void merge(const FunctionSamples &Other, uint64_t Weight = 1) {
-    addTotalSamples(Other.getTotalSamples(), Weight);
-    addHeadSamples(Other.getHeadSamples(), Weight);
+  sampleprof_error merge(const FunctionSamples &Other, uint64_t Weight = 1) {
+    sampleprof_error Result = sampleprof_error::success;
+    MergeResult(Result, addTotalSamples(Other.getTotalSamples(), Weight));
+    MergeResult(Result, addHeadSamples(Other.getHeadSamples(), Weight));
     for (const auto &I : Other.getBodySamples()) {
       const LineLocation &Loc = I.first;
       const SampleRecord &Rec = I.second;
-      BodySamples[Loc].merge(Rec, Weight);
+      MergeResult(Result, BodySamples[Loc].merge(Rec, Weight));
     }
     for (const auto &I : Other.getCallsiteSamples()) {
       const CallsiteLocation &Loc = I.first;
       const FunctionSamples &Rec = I.second;
-      functionSamplesAt(Loc).merge(Rec, Weight);
+      MergeResult(Result, functionSamplesAt(Loc).merge(Rec, Weight));
     }
+    return Result;
   }
 
 private: