From e886504902fe5479f9478235bd2307c7053a154c Mon Sep 17 00:00:00 2001
From: Mike Curtiss
Date: Mon, 13 May 2013 22:05:06 -0700
Subject: [PATCH] gen::sample

Summary: Take a random sample of size N from a range. Clients can also
pass in a custom random number generator.

Test Plan: Added test and benchmark.

Reviewed By: tjackson@fb.com

FB internal diff: D811260
---
 folly/experimental/Gen-inl.h             | 89 +++++++++++++++++++++++-
 folly/experimental/Gen.h                 |  4 ++
 folly/experimental/test/GenBenchmark.cpp | 11 +++
 folly/experimental/test/GenTest.cpp      | 40 +++++++++++
 4 files changed, 142 insertions(+), 2 deletions(-)

diff --git a/folly/experimental/Gen-inl.h b/folly/experimental/Gen-inl.h
index 73c10816..6a90d266 100644
--- a/folly/experimental/Gen-inl.h
+++ b/folly/experimental/Gen-inl.h
@@ -156,8 +156,10 @@ class GenImpl : public FBounded {
 
   /**
    * apply() - Send all values produced by this generator to given
-   * handler until it returns false. Returns true if the false iff the handler
-   * returned false.
+   * handler until the handler returns false. Returns true until the handler
+   * returns false. GOTCHA: It should return true even if it completes (without
+   * the handler returning false), as 'Chain' uses the return value of apply
+   * to determine if it should process the second object in its chain.
    */
   template<class Handler>
   bool apply(Handler&& handler) const;
@@ -773,6 +775,84 @@ class Take : public Operator {
   }
 };
 
