gen::sample
[folly.git] / folly / experimental / Gen-inl.h
index 73c1081686f1eda19e279f49597a7f8e8766a750..6a90d26651ef8385cd07865642c01bc6149e221c 100644 (file)
@@ -156,8 +156,10 @@ class GenImpl : public FBounded<Self> {
 
   /**
    * apply() - Send all values produced by this generator to given
-   * handler until it returns false. Returns true if the false iff the handler
-   * returned false.
+   * handler until the handler returns false. Returns true until the handler
+   * returns false. GOTCHA: It should return true even if it completes (without
+   * the handler returning false), as 'Chain' uses the return value of apply
+   * to determine if it should process the second object in its chain.
    */
   template<class Handler>
   bool apply(Handler&& handler) const;
@@ -773,6 +775,84 @@ class Take : public Operator<Take> {
   }
 };
 
+/**
+ * Sample - For taking a random sample of N elements from a sequence
+ * (without replacement).
+ */
+template<class Random>
+class Sample : public Operator<Sample<Random>> {
+  size_t count_;
+  Random rng_;
+ public:
+  explicit Sample(size_t count, Random rng)
+    : count_(count), rng_(std::move(rng)) {}
+
+  template<class Value,
+           class Source,
+           class Rand,
+           class StorageType = typename std::decay<Value>::type>
+  class Generator :
+          public GenImpl<StorageType&&,
+                         Generator<Value, Source, Rand, StorageType>> {
+    static_assert(!Source::infinite, "Cannot sample infinite source!");
+    // It's too easy to bite ourselves if random generator is only 16-bit
+    static_assert(Random::max() >= std::numeric_limits<int32_t>::max() - 1,
+                  "Random number generator must support big values");
+    Source source_;
+    size_t count_;
+    mutable Rand rng_;
+  public:
+    explicit Generator(Source source, size_t count, Random rng)
+      : source_(std::move(source)) , count_(count), rng_(std::move(rng)) {}
+
+    template<class Handler>
+    bool apply(Handler&& handler) const {
+      if (count_ == 0) { return false; }
+      std::vector<StorageType> v;
+      v.reserve(count_);
+      // use reservoir sampling to give each source value an equal chance
+      // of appearing in our output.
+      size_t n = 1;
+      source_.foreach([&](Value value) -> void {
+          if (v.size() < count_) {
+            v.push_back(std::forward<Value>(value));
+          } else {
+            // alternatively, we could create a std::uniform_int_distribution
+            // instead of using modulus, but benchmarks show this has
+            // substantial overhead.
+            size_t index = rng_() % n;
+            if (index < v.size()) {
+              v[index] = std::forward<Value>(value);
+            }
+          }
+          ++n;
+        });
+
+      // output is unsorted!
+      for (auto& val: v) {
+        if (!handler(std::move(val))) {
+          return false;
+        }
+      }
+      return true;
+    }
+  };
+
+  template<class Source,
+           class Value,
+           class Gen = Generator<Value, Source, Random>>
+  Gen compose(GenImpl<Value, Source>&& source) const {
+    return Gen(std::move(source.self()), count_, rng_);
+  }
+
+  template<class Source,
+           class Value,
+           class Gen = Generator<Value, Source, Random>>
+  Gen compose(const GenImpl<Value, Source>& source) const {
+    return Gen(source.self(), count_, rng_);
+  }
+};
+
 /**
  * Skip - For skipping N items from the beginning of a source generator.
  *
@@ -1706,6 +1786,11 @@ inline detail::Take take(size_t count) {
   return detail::Take(count);
 }
 
+template<class Random = std::default_random_engine>
+inline detail::Sample<Random> sample(size_t count, Random rng = Random()) {
+  return detail::Sample<Random>(count, std::move(rng));
+}
+
 inline detail::Skip skip(size_t count) {
   return detail::Skip(count);
 }