move wangle/futures to futures
[folly.git] / folly / wangle / channel / ChannelPipeline.h
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #pragma once
18
19 #include <folly/wangle/channel/ChannelHandlerContext.h>
20 #include <folly/futures/Future.h>
21 #include <folly/io/async/AsyncTransport.h>
22 #include <folly/io/async/DelayedDestruction.h>
23 #include <folly/ExceptionWrapper.h>
24 #include <folly/Memory.h>
25 #include <glog/logging.h>
26
27 namespace folly { namespace wangle {
28
29 /*
30  * R is the inbound type, i.e. inbound calls start with pipeline.read(R)
31  * W is the outbound type, i.e. outbound calls start with pipeline.write(W)
32  */
33 template <class R, class W, class... Handlers>
34 class ChannelPipeline;
35
36 template <class R, class W>
37 class ChannelPipeline<R, W> : public DelayedDestruction {
38  public:
39   ChannelPipeline() {}
40   ~ChannelPipeline() {}
41
42   std::shared_ptr<AsyncTransport> getTransport() {
43     return transport_;
44   }
45
46   void setWriteFlags(WriteFlags flags) {
47     writeFlags_ = flags;
48   }
49
50   WriteFlags getWriteFlags() {
51     return writeFlags_;
52   }
53
54   void setReadBufferSettings(uint64_t minAvailable, uint64_t allocationSize) {
55     readBufferSettings_ = std::make_pair(minAvailable, allocationSize);
56   }
57
58   std::pair<uint64_t, uint64_t> getReadBufferSettings() {
59     return readBufferSettings_;
60   }
61
62   void read(R msg) {
63     front_->read(std::forward<R>(msg));
64   }
65
66   void readEOF() {
67     front_->readEOF();
68   }
69
70   void readException(exception_wrapper e) {
71     front_->readException(std::move(e));
72   }
73
74   Future<void> write(W msg) {
75     return back_->write(std::forward<W>(msg));
76   }
77
78   Future<void> close() {
79     return back_->close();
80   }
81
82   template <class H>
83   ChannelPipeline& addBack(H&& handler) {
84     ctxs_.push_back(folly::make_unique<ContextImpl<ChannelPipeline, H>>(
85         this, std::forward<H>(handler)));
86     return *this;
87   }
88
89   template <class H>
90   ChannelPipeline& addFront(H&& handler) {
91     ctxs_.insert(
92         ctxs_.begin(),
93         folly::make_unique<ContextImpl<ChannelPipeline, H>>(
94             this,
95             std::forward<H>(handler)));
96     return *this;
97   }
98
99   template <class H>
100   H* getHandler(int i) {
101     auto ctx = dynamic_cast<ContextImpl<ChannelPipeline, H>*>(ctxs_[i].get());
102     CHECK(ctx);
103     return ctx->getHandler();
104   }
105
106   void finalize() {
107     finalizeHelper();
108     InboundChannelHandlerContext<R>* front;
109     front_ = dynamic_cast<InboundChannelHandlerContext<R>*>(
110         ctxs_.front().get());
111     if (!front_) {
112       throw std::invalid_argument("wrong type for first handler");
113     }
114   }
115
116  protected:
117   explicit ChannelPipeline(bool shouldFinalize) {
118     CHECK(!shouldFinalize);
119   }
120
121   void finalizeHelper() {
122     if (ctxs_.empty()) {
123       return;
124     }
125
126     for (size_t i = 0; i < ctxs_.size() - 1; i++) {
127       ctxs_[i]->link(ctxs_[i+1].get());
128     }
129
130     back_ = dynamic_cast<OutboundChannelHandlerContext<W>*>(ctxs_.back().get());
131     if (!back_) {
132       throw std::invalid_argument("wrong type for last handler");
133     }
134   }
135
136   PipelineContext* getLocalFront() {
137     return ctxs_.empty() ? nullptr : ctxs_.front().get();
138   }
139
140   static const bool is_end{true};
141
142   std::shared_ptr<AsyncTransport> transport_;
143   WriteFlags writeFlags_{WriteFlags::NONE};
144   std::pair<uint64_t, uint64_t> readBufferSettings_{2048, 2048};
145
146   void attachPipeline() {}
147
148   void attachTransport(
149       std::shared_ptr<AsyncTransport> transport) {
150     transport_ = std::move(transport);
151   }
152
153   void detachTransport() {
154     transport_ = nullptr;
155   }
156
157   OutboundChannelHandlerContext<W>* back_{nullptr};
158
159  private:
160   InboundChannelHandlerContext<R>* front_{nullptr};
161   std::vector<std::unique_ptr<PipelineContext>> ctxs_;
162 };
163
164 template <class R, class W, class Handler, class... Handlers>
165 class ChannelPipeline<R, W, Handler, Handlers...>
166   : public ChannelPipeline<R, W, Handlers...> {
167  protected:
168   template <class HandlerArg, class... HandlersArgs>
169   ChannelPipeline(
170       bool shouldFinalize,
171       HandlerArg&& handlerArg,
172       HandlersArgs&&... handlersArgs)
173     : ChannelPipeline<R, W, Handlers...>(
174           false,
175           std::forward<HandlersArgs>(handlersArgs)...),
176           ctx_(this, std::forward<HandlerArg>(handlerArg)) {
177     if (shouldFinalize) {
178       finalize();
179     }
180   }
181
182  public:
183   template <class... HandlersArgs>
184   explicit ChannelPipeline(HandlersArgs&&... handlersArgs)
185     : ChannelPipeline(true, std::forward<HandlersArgs>(handlersArgs)...) {}
186
187   ~ChannelPipeline() {}
188
189   void destroy() override { }
190
191   void read(R msg) {
192     typename ChannelPipeline<R, W>::DestructorGuard dg(
193         static_cast<DelayedDestruction*>(this));
194     front_->read(std::forward<R>(msg));
195   }
196
197   void readEOF() {
198     typename ChannelPipeline<R, W>::DestructorGuard dg(
199         static_cast<DelayedDestruction*>(this));
200     front_->readEOF();
201   }
202
203   void readException(exception_wrapper e) {
204     typename ChannelPipeline<R, W>::DestructorGuard dg(
205         static_cast<DelayedDestruction*>(this));
206     front_->readException(std::move(e));
207   }
208
209   Future<void> write(W msg) {
210     typename ChannelPipeline<R, W>::DestructorGuard dg(
211         static_cast<DelayedDestruction*>(this));
212     return back_->write(std::forward<W>(msg));
213   }
214
215   Future<void> close() {
216     typename ChannelPipeline<R, W>::DestructorGuard dg(
217         static_cast<DelayedDestruction*>(this));
218     return back_->close();
219   }
220
221   void attachTransport(
222       std::shared_ptr<AsyncTransport> transport) {
223     typename ChannelPipeline<R, W>::DestructorGuard dg(
224         static_cast<DelayedDestruction*>(this));
225     CHECK((!ChannelPipeline<R, W>::transport_));
226     ChannelPipeline<R, W, Handlers...>::attachTransport(std::move(transport));
227     forEachCtx([&](PipelineContext* ctx){
228       ctx->attachTransport();
229     });
230   }
231
232   void detachTransport() {
233     typename ChannelPipeline<R, W>::DestructorGuard dg(
234         static_cast<DelayedDestruction*>(this));
235     ChannelPipeline<R, W, Handlers...>::detachTransport();
236     forEachCtx([&](PipelineContext* ctx){
237       ctx->detachTransport();
238     });
239   }
240
241   std::shared_ptr<AsyncTransport> getTransport() {
242     return ChannelPipeline<R, W>::transport_;
243   }
244
245   template <class H>
246   ChannelPipeline& addBack(H&& handler) {
247     ChannelPipeline<R, W>::addBack(std::move(handler));
248     return *this;
249   }
250
251   template <class H>
252   ChannelPipeline& addFront(H&& handler) {
253     ctxs_.insert(
254         ctxs_.begin(),
255         folly::make_unique<ContextImpl<ChannelPipeline, H>>(
256             this,
257             std::move(handler)));
258     return *this;
259   }
260
261   template <class H>
262   H* getHandler(size_t i) {
263     if (i > ctxs_.size()) {
264       return ChannelPipeline<R, W, Handlers...>::template getHandler<H>(
265           i - (ctxs_.size() + 1));
266     } else {
267       auto pctx = (i == ctxs_.size()) ? &ctx_ : ctxs_[i].get();
268       auto ctx = dynamic_cast<ContextImpl<ChannelPipeline, H>*>(pctx);
269       return ctx->getHandler();
270     }
271   }
272
273   void finalize() {
274     finalizeHelper();
275     auto ctx = ctxs_.empty() ? &ctx_ : ctxs_.front().get();
276     front_ = dynamic_cast<InboundChannelHandlerContext<R>*>(ctx);
277     if (!front_) {
278       throw std::invalid_argument("wrong type for first handler");
279     }
280   }
281
282  protected:
283   void finalizeHelper() {
284     ChannelPipeline<R, W, Handlers...>::finalizeHelper();
285     back_ = ChannelPipeline<R, W, Handlers...>::back_;
286     if (!back_) {
287       auto is_end = ChannelPipeline<R, W, Handlers...>::is_end;
288       CHECK(is_end);
289       back_ = dynamic_cast<OutboundChannelHandlerContext<W>*>(&ctx_);
290       if (!back_) {
291         throw std::invalid_argument("wrong type for last handler");
292       }
293     }
294
295     if (!ctxs_.empty()) {
296       for (size_t i = 0; i < ctxs_.size() - 1; i++) {
297         ctxs_[i]->link(ctxs_[i+1].get());
298       }
299       ctxs_.back()->link(&ctx_);
300     }
301
302     auto nextFront = ChannelPipeline<R, W, Handlers...>::getLocalFront();
303     if (nextFront) {
304       ctx_.link(nextFront);
305     }
306   }
307
308   PipelineContext* getLocalFront() {
309     return ctxs_.empty() ? &ctx_ : ctxs_.front().get();
310   }
311
312   static const bool is_end{false};
313   InboundChannelHandlerContext<R>* front_{nullptr};
314   OutboundChannelHandlerContext<W>* back_{nullptr};
315
316  private:
317   template <class F>
318   void forEachCtx(const F& func) {
319     for (auto& ctx : ctxs_) {
320       func(ctx.get());
321     }
322     func(&ctx_);
323   }
324
325   ContextImpl<ChannelPipeline, Handler> ctx_;
326   std::vector<std::unique_ptr<PipelineContext>> ctxs_;
327 };
328
329 }}
330
331 namespace folly {
332
333 class AsyncSocket;
334
335 template <typename Pipeline>
336 class PipelineFactory {
337  public:
338   virtual Pipeline* newPipeline(std::shared_ptr<AsyncSocket>) = 0;
339   virtual ~PipelineFactory() {}
340 };
341
342 }