(Wangle) Implement collect* using mapSetCallback and shared_ptrs
authorHannes Roth <hannesr@fb.com>
Fri, 8 May 2015 23:34:32 +0000 (16:34 -0700)
committerPraveen Kumar Ramakrishnan <praveenr@fb.com>
Tue, 12 May 2015 00:02:48 +0000 (17:02 -0700)
Summary:
I figured it would make sense to implement all the collect* functions using a shared_ptr<Context>, instead of doing our manual reference counting and all that. Fulfilling the promise in the destructor seemed like the icing on the cake. Also saves some line of code.

Test Plan: Run all the tests.

Reviewed By: hans@fb.com

Subscribers: folly-diffs@, jsedgwick, yfeldblum, chalfant

FB internal diff: D2015320

Signature: t1:2015320:1431106133:ac3001b3696fc75230afe70908ed349102b02a45

folly/futures/Future-inl.h
folly/futures/Future.cpp
folly/futures/detail/Core.h

index 33b758ea2db126692601b62850a9a5ca9a2a035d..62f480072e68fae901219dc54b96c40b064b28d7 100644 (file)
@@ -531,22 +531,31 @@ inline Future<void> via(Executor* executor) {
   return makeFuture().via(executor);
 }
 
-// when (variadic)
+// mapSetCallback calls func(i, Try<T>) when every future completes
+
+template <class T, class InputIterator, class F>
+void mapSetCallback(InputIterator first, InputIterator last, F func) {
+  for (size_t i = 0; first != last; ++first, ++i) {
+    first->setCallback_([func, i](Try<T>&& t) {
+      func(i, std::move(t));
+    });
+  }
+}
+
+// collectAll (variadic)
 
 template <typename... Fs>
 typename detail::VariadicContext<
   typename std::decay<Fs>::type::value_type...>::type
 collectAll(Fs&&... fs) {
-  auto ctx =
-    new detail::VariadicContext<typename std::decay<Fs>::type::value_type...>();
-  ctx->total = sizeof...(fs);
-  auto f_saved = ctx->p.getFuture();
+  auto ctx = std::make_shared<detail::VariadicContext<
+    typename std::decay<Fs>::type::value_type...>>();
   detail::collectAllVariadicHelper(ctx,
     std::forward<typename std::decay<Fs>::type>(fs)...);
-  return f_saved;
+  return ctx->p.getFuture();
 }
 
-// when (iterator)
+// collectAll (iterator)
 
 template <class InputIterator>
 Future<
