InstrProf: Add unit tests for the profile reader and writer
[oota-llvm.git] / lib / ProfileData / InstrProfWriter.cpp
index 19b6890b36bb7565f84446a6df1cc4db706f569b..2c72efe3faa92fcd79667e4fef38afdb449f8b07 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ProfileData/InstrProfWriter.h"
+#include "InstrProfIndexed.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/EndianStream.h"
 #include "llvm/Support/OnDiskHashTable.h"
 
-#include "InstrProfIndexed.h"
-
 using namespace llvm;
 
 namespace {
@@ -27,8 +26,8 @@ public:
   typedef StringRef key_type;
   typedef StringRef key_type_ref;
 
-  typedef InstrProfWriter::CounterData data_type;
-  typedef const InstrProfWriter::CounterData &data_type_ref;
+  typedef const InstrProfWriter::CounterData *const data_type;
+  typedef const InstrProfWriter::CounterData *const data_type_ref;
 
   typedef uint64_t hash_value_type;
   typedef uint64_t offset_type;
@@ -45,7 +44,9 @@ public:
     offset_type N = K.size();
     LE.write<offset_type>(N);
 
-    offset_type M = (1 + V.Counts.size()) * sizeof(uint64_t);
+    offset_type M = 0;
+    for (const auto &Counts : *V)
+      M += (2 + Counts.second.size()) * sizeof(uint64_t);
     LE.write<offset_type>(M);
 
     return std::make_pair(N, M);
@@ -59,51 +60,58 @@ public:
                        offset_type) {
     using namespace llvm::support;
     endian::Writer<little> LE(Out);
-    LE.write<uint64_t>(V.Hash);
-    for (uint64_t I : V.Counts)
-      LE.write<uint64_t>(I);
+
+    for (const auto &Counts : *V) {
+      LE.write<uint64_t>(Counts.first);
+      LE.write<uint64_t>(Counts.second.size());
+      for (uint64_t I : Counts.second)
+        LE.write<uint64_t>(I);
+    }
   }
 };
 }
 
-error_code InstrProfWriter::addFunctionCounts(StringRef FunctionName,
-                                              uint64_t FunctionHash,
-                                              ArrayRef<uint64_t> Counters) {
-  auto Where = FunctionData.find(FunctionName);
-  if (Where == FunctionData.end()) {
-    // If this is the first time we've seen this function, just add it.
-    auto &Data = FunctionData[FunctionName];
-    Data.Hash = FunctionHash;
-    Data.Counts = Counters;
-    return instrprof_error::success;;
+std::error_code
+InstrProfWriter::addFunctionCounts(StringRef FunctionName,
+                                   uint64_t FunctionHash,
+                                   ArrayRef<uint64_t> Counters) {
+  auto &CounterData = FunctionData[FunctionName];
+
+  auto Where = CounterData.find(FunctionHash);
+  if (Where == CounterData.end()) {
+    // We've never seen a function with this name and hash, add it.
+    CounterData[FunctionHash] = Counters;
+    // We keep track of the max function count as we go for simplicity.
+    if (Counters[0] > MaxFunctionCount)
+      MaxFunctionCount = Counters[0];
+    return instrprof_error::success;
   }
 
-  auto &Data = Where->getValue();
-  // We can only add to existing functions if they match, so we check the hash
-  // and number of counters.
-  if (Data.Hash != FunctionHash)
-    return instrprof_error::hash_mismatch;
-  if (Data.Counts.size() != Counters.size())
+  // We're updating a function we've seen before.
+  auto &FoundCounters = Where->second;
+  // If the number of counters doesn't match we either have bad data or a hash
+  // collision.
+  if (FoundCounters.size() != Counters.size())
     return instrprof_error::count_mismatch;
-  // These match, add up the counters.
+
   for (size_t I = 0, E = Counters.size(); I < E; ++I) {
-    if (Data.Counts[I] + Counters[I] < Data.Counts[I])
+    if (FoundCounters[I] + Counters[I] < FoundCounters[I])
       return instrprof_error::counter_overflow;
-    Data.Counts[I] += Counters[I];
+    FoundCounters[I] += Counters[I];
   }
+  // We keep track of the max function count as we go for simplicity.
+  if (FoundCounters[0] > MaxFunctionCount)
+    MaxFunctionCount = FoundCounters[0];
+
   return instrprof_error::success;
 }
 
-void InstrProfWriter::write(raw_fd_ostream &OS) {
+std::pair<uint64_t, uint64_t> InstrProfWriter::writeImpl(raw_ostream &OS) {
   OnDiskChainedHashTableGenerator<InstrProfRecordTrait> Generator;
-  uint64_t MaxFunctionCount = 0;
 
   // Populate the hash table generator.
-  for (const auto &I : FunctionData) {
-    Generator.insert(I.getKey(), I.getValue());
-    if (I.getValue().Counts[0] > MaxFunctionCount)
-      MaxFunctionCount = I.getValue().Counts[0];
-  }
+  for (const auto &I : FunctionData)
+    Generator.insert(I.getKey(), &I.getValue());
 
   using namespace llvm::support;
   endian::Writer<little> LE(OS);
@@ -120,7 +128,31 @@ void InstrProfWriter::write(raw_fd_ostream &OS) {
   // Write the hash table.
   uint64_t HashTableStart = Generator.Emit(OS);
 
+  return std::make_pair(HashTableStartLoc, HashTableStart);
+}
+
+void InstrProfWriter::write(raw_fd_ostream &OS) {
+  // Write the hash table.
+  auto TableStart = writeImpl(OS);
+
   // Go back and fill in the hash table start.
-  OS.seek(HashTableStartLoc);
-  LE.write<uint64_t>(HashTableStart);
+  using namespace support;
+  OS.seek(TableStart.first);
+  endian::Writer<little>(OS).write<uint64_t>(TableStart.second);
+}
+
+std::string InstrProfWriter::writeString() {
+  std::string Result;
+  llvm::raw_string_ostream OS(Result);
+  // Write the hash table.
+  auto TableStart = writeImpl(OS);
+  OS.flush();
+
+  // Go back and fill in the hash table start.
+  using namespace support;
+  uint64_t Bytes = endian::byte_swap<uint64_t, little>(TableStart.second);
+  Result.replace(TableStart.first, sizeof(uint64_t), (const char *)&Bytes,
+                 sizeof(uint64_t));
+
+  return Result;
 }