gen::sample
authorMike Curtiss <mcurtiss@fb.com>
Tue, 14 May 2013 05:05:06 +0000 (22:05 -0700)
committerSara Golemon <sgolemon@fb.com>
Mon, 20 May 2013 18:01:27 +0000 (11:01 -0700)
Summary:
Take a random sample of size N from a range.  Clients
can also pass in a custom random number generator.

Test Plan: Added test and benchmark.

Reviewed By: tjackson@fb.com

FB internal diff: D811260

folly/experimental/Gen-inl.h
folly/experimental/Gen.h
folly/experimental/test/GenBenchmark.cpp
folly/experimental/test/GenTest.cpp

index 73c1081686f1eda19e279f49597a7f8e8766a750..6a90d26651ef8385cd07865642c01bc6149e221c 100644 (file)
@@ -156,8 +156,10 @@ class GenImpl : public FBounded<Self> {
 
   /**
    * apply() - Send all values produced by this generator to given
-   * handler until it returns false. Returns true if the false iff the handler
-   * returned false.
+   * handler until the handler returns false. Returns true until the handler
+   * returns false. GOTCHA: It should return true even if it completes (without
+   * the handler returning false), as 'Chain' uses the return value of apply
+   * to determine if it should process the second object in its chain.
    */
   template<class Handler>
   bool apply(Handler&& handler) const;
@@ -773,6 +775,84 @@ class Take : public Operator<Take> {
   }
 };
 