@@ -556,155 +565,87 @@ collectAll(InputIterator first, InputIterator last) {
   typedef
     typename std::iterator_traits<InputIterator>::value_type::value_type T;
 
-  if (first >= last) {
-    return makeFuture(std::vector<Try<T>>());
-  }
-  size_t n = std::distance(first, last);
-
-  auto ctx = new detail::WhenAllContext<T>();
-
-  ctx->results.resize(n);
-
-  auto f_saved = ctx->p.getFuture();
-
-  for (size_t i = 0; first != last; ++first, ++i) {
-     assert(i < n);
-     auto& f = *first;
-     f.setCallback_([ctx, i, n](Try<T> t) {
-       ctx->results[i] = std::move(t);
-       if (++ctx->count == n) {
-         ctx->p.setValue(std::move(ctx->results));
-         delete ctx;
-       }
-     });
-  }
+  struct CollectAllContext {
+    CollectAllContext(int n) : results(n) {}
+    ~CollectAllContext() {
+      p.setValue(std::move(results));
+    }
+    Promise<std::vector<Try<T>>> p;
+    std::vector<Try<T>> results;
+  };
 
-  return f_saved;
+  auto ctx = std::make_shared<CollectAllContext>(std::distance(first, last));
+  mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
+    ctx->results[i] = std::move(t);
+  });
+  return ctx->p.getFuture();
 }
 
 namespace detail {
 
-template <class, class, typename = void> struct CollectContextHelper;
-
-template <class T, class VecT>
-struct CollectContextHelper<T, VecT,
-    typename std::enable_if<std::is_same<T, VecT>::value>::type> {
-  static inline std::vector<T>&& getResults(std::vector<VecT>& results) {
-    return std::move(results);
-  }
-};
-
-template <class T, class VecT>
-struct CollectContextHelper<T, VecT,
-    typename std::enable_if<!std::is_same<T, VecT>::value>::type> {
-  static inline std::vector<T> getResults(std::vector<VecT>& results) {
-    std::vector<T> finalResults;
-    finalResults.reserve(results.size());
-    for (auto& opt : results) {
-      finalResults.push_back(std::move(opt.value()));
-    }
-    return finalResults;
-  }
-};
-
 template <typename T>
 struct CollectContext {
-
-  typedef typename std::conditional<
-    std::is_default_constructible<T>::value,
-    T,
-    Optional<T>
-   >::type VecT;
-
-  explicit CollectContext(int n) : count(0), success_count(0), threw(false) {
-    results.resize(n);
-  }
-
-  Promise<std::vector<T>> p;
-  std::vector<VecT> results;
-  std::atomic<size_t> count, success_count;
-  std::atomic_bool threw;
-
-  typedef std::vector<T> result_type;
-
-  static inline Future<std::vector<T>> makeEmptyFuture() {
-    return makeFuture(std::vector<T>());
-  }
-
-  inline void setValue() {
-    p.setValue(CollectContextHelper<T, VecT>::getResults(results));
+  struct Nothing { explicit Nothing(int n) {} };
+
+  using Result = typename std::conditional<
+    std::is_void<T>::value,
+    void,
+    std::vector<T>>::type;
+
+  using InternalResult = typename std::conditional<
+    std::is_void<T>::value,
+    Nothing,
+    std::vector<Optional<T>>>::type;
+
+  explicit CollectContext(int n) : result(n) {}
+  ~CollectContext() {
+    if (!threw.exchange(true)) {
+      // map Optional<T> -> T
+      std::vector<T> finalResult;
+      finalResult.reserve(result.size());
+      std::transform(result.begin(), result.end(),
+                     std::back_inserter(finalResult),
+                     [](Optional<T>& o) { return std::move(o.value()); });
+      p.setValue(std::move(finalResult));
+    }
   }
-
-  inline void addResult(int i, Try<T>& t) {
-    results[i] = std::move(t.value());
+  inline void setPartialResult(size_t i, Try<T>& t) {
+    result[i] = std::move(t.value());
   }
+  Promise<Result> p;
+  InternalResult result;
+  std::atomic<bool> threw;
 };
 
-template <>
-struct CollectContext<void> {
-
-  explicit CollectContext(int n) : count(0), success_count(0), threw(false) {}
+// Specialize for void (implementations in Future.cpp)
 
-  Promise<void> p;
-  std::atomic<size_t> count, success_count;
-  std::atomic_bool threw;
-
-  typedef void result_type;
-
-  static inline Future<void> makeEmptyFuture() {
-    return makeFuture();
-  }
-
-  inline void setValue() {
-    p.setValue();
-  }
+template <>
+CollectContext<void>::~CollectContext();
 
-  inline void addResult(int i, Try<void>& t) {
-    // do nothing
-  }
-};
+template <>
+void CollectContext<void>::setPartialResult(size_t i, Try<void>& t);
 
-} // detail
+}
 
 template <class InputIterator>
 Future<typename detail::CollectContext<
-  typename std::iterator_traits<InputIterator>::value_type::value_type
->::result_type>
+  typename std::iterator_traits<InputIterator>::value_type::value_type>::Result>
 collect(InputIterator first, InputIterator last) {
   typedef
     typename std::iterator_traits<InputIterator>::value_type::value_type T;
 
-  if (first >= last) {
-    return detail::CollectContext<T>::makeEmptyFuture();
-  }
-
-  size_t n = std::distance(first, last);
-  auto ctx = new detail::CollectContext<T>(n);
-  auto f_saved = ctx->p.getFuture();
-
-  for (size_t i = 0; first != last; ++first, ++i) {
-     assert(i < n);
-     auto& f = *first;
-     f.setCallback_([ctx, i, n](Try<T> t) {
-
-       if (t.hasException()) {
-         if (!ctx->threw.exchange(true)) {
-           ctx->p.setException(std::move(t.exception()));
-         }
-       } else if (!ctx->threw) {
-         ctx->addResult(i, t);
-         if (++ctx->success_count == n) {
-           ctx->setValue();
-         }
-       }
-
-       if (++ctx->count == n) {
-         delete ctx;
+  auto ctx = std::make_shared<detail::CollectContext<T>>(
+    std::distance(first, last));
+  mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
+    if (t.hasException()) {
+       if (!ctx->threw.exchange(true)) {
+         ctx->p.setException(std::move(t.exception()));
        }
-     });
-  }
-
-  return f_saved;
+     } else if (!ctx->threw) {
+       ctx->setPartialResult(i, t);
+     }
+  });
+  return ctx->p.getFuture();
 }
 
 template <class InputIterator>
