Random number generator utilities for folly
authorBen Maurer <bmaurer@fb.com>
Tue, 25 Feb 2014 02:37:24 +0000 (18:37 -0800)
committerDave Watson <davejwatson@fb.com>
Fri, 28 Feb 2014 22:01:19 +0000 (14:01 -0800)
Summary:
In looking at how people were using common/base/Random, I noticed
a number of issues with our current usage of random number generators

1) People would simply declare a RandomInt32 without seeding it. This
results in a predictable seed
2) We initialize a Mersenne Twister RNG from a single int32. This
causes us to have a more predictable starting sequence of numbers
3) People aren't consistently using thread-local RNGs
4) Many of the APIs lack consistency. For example random32 takes a
max parameter that is exclusive while uniformRandom32 uses inclusive
boundries

I'm hoping a better API can fix this. thread_prng implements the Generator
concept with result_type = int32. It isn't actually a random number generator,
but it uses a thread local to point to a real generator. An advantage
of this is that it can be used in existing APIs but that it doesn't expose
the implementation of the RNG via the header file. One thing that's a bit
weird about it is that if you copy the object across threads it could
cause an error.

The Random class provides utilities that take any type of random number
generator. This has the advantage of allowing the user to pass a RNG
meant for testing or a secure RNG based on /dev/random. Another advnatage
is if you're woried about the performance of TLS lookups, you can
cache a local thread_prng which memoizes the TLS lookup.

If available, we use a SIMD optimized MT API

Some open questions:

1) What functions should be in random
2) Should the default RNG be a 64 or 32 bit based RNG

Test Plan: Benchmark runs

Reviewed By: simpkins@fb.com

FB internal diff: D1181864

folly/Random.cpp
folly/Random.h
folly/test/RandomTest.cpp

index 2448645..3ba3411 100644 (file)
 #include <atomic>
 #include <unistd.h>
 #include <sys/time.h>
+#include <random>
+#include <array>
+
+#if __GNUC_PREREQ(4, 8)
+#include <ext/random>
+#define USE_SIMD_PRNG
+#endif
 
 namespace folly {
 
@@ -39,4 +46,39 @@ uint32_t randomNumberSeed() {
        + kPrime3 * static_cast<uint32_t>(tv.tv_usec);
 }
 
+
+folly::ThreadLocalPtr<ThreadLocalPRNG::LocalInstancePRNG>
+ThreadLocalPRNG::localInstance;
+
+class ThreadLocalPRNG::LocalInstancePRNG {
+#ifdef USE_SIMD_PRNG
+  typedef  __gnu_cxx::sfmt19937 RNG;
+#else
+  typedef std::mt19937 RNG;
+#endif
+
+  static RNG makeRng() {
+    std::array<int, RNG::state_size> seed_data;
+    std::random_device r;
+    std::generate_n(seed_data.data(), seed_data.size(), std::ref(r));
+    std::seed_seq seq(std::begin(seed_data), std::end(seed_data));
+    return RNG(seq);
+  }
+
+ public:
+  LocalInstancePRNG() : rng(std::move(makeRng())) {}
+
+  RNG rng;
+};
+
+ThreadLocalPRNG::LocalInstancePRNG* ThreadLocalPRNG::initLocal() {
+  auto ret = new LocalInstancePRNG;
+  localInstance.reset(ret);
+  return ret;
+}
+
+uint32_t ThreadLocalPRNG::getImpl(LocalInstancePRNG* local) {
+  return local->rng();
+}
+
 }
index daa8e1e..f587418 100644 (file)
@@ -18,6 +18,7 @@
 #define FOLLY_BASE_RANDOM_H_
 
 #include <stdint.h>
