[llvm-profdata] Add support for weighted merge of profile data
[oota-llvm.git] / include / llvm / ProfileData / InstrProf.h
index 95648511910298c5e73e27cc04d05f6babd6703c..e1ed2e9ce48c995b2cc562cc44931c45dbcfee80 100644 (file)
@@ -218,7 +218,8 @@ struct InstrProfValueSiteRecord {
   }
 
   /// Merge data from another InstrProfValueSiteRecord
-  void mergeValueData(InstrProfValueSiteRecord &Input) {
+  /// Optionally scale merged counts by \p Weight.
+  void mergeValueData(InstrProfValueSiteRecord &Input, uint64_t Weight = 1) {
     this->sortByTargetValues();
     Input.sortByTargetValues();
     auto I = ValueData.begin();
@@ -228,7 +229,11 @@ struct InstrProfValueSiteRecord {
       while (I != IE && I->Value < J->Value)
         ++I;
       if (I != IE && I->Value == J->Value) {
-        I->Count = SaturatingAdd(I->Count, J->Count);
+        // TODO: Check for counter overflow and return error if it occurs.
+        uint64_t JCount = J->Count;
+        if (Weight > 1)
+          JCount = SaturatingMultiply(JCount, Weight);
+        I->Count = SaturatingAdd(I->Count, JCount);
         ++I;
         continue;
       }
@@ -274,7 +279,8 @@ struct InstrProfRecord {
                            ValueMapType *HashKeys);
 
   /// Merge the counts in \p Other into this one.
-  inline instrprof_error merge(InstrProfRecord &Other);
+  /// Optionally scale merged counts by \p Weight.
+  inline instrprof_error merge(InstrProfRecord &Other, uint64_t Weight = 1);
 
   /// Used by InstrProfWriter: update the value strings to commoned strings in
   /// the writer instance.
@@ -326,7 +332,9 @@ private:
   }
 
   // Merge Value Profile data from Src record to this record for ValueKind.
-  instrprof_error mergeValueProfData(uint32_t ValueKind, InstrProfRecord &Src) {
+  // Scale merged value counts by \p Weight.
+  instrprof_error mergeValueProfData(uint32_t ValueKind, InstrProfRecord &Src,
+                                     uint64_t Weight) {
     uint32_t ThisNumValueSites = getNumValueSites(ValueKind);
     uint32_t OtherNumValueSites = Src.getNumValueSites(ValueKind);
     if (ThisNumValueSites != OtherNumValueSites)
@@ -336,7 +344,7 @@ private:
     std::vector<InstrProfValueSiteRecord> &OtherSiteRecords =
         Src.getValueSitesForKind(ValueKind);
     for (uint32_t I = 0; I < ThisNumValueSites; I++)
-      ThisSiteRecords[I].mergeValueData(OtherSiteRecords[I]);
+      ThisSiteRecords[I].mergeValueData(OtherSiteRecords[I], Weight);
     return instrprof_error::success;
   }
 };
@@ -422,7 +430,8 @@ void InstrProfRecord::updateStrings(InstrProfStringTable *StrTab) {
       VData.Value = (uint64_t)StrTab->insertString((const char *)VData.Value);
 }
 
-instrprof_error InstrProfRecord::merge(InstrProfRecord &Other) {
+instrprof_error InstrProfRecord::merge(InstrProfRecord &Other,
+                                       uint64_t Weight) {
   // If the number of counters doesn't match we either have bad data
   // or a hash collision.
   if (Counts.size() != Other.Counts.size())
@@ -432,13 +441,19 @@ instrprof_error InstrProfRecord::merge(InstrProfRecord &Other) {
 
   for (size_t I = 0, E = Other.Counts.size(); I < E; ++I) {
     bool ResultOverflowed;
-    Counts[I] = SaturatingAdd(Counts[I], Other.Counts[I], ResultOverflowed);
+    uint64_t OtherCount = Other.Counts[I];
+    if (Weight > 1) {
+      OtherCount = SaturatingMultiply(OtherCount, Weight, ResultOverflowed);
+      if (ResultOverflowed)
+        Result = instrprof_error::counter_overflow;
+    }
+    Counts[I] = SaturatingAdd(Counts[I], OtherCount, ResultOverflowed);
     if (ResultOverflowed)
       Result = instrprof_error::counter_overflow;
   }
 
   for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) {
-    instrprof_error MergeValueResult = mergeValueProfData(Kind, Other);
+    instrprof_error MergeValueResult = mergeValueProfData(Kind, Other, Weight);
     if (MergeValueResult != instrprof_error::success)
       Result = MergeValueResult;
   }