@@ -712,25 +653,24 @@ Future<
   std::pair<size_t,
             Try<
               typename
-              std::iterator_traits<InputIterator>::value_type::value_type> > >
+              std::iterator_traits<InputIterator>::value_type::value_type>>>
 collectAny(InputIterator first, InputIterator last) {
   typedef
     typename std::iterator_traits<InputIterator>::value_type::value_type T;
 
-  auto ctx = new detail::WhenAnyContext<T>(std::distance(first, last));
-  auto f_saved = ctx->p.getFuture();
-
-  for (size_t i = 0; first != last; first++, i++) {
-    auto& f = *first;
-    f.setCallback_([i, ctx](Try<T>&& t) {
-      if (!ctx->done.exchange(true)) {
-        ctx->p.setValue(std::make_pair(i, std::move(t)));
-      }
-      ctx->decref();
-    });
-  }
+  struct CollectAnyContext {
+    CollectAnyContext(size_t n) : done(false) {};
+    Promise<std::pair<size_t, Try<T>>> p;
+    std::atomic<bool> done;
+  };
 
-  return f_saved;
+  auto ctx = std::make_shared<CollectAnyContext>(std::distance(first, last));
+  mapSetCallback<T>(first, last, [ctx](size_t i, Try<T>&& t) {
+    if (!ctx->done.exchange(true)) {
+      ctx->p.setValue(std::make_pair(i, std::move(t)));
+    }
+  });
+  return ctx->p.getFuture();
 }
 
 template <class InputIterator>