+/**
+ * Sample - For taking a random sample of N elements from a sequence
+ * (without replacement).
+ */
+template<class Random>
+class Sample : public Operator<Sample<Random>> {
+  size_t count_;
+  Random rng_;
+ public:
+  explicit Sample(size_t count, Random rng)
+    : count_(count), rng_(std::move(rng)) {}
+
+  template<class Value,
+           class Source,
+           class Rand,
+           class StorageType = typename std::decay<Value>::type>
+  class Generator :
+          public GenImpl<StorageType&&,
+                         Generator<Value, Source, Rand, StorageType>> {
+    static_assert(!Source::infinite, "Cannot sample infinite source!");
+    // It's too easy to bite ourselves if random generator is only 16-bit
+    static_assert(Random::max() >= std::numeric_limits<int32_t>::max() - 1,
+                  "Random number generator must support big values");
+    Source source_;
+    size_t count_;
+    mutable Rand rng_;
+  public:
+    explicit Generator(Source source, size_t count, Random rng)
+      : source_(std::move(source)) , count_(count), rng_(std::move(rng)) {}
+
+    template<class Handler>
+    bool apply(Handler&& handler) const {
+      if (count_ == 0) { return false; }
+      std::vector<StorageType> v;
+      v.reserve(count_);
+      // use reservoir sampling to give each source value an equal chance
+      // of appearing in our output.
+      size_t n = 1;
+      source_.foreach([&](Value value) -> void {
+          if (v.size() < count_) {
+            v.push_back(std::forward<Value>(value));
+          } else {
+            // alternatively, we could create a std::uniform_int_distribution
+            // instead of using modulus, but benchmarks show this has
+            // substantial overhead.
+            size_t index = rng_() % n;
+            if (index < v.size()) {
+              v[index] = std::forward<Value>(value);
+            }
+          }
+          ++n;
+        });
+
+      // output is unsorted!
+      for (auto& val: v) {
+        if (!handler(std::move(val))) {
+          return false;
+        }
+      }
+      return true;
+    }
+  };
+
+  template<class Source,
+           class Value,
+           class Gen = Generator<Value, Source, Random>>
+  Gen compose(GenImpl<Value, Source>&& source) const {
+    return Gen(std::move(source.self()), count_, rng_);
+  }
+
+  template<class Source,
+           class Value,
+           class Gen = Generator<Value, Source, Random>>
+  Gen compose(const GenImpl<Value, Source>& source) const {
+    return Gen(source.self(), count_, rng_);
+  }
+};
+
 /**
  * Skip - For skipping N items from the beginning of a source generator.
  *
@@ -1706,6 +1786,11 @@ inline detail::Take take(size_t count) {
   return detail::Take(count);
 }
 
+template<class Random = std::default_random_engine>
+inline detail::Sample<Random> sample(size_t count, Random rng = Random()) {
+  return detail::Sample<Random>(count, std::move(rng));
+}
+
 inline detail::Skip skip(size_t count) {
   return detail::Skip(count);
 }
index 0946c0f311381822a4fba80989bcf8b42332046f..130cc7732c60b9a2cff336a5d79e64fa09994821 100644 (file)
@@ -21,6 +21,7 @@
 #include <type_traits>
 #include <utility>
 #include <algorithm>
+#include <random>
 #include <vector>
 #include <unordered_set>
 
@@ -210,6 +211,9 @@ class Until;
 
 class Take;
 
+template<class Rand>
+class Sample;
+
 class Skip;
 
 template<class Selector, class Comparer = Less>
index d9949117f21d1debdc903cea43638bd5917e9def..cb33e0004c22dbf68bfdb76cfcfe9d213e31c67a 100644 (file)
@@ -280,6 +280,17 @@ BENCHMARK_RELATIVE(Composed_GenRegular, iters) {
 
 BENCHMARK_DRAW_LINE()
 
+BENCHMARK(Sample, iters) {
+  size_t s = 0;
+  while (iters--) {
+    auto sampler = seq(1, 10 * 1000 * 1000) | sample(1000);
+    s += (sampler | sum);
+  }
+  folly::doNotOptimizeAway(s);
+}
+
+BENCHMARK_DRAW_LINE()
+
 namespace {
 
 const char* const kLine = "The quick brown fox jumped over the lazy dog.\n";
index d5a05ea5cebe0d2c309ac86b86b0862ee9e2ba19..33dc7cd476f97e8563b7c840d0758bc6bd25d660 100644 (file)
@@ -17,6 +17,7 @@
 #include <glog/logging.h>
 #include <gtest/gtest.h>
 #include <iostream>
+#include <random>
 #include <set>
 #include <vector>
 #include "folly/experimental/Gen.h"
@@ -173,6 +174,40 @@ TEST(Gen, Take) {
   EXPECT_EQ(expected, actual);
 }
 
+TEST(Gen, Sample) {
+  std::mt19937 rnd(42);
+
+  auto sampler =
+      seq(1, 100)
+    | sample(50, rnd);
+  std::unordered_map<int,int> hits;
+  const int kNumIters = 80;
+  for (int i = 0; i < kNumIters; i++) {
+    auto vec = sampler | as<vector<int>>();
+    EXPECT_EQ(vec.size(), 50);
+    auto uniq = fromConst(vec) | as<set<int>>();
+    EXPECT_EQ(uniq.size(), vec.size());  // sampling without replacement
+    for (auto v: vec) {
+      ++hits[v];
+    }
+  }
+
+  // In 80 separate samples of our range, we should have seen every value
+  // at least once and no value all 80 times. (The odds of either of those
+  // events is 1/2^80).
+  EXPECT_EQ(hits.size(), 100);
+  for (auto hit: hits) {
+    EXPECT_GT(hit.second, 0);
+    EXPECT_LT(hit.second, kNumIters);
+  }
+
+  auto small =
+      seq(1, 5)
+    | sample(10);
+  EXPECT_EQ((small | sum), 15);
+  EXPECT_EQ((small | take(3) | count), 3);
+}
+
 TEST(Gen, Skip) {
   auto gen =
       seq(1, 1000)
@@ -672,6 +707,11 @@ TEST(Gen, DynamicObject) {
   EXPECT_EQ(dynamic(6), from(obj.items()) | get<1>() | sum);
 }
 
+TEST(Gen, Collect) {
+  auto s = from({7, 6, 5, 4, 3}) | as<set<int>>();
+  EXPECT_EQ(s.size(), 5);
+}
+
 TEST(StringGen, EmptySplit) {
   auto collect = eachTo<std::string>() | as<vector>();
   {