+/**
+ * Sample - For taking a random sample of N elements from a sequence
+ * (without replacement).
+ */
+template<class Random>
+class Sample : public Operator<Sample<Random>> {
+  size_t count_;
+  Random rng_;
+ public:
+  explicit Sample(size_t count, Random rng)
+    : count_(count), rng_(std::move(rng)) {}
+
+  template<class Value,
+           class Source,
+           class Rand,
+           class StorageType = typename std::decay<Value>::type>
+  class Generator :
+          public GenImpl<StorageType&&,
+                         Generator<Value, Source, Rand, StorageType>> {
+    static_assert(!Source::infinite, "Cannot sample infinite source!");
+    // It's too easy to bite ourselves if random generator is only 16-bit
+    static_assert(Random::max() >= std::numeric_limits<int32_t>::max() - 1,
+                  "Random number generator must support big values");
+    Source source_;
+    size_t count_;
+    mutable Rand rng_;
+   public:
+    explicit Generator(Source source, size_t count, Random rng)
+      : source_(std::move(source)) , count_(count), rng_(std::move(rng)) {}
+
+    template<class Handler>
+    bool apply(Handler&& handler) const {
+      if (count_ == 0) { return true; }  // completed: handler never declined
+      std::vector<StorageType> v;
+      v.reserve(count_);
+      // use reservoir sampling to give each source value an equal chance
+      // of appearing in our output.
+      size_t n = 1;
+      source_.foreach([&](Value value) -> void {
+          if (v.size() < count_) {
+            v.push_back(std::forward<Value>(value));
+          } else {
+            // alternatively, we could create a std::uniform_int_distribution
+            // instead of using modulus, but benchmarks show this has
+            // substantial overhead.
+            size_t index = rng_() % n;
+            if (index < v.size()) {
+              v[index] = std::forward<Value>(value);
+            }
+          }
+          ++n;
+        });
+
+      // output is unsorted!
+      for (auto& val: v) {
+        if (!handler(std::move(val))) {
+          return false;
+        }
+      }
+      return true;
+    }
+  };
+
+  template<class Source,
+           class Value,
+           class Gen = Generator<Value, Source, Random>>
+  Gen compose(GenImpl<Value, Source>&& source) const {
+    return Gen(std::move(source.self()), count_, rng_);
+  }
+
+  template<class Source,
+           class Value,
+           class Gen = Generator<Value, Source, Random>>
+  Gen compose(const GenImpl<Value, Source>& source) const {
+    return Gen(source.self(), count_, rng_);
+  }
+};
+
 /**
  * Skip - For skipping N items from the beginning of a source generator.
  *
@@ -1706,6 +1786,11 @@ inline detail::Take take(size_t count) {
   return detail::Take(count);
 }
 
+template<class Random = std::default_random_engine>
+inline detail::Sample<Random> sample(size_t count, Random rng = Random()) {
+  return detail::Sample<Random>(count, std::move(rng));
+}
+
 inline detail::Skip skip(size_t count) {
   return detail::Skip(count);
 }
diff --git a/folly/experimental/Gen.h b/folly/experimental/Gen.h
index 0946c0f3..130cc773 100644
--- a/folly/experimental/Gen.h
+++ b/folly/experimental/Gen.h
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include <random>
 #include
 #include
 
@@ -210,6 +211,9 @@ class Until;
 
 class Take;
 
+template<class Rand>
+class Sample;
+
 class Skip;
 
 template
diff --git a/folly/experimental/test/GenBenchmark.cpp b/folly/experimental/test/GenBenchmark.cpp
index d9949117..cb33e000 100644
--- a/folly/experimental/test/GenBenchmark.cpp
+++ b/folly/experimental/test/GenBenchmark.cpp
@@ -280,6 +280,17 @@ BENCHMARK_RELATIVE(Composed_GenRegular, iters) {
 
 BENCHMARK_DRAW_LINE()
 
+BENCHMARK(Sample, iters) {
+  size_t s = 0;
+  while (iters--) {
+    auto sampler = seq(1, 10 * 1000 * 1000) | sample(1000);
+    s += (sampler | sum);
+  }
+  folly::doNotOptimizeAway(s);
+}
+
+BENCHMARK_DRAW_LINE()
+
 namespace {
 
 const char* const kLine = "The quick brown fox jumped over the lazy dog.\n";
diff --git a/folly/experimental/test/GenTest.cpp b/folly/experimental/test/GenTest.cpp
index d5a05ea5..33dc7cd4 100644
--- a/folly/experimental/test/GenTest.cpp
+++ b/folly/experimental/test/GenTest.cpp
@@ -17,6 +17,7 @@
 #include
 #include
 #include
+#include <random>
 #include
 #include
 #include "folly/experimental/Gen.h"
@@ -173,6 +174,40 @@ TEST(Gen, Take) {
   EXPECT_EQ(expected, actual);
 }
 
+TEST(Gen, Sample) {
+  std::mt19937 rnd(42);
+
+  auto sampler =
+      seq(1, 100)
+    | sample(50, rnd);
+  std::unordered_map<int, int> hits;
+  const int kNumIters = 80;
+  for (int i = 0; i < kNumIters; i++) {
+    auto vec = sampler | as<std::vector<int>>();
+    EXPECT_EQ(vec.size(), 50);
+    auto uniq = fromConst(vec) | as<std::set<int>>();
+    EXPECT_EQ(uniq.size(), vec.size());  // sampling without replacement
+    for (auto v: vec) {
+      ++hits[v];
+    }
+  }
+
+  // In 80 separate samples of our range, we should have seen every value
+  // at least once and no value all 80 times. (The odds of either of those
+  // events is 1/2^80).
+  EXPECT_EQ(hits.size(), 100);
+  for (auto hit: hits) {
+    EXPECT_GT(hit.second, 0);
+    EXPECT_LT(hit.second, kNumIters);
+  }
+
+  auto small =
+      seq(1, 5)
+    | sample(10);
+  EXPECT_EQ((small | sum), 15);
+  EXPECT_EQ((small | take(3) | count), 3);
+}
+
 TEST(Gen, Skip) {
   auto gen =
       seq(1, 1000)
@@ -672,6 +707,11 @@ TEST(Gen, DynamicObject) {
   EXPECT_EQ(dynamic(6), from(obj.items()) | get<1>() | sum);
 }
 
+TEST(Gen, Collect) {
+  auto s = from({7, 6, 5, 4, 3}) | as<std::set<int>>();
+  EXPECT_EQ(s.size(), 5);
+}
+
 TEST(StringGen, EmptySplit) {
   auto collect = eachTo() | as();
   {
-- 
2.34.1