@@ -741,38 +681,29 @@ collectN(InputIterator first, InputIterator last, size_t n) {
     std::iterator_traits<InputIterator>::value_type::value_type T;
   typedef std::vector<std::pair<size_t, Try<T>>> V;
 
-  struct ctx_t {
+  struct CollectNContext {
     V v;
-    size_t completed;
+    std::atomic<size_t> completed = {0};
     Promise<V> p;
   };
-  auto ctx = std::make_shared<ctx_t>();
-  ctx->completed = 0;
-
-  // for each completed Future, increase count and add to vector, until we
-  // have n completed futures at which point we fulfill our Promise with the
-  // vector
-  auto it = first;
-  size_t i = 0;
-  while (it != last) {
-    it->then([ctx, n, i](Try<T>&& t) {
-      auto& v = ctx->v;
+  auto ctx = std::make_shared<CollectNContext>();
+
+  if (std::distance(first, last) < n) {
+    ctx->p.setException(std::runtime_error("Not enough futures"));
+  } else {
+    // for each completed Future, increase count and add to vector, until we
+    // have n completed futures at which point we fulfil our Promise with the
+    // vector
+    mapSetCallback<T>(first, last, [ctx, n](size_t i, Try<T>&& t) {
       auto c = ++ctx->completed;
       if (c <= n) {
         assert(ctx->v.size() < n);
-        v.push_back(std::make_pair(i, std::move(t)));
+        ctx->v.push_back(std::make_pair(i, std::move(t)));
         if (c == n) {
-          ctx->p.setTry(Try<V>(std::move(v)));
+          ctx->p.setTry(Try<V>(std::move(ctx->v)));
         }
       }
     });
-
-    it++;
-    i++;
-  }
-
-  if (i < n) {
-    ctx->p.setException(std::runtime_error("Not enough futures"));
   }
 
   return ctx->p.getFuture();
index 0f6dc3d15128a1c1e56fc68f146900afc0e5347b..78f33d2608426dcb8b7ed31463d38852ef6f94ea 100644 (file)
@@ -39,3 +39,19 @@ Future<void> sleep(Duration dur, Timekeeper* tk) {
 }
 
 }}
+
+namespace folly { namespace detail {
+
+template <>
+CollectContext<void>::~CollectContext() {
+  if (!threw.exchange(true)) {
+    p.setValue();
+  }
+}
+
+template <>
+void CollectContext<void>::setPartialResult(size_t i, Try<void>& t) {
+  // Nothing to do for void
+}
+
+}}
index 7e23dd7c665f13b14a341ca4d1faeb9ec80c9752..65e2cb1d76d141f0dbc9f345fda92169dd7be02a 100644 (file)
@@ -319,59 +319,33 @@ class Core {
 
 template <typename... Ts>
 struct VariadicContext {
-  VariadicContext() : total(0), count(0) {}
-  Promise<std::tuple<Try<Ts>... > > p;
+  VariadicContext() {}
+  ~VariadicContext() {
+    p.setValue(std::move(results));
+  }
+  Promise<std::tuple<Try<Ts>... >> p;
   std::tuple<Try<Ts>... > results;
-  size_t total;
-  std::atomic<size_t> count;
   typedef Future<std::tuple<Try<Ts>...>> type;
 };
 
 template <typename... Ts, typename THead, typename... Fs>
 typename std::enable_if<sizeof...(Fs) == 0, void>::type
-collectAllVariadicHelper(VariadicContext<Ts...> *ctx, THead&& head, Fs&&... tail) {
+collectAllVariadicHelper(std::shared_ptr<VariadicContext<Ts...>> ctx,
+                         THead&& head, Fs&&... tail) {
   head.setCallback_([ctx](Try<typename THead::value_type>&& t) {
     std::get<sizeof...(Ts) - sizeof...(Fs) - 1>(ctx->results) = std::move(t);
-    if (++ctx->count == ctx->total) {
-      ctx->p.setValue(std::move(ctx->results));
-      delete ctx;
-    }
   });
 }
 
 template <typename... Ts, typename THead, typename... Fs>
 typename std::enable_if<sizeof...(Fs) != 0, void>::type
-collectAllVariadicHelper(VariadicContext<Ts...> *ctx, THead&& head, Fs&&... tail) {
+collectAllVariadicHelper(std::shared_ptr<VariadicContext<Ts...>> ctx,
+                         THead&& head, Fs&&... tail) {
   head.setCallback_([ctx](Try<typename THead::value_type>&& t) {
     std::get<sizeof...(Ts) - sizeof...(Fs) - 1>(ctx->results) = std::move(t);
-    if (++ctx->count == ctx->total) {
-      ctx->p.setValue(std::move(ctx->results));
-      delete ctx;
-    }
   });
   // template tail-recursion
   collectAllVariadicHelper(ctx, std::forward<Fs>(tail)...);
 }
 
-template <typename T>
-struct WhenAllContext {
-  WhenAllContext() : count(0) {}
-  Promise<std::vector<Try<T> > > p;
-  std::vector<Try<T> > results;
-  std::atomic<size_t> count;
-};
-
-template <typename T>
-struct WhenAnyContext {
-  explicit WhenAnyContext(size_t n) : done(false), ref_count(n) {};
-  Promise<std::pair<size_t, Try<T>>> p;
-  std::atomic<bool> done;
-  std::atomic<size_t> ref_count;
-  void decref() {
-    if (--ref_count == 0) {
-      delete this;
-    }
-  }
-};
-
 }} // folly::detail