Expected coroutines support
authorYedidya Feldblum <yfeldblum@fb.com>
Thu, 26 Oct 2017 03:27:41 +0000 (20:27 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Oct 2017 03:39:00 +0000 (20:39 -0700)
Summary:
[Folly] `Expected` coroutines support.

Copied from `Optional` coroutines support.

Reviewed By: ericniebler, Orvid

Differential Revision: D5923792

fbshipit-source-id: 8661012c65762a0e540a4af2fd2fc237a8cb87a1

folly/Expected.h
folly/test/ExpectedCoroutinesTest.cpp [new file with mode: 0644]

index c2d4e42d07784fbcd5739198bc843e1a159fa186..c9874eba239531796a956a43afaed509bb784ee6 100644 (file)
 #include <type_traits>
 #include <utility>
 
+#include <glog/logging.h>
+
 #include <folly/Likely.h>
+#include <folly/Optional.h>
 #include <folly/Portability.h>
 #include <folly/Preprocessor.h>
 #include <folly/Traits.h>
@@ -93,6 +96,10 @@ using ExpectedErrorType =
 
 // Details...
 namespace expected_detail {
+
+template <typename Value, typename Error>
+struct PromiseReturn;
+
 #ifdef _MSC_VER
 // MSVC 2015 can't handle the StrictConjunction, so we have
 // to use std::conjunction instead.
@@ -1034,6 +1041,12 @@ class Expected final : expected_detail::ExpectedStorage<Value, Error> {
     return *this;
   }
 
+  // Used only when an Expected is used with coroutines on MSVC
+  /* implicit */ Expected(const expected_detail::PromiseReturn<Value, Error>& p)
+      : Expected{} {
+    p.promise_->value_ = this;
+  }
+
   template <class... Ts FOLLY_REQUIRES_TRAILING(
       std::is_constructible<Value, Ts&&...>::value)>
   void emplace(Ts&&... ts) {
@@ -1413,3 +1426,99 @@ bool operator>(const Value& other, const Expected<Value, Error>&) = delete;
 
 #undef FOLLY_REQUIRES
 #undef FOLLY_REQUIRES_TRAILING
+
+// Enable the use of folly::Expected with `co_await`
+// Inspired by https://github.com/toby-allsopp/coroutine_monad
+#if FOLLY_HAS_COROUTINES
+#include <experimental/coroutine>
+
+namespace folly {
+namespace expected_detail {
+template <typename Value, typename Error>
+struct Promise;
+
+template <typename Value, typename Error>
+struct PromiseReturn {
+  Optional<Expected<Value, Error>> storage_;
+  Promise<Value, Error>* promise_;
+  /* implicit */ PromiseReturn(Promise<Value, Error>& promise) noexcept
+      : promise_(&promise) {
+    promise_->value_ = &storage_;
+  }
+  PromiseReturn(PromiseReturn&& that) noexcept
+      : PromiseReturn{*that.promise_} {}
+  ~PromiseReturn() {}
+  /* implicit */ operator Expected<Value, Error>() & {
+    return std::move(*storage_);
+  }
+};
+
+template <typename Value, typename Error>
+struct Promise {
+  Optional<Expected<Value, Error>>* value_ = nullptr;
+  Promise() = default;
+  Promise(Promise const&) = delete;
+  // This should work regardless of whether the compiler generates:
+  //    folly::Expected<Value, Error> retobj{ p.get_return_object(); } // MSVC
+  // or:
+  //    auto retobj = p.get_return_object(); // clang
+  PromiseReturn<Value, Error> get_return_object() noexcept {
+    return *this;
+  }
+  std::experimental::suspend_never initial_suspend() const noexcept {
+    return {};
+  }
+  std::experimental::suspend_never final_suspend() const {
+    return {};
+  }
+  template <typename U>
+  void return_value(U&& u) {
+    value_->emplace(static_cast<U&&>(u));
+  }
+  void unhandled_exception() {
+    // Technically, throwing from unhandled_exception is underspecified:
+    // https://github.com/GorNishanov/CoroutineWording/issues/17
+    throw;
+  }
+};
+
+template <typename Value, typename Error>
+struct Awaitable {
+  Expected<Value, Error> o_;
+
+  explicit Awaitable(Expected<Value, Error> o) : o_(std::move(o)) {}
+
+  bool await_ready() const noexcept {
+    return o_.hasValue();
+  }
+  Value await_resume() {
+    return std::move(o_.value());
+  }
+
+  // Explicitly only allow suspension into a Promise
+  template <typename U>
+  void await_suspend(std::experimental::coroutine_handle<Promise<U, Error>> h) {
+    *h.promise().value_ = makeUnexpected(std::move(o_.error()));
+    // Abort the rest of the coroutine. resume() is not going to be called
+    h.destroy();
+  }
+};
+} // namespace expected_detail
+
+template <typename Value, typename Error>
+expected_detail::Awaitable<Value, Error>
+/* implicit */ operator co_await(Expected<Value, Error> o) {
+  return expected_detail::Awaitable<Value, Error>{std::move(o)};
+}
+} // namespace folly
+
+// This makes folly::Optional<Value> useable as a coroutine return type..
+FOLLY_NAMESPACE_STD_BEGIN
+namespace experimental {
+template <typename Value, typename Error, typename... Args>
+struct coroutine_traits<folly::Expected<Value, Error>, Args...> {
+  using promise_type = folly::expected_detail::Promise<Value, Error>;
+};
+} // namespace experimental
+FOLLY_NAMESPACE_STD_END
+#endif // FOLLY_HAS_COROUTINES
diff --git a/folly/test/ExpectedCoroutinesTest.cpp b/folly/test/ExpectedCoroutinesTest.cpp
new file mode 100644 (file)
index 0000000..e1d8717
--- /dev/null
@@ -0,0 +1,140 @@
+/*
+ * Copyright 2017 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <folly/Expected.h>
+#include <folly/Portability.h>
+#include <folly/ScopeGuard.h>
+#include <folly/portability/GTest.h>
+
+using namespace folly;
+
+namespace {
+
+struct Exn {};
+
+// not default-constructible, thereby preventing Expected<T, Err> from being
+// default-constructible, forcing our implementation to handle such cases
+class Err {
+ private:
+  enum class Type { Bad, Badder, Baddest };
+
+  Type type_;
+
+  constexpr Err(Type type) : type_(type) {}
+
+ public:
+  Err(Err const&) = default;
+  Err(Err&&) = default;
+  Err& operator=(Err const&) = default;
+  Err& operator=(Err&&) = default;
+
+  friend bool operator==(Err a, Err b) {
+    return a.type_ == b.type_;
+  }
+  friend bool operator!=(Err a, Err b) {
+    return a.type_ != b.type_;
+  }
+
+  static constexpr Err bad() {
+    return Type::Bad;
+  }
+  static constexpr Err badder() {
+    return Type::Badder;
+  }
+  static constexpr Err baddest() {
+    return Type::Baddest;
+  }
+};
+
+Expected<int, Err> f1() {
+  return 7;
+}
+
+Expected<double, Err> f2(int x) {
+  return 2.0 * x;
+}
+
+// move-only type
+Expected<std::unique_ptr<int>, Err> f3(int x, double y) {
+  return std::make_unique<int>(int(x + y));
+}
+
+// error result
+Expected<int, Err> f4(int, double, Err err) {
+  return makeUnexpected(err);
+}
+
+// exception
+Expected<int, Err> throws() {
+  throw Exn{};
+}
+
+} // namespace
+
+#if FOLLY_HAS_COROUTINES
+
+TEST(Expected, CoroutineSuccess) {
+  auto r0 = []() -> Expected<int, Err> {
+    auto x = co_await f1();
+    EXPECT_EQ(7, x);
+    auto y = co_await f2(x);
+    EXPECT_EQ(2.0 * 7, y);
+    auto z = co_await f3(x, y);
+    EXPECT_EQ(int(2.0 * 7 + 7), *z);
+    co_return* z;
+  }();
+  EXPECT_TRUE(r0.hasValue());
+  EXPECT_EQ(21, *r0);
+}
+
+TEST(Expected, CoroutineFailure) {
+  auto r1 = []() -> Expected<int, Err> {
+    auto x = co_await f1();
+    auto y = co_await f2(x);
+    auto z = co_await f4(x, y, Err::badder());
+    ADD_FAILURE();
+    co_return z;
+  }();
+  EXPECT_TRUE(r1.hasError());
+  EXPECT_EQ(Err::badder(), r1.error());
+}
+
+TEST(Expected, CoroutineException) {
+  EXPECT_THROW(
+      ([]() -> Expected<int, Err> {
+        auto x = co_await throws();
+        ADD_FAILURE();
+        co_return x;
+      }()),
+      Exn);
+}
+
+// this test makes sure that the coroutine is destroyed properly
+TEST(Expected, CoroutineCleanedUp) {
+  int count_dest = 0;
+  auto r = [&]() -> Expected<int, Err> {
+    SCOPE_EXIT {
+      ++count_dest;
+    };
+    auto x = co_await Expected<int, Err>(makeUnexpected(Err::badder()));
+    ADD_FAILURE() << "Should not be resuming";
+    co_return x;
+  }();
+  EXPECT_FALSE(r.hasValue());
+  EXPECT_EQ(1, count_dest);
+}
+
+#endif