+#include "folly/ThreadLocal.h"
 
 namespace folly {
 
@@ -26,6 +27,136 @@ namespace folly {
  */
 uint32_t randomNumberSeed();
 
+class Random;
+
+/**
+ * A PRNG with one instance per thread. This PRNG uses a mersenne twister random
+ * number generator and is seeded from /dev/urandom. It should not be used for
+ * anything which requires security, only for statistical randomness.
+ *
+ * An instance of this class represents the current threads PRNG. This means
+ * copying an instance of this class across threads will result in corruption
+ *
+ * Most users will use the Random class which implicitly creates this class.
+ * However, if you are worried about performance, you can memoize the TLS
+ * lookups that get the per thread state by manually using this class:
+ *
+ * ThreadLocalPRNG rng = Random::threadLocalPRNG()
+ * for (...) {
+ *   Random::rand32(rng);
+ * }
+ */
+class ThreadLocalPRNG {
+ public:
+  typedef uint32_t result_type;
+
+  uint32_t operator()() {
+    // Using a static method allows the compiler to avoid allocating stack space
+    // for this class.
+    return getImpl(local_);
+  }
+
+  static constexpr result_type min() {
+    return std::numeric_limits<result_type>::min();
+  }
+  static constexpr result_type max() {
+    return std::numeric_limits<result_type>::max();
+  }
+  friend class Random;
+
+  ThreadLocalPRNG() {
+    local_ = localInstance.get();
+    if (!local_) {
+      local_ = initLocal();
+    }
+  }
+
+ private:
+  class LocalInstancePRNG;
+  static LocalInstancePRNG* initLocal();
+  static folly::ThreadLocalPtr<ThreadLocalPRNG::LocalInstancePRNG>
+    localInstance;
+
+  static result_type getImpl(LocalInstancePRNG* local);
+  LocalInstancePRNG* local_;
+};
+
+
+
+class Random {
+
+ private:
+  template<class RNG>
+  using ValidRNG = typename std::enable_if<
+   std::is_unsigned<typename std::result_of<RNG&()>::type>::value,
+   RNG>::type;
+
+ public:
+
+  /**
+   * Returns a random uint32_t
+   */
+  template<class RNG = ThreadLocalPRNG>
+  static uint32_t rand32(ValidRNG<RNG>  rrng = RNG()) {
+    uint32_t r = rrng.operator()();
+    return r;
+  }
+
+  /**
+   * Returns a random uint32_t in [0, max). If max == 0, returns 0.
+   */
+  template<class RNG = ThreadLocalPRNG>
+  static uint32_t rand32(uint32_t max, ValidRNG<RNG> rng = RNG()) {
+    if (max == 0) {
+      return 0;
+    }
+
+    return std::uniform_int_distribution<uint32_t>(0, max - 1)(rng);
+  }
+
+  /**
+   * Returns a random uint64_t
+   */
+  template<class RNG = ThreadLocalPRNG>
+  static uint64_t rand64(ValidRNG<RNG> rng = RNG()) {
+    return ((uint64_t) rng() << 32) | rng();
+  }
+
+  /**
+   * Returns a random uint64_t in [0, max). If max == 0, returns 0.
+   */
+  template<class RNG = ThreadLocalPRNG>
+  static uint64_t rand64(uint64_t max, ValidRNG<RNG> rng = RNG()) {
+    if (max == 0) {
+      return 0;
+    }
+
+    return std::uniform_int_distribution<uint64_t>(0, max - 1)(rng);
+  }
+
+  /**
+   * Returns true 1/n of the time. If n == 0, always returns false
+   */
+  template<class RNG = ThreadLocalPRNG>
+  static bool oneIn(uint32_t n, ValidRNG<RNG> rng = RNG()) {
+    if (n == 0) {
+      return false;
+    }
+
+    return rand32(n, rng) == 0;
+  }
+
+  /**
+   * Returns a double in [0, 1)
+   */
+  template<class RNG = ThreadLocalPRNG>
+  static double randDouble01(ValidRNG<RNG> rng = RNG()) {
+    return std::generate_canonical<double, std::numeric_limits<double>::digits>
+      (rng);
+  }
+
+};
+
 }
 
 #endif
index b23228a..09aff3c 100644 (file)
@@ -15,6 +15,9 @@
  */
 
 #include "folly/Random.h"
+#include "folly/Range.h"
+#include "folly/Benchmark.h"
+#include "folly/Foreach.h"
 
 #include <glog/logging.h>
 #include <gtest/gtest.h>
@@ -22,6 +25,7 @@
 #include <algorithm>
 #include <thread>
 #include <vector>
+#include <random>
 
 using namespace folly;
 
@@ -50,3 +54,54 @@ TEST(Random, MultiThreaded) {
     EXPECT_LT(seeds[i], seeds[i+1]);
   }
 }
+
+BENCHMARK(minstdrand, n) {
+  BenchmarkSuspender braces;
+  std::random_device rd;
+  std::minstd_rand rng(rd());
+
+  braces.dismiss();
+
+  FOR_EACH_RANGE (i, 0, n) {
+    doNotOptimizeAway(rng());
+  }
+}
+
+BENCHMARK(mt19937, n) {
+  BenchmarkSuspender braces;
+  std::random_device rd;
+  std::mt19937 rng(rd());
+
+  braces.dismiss();
+
+  FOR_EACH_RANGE (i, 0, n) {
+    doNotOptimizeAway(rng());
+  }
+}
+
+BENCHMARK(threadprng, n) {
+  BenchmarkSuspender braces;
+  ThreadLocalPRNG tprng;
+  tprng();
+
+  braces.dismiss();
+
+  FOR_EACH_RANGE (i, 0, n) {
+    doNotOptimizeAway(tprng());
+  }
+}
+
+BENCHMARK(RandomDouble) { doNotOptimizeAway(Random::randDouble01()); }
+BENCHMARK(Random32) { doNotOptimizeAway(Random::rand32()); }
+BENCHMARK(Random32Num) { doNotOptimizeAway(Random::rand32(100)); }
+BENCHMARK(Random64) { doNotOptimizeAway(Random::rand64()); }
+BENCHMARK(Random64Num) { doNotOptimizeAway(Random::rand64(100ul << 32)); }
+BENCHMARK(Random64OneIn) { doNotOptimizeAway(Random::oneIn(100)); }
+
+int main(int argc, char** argv) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+
+  runBenchmarks();
+
+  return 0;
+}