Destroy promise/future callback functions before waking waiters
[folly.git] / folly / futures / Future-inl.h
index 0e6c7e31c40f8e32f1205530844429a84044f203..e1e34956a6aa3cffc96d75072a80aa9b60e038cb 100644 (file)
@@ -48,7 +48,76 @@ typedef folly::Baton<> FutureBatonType;
 }
 
 namespace detail {
-  std::shared_ptr<Timekeeper> getTimekeeperSingleton();
+std::shared_ptr<Timekeeper> getTimekeeperSingleton();
+
+//  Guarantees that the stored functor is destructed before the stored promise
+//  may be fulfilled. Assumes the stored functor to be noexcept-destructible.
+template <typename T, typename F>
+class CoreCallbackState {
+ public:
+  template <typename FF>
+  CoreCallbackState(Promise<T>&& promise, FF&& func) noexcept(
+      noexcept(F(std::declval<FF>())))
+      : func_(std::forward<FF>(func)), promise_(std::move(promise)) {
+    assert(before_barrier());
+  }
+
+  CoreCallbackState(CoreCallbackState&& that) noexcept(
+      noexcept(F(std::declval<F>()))) {
+    if (that.before_barrier()) {
+      new (&func_) F(std::move(that.func_));
+      promise_ = that.stealPromise();
+    }
+  }
+
+  CoreCallbackState& operator=(CoreCallbackState&&) = delete;
+
+  ~CoreCallbackState() {
+    if (before_barrier()) {
+      stealPromise();
+    }
+  }
+
+  template <typename... Args>
+  auto invoke(Args&&... args) noexcept(
+      noexcept(std::declval<F&&>()(std::declval<Args&&>()...))) {
+    assert(before_barrier());
+    return std::move(func_)(std::forward<Args>(args)...);
+  }
+
+  void setTry(Try<T>&& t) {
+    stealPromise().setTry(std::move(t));
+  }
+
+  void setException(exception_wrapper&& ew) {
+    stealPromise().setException(std::move(ew));
+  }
+
+  Promise<T> stealPromise() noexcept {
+    assert(before_barrier());
+    func_.~F();
+    return std::move(promise_);
+  }
+
+ private:
+  bool before_barrier() const noexcept {
+    return !promise_.isFulfilled();
+  }
+
+  union {
+    F func_;
+  };
+  Promise<T> promise_{detail::EmptyConstruct{}};
+};
+
+template <typename T, typename F>
+inline auto makeCoreCallbackState(Promise<T>&& p, F&& f) noexcept(
+    noexcept(CoreCallbackState<T, _t<std::decay<F>>>(
+        std::declval<Promise<T>&&>(),
+        std::declval<F&&>()))) {
+  return CoreCallbackState<T, _t<std::decay<F>>>(
+      std::move(p), std::forward<F>(f));
+}
 }
 
 template <class T>
@@ -160,13 +229,13 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
      in the destruction of the Future used to create it.
      */
   setCallback_(
-      [ func = std::forward<F>(func), pm = std::move(p) ](Try<T> && t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
         if (!isTry && t.hasException()) {
-          pm.setException(std::move(t.exception()));
+          state.setException(std::move(t.exception()));
         } else {
-          pm.setWith([&]() {
-            return std::move(func)(t.template get<isTry, Args>()...);
-          });
+          state.setTry(makeTryWith(
+              [&] { return state.invoke(t.template get<isTry, Args>()...); }));
         }
       });
 
@@ -191,30 +260,31 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
   auto f = p.getFuture();
   f.core_->setExecutorNoLock(getExecutor());
 
-  setCallback_([ func = std::forward<F>(func), pm = std::move(p) ](
-      Try<T> && t) mutable {
-    auto ew = [&] {
-      if (!isTry && t.hasException()) {
-        return std::move(t.exception());
-      } else {
-        try {
-          auto f2 = std::move(func)(t.template get<isTry, Args>()...);
-          // that didn't throw, now we can steal p
-          f2.setCallback_([p = std::move(pm)](Try<B> && b) mutable {
-            p.setTry(std::move(b));
-          });
-          return exception_wrapper();
-        } catch (const std::exception& e) {
-          return exception_wrapper(std::current_exception(), e);
-        } catch (...) {
-          return exception_wrapper(std::current_exception());
+  setCallback_(
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
+        auto ew = [&] {
+          if (!isTry && t.hasException()) {
+            return std::move(t.exception());
+          } else {
+            try {
+              auto f2 = state.invoke(t.template get<isTry, Args>()...);
+              // that didn't throw, now we can steal p
+              f2.setCallback_([p = state.stealPromise()](Try<B> && b) mutable {
+                p.setTry(std::move(b));
+              });
+              return exception_wrapper();
+            } catch (const std::exception& e) {
+              return exception_wrapper(std::current_exception(), e);
+            } catch (...) {
+              return exception_wrapper(std::current_exception());
+            }
+          }
+        }();
+        if (ew) {
+          state.setException(std::move(ew));
         }
-      }
-    }();
-    if (ew) {
-      pm.setException(std::move(ew));
-    }
-  });
+      });
 
   return f;
 }
@@ -266,11 +336,12 @@ Future<T>::onError(F&& func) {
   auto f = p.getFuture();
 
   setCallback_(
-      [ func = std::forward<F>(func), pm = std::move(p) ](Try<T> && t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
         if (!t.template withException<Exn>([&](Exn& e) {
-              pm.setWith([&] { return std::move(func)(e); });
+              state.setTry(makeTryWith([&] { return state.invoke(e); }));
             })) {
-          pm.setTry(std::move(t));
+          state.setTry(std::move(t));
         }
       });
 
@@ -293,29 +364,29 @@ Future<T>::onError(F&& func) {
   Promise<T> p;
   auto f = p.getFuture();
 
-  setCallback_([ pm = std::move(p), func = std::forward<F>(func) ](
-      Try<T> && t) mutable {
-    if (!t.template withException<Exn>([&](Exn& e) {
-          auto ew = [&] {
-            try {
-              auto f2 = std::move(func)(e);
-              f2.setCallback_([pm = std::move(pm)](Try<T> && t2) mutable {
-                pm.setTry(std::move(t2));
-              });
-              return exception_wrapper();
-            } catch (const std::exception& e2) {
-              return exception_wrapper(std::current_exception(), e2);
-            } catch (...) {
-              return exception_wrapper(std::current_exception());
-            }
-          }();
-          if (ew) {
-            pm.setException(std::move(ew));
-          }
-        })) {
-      pm.setTry(std::move(t));
-    }
-  });
+  setCallback_(
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
+        if (!t.template withException<Exn>([&](Exn& e) {
+              auto ew = [&] {
+                try {
+                  auto f2 = state.invoke(e);
+                  f2.setCallback_([p = state.stealPromise()](
+                      Try<T> && t2) mutable { p.setTry(std::move(t2)); });
+                  return exception_wrapper();
+                } catch (const std::exception& e2) {
+                  return exception_wrapper(std::current_exception(), e2);
+                } catch (...) {
+                  return exception_wrapper(std::current_exception());
+                }
+              }();
+              if (ew) {
+                state.setException(std::move(ew));
+              }
+            })) {
+          state.setTry(std::move(t));
+        }
+      });
 
   return f;
 }
@@ -349,13 +420,14 @@ Future<T>::onError(F&& func) {
   Promise<T> p;
   auto f = p.getFuture();
   setCallback_(
-      [ pm = std::move(p), func = std::forward<F>(func) ](Try<T> t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> t) mutable {
         if (t.hasException()) {
           auto ew = [&] {
             try {
-              auto f2 = std::move(func)(std::move(t.exception()));
-              f2.setCallback_([pm = std::move(pm)](Try<T> t2) mutable {
-                pm.setTry(std::move(t2));
+              auto f2 = state.invoke(std::move(t.exception()));
+              f2.setCallback_([p = state.stealPromise()](Try<T> t2) mutable {
+                p.setTry(std::move(t2));
               });
               return exception_wrapper();
             } catch (const std::exception& e2) {
@@ -365,10 +437,10 @@ Future<T>::onError(F&& func) {
             }
           }();
           if (ew) {
-            pm.setException(std::move(ew));
+            state.setException(std::move(ew));
           }
         } else {
-          pm.setTry(std::move(t));
+          state.setTry(std::move(t));
         }
       });
 
@@ -390,11 +462,13 @@ Future<T>::onError(F&& func) {
   Promise<T> p;
   auto f = p.getFuture();
   setCallback_(
-      [ pm = std::move(p), func = std::forward<F>(func) ](Try<T> t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> t) mutable {
         if (t.hasException()) {
-          pm.setWith([&] { return std::move(func)(std::move(t.exception())); });
+          state.setTry(makeTryWith(
+              [&] { return state.invoke(std::move(t.exception())); }));
         } else {
-          pm.setTry(std::move(t));
+          state.setTry(std::move(t));
         }
       });