[Support] Add saturating multiply-add support function
[oota-llvm.git] / include / llvm / ProfileData / SampleProf.h
index f62f79064c4e654f11f9f7454625e8bdfcbd08ae..6c39cf9458dcb4f39e86303cef6b03b40f5fc4be 100644 (file)
@@ -38,13 +38,24 @@ enum class sampleprof_error {
   unrecognized_format,
   unsupported_writing_format,
   truncated_name_table,
-  not_implemented
+  not_implemented,
+  counter_overflow
 };
 
 inline std::error_code make_error_code(sampleprof_error E) {
   return std::error_code(static_cast<int>(E), sampleprof_category());
 }
 
+inline sampleprof_error MergeResult(sampleprof_error &Accumulator,
+                                    sampleprof_error Result) {
+  // Prefer first error encountered as later errors may be secondary effects of
+  // the initial problem.
+  if (Accumulator == sampleprof_error::success &&
+      Result != sampleprof_error::success)
+    Accumulator = Result;
+  return Accumulator;
+}
+
 } // end namespace llvm
 
 namespace std {
@@ -123,18 +134,30 @@ public:
   SampleRecord() : NumSamples(0), CallTargets() {}
 
   /// Increment the number of samples for this record by \p S.
+  /// Optionally scale sample count \p S by \p Weight.
   ///
   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
   /// around unsigned integers.
-  void addSamples(uint64_t S) { NumSamples = SaturatingAdd(NumSamples, S); }
+  sampleprof_error addSamples(uint64_t S, uint64_t Weight = 1) {
+    bool Overflowed;
+    NumSamples = SaturatingMultiplyAdd(S, Weight, NumSamples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
+  }
 
   /// Add called function \p F with samples \p S.
+  /// Optionally scale sample count \p S by \p Weight.
   ///
   /// Sample counts accumulate using saturating arithmetic, to avoid wrapping
   /// around unsigned integers.
-  void addCalledTarget(StringRef F, uint64_t S) {
+  sampleprof_error addCalledTarget(StringRef F, uint64_t S,
+                                   uint64_t Weight = 1) {
     uint64_t &TargetSamples = CallTargets[F];
-    TargetSamples = SaturatingAdd(TargetSamples, S);
+    bool Overflowed;
+    TargetSamples =
+        SaturatingMultiplyAdd(S, Weight, TargetSamples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
   }
 
   /// Return true if this sample record contains function calls.
@@ -144,10 +167,13 @@ public:
   const CallTargetMap &getCallTargets() const { return CallTargets; }
 
   /// Merge the samples in \p Other into this record.
-  void merge(const SampleRecord &Other) {
-    addSamples(Other.getSamples());
-    for (const auto &I : Other.getCallTargets())
-      addCalledTarget(I.first(), I.second);
+  /// Optionally scale sample counts by \p Weight.
+  sampleprof_error merge(const SampleRecord &Other, uint64_t Weight = 1) {
+    sampleprof_error Result = addSamples(Other.getSamples(), Weight);
+    for (const auto &I : Other.getCallTargets()) {
+      MergeResult(Result, addCalledTarget(I.first(), I.second, Weight));
+    }
+    return Result;
   }
 
   void print(raw_ostream &OS, unsigned Indent) const;
@@ -174,16 +200,31 @@ public:
   FunctionSamples() : TotalSamples(0), TotalHeadSamples(0) {}
   void print(raw_ostream &OS = dbgs(), unsigned Indent = 0) const;
   void dump() const;
-  void addTotalSamples(uint64_t Num) { TotalSamples += Num; }
-  void addHeadSamples(uint64_t Num) { TotalHeadSamples += Num; }
-  void addBodySamples(uint32_t LineOffset, uint32_t Discriminator,
-                      uint64_t Num) {
-    BodySamples[LineLocation(LineOffset, Discriminator)].addSamples(Num);
+  sampleprof_error addTotalSamples(uint64_t Num, uint64_t Weight = 1) {
+    bool Overflowed;
+    TotalSamples =
+        SaturatingMultiplyAdd(Num, Weight, TotalSamples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
+  }
+  sampleprof_error addHeadSamples(uint64_t Num, uint64_t Weight = 1) {
+    bool Overflowed;
+    TotalHeadSamples =
+        SaturatingMultiplyAdd(Num, Weight, TotalHeadSamples, &Overflowed);
+    return Overflowed ? sampleprof_error::counter_overflow
+                      : sampleprof_error::success;
+  }
+  sampleprof_error addBodySamples(uint32_t LineOffset, uint32_t Discriminator,
+                                  uint64_t Num, uint64_t Weight = 1) {
+    return BodySamples[LineLocation(LineOffset, Discriminator)].addSamples(
+        Num, Weight);
   }
-  void addCalledTargetSamples(uint32_t LineOffset, uint32_t Discriminator,
-                              std::string FName, uint64_t Num) {
-    BodySamples[LineLocation(LineOffset, Discriminator)].addCalledTarget(FName,
-                                                                         Num);
+  sampleprof_error addCalledTargetSamples(uint32_t LineOffset,
+                                          uint32_t Discriminator,
+                                          std::string FName, uint64_t Num,
+                                          uint64_t Weight = 1) {
+    return BodySamples[LineLocation(LineOffset, Discriminator)].addCalledTarget(
+        FName, Num, Weight);
   }
 
   /// Return the number of samples collected at the given location.
@@ -232,19 +273,22 @@ public:
   }
 
   /// Merge the samples in \p Other into this one.
-  void merge(const FunctionSamples &Other) {
-    addTotalSamples(Other.getTotalSamples());
-    addHeadSamples(Other.getHeadSamples());
+  /// Optionally scale samples by \p Weight.
+  sampleprof_error merge(const FunctionSamples &Other, uint64_t Weight = 1) {
+    sampleprof_error Result = sampleprof_error::success;
+    MergeResult(Result, addTotalSamples(Other.getTotalSamples(), Weight));
+    MergeResult(Result, addHeadSamples(Other.getHeadSamples(), Weight));
     for (const auto &I : Other.getBodySamples()) {
       const LineLocation &Loc = I.first;
       const SampleRecord &Rec = I.second;
-      BodySamples[Loc].merge(Rec);
+      MergeResult(Result, BodySamples[Loc].merge(Rec, Weight));
     }
     for (const auto &I : Other.getCallsiteSamples()) {
       const CallsiteLocation &Loc = I.first;
       const FunctionSamples &Rec = I.second;
-      functionSamplesAt(Loc).merge(Rec);
+      MergeResult(Result, functionSamplesAt(Loc).merge(Rec, Weight));
     }
+    return Result;
   }
 
 private: