Destroy promise/future callback functions before waking waiters
authorYedidya Feldblum <yfeldblum@fb.com>
Thu, 1 Jun 2017 05:41:16 +0000 (22:41 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 1 Jun 2017 05:51:28 +0000 (22:51 -0700)
Summary:
Code may pass a callback which captures an object with a destructor which mutates through a stored reference, triggering heap-use-after-free or stack-use-after-scope.

```lang=c++
void performDataRace() {
  auto number = std::make_unique<int>(0);
  auto guard = folly::makeGuard([&number] { *number = 1; });
  folly::via(getSomeExecutor(), [guard = std::move(guard)]() mutable {}).wait();
  // data race - we may wake and destruct number before guard is destructed on the
  // executor thread, which is both stack-use-after-scope and heap-use-after-free!
}
```

We can avoid this condition by always destructing the provided functor before setting any result on the promise.

Retry at {D4982969}.

Reviewed By: andriigrynenko

Differential Revision: D5058750

fbshipit-source-id: 4d1d878b4889e5e6474941187f03de5fa84d3061

folly/futures/Future-inl.h
folly/futures/Promise-inl.h
folly/futures/Promise.h
folly/futures/test/CallbackLifetimeTest.cpp [new file with mode: 0644]
folly/test/Makefile.am

index 4089a67..427be4d 100644 (file)
@@ -48,7 +48,76 @@ typedef folly::Baton<> FutureBatonType;
 }
 
 namespace detail {
-  std::shared_ptr<Timekeeper> getTimekeeperSingleton();
+std::shared_ptr<Timekeeper> getTimekeeperSingleton();
+
+//  Guarantees that the stored functor is destructed before the stored promise
+//  may be fulfilled. Assumes the stored functor to be noexcept-destructible.
+template <typename T, typename F>
+class CoreCallbackState {
+ public:
+  template <typename FF>
+  CoreCallbackState(Promise<T>&& promise, FF&& func) noexcept(
+      noexcept(F(std::declval<FF>())))
+      : func_(std::forward<FF>(func)), promise_(std::move(promise)) {
+    assert(before_barrier());
+  }
+
+  CoreCallbackState(CoreCallbackState&& that) noexcept(
+      noexcept(F(std::declval<F>()))) {
+    if (that.before_barrier()) {
+      new (&func_) F(std::move(that.func_));
+      promise_ = that.stealPromise();
+    }
+  }
+
+  CoreCallbackState& operator=(CoreCallbackState&&) = delete;
+
+  ~CoreCallbackState() {
+    if (before_barrier()) {
+      stealPromise();
+    }
+  }
+
+  template <typename... Args>
+  auto invoke(Args&&... args) noexcept(
+      noexcept(std::declval<F&&>()(std::declval<Args&&>()...))) {
+    assert(before_barrier());
+    return std::move(func_)(std::forward<Args>(args)...);
+  }
+
+  void setTry(Try<T>&& t) {
+    stealPromise().setTry(std::move(t));
+  }
+
+  void setException(exception_wrapper&& ew) {
+    stealPromise().setException(std::move(ew));
+  }
+
+  Promise<T> stealPromise() noexcept {
+    assert(before_barrier());
+    func_.~F();
+    return std::move(promise_);
+  }
+
+ private:
+  bool before_barrier() const noexcept {
+    return !promise_.isFulfilled();
+  }
+
+  union {
+    F func_;
+  };
+  Promise<T> promise_{detail::EmptyConstruct{}};
+};
+
+template <typename T, typename F>
+inline auto makeCoreCallbackState(Promise<T>&& p, F&& f) noexcept(
+    noexcept(CoreCallbackState<T, _t<std::decay<F>>>(
+        std::declval<Promise<T>&&>(),
+        std::declval<F&&>()))) {
+  return CoreCallbackState<T, _t<std::decay<F>>>(
+      std::move(p), std::forward<F>(f));
+}
 }
 
 template <class T>
@@ -160,13 +229,13 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
      in the destruction of the Future used to create it.
      */
   setCallback_(
-      [ func = std::forward<F>(func), pm = std::move(p) ](Try<T> && t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
         if (!isTry && t.hasException()) {
-          pm.setException(std::move(t.exception()));
+          state.setException(std::move(t.exception()));
         } else {
-          pm.setWith([&]() {
-            return std::move(func)(t.template get<isTry, Args>()...);
-          });
+          state.setTry(makeTryWith(
+              [&] { return state.invoke(t.template get<isTry, Args>()...); }));
         }
       });
 
@@ -191,30 +260,31 @@ Future<T>::thenImplementation(F&& func, detail::argResult<isTry, F, Args...>) {
   auto f = p.getFuture();
   f.core_->setExecutorNoLock(getExecutor());
 
-  setCallback_([ func = std::forward<F>(func), pm = std::move(p) ](
-      Try<T> && t) mutable {
-    auto ew = [&] {
-      if (!isTry && t.hasException()) {
-        return std::move(t.exception());
-      } else {
-        try {
-          auto f2 = std::move(func)(t.template get<isTry, Args>()...);
-          // that didn't throw, now we can steal p
-          f2.setCallback_([p = std::move(pm)](Try<B> && b) mutable {
-            p.setTry(std::move(b));
-          });
-          return exception_wrapper();
-        } catch (const std::exception& e) {
-          return exception_wrapper(std::current_exception(), e);
-        } catch (...) {
-          return exception_wrapper(std::current_exception());
+  setCallback_(
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
+        auto ew = [&] {
+          if (!isTry && t.hasException()) {
+            return std::move(t.exception());
+          } else {
+            try {
+              auto f2 = state.invoke(t.template get<isTry, Args>()...);
+              // that didn't throw, now we can steal p
+              f2.setCallback_([p = state.stealPromise()](Try<B> && b) mutable {
+                p.setTry(std::move(b));
+              });
+              return exception_wrapper();
+            } catch (const std::exception& e) {
+              return exception_wrapper(std::current_exception(), e);
+            } catch (...) {
+              return exception_wrapper(std::current_exception());
+            }
+          }
+        }();
+        if (ew) {
+          state.setException(std::move(ew));
         }
-      }
-    }();
-    if (ew) {
-      pm.setException(std::move(ew));
-    }
-  });
+      });
 
   return f;
 }
@@ -266,11 +336,12 @@ Future<T>::onError(F&& func) {
   auto f = p.getFuture();
 
   setCallback_(
-      [ func = std::forward<F>(func), pm = std::move(p) ](Try<T> && t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
         if (!t.template withException<Exn>([&](Exn& e) {
-              pm.setWith([&] { return std::move(func)(e); });
+              state.setTry(makeTryWith([&] { return state.invoke(e); }));
             })) {
-          pm.setTry(std::move(t));
+          state.setTry(std::move(t));
         }
       });
 
@@ -293,29 +364,29 @@ Future<T>::onError(F&& func) {
   Promise<T> p;
   auto f = p.getFuture();
 
-  setCallback_([ pm = std::move(p), func = std::forward<F>(func) ](
-      Try<T> && t) mutable {
-    if (!t.template withException<Exn>([&](Exn& e) {
-          auto ew = [&] {
-            try {
-              auto f2 = std::move(func)(e);
-              f2.setCallback_([pm = std::move(pm)](Try<T> && t2) mutable {
-                pm.setTry(std::move(t2));
-              });
-              return exception_wrapper();
-            } catch (const std::exception& e2) {
-              return exception_wrapper(std::current_exception(), e2);
-            } catch (...) {
-              return exception_wrapper(std::current_exception());
-            }
-          }();
-          if (ew) {
-            pm.setException(std::move(ew));
-          }
-        })) {
-      pm.setTry(std::move(t));
-    }
-  });
+  setCallback_(
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> && t) mutable {
+        if (!t.template withException<Exn>([&](Exn& e) {
+              auto ew = [&] {
+                try {
+                  auto f2 = state.invoke(e);
+                  f2.setCallback_([p = state.stealPromise()](
+                      Try<T> && t2) mutable { p.setTry(std::move(t2)); });
+                  return exception_wrapper();
+                } catch (const std::exception& e2) {
+                  return exception_wrapper(std::current_exception(), e2);
+                } catch (...) {
+                  return exception_wrapper(std::current_exception());
+                }
+              }();
+              if (ew) {
+                state.setException(std::move(ew));
+              }
+            })) {
+          state.setTry(std::move(t));
+        }
+      });
 
   return f;
 }
