modifiable channel pipelines
authorJames Sedgwick <jsedgwick@fb.com>
Fri, 21 Nov 2014 19:30:14 +0000 (11:30 -0800)
committerDave Watson <davejwatson@fb.com>
Thu, 11 Dec 2014 15:58:31 +0000 (07:58 -0800)
Summary:
Basically the same interface as before, but you must specify the read and write types for the ends of pipeline. Implementation is cleaner as well; there's fewer levels of indirection

This dynamic casts shit all over the place and is less typesafe then the previous iteration, but I think with some carefully placed static_asserts, could be just as safe (in the case where you don't do any modification, anyway)

Right now you can only add to the front or back of the pipeline but the way it's set up you could add any number of mutations, including ones that are triggered by handlers. But this should (might?) be enough for Tunnel, which was the motivation.

Test Plan: basic test compiles, thrift2 diff still works with a one line change

Reviewed By: hans@fb.com

Subscribers: trunkagent, fugalh, njormrod, folly-diffs@, bmatheny

FB internal diff: D1661169

Tasks: 5002299

Signature: t1:1661169:1416521727:1f126279796c0b09d1905b9f7dbc48a9e5540271

folly/experimental/wangle/channel/ChannelHandler.h
folly/experimental/wangle/channel/ChannelHandlerContext.h [new file with mode: 0644]
folly/experimental/wangle/channel/ChannelPipeline.h
folly/experimental/wangle/channel/ChannelTest.cpp

index 3ef7ae50514ca4c113734a3f8a0297dbadeb8021..27a324470631af6b084a3a24b668c27296599187 100644 (file)
@@ -138,7 +138,6 @@ class ChannelHandlerPtr : public ChannelHandler<
     }
   }
 
-
   void attachTransport(Context* ctx) override {
     ctx_ = ctx;
     if (handler_) {
diff --git a/folly/experimental/wangle/channel/ChannelHandlerContext.h b/folly/experimental/wangle/channel/ChannelHandlerContext.h
new file mode 100644 (file)
index 0000000..b0f1064
--- /dev/null
@@ -0,0 +1,251 @@
+/*
+ * Copyright 2014 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/wangle/Future.h>
+#include <folly/ExceptionWrapper.h>
+
+namespace folly { namespace wangle {
+
+template <class In, class Out>
+class ChannelHandlerContext {
+ public:
+  virtual ~ChannelHandlerContext() {}
+
+  virtual void fireRead(In msg) = 0;
+  virtual void fireReadEOF() = 0;
+  virtual void fireReadException(exception_wrapper e) = 0;
+
+  virtual Future<void> fireWrite(Out msg) = 0;
+  virtual Future<void> fireClose() = 0;
+
+  virtual std::shared_ptr<AsyncTransport> getTransport() = 0;
+
+  virtual void setWriteFlags(WriteFlags flags) = 0;
+  virtual WriteFlags getWriteFlags() = 0;
+
+  virtual void setReadBufferSettings(
+      uint64_t minAvailable,
+      uint64_t allocationSize) = 0;
+  virtual std::pair<uint64_t, uint64_t> getReadBufferSettings() = 0;
+
+  /* TODO
+  template <class H>
+  virtual void addHandlerBefore(H&&) {}
+  template <class H>
+  virtual void addHandlerAfter(H&&) {}
+  template <class H>
+  virtual void replaceHandler(H&&) {}
+  virtual void removeHandler() {}
+  */
+};
+
+class PipelineContext {
+ public:
+  virtual ~PipelineContext() {}
+
+  virtual void attachTransport() = 0;
+  virtual void detachTransport() = 0;
+
+  void link(PipelineContext* other) {
+    setNextIn(other);
+    other->setNextOut(this);
+  }
+
+ protected:
+  virtual void setNextIn(PipelineContext* ctx) = 0;
+  virtual void setNextOut(PipelineContext* ctx) = 0;
+};
+
+template <class In>
+class InboundChannelHandlerContext {
+ public:
+  virtual ~InboundChannelHandlerContext() {}
+  virtual void read(In msg) = 0;
+  virtual void readEOF() = 0;
+  virtual void readException(exception_wrapper e) = 0;
+};
+
+template <class Out>
+class OutboundChannelHandlerContext {
+ public:
+  virtual ~OutboundChannelHandlerContext() {}
+  virtual Future<void> write(Out msg) = 0;
+  virtual Future<void> close() = 0;
+};
+
+template <class P, class H>
+class ContextImpl : public ChannelHandlerContext<typename H::rout,
+                                                 typename H::wout>,
+                    public InboundChannelHandlerContext<typename H::rin>,
+                    public OutboundChannelHandlerContext<typename H::win>,
+                    public PipelineContext {
+ public:
+  typedef typename H::rin Rin;
+  typedef typename H::rout Rout;
+  typedef typename H::win Win;
+  typedef typename H::wout Wout;
+
+  template <class HandlerArg>
+  explicit ContextImpl(P* pipeline, HandlerArg&& handlerArg)
+    : pipeline_(pipeline),
+      handler_(std::forward<HandlerArg>(handlerArg)) {
+    handler_.attachPipeline(this);
+  }
+
+  ~ContextImpl() {
+    handler_.detachPipeline(this);
+  }
+
+  H* getHandler() {
+    return &handler_;
+  }
+
+  // PipelineContext overrides
+  void setNextIn(PipelineContext* ctx) override {
+    auto nextIn = dynamic_cast<InboundChannelHandlerContext<Rout>*>(ctx);
+    if (nextIn) {
+      nextIn_ = nextIn;
+    } else {
+      throw std::invalid_argument("wrong type in setNextIn");
+    }
+  }
+
+  void setNextOut(PipelineContext* ctx) override {
+    auto nextOut = dynamic_cast<OutboundChannelHandlerContext<Wout>*>(ctx);
+    if (nextOut) {
+      nextOut_ = nextOut;
+    } else {
+      throw std::invalid_argument("wrong type in setNextOut");
+    }
+  }
+
+  void attachTransport() override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    handler_.attachTransport(this);
+  }
+
+  void detachTransport() override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    handler_.detachTransport(this);
+  }
+
+  // ChannelHandlerContext overrides
+  void fireRead(Rout msg) override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    if (nextIn_) {
+      nextIn_->read(std::forward<Rout>(msg));
+    } else {
+      LOG(WARNING) << "read reached end of pipeline";
+    }
+  }
+
+  void fireReadEOF() override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    if (nextIn_) {
+      nextIn_->readEOF();
+    } else {
+      LOG(WARNING) << "readEOF reached end of pipeline";
+    }
+  }
+
+  void fireReadException(exception_wrapper e) override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    if (nextIn_) {
+      nextIn_->readException(std::move(e));
+    } else {
+      LOG(WARNING) << "readException reached end of pipeline";
+    }
+  }
+
+  Future<void> fireWrite(Wout msg) override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    if (nextOut_) {
+      return nextOut_->write(std::forward<Wout>(msg));
+    } else {
+      LOG(WARNING) << "write reached end of pipeline";
+      return makeFuture();
+    }
+  }
+
+  Future<void> fireClose() override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    if (nextOut_) {
+      return nextOut_->close();
+    } else {
+      LOG(WARNING) << "close reached end of pipeline";
+      return makeFuture();
+    }
+  }
+
+  std::shared_ptr<AsyncTransport> getTransport() override {
+    return pipeline_->getTransport();
+  }
+
+  void setWriteFlags(WriteFlags flags) override {
+    pipeline_->setWriteFlags(flags);
+  }
+
+  WriteFlags getWriteFlags() override {
+    return pipeline_->getWriteFlags();
+  }
+
+  void setReadBufferSettings(
+      uint64_t minAvailable,
+      uint64_t allocationSize) override {
+    pipeline_->setReadBufferSettings(minAvailable, allocationSize);
+  }
+
+  std::pair<uint64_t, uint64_t> getReadBufferSettings() override {
+    return pipeline_->getReadBufferSettings();
+  }
+
+  // InboundChannelHandlerContext overrides
+  void read(Rin msg) override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    handler_.read(this, std::forward<Rin>(msg));
+  }
+
+  void readEOF() override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    handler_.readEOF(this);
+  }
+
+  void readException(exception_wrapper e) override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    handler_.readException(this, std::move(e));
+  }
+
+  // OutboundChannelHandlerContext overrides
+  Future<void> write(Win msg) override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    return handler_.write(this, std::forward<Win>(msg));
+  }
+
+  Future<void> close() override {
+    typename P::DestructorGuard dg(static_cast<DelayedDestruction*>(pipeline_));
+    return handler_.close(this);
+  }
+
+ private:
+  P* pipeline_;
+  H handler_;
+  InboundChannelHandlerContext<Rout>* nextIn_{nullptr};
+  OutboundChannelHandlerContext<Wout>* nextOut_{nullptr};
+};
+
+}}
index de80a856cf0fac290362a23fb41a71a641741f03..7af4adb994cc528bb89968c0b739d174eaa79a12 100644 (file)
 
 #pragma once
 
