[Support] Add saturating multiply-add support function
[oota-llvm.git] / lib / ProfileData / InstrProf.cpp
index df3f8fade3b2f806db28751490a7b42ab6da4bb2..d6777639abe7ab2adf873d97fb6976adc027e75f 100644 (file)
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ProfileData/InstrProf.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
@@ -167,44 +168,51 @@ GlobalVariable *createPGOFuncNameVar(Function &F, StringRef FuncName) {
 int collectPGOFuncNameStrings(const std::vector<std::string> &NameStrs,
                               bool doCompression, std::string &Result) {
   uint8_t Header[16], *P = Header;
-  std::string UncompressedNameStrings;
+  std::string UncompressedNameStrings =
+      join(NameStrs.begin(), NameStrs.end(), StringRef(" "));
 
-  for (auto NameStr : NameStrs) {
-    UncompressedNameStrings += NameStr;
-    UncompressedNameStrings.append(" ");
-  }
   unsigned EncLen = encodeULEB128(UncompressedNameStrings.length(), P);
   P += EncLen;
-  if (!doCompression) {
-    EncLen = encodeULEB128(0, P);
+
+  auto WriteStringToResult = [&](size_t CompressedLen,
+                                 const std::string &InputStr) {
+    EncLen = encodeULEB128(CompressedLen, P);
     P += EncLen;
-    Result.append(reinterpret_cast<char *>(&Header[0]), P - &Header[0]);
-    Result += UncompressedNameStrings;
+    char *HeaderStr = reinterpret_cast<char *>(&Header[0]);
+    unsigned HeaderLen = P - &Header[0];
+    Result.append(HeaderStr, HeaderLen);
+    Result += InputStr;
     return 0;
-  }
+  };
+
+  if (!doCompression)
+    return WriteStringToResult(0, UncompressedNameStrings);
+
   SmallVector<char, 128> CompressedNameStrings;
   zlib::Status Success =
       zlib::compress(StringRef(UncompressedNameStrings), CompressedNameStrings,
                      zlib::BestSizeCompression);
-  assert(Success == zlib::StatusOK);
+
   if (Success != zlib::StatusOK)
     return 1;
-  EncLen = encodeULEB128(CompressedNameStrings.size(), P);
-  P += EncLen;
-  Result.append(reinterpret_cast<char *>(&Header[0]), P - &Header[0]);
-  Result +=
-      std::string(CompressedNameStrings.data(), CompressedNameStrings.size());
-  return 0;
+
+  return WriteStringToResult(
+      CompressedNameStrings.size(),
+      std::string(CompressedNameStrings.data(), CompressedNameStrings.size()));
+}
+
+StringRef getPGOFuncNameInitializer(GlobalVariable *NameVar) {
+  auto *Arr = cast<ConstantDataArray>(NameVar->getInitializer());
+  StringRef NameStr =
+      Arr->isCString() ? Arr->getAsCString() : Arr->getAsString();
+  return NameStr;
 }
 
 int collectPGOFuncNameStrings(const std::vector<GlobalVariable *> &NameVars,
                               std::string &Result) {
   std::vector<std::string> NameStrs;
   for (auto *NameVar : NameVars) {
-    auto *Arr = cast<ConstantDataArray>(NameVar->getInitializer());
-    StringRef NameStr =
-        Arr->isCString() ? Arr->getAsCString() : Arr->getAsString();
-    NameStrs.push_back(NameStr.str());
+    NameStrs.push_back(getPGOFuncNameInitializer(NameVar));
   }
   return collectPGOFuncNameStrings(NameStrs, zlib::isAvailable(), Result);
 }
@@ -237,20 +245,10 @@ int readPGOFuncNameStrings(StringRef NameStrings, InstrProfSymtab &Symtab) {
       P += UncompressedSize;
     }
     // Now parse the name strings.
-    size_t NameStart = 0;
-    bool isLast = false;
-    do {
-      size_t NameStop = NameStrings.find(' ', NameStart);
-      if (NameStop == StringRef::npos)
-        return 1;
-      if (NameStop == NameStrings.size() - 1)
-        isLast = true;
-      StringRef Name = NameStrings.substr(NameStart, NameStop - NameStart);
+    SmallVector<StringRef, 0> Names;
+    NameStrings.split(Names, ' ');
+    for (StringRef &Name : Names)
       Symtab.addFuncName(Name);
-      if (isLast)
-        break;
-      NameStart = NameStop + 1;
-    } while (true);
 
     while (P < EndP && *P == 0)
       P++;
@@ -259,9 +257,8 @@ int readPGOFuncNameStrings(StringRef NameStrings, InstrProfSymtab &Symtab) {
   return 0;
 }
 
-instrprof_error
-InstrProfValueSiteRecord::mergeValueData(InstrProfValueSiteRecord &Input,
-                                         uint64_t Weight) {
+instrprof_error InstrProfValueSiteRecord::merge(InstrProfValueSiteRecord &Input,
+                                                uint64_t Weight) {
   this->sortByTargetValues();
   Input.sortByTargetValues();
   auto I = ValueData.begin();
@@ -272,14 +269,8 @@ InstrProfValueSiteRecord::mergeValueData(InstrProfValueSiteRecord &Input,
     while (I != IE && I->Value < J->Value)
       ++I;
     if (I != IE && I->Value == J->Value) {
-      uint64_t JCount = J->Count;
       bool Overflowed;
-      if (Weight > 1) {
-        JCount = SaturatingMultiply(JCount, Weight, &Overflowed);
-        if (Overflowed)
-          Result = instrprof_error::counter_overflow;
-      }
-      I->Count = SaturatingAdd(I->Count, JCount, &Overflowed);
+      I->Count = SaturatingMultiplyAdd(J->Count, Weight, I->Count, &Overflowed);
       if (Overflowed)
         Result = instrprof_error::counter_overflow;
       ++I;
@@ -290,6 +281,17 @@ InstrProfValueSiteRecord::mergeValueData(InstrProfValueSiteRecord &Input,
   return Result;
 }
 
+instrprof_error InstrProfValueSiteRecord::scale(uint64_t Weight) {
+  instrprof_error Result = instrprof_error::success;
+  for (auto I = ValueData.begin(), IE = ValueData.end(); I != IE; ++I) {
+    bool Overflowed;
+    I->Count = SaturatingMultiply(I->Count, Weight, &Overflowed);
+    if (Overflowed)
+      Result = instrprof_error::counter_overflow;
+  }
+  return Result;
+}
+
 // Merge Value Profile data from Src record to this record for ValueKind.
 // Scale merged value counts by \p Weight.
 instrprof_error InstrProfRecord::mergeValueProfData(uint32_t ValueKind,
@@ -305,8 +307,7 @@ instrprof_error InstrProfRecord::mergeValueProfData(uint32_t ValueKind,
       Src.getValueSitesForKind(ValueKind);
   instrprof_error Result = instrprof_error::success;
   for (uint32_t I = 0; I < ThisNumValueSites; I++)
-    MergeResult(Result,
-                ThisSiteRecords[I].mergeValueData(OtherSiteRecords[I], Weight));
+    MergeResult(Result, ThisSiteRecords[I].merge(OtherSiteRecords[I], Weight));
   return Result;
 }
 
@@ -321,13 +322,8 @@ instrprof_error InstrProfRecord::merge(InstrProfRecord &Other,
 
   for (size_t I = 0, E = Other.Counts.size(); I < E; ++I) {
     bool Overflowed;
-    uint64_t OtherCount = Other.Counts[I];
-    if (Weight > 1) {
-      OtherCount = SaturatingMultiply(OtherCount, Weight, &Overflowed);
-      if (Overflowed)
-        Result = instrprof_error::counter_overflow;
-    }
-    Counts[I] = SaturatingAdd(Counts[I], OtherCount, &Overflowed);
+    Counts[I] =
+        SaturatingMultiplyAdd(Other.Counts[I], Weight, Counts[I], &Overflowed);
     if (Overflowed)
       Result = instrprof_error::counter_overflow;
   }
@@ -338,6 +334,32 @@ instrprof_error InstrProfRecord::merge(InstrProfRecord &Other,
   return Result;
 }
 
+instrprof_error InstrProfRecord::scaleValueProfData(uint32_t ValueKind,
+                                                    uint64_t Weight) {
+  uint32_t ThisNumValueSites = getNumValueSites(ValueKind);
+  std::vector<InstrProfValueSiteRecord> &ThisSiteRecords =
+      getValueSitesForKind(ValueKind);
+  instrprof_error Result = instrprof_error::success;
+  for (uint32_t I = 0; I < ThisNumValueSites; I++)
+    MergeResult(Result, ThisSiteRecords[I].scale(Weight));
+  return Result;
+}
+
+instrprof_error InstrProfRecord::scale(uint64_t Weight) {
+  instrprof_error Result = instrprof_error::success;
+  for (auto &Count : this->Counts) {
+    bool Overflowed;
+    Count = SaturatingMultiply(Count, Weight, &Overflowed);
+    if (Overflowed && Result == instrprof_error::success) {
+      Result = instrprof_error::counter_overflow;
+    }
+  }
+  for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)
+    MergeResult(Result, scaleValueProfData(Kind, Weight));
+
+  return Result;
+}
+
 // Map indirect call target name hash to name string.
 uint64_t InstrProfRecord::remapValue(uint64_t Value, uint32_t ValueKind,
                                      ValueMapType *ValueMap) {