@@ -349,13 +420,14 @@ Future<T>::onError(F&& func) {
   Promise<T> p;
   auto f = p.getFuture();
   setCallback_(
-      [ pm = std::move(p), func = std::forward<F>(func) ](Try<T> t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> t) mutable {
         if (t.hasException()) {
           auto ew = [&] {
             try {
-              auto f2 = std::move(func)(std::move(t.exception()));
-              f2.setCallback_([pm = std::move(pm)](Try<T> t2) mutable {
-                pm.setTry(std::move(t2));
+              auto f2 = state.invoke(std::move(t.exception()));
+              f2.setCallback_([p = state.stealPromise()](Try<T> t2) mutable {
+                p.setTry(std::move(t2));
               });
               return exception_wrapper();
             } catch (const std::exception& e2) {
@@ -365,10 +437,10 @@ Future<T>::onError(F&& func) {
             }
           }();
           if (ew) {
-            pm.setException(std::move(ew));
+            state.setException(std::move(ew));
           }
         } else {
-          pm.setTry(std::move(t));
+          state.setTry(std::move(t));
         }
       });
 
@@ -390,11 +462,13 @@ Future<T>::onError(F&& func) {
   Promise<T> p;
   auto f = p.getFuture();
   setCallback_(
-      [ pm = std::move(p), func = std::forward<F>(func) ](Try<T> t) mutable {
+      [state = detail::makeCoreCallbackState(
+           std::move(p), std::forward<F>(func))](Try<T> t) mutable {
         if (t.hasException()) {
-          pm.setWith([&] { return std::move(func)(std::move(t.exception())); });
+          state.setTry(makeTryWith(
+              [&] { return state.invoke(std::move(t.exception())); }));
         } else {
-          pm.setTry(std::move(t));
+          state.setTry(std::move(t));
         }
       });
 
index c55d34c..61d4d93 100644 (file)
@@ -59,6 +59,10 @@ void Promise<T>::throwIfRetrieved() {
   }
 }
 
+template <class T>
+Promise<T>::Promise(detail::EmptyConstruct) noexcept
+    : retrieved_(false), core_(nullptr) {}
+
 template <class T>
 Promise<T>::~Promise() {
   detach();
index a2793e2..7511912 100644 (file)
@@ -25,6 +25,12 @@ namespace folly {
 // forward declaration
 template <class T> class Future;
 
+namespace detail {
+struct EmptyConstruct {};
+template <typename T, typename F>
+class CoreCallbackState;
+}
+
 template <class T>
 class Promise {
  public:
@@ -98,6 +104,8 @@ class Promise {
  private:
   typedef typename Future<T>::corePtr corePtr;
   template <class> friend class Future;
+  template <class, class>
+  friend class detail::CoreCallbackState;
 
   // Whether the Future has been retrieved (a one-time operation).
   bool retrieved_;
@@ -105,6 +113,8 @@ class Promise {
   // shared core state object
   corePtr core_;
 
+  explicit Promise(detail::EmptyConstruct) noexcept;
+
   void throwIfFulfilled();
   void throwIfRetrieved();
   void detach();
diff --git a/folly/futures/test/CallbackLifetimeTest.cpp b/folly/futures/test/CallbackLifetimeTest.cpp
new file mode 100644 (file)
index 0000000..98fa9fd
--- /dev/null
@@ -0,0 +1,207 @@
+/*
+ * Copyright 2017 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/futures/Future.h>
+
+#include <thread>
+
+#include <folly/futures/test/TestExecutor.h>
+#include <folly/portability/GTest.h>
+
+using namespace folly;
+
+namespace {
+
+/***
+ *  The basic premise is to check that the callback passed to then or onError
+ *  is destructed before wait returns on the resulting future.
+ *
+ *  The approach is to use callbacks where the destructor sleeps 500ms and then
+ *  mutates a counter allocated on the caller stack. The caller checks the
+ *  counter immediately after calling wait. Were the callback not destructed
+ *  before wait returns, then we would very likely see an unchanged counter just
+ *  after wait returns. But if, as we expect, the callback were destructed
+ *  before wait returns, then we must be guaranteed to see a mutated counter
+ *  just after wait returns.
+ *
+ *  Note that the failure condition is not strictly guaranteed under load. :(
+ */
+class CallbackLifetimeTest : public testing::Test {
+ public:
+  using CounterPtr = std::unique_ptr<size_t>;
+
+  static bool kRaiseWillThrow() {
+    return true;
+  }
+  static constexpr auto kDelay() {
+    return std::chrono::milliseconds(500);
+  }
+
+  auto mkC() {
+    return std::make_unique<size_t>(0);
+  }
+  auto mkCGuard(CounterPtr& ptr) {
+    return makeGuard([&] {
+      /* sleep override */ std::this_thread::sleep_for(kDelay());
+      ++*ptr;
+    });
+  }
+
+  static void raise() {
+    if (kRaiseWillThrow()) { // to avoid marking [[noreturn]]
+      throw std::runtime_error("raise");
+    }
+  }
+  static Future<Unit> raiseFut() {
+    raise();
+    return makeFuture();
+  }
+
+  TestExecutor executor{2}; // need at least 2 threads for internal futures
+};
+}
+
+TEST_F(CallbackLifetimeTest, thenReturnsValue) {
+  auto c = mkC();
+  via(&executor).then([_ = mkCGuard(c)]{}).wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, thenReturnsValueThrows) {
+  auto c = mkC();
+  via(&executor).then([_ = mkCGuard(c)] { raise(); }).wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, thenReturnsFuture) {
+  auto c = mkC();
+  via(&executor).then([_ = mkCGuard(c)] { return makeFuture(); }).wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, thenReturnsFutureThrows) {
+  auto c = mkC();
+  via(&executor).then([_ = mkCGuard(c)] { return raiseFut(); }).wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueMatch) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](std::exception&){})
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueMatchThrows) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](std::exception&) { raise(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueWrong) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](std::logic_error&){})
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsValueWrongThrows) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](std::logic_error&) { raise(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureMatch) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](std::exception&) { return makeFuture(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureMatchThrows) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](std::exception&) { return raiseFut(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureWrong) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](std::logic_error&) { return makeFuture(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesExnReturnsFutureWrongThrows) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](std::logic_error&) { return raiseFut(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsValue) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](exception_wrapper &&){})
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsValueThrows) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](exception_wrapper &&) { raise(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsFuture) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](exception_wrapper &&) { return makeFuture(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
+
+TEST_F(CallbackLifetimeTest, onErrorTakesWrapReturnsFutureThrows) {
+  auto c = mkC();
+  via(&executor)
+      .then(raise)
+      .onError([_ = mkCGuard(c)](exception_wrapper &&) { return raiseFut(); })
+      .wait();
+  EXPECT_EQ(1, *c);
+}
index d3d12b5..60fdea5 100644 (file)
@@ -262,6 +262,7 @@ unit_test_LDADD = libfollytestmain.la
 TESTS += unit_test
 
 futures_test_SOURCES = \
+    ../futures/test/CallbackLifetimeTest.cpp \
     ../futures/test/CollectTest.cpp \
     ../futures/test/ContextTest.cpp \
     ../futures/test/CoreTest.cpp \