+#include <folly/experimental/wangle/channel/ChannelHandlerContext.h>
 #include <folly/wangle/Future.h>
+#include <folly/io/async/AsyncTransport.h>
 #include <folly/io/async/DelayedDestruction.h>
 #include <folly/ExceptionWrapper.h>
+#include <folly/Memory.h>
 #include <glog/logging.h>
-#include <thrift/lib/cpp/async/TAsyncTransport.h>
 
 namespace folly { namespace wangle {
 
-template <class In, class Out>
-class ChannelHandlerContext {
- public:
-  virtual ~ChannelHandlerContext() {}
-
-  virtual void fireRead(In msg) = 0;
-  virtual void fireReadEOF() = 0;
-  virtual void fireReadException(exception_wrapper e) = 0;
-
-  virtual Future<void> fireWrite(Out msg) = 0;
-  virtual Future<void> fireClose() = 0;
-
-  virtual std::shared_ptr<AsyncTransport> getTransport() = 0;
-
-  virtual void setWriteFlags(WriteFlags flags) = 0;
-  virtual WriteFlags getWriteFlags() = 0;
-
-  virtual void setReadBufferSettings(
-      uint64_t minAvailable,
-      uint64_t allocationSize) = 0;
-  virtual std::pair<uint64_t, uint64_t> getReadBufferSettings() = 0;
-};
+template <class R, class W, class... Handlers>
+class ChannelPipeline;
 
-template <class Out>
-class OutboundChannelHandlerContext {
+template <class R, class W>
+class ChannelPipeline<R, W> : public DelayedDestruction {
  public:
-  virtual ~OutboundChannelHandlerContext() {}
-  virtual Future<void> write(Out msg) = 0;
-  virtual Future<void> close() = 0;
-};
+  ChannelPipeline() {}
+  ~ChannelPipeline() {}
 
-template <class... Handlers>
-class ChannelPipeline;
+  std::shared_ptr<AsyncTransport> getTransport() {
+    return transport_;
+  }
 
-template <>
-class ChannelPipeline<> : public DelayedDestruction {
- public:
   void setWriteFlags(WriteFlags flags) {
     writeFlags_ = flags;
   }
@@ -77,41 +55,87 @@ class ChannelPipeline<> : public DelayedDestruction {
     return readBufferSettings_;
   }
 
- protected:
-  static const bool is_end{true};
-  typedef void LastHandler;
-  typedef void OutboundContext;
-
-  std::shared_ptr<AsyncTransport> transport_;
-  WriteFlags writeFlags_{WriteFlags::NONE};
-  std::pair<uint64_t, uint64_t> readBufferSettings_{2048, 2048};
-
-  ~ChannelPipeline() {}
-
-  template <class T>
-  void read(T&& msg) {
-    LOG(FATAL) << "impossibru";
+  void read(R msg) {
+    front_->read(std::forward<R>(msg));
   }
 
   void readEOF() {
-    LOG(FATAL) << "impossibru";
+    front_->readEOF();
   }
 
   void readException(exception_wrapper e) {
-    LOG(FATAL) << "impossibru";
+    front_->readException(std::move(e));
   }
 
-  template <class T>
-  Future<void> write(T&& msg) {
-    LOG(FATAL) << "impossibru";
-    return makeFuture();
+  Future<void> write(W msg) {
+    return back_->write(std::forward<W>(msg));
   }
 
   Future<void> close() {
-    LOG(FATAL) << "impossibru";
-    return makeFuture();
+    return back_->close();
   }
 
+  template <class H>
+  ChannelPipeline& addBack(H&& handler) {
+    ctxs_.push_back(folly::make_unique<ContextImpl<ChannelPipeline, H>>(
+        this, std::forward<H>(handler)));
+    return *this;
+  }
+
+  template <class H>
+  ChannelPipeline& addFront(H&& handler) {
+    ctxs_.insert(0, folly::make_unique<ContextImpl<ChannelPipeline, H>>(
+        this, std::forward<H>(handler)));
+    return *this;
+  }
+
+  template <class H>
+  H* getHandler(int i) {
+    auto ctx = dynamic_cast<ContextImpl<ChannelPipeline, H>*>(ctxs_[i].get());
+    CHECK(ctx);
+    return ctx->getHandler();
+  }
+
+  void finalize() {
+    finalizeHelper();
+    InboundChannelHandlerContext<R>* front;
+    front_ = dynamic_cast<InboundChannelHandlerContext<R>*>(
+        ctxs_.front().get());
+    if (!front_) {
+      throw std::invalid_argument("wrong type for first handler");
+    }
+  }
+
+ protected:
+  explicit ChannelPipeline(bool shouldFinalize) {
+    CHECK(!shouldFinalize);
+  }
+
+  void finalizeHelper() {
+    if (ctxs_.empty()) {
+      return;
+    }
+
+    for (int i = 0; i < ctxs_.size() - 1; i++) {
+      ctxs_[i]->link(ctxs_[i+1].get());
+    }
+
+    back_ = dynamic_cast<OutboundChannelHandlerContext<W>*>(ctxs_.back().get());
+    if (!back_) {
+      throw std::invalid_argument("wrong type for last handler");
+    }
+  }
+
+  PipelineContext* getLocalFront() {
+    return ctxs_.empty() ? nullptr : ctxs_.front().get();
+  }
+
+  static const bool is_end{true};
+
+  std::shared_ptr<AsyncTransport> transport_;
+  WriteFlags writeFlags_{WriteFlags::NONE};
+  std::pair<uint64_t, uint64_t> readBufferSettings_{2048, 2048};
+
   void attachPipeline() {}
 
   void attachTransport(
@@ -123,234 +147,172 @@ class ChannelPipeline<> : public DelayedDestruction {
     transport_ = nullptr;
   }
 
-  template <class T>
-  void setOutboundContext(T ctx) {}
+  OutboundChannelHandlerContext<W>* back_{nullptr};
 
-  template <class H>
-  H* getHandler(size_t i) {
-    LOG(FATAL) << "impossibru";
-  }
+ private:
+  InboundChannelHandlerContext<R>* front_{nullptr};
+  std::vector<std::unique_ptr<PipelineContext>> ctxs_;
 };
 
-template <class Handler, class... Handlers>
-class ChannelPipeline<Handler, Handlers...>
-  : public ChannelPipeline<Handlers...> {
+template <class R, class W, class Handler, class... Handlers>
+class ChannelPipeline<R, W, Handler, Handlers...>
+  : public ChannelPipeline<R, W, Handlers...> {
  protected:
-  typedef typename std::conditional<
-      ChannelPipeline<Handlers...>::is_end,
-      Handler,
-      typename ChannelPipeline<Handlers...>::LastHandler>::type
-    LastHandler;
-
- public:
   template <class HandlerArg, class... HandlersArgs>
-  ChannelPipeline(HandlerArg&& handlerArg, HandlersArgs&&... handlersArgs)
-    : ChannelPipeline<Handlers...>(std::forward<HandlersArgs>(handlersArgs)...),
-      handler_(std::forward<HandlerArg>(handlerArg)),
-      ctx_(this) {
-    handler_.attachPipeline(&ctx_);
-    ChannelPipeline<Handlers...>::setOutboundContext(&ctx_);
+  ChannelPipeline(
+      bool shouldFinalize,
+      HandlerArg&& handlerArg,
+      HandlersArgs&&... handlersArgs)
+    : ChannelPipeline<R, W, Handlers...>(
+          false,
+          std::forward<HandlersArgs>(handlersArgs)...),
+          ctx_(this, std::forward<HandlerArg>(handlerArg)) {
+    if (shouldFinalize) {
+      finalize();
+    }
   }
+ public:
+  template <class... HandlersArgs>
+  explicit ChannelPipeline(HandlersArgs&&... handlersArgs)
+    : ChannelPipeline(true, std::forward<HandlersArgs>(handlersArgs)...) {}
 
   ~ChannelPipeline() {}
 
-  void destroy() override {
-    handler_.detachPipeline(&ctx_);
-  }
+  void destroy() override { }
 
-  void read(typename Handler::rin msg) {
-    ChannelPipeline<>::DestructorGuard dg(
+  void read(R msg) {
+    typename ChannelPipeline<R, W>::DestructorGuard dg(
         static_cast<DelayedDestruction*>(this));
-    handler_.read(&ctx_, std::forward<typename Handler::rin>(msg));
+    front_->read(std::forward<R>(msg));
   }
 
   void readEOF() {
-    ChannelPipeline<>::DestructorGuard dg(
+    typename ChannelPipeline<R, W>::DestructorGuard dg(
         static_cast<DelayedDestruction*>(this));
-    handler_.readEOF(&ctx_);
+    front_->readEOF();
   }
 
   void readException(exception_wrapper e) {
-    ChannelPipeline<>::DestructorGuard dg(
+    typename ChannelPipeline<R, W>::DestructorGuard dg(
         static_cast<DelayedDestruction*>(this));
-    handler_.readException(&ctx_, std::move(e));
+    front_->readEOF(std::move(e));
   }
 
-  Future<void> write(typename LastHandler::win msg) {
-    ChannelPipeline<>::DestructorGuard dg(
+  Future<void> write(W msg) {
+    typename ChannelPipeline<R, W>::DestructorGuard dg(
         static_cast<DelayedDestruction*>(this));
-    return ChannelPipeline<LastHandler>::writeHere(
-        std::forward<typename LastHandler::win>(msg));
+    return back_->write(std::forward<W>(msg));
   }
 
   Future<void> close() {
-    ChannelPipeline<>::DestructorGuard dg(
+    typename ChannelPipeline<R, W>::DestructorGuard dg(
         static_cast<DelayedDestruction*>(this));
-    return ChannelPipeline<LastHandler>::closeHere();
+    return back_->close();
   }
 
   void attachTransport(
       std::shared_ptr<AsyncTransport> transport) {
-    ChannelPipeline<>::DestructorGuard dg(
+    typename ChannelPipeline<R, W>::DestructorGuard dg(
         static_cast<DelayedDestruction*>(this));
-    CHECK(!ChannelPipeline<>::transport_);
-    ChannelPipeline<Handlers...>::attachTransport(std::move(transport));
-    handler_.attachTransport(&ctx_);
+    CHECK((!ChannelPipeline<R, W>::transport_));
+    ChannelPipeline<R, W, Handlers...>::attachTransport(std::move(transport));
+    forEachCtx([&](PipelineContext* ctx){
+      ctx->attachTransport();
+    });
   }
 
   void detachTransport() {
-    ChannelPipeline<>::DestructorGuard dg(
+    typename ChannelPipeline<R, W>::DestructorGuard dg(
         static_cast<DelayedDestruction*>(this));
-    ChannelPipeline<Handlers...>::detachTransport();
-    handler_.detachTransport(&ctx_);
+    ChannelPipeline<R, W, Handlers...>::detachTransport();
+    forEachCtx([&](PipelineContext* ctx){
+      ctx->detachTransport();
+    });
   }
 
   std::shared_ptr<AsyncTransport> getTransport() {
-    return ChannelPipeline<>::transport_;
+    return ChannelPipeline<R, W>::transport_;
   }
 
   template <class H>
-  H* getHandler(size_t i) {
-    if (i == 0) {
-      auto ptr = dynamic_cast<H*>(&handler_);
-      CHECK(ptr);
-      return ptr;
-    } else {
-      return ChannelPipeline<Handlers...>::template getHandler<H>(i-1);
-    }
-  }
-
- protected:
-  static const bool is_end{false};
-
-  typedef OutboundChannelHandlerContext<typename Handler::wout> OutboundContext;
-
-  void setOutboundContext(OutboundContext* ctx) {
-    outboundCtx_ = ctx;
-  }
-
-  Future<void> writeHere(typename Handler::win msg) {
-    return handler_.write(&ctx_, std::forward<typename Handler::win>(msg));
+  ChannelPipeline& addBack(H&& handler) {
+    ChannelPipeline<R, W>::addBack(std::move(handler));
+    return *this;
   }
 
-  Future<void> closeHere() {
-    return handler_.close(&ctx_);
+  template <class H>
+  ChannelPipeline& addFront(H&& handler) {
+    ctxs_.insert(0, folly::make_unique<ContextImpl<ChannelPipeline, H>>(
+        this, std::move(handler)));
+    return *this;
   }
 
- private:
-  class Context
-    : public ChannelHandlerContext<typename Handler::rout,
-                                   typename Handler::wout>,
-      public OutboundChannelHandlerContext<typename Handler::win> {
-   public:
-    explicit Context(ChannelPipeline* pipeline) : pipeline_(pipeline) {}
-    ChannelPipeline* pipeline_;
-
-    void fireRead(typename Handler::rout msg) override {
-      ChannelPipeline<>::DestructorGuard dg(pipeline_);
-      pipeline_->fireRead(std::forward<typename Handler::rout>(msg));
-    }
-
-    void fireReadEOF() override {
-      ChannelPipeline<>::DestructorGuard dg(pipeline_);
-      return pipeline_->fireReadEOF();
-    }
-
-    void fireReadException(exception_wrapper e) override {
-      ChannelPipeline<>::DestructorGuard dg(pipeline_);
-      return pipeline_->fireReadException(std::move(e));
-    }
-
-    Future<void> fireWrite(typename Handler::wout msg) override {
-      ChannelPipeline<>::DestructorGuard dg(pipeline_);
-      return pipeline_->fireWrite(std::forward<typename Handler::wout>(msg));
-    }
-
-    Future<void> write(typename Handler::win msg) override {
-      ChannelPipeline<>::DestructorGuard dg(pipeline_);
-      return pipeline_->writeHere(std::forward<typename Handler::win>(msg));
-    }
-
-    Future<void> fireClose() override {
-      ChannelPipeline<>::DestructorGuard dg(pipeline_);
-      return pipeline_->fireClose();
-    }
-
-    Future<void> close() override {
-      ChannelPipeline<>::DestructorGuard dg(pipeline_);
-      return pipeline_->closeHere();
-    }
-
-    std::shared_ptr<AsyncTransport> getTransport() override {
-      return pipeline_->transport_;
-    }
-
-    void setWriteFlags(WriteFlags flags) override {
-      pipeline_->setWriteFlags(flags);
-    }
-
-    WriteFlags getWriteFlags() override {
-      return pipeline_->getWriteFlags();
+  template <class H>
+  H* getHandler(size_t i) {
+    if (i > ctxs_.size()) {
+      return ChannelPipeline<R, W, Handlers...>::template getHandler<H>(
+          i - (ctxs_.size() + 1));
+    } else {
+      auto pctx = (i == ctxs_.size()) ? &ctx_ : ctxs_[i].get();
+      auto ctx = dynamic_cast<ContextImpl<ChannelPipeline, H>*>(pctx);
+      return ctx->getHandler();
     }
+  }
 
-    void setReadBufferSettings(
-        uint64_t minAvailable,
-        uint64_t allocationSize) override {
-      pipeline_->setReadBufferSettings(minAvailable, allocationSize);
+  void finalize() {
+    finalizeHelper();
+    auto ctx = ctxs_.empty() ? &ctx_ : ctxs_.front().get();
+    front_ = dynamic_cast<InboundChannelHandlerContext<R>*>(ctx);
+    if (!front_) {
+      throw std::invalid_argument("wrong type for first handler");
     }
+  }
 
-    std::pair<uint64_t, uint64_t> getReadBufferSettings() override {
-      return pipeline_->getReadBufferSettings();
+ protected:
+  void finalizeHelper() {
+    ChannelPipeline<R, W, Handlers...>::finalizeHelper();
+    back_ = ChannelPipeline<R, W, Handlers...>::back_;
+    if (!back_) {
+      auto is_end = ChannelPipeline<R, W, Handlers...>::is_end;
+      CHECK(is_end);
+      back_ = dynamic_cast<OutboundChannelHandlerContext<W>*>(&ctx_);
+      if (!back_) {
+        throw std::invalid_argument("wrong type for last handler");
+      }
     }
-  };
 
-  void fireRead(typename Handler::rout msg) {
-    if (!ChannelPipeline<Handlers...>::is_end) {
-      ChannelPipeline<Handlers...>::read(
-          std::forward<typename Handler::rout>(msg));
-    } else {
-      LOG(WARNING) << "read() reached end of pipeline";
+    if (!ctxs_.empty()) {
+      for (int i = 0; i < ctxs_.size() - 1; i++) {
+        ctxs_[i]->link(ctxs_[i+1].get());
+      }
+      ctxs_.back()->link(&ctx_);
     }
-  }
 
-  void fireReadEOF() {
-    if (!ChannelPipeline<Handlers...>::is_end) {
-      ChannelPipeline<Handlers...>::readEOF();
-    } else {
-      LOG(WARNING) << "readEOF() reached end of pipeline";
+    auto nextFront = ChannelPipeline<R, W, Handlers...>::getLocalFront();
+    if (nextFront) {
+      ctx_.link(nextFront);
     }
   }
 
-  void fireReadException(exception_wrapper e) {
-    if (!ChannelPipeline<Handlers...>::is_end) {
-      ChannelPipeline<Handlers...>::readException(std::move(e));
-    } else {
-      LOG(WARNING) << "readException() reached end of pipeline";
-    }
+  PipelineContext* getLocalFront() {
+    return ctxs_.empty() ? &ctx_ : ctxs_.front().get();
   }
 
-  Future<void> fireWrite(typename Handler::wout msg) {
-    if (outboundCtx_) {
-      return outboundCtx_->write(std::forward<typename Handler::wout>(msg));
-    } else {
-      LOG(WARNING) << "write() reached end of pipeline";
-      return makeFuture();
-    }
-  }
+  static const bool is_end{false};
+  InboundChannelHandlerContext<R>* front_{nullptr};
+  OutboundChannelHandlerContext<W>* back_{nullptr};
 
-  Future<void> fireClose() {
-    if (outboundCtx_) {
-      return outboundCtx_->close();
-    } else {
-      LOG(WARNING) << "close() reached end of pipeline";
-      return makeFuture();
+ private:
+  template <class F>
+  void forEachCtx(const F& func) {
+    for (auto& ctx : ctxs_) {
+      func(ctx.get());
     }
+    func(&ctx_);
   }
 
-  friend class Context;
-  Handler handler_;
-  Context ctx_;
-  OutboundContext* outboundCtx_{nullptr};
+  ContextImpl<ChannelPipeline, Handler> ctx_;
+  std::vector<std::unique_ptr<PipelineContext>> ctxs_;
 };
 
 }}
index e01515837a00f13f5fa02b8923dd2cafd815b730..6b7ec89744ad63a6eb1a72941ed6838ebc520cf4 100644 (file)
@@ -17,6 +17,7 @@
 #include <folly/experimental/wangle/channel/ChannelHandler.h>
 #include <folly/experimental/wangle/channel/ChannelPipeline.h>
 #include <folly/io/IOBufQueue.h>
+#include <folly/Memory.h>
 #include <folly/Conv.h>
 #include <gtest/gtest.h>
 
@@ -63,7 +64,7 @@ class EchoService : public ChannelHandlerAdapter<std::string> {
 };
 
 TEST(ChannelTest, PlzCompile) {
-  ChannelPipeline<
+  ChannelPipeline<IOBuf, IOBuf,
     BytesPassthrough,
     BytesPassthrough,
     // If this were useful it wouldn't be that hard
@@ -71,18 +72,34 @@ TEST(ChannelTest, PlzCompile) {
     BytesPassthrough>
   pipeline(BytesPassthrough(), BytesPassthrough(), BytesPassthrough);
 
-  ChannelPipeline<
+  ChannelPipeline<int, std::string,
     ChannelHandlerPtr<ToString>,
     KittyPrepender,
-    KittyPrepender,
-    EchoService>
+    KittyPrepender>
   kittyPipeline(
       std::make_shared<ToString>(),
       KittyPrepender{},
-      KittyPrepender{},
-      EchoService{});
+      KittyPrepender{});
+  kittyPipeline.addBack(KittyPrepender{});
+  kittyPipeline.addBack(EchoService{});
+  kittyPipeline.finalize();
   kittyPipeline.read(5);
 
   auto handler = kittyPipeline.getHandler<KittyPrepender>(2);
   CHECK(handler);
+
+  auto p = folly::make_unique<int>(42);
+  folly::Optional<std::unique_ptr<int>> foo{std::move(p)};
+}
+
+TEST(ChannelTest, PlzCompile2) {
+  EchoService echoService;
+  ChannelPipeline<int, std::string> pipeline;
+  pipeline
+    .addBack(ToString())
+    .addBack(KittyPrepender())
+    .addBack(KittyPrepender())
+    .addBack(ChannelHandlerPtr<EchoService, false>(&echoService))
+    .finalize();
+  pipeline.read(42);
 }