Batch
authorKittipat Virochsiri <kittipat@fb.com>
Fri, 2 Aug 2013 20:25:35 +0000 (13:25 -0700)
committerSara Golemon <sgolemon@fb.com>
Wed, 28 Aug 2013 21:30:11 +0000 (14:30 -0700)
Summary: Convert stream of elements to stream of fixed-sized vectors.

Test Plan: unit tests

Reviewed By: tjackson@fb.com

FB internal diff: D912290

folly/experimental/Gen-inl.h
folly/experimental/Gen.h
folly/experimental/test/GenTest.cpp

index 9a7d568f8d9ae9be102dc195c3724e8cec134af3..cc8fd5c59213b16afdd541fb97effff855db6275 100644 (file)
@@ -1153,6 +1153,83 @@ class Distinct : public Operator<Distinct<Selector>> {
   }
 };
 
+/**
+ * Batch - For producing fixed-size batches of each value from a source.
+ *
+ * This type is usually used through the 'batch' helper function:
+ *
+ *   auto batchSums
+ *     = seq(1, 10)
+ *     | batch(3)
+ *     | map([](const std::vector<int>& batch) {
+ *         return from(batch) | sum;
+ *       })
+ *     | as<vector>();
+ */
+class Batch : public Operator<Batch> {
+  size_t batchSize_;
+ public:
+  explicit Batch(size_t batchSize)
+    : batchSize_(batchSize) {
+    if (batchSize_ == 0) {
+      throw std::invalid_argument("Batch size must be non-zero!");
+    }
+  }
+
+  template<class Value,
+           class Source,
+           class StorageType = typename std::decay<Value>::type,
+           class VectorType = std::vector<StorageType>>
+  class Generator :
+      public GenImpl<VectorType&,
+                     Generator<Value, Source, StorageType, VectorType>> {
+    Source source_;
+    size_t batchSize_;
+  public:
+    explicit Generator(Source source, size_t batchSize)
+      : source_(std::move(source))
+      , batchSize_(batchSize) {}
+
+    template<class Handler>
+    bool apply(Handler&& handler) const {
+      VectorType batch_;
+      batch_.reserve(batchSize_);
+      bool shouldContinue = source_.apply([&](Value value) -> bool {
+          batch_.push_back(std::forward<Value>(value));
+          if (batch_.size() == batchSize_) {
+            bool needMore = handler(batch_);
+            batch_.clear();
+            return needMore;
+          }
+          // Always need more if the handler is not called.
+          return true;
+        });
+      // Flush everything, if and only if `handler` hasn't returned false.
+      if (shouldContinue && !batch_.empty()) {
+        shouldContinue = handler(batch_);
+        batch_.clear();
+      }
+      return shouldContinue;
+    }
+
+    static constexpr bool infinite = Source::infinite;
+  };
+
+  template<class Source,
+           class Value,
+           class Gen = Generator<Value, Source>>
+  Gen compose(GenImpl<Value, Source>&& source) const {
+    return Gen(std::move(source.self()), batchSize_);
+  }
+
+  template<class Source,
+           class Value,
+           class Gen = Generator<Value, Source>>
+  Gen compose(const GenImpl<Value, Source>& source) const {
+    return Gen(source.self(), batchSize_);
+  }
+};
+
 /**
  * Composed - For building up a pipeline of operations to perform, absent any
  * particular source generator. Useful for building up custom pipelines.
@@ -1994,6 +2071,10 @@ inline detail::Skip skip(size_t count) {
   return detail::Skip(count);
 }
 
+inline detail::Batch batch(size_t batchSize) {
+  return detail::Batch(batchSize);
+}
+
 }} //folly::gen
 
 #pragma GCC diagnostic pop
index 31c1cedc219a88b0415b1e14015161ee145f25fe..fa1ec1b2981e73128c850931080c7eee5d50f554 100644 (file)
@@ -310,6 +310,8 @@ class RangeConcat;
 
 class Cycle;
 
+class Batch;
+
 /*
  * Sinks
  */
index 9ab1aa57a8f561323cd8448022dc586348b1accc..a07265552f277c7acfdc0be84a3d1cba90685ef9 100644 (file)
@@ -1283,6 +1283,32 @@ TEST(Gen, Guard) {
                runtime_error);
 }
 
+TEST(Gen, Batch) {
+  EXPECT_EQ((vector<vector<int>> { {1} }),
+            seq(1, 1) | batch(5) | as<vector>());
+  EXPECT_EQ((vector<vector<int>> { {1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11} }),
+            seq(1, 11) | batch(3) | as<vector>());
+  EXPECT_THROW(seq(1, 1) | batch(0) | as<vector>(),
+               std::invalid_argument);
+}
+
+TEST(Gen, BatchMove) {
+  auto expected = vector<vector<int>>{ {0, 1}, {2, 3}, {4} };
+  auto actual =
+      seq(0, 4)
+    | mapped([](int i) { return std::unique_ptr<int>(new int(i)); })
+    | batch(2)
+    | mapped([](std::vector<std::unique_ptr<int>>& pVector) {
+        std::vector<int> iVector;
+        for (const auto& p : pVector) {
+          iVector.push_back(*p);
+        };
+        return iVector;
+      })
+    | as<vector>();
+  EXPECT_EQ(expected, actual);
+}
+
 int main(int argc, char *argv[]) {
   testing::InitGoogleTest(&argc, argv);
   google::ParseCommandLineFlags(&argc, &argv, true);