pass RNG by reference so state is updated on each call
authorChristopher Small <cas@fb.com>
Tue, 27 Dec 2016 00:45:14 +0000 (16:45 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 27 Dec 2016 00:47:55 +0000 (16:47 -0800)
Summary: folly::Random was taking the RNG by value (not reference) so it was not updating the RNG's state on each invocation -- so the RNG would not advance to the next value in the sequence.

Reviewed By: yfeldblum, nbronson

Differential Revision: D4362999

fbshipit-source-id: f93fc11911b92e230ac0cc2406151474d15f85af

folly/Random.h
folly/test/RandomTest.cpp

index 0e52772aa5b746970060ab07d7d2ae3ff83ca4dd..3cb553da66d5e9d9ce34d7c5f2be57b9fdfffc80 100644 (file)
@@ -133,25 +133,22 @@ class Random {
    * Returns a random uint32_t
    */
   static uint32_t rand32() {
-    ThreadLocalPRNG prng;
-    return rand32(prng);
+    return rand32(ThreadLocalPRNG());
   }
 
   /**
    * Returns a random uint32_t given a specific RNG
    */
   template <class RNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static uint32_t rand32(RNG rng) {
-    uint32_t r = rng.operator()();
-    return r;
+  static uint32_t rand32(RNG&& rng) {
+    return rng();
   }
 
   /**
    * Returns a random uint32_t in [0, max). If max == 0, returns 0.
    */
   static uint32_t rand32(uint32_t max) {
-    ThreadLocalPRNG prng;
-    return rand32(max, prng);
+    return rand32(0, max, ThreadLocalPRNG());
   }
 
   /**
@@ -159,84 +156,123 @@ class Random {
    * If max == 0, returns 0.
    */
   template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static uint32_t rand32(uint32_t max, RNG rng = RNG()) {
-    if (max == 0) {
-      return 0;
-    }
-
-    return std::uniform_int_distribution<uint32_t>(0, max - 1)(rng);
+  static uint32_t rand32(uint32_t max, RNG&& rng) {
+    return rand32(0, max, rng);
   }
 
   /**
    * Returns a random uint32_t in [min, max). If min == max, returns 0.
    */
+  static uint32_t rand32(uint32_t min, uint32_t max) {
+    return rand32(min, max, ThreadLocalPRNG());
+  }
+
+  /**
+   * Returns a random uint32_t in [min, max) given a specific RNG.
+   * If min == max, returns 0.
+   */
   template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static uint32_t rand32(uint32_t min, uint32_t max, RNG rng = RNG()) {
+  static uint32_t rand32(uint32_t min, uint32_t max, RNG&& rng) {
     if (min == max) {
       return 0;
     }
-
     return std::uniform_int_distribution<uint32_t>(min, max - 1)(rng);
   }
 
+  /**
+   * Returns a random uint64_t
+   */
+  static uint64_t rand64() {
+    return rand64(ThreadLocalPRNG());
+  }
+
   /**
    * Returns a random uint64_t
    */
   template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static uint64_t rand64(RNG rng = RNG()) {
-    return ((uint64_t) rng() << 32) | rng();
+  static uint64_t rand64(RNG&& rng) {
+    return ((uint64_t)rng() << 32) | rng();
+  }
+
+  /**
+   * Returns a random uint64_t in [0, max). If max == 0, returns 0.
+   */
+  static uint64_t rand64(uint64_t max) {
+    return rand64(0, max, ThreadLocalPRNG());
   }
 
   /**
    * Returns a random uint64_t in [0, max). If max == 0, returns 0.
    */
   template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static uint64_t rand64(uint64_t max, RNG rng = RNG()) {
-    if (max == 0) {
-      return 0;
-    }
+  static uint64_t rand64(uint64_t max, RNG&& rng) {
+    return rand64(0, max, rng);
+  }
 
-    return std::uniform_int_distribution<uint64_t>(0, max - 1)(rng);
+  /**
+   * Returns a random uint64_t in [min, max). If min == max, returns 0.
+   */
+  static uint64_t rand64(uint64_t min, uint64_t max) {
+    return rand64(min, max, ThreadLocalPRNG());
   }
 
   /**
    * Returns a random uint64_t in [min, max). If min == max, returns 0.
    */
   template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static uint64_t rand64(uint64_t min, uint64_t max, RNG rng = RNG()) {
+  static uint64_t rand64(uint64_t min, uint64_t max, RNG&& rng) {
     if (min == max) {
       return 0;
     }
-
     return std::uniform_int_distribution<uint64_t>(min, max - 1)(rng);
   }
 
+  /**
+   * Returns true 1/n of the time. If n == 0, always returns false
+   */
+  static bool oneIn(uint32_t n) {
+    return oneIn(n, ThreadLocalPRNG());
+  }
+
   /**
    * Returns true 1/n of the time. If n == 0, always returns false
    */
   template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static bool oneIn(uint32_t n, ValidRNG<RNG> rng = RNG()) {
+  static bool oneIn(uint32_t n, RNG&& rng) {
     if (n == 0) {
       return false;
     }
+    return rand32(0, n, rng) == 0;
+  }
 
-    return rand32(n, rng) == 0;
+  /**
+   * Returns a double in [0, 1)
+   */
+  static double randDouble01() {
+    return randDouble01(ThreadLocalPRNG());
   }
 
   /**
    * Returns a double in [0, 1)
    */
   template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static double randDouble01(RNG rng = RNG()) {
-    return std::generate_canonical<double, std::numeric_limits<double>::digits>
-      (rng);
+  static double randDouble01(RNG&& rng) {
+    return std::generate_canonical<double, std::numeric_limits<double>::digits>(
+        rng);
+  }
+
+  /**
+    * Returns a double in [min, max), if min == max, returns 0.
+    */
+  static double randDouble(double min, double max) {
+    return randDouble(min, max, ThreadLocalPRNG());
   }
 
   /**
     * Returns a double in [min, max), if min == max, returns 0.
     */
   template <class RNG = ThreadLocalPRNG, class /* EnableIf */ = ValidRNG<RNG>>
-  static double randDouble(double min, double max, RNG rng = RNG()) {
+  static double randDouble(double min, double max, RNG&& rng) {
     if (std::fabs(max - min) < std::numeric_limits<double>::epsilon()) {
       return 0;
     }
index 7d175b6e3bb2fe79ab7deed7ef1438e767009130..da906f69bc42f24cf2a49b86dd4f64c582597739 100644 (file)
@@ -22,6 +22,7 @@
 #include <thread>
 #include <vector>
 #include <random>
+#include <unordered_set>
 
 #include <folly/portability/GTest.h>
 
@@ -94,3 +95,46 @@ TEST(Random, MultiThreaded) {
     EXPECT_LT(seeds[i], seeds[i+1]);
   }
 }
+
+TEST(Random, sanity) {
+  // edge cases
+  EXPECT_EQ(folly::Random::rand32(0), 0);
+  EXPECT_EQ(folly::Random::rand32(12, 12), 0);
+  EXPECT_EQ(folly::Random::rand64(0), 0);
+  EXPECT_EQ(folly::Random::rand64(12, 12), 0);
+
+  // 32-bit repeatability, uniqueness
+  constexpr int kTestSize = 1000;
+  {
+    std::vector<uint32_t> vals;
+    folly::Random::DefaultGenerator rng;
+    rng.seed(0xdeadbeef);
+    for (int i = 0; i < kTestSize; ++i) {
+      vals.push_back(folly::Random::rand32(rng));
+    }
+    rng.seed(0xdeadbeef);
+    for (int i = 0; i < kTestSize; ++i) {
+      EXPECT_EQ(vals[i], folly::Random::rand32(rng));
+    }
+    EXPECT_EQ(
+        vals.size(),
+        std::unordered_set<uint32_t>(vals.begin(), vals.end()).size());
+  }
+
+  // 64-bit repeatability, uniqueness
+  {
+    std::vector<uint64_t> vals;
+    folly::Random::DefaultGenerator rng;
+    rng.seed(0xdeadbeef);
+    for (int i = 0; i < kTestSize; ++i) {
+      vals.push_back(folly::Random::rand64(rng));
+    }
+    rng.seed(0xdeadbeef);
+    for (int i = 0; i < kTestSize; ++i) {
+      EXPECT_EQ(vals[i], folly::Random::rand64(rng));
+    }
+    EXPECT_EQ(
+        vals.size(),
+        std::unordered_set<uint32_t>(vals.begin(), vals.end()).size());
+  }
+}