Future<Unit> wangle fixup
[folly.git] / folly / wangle / channel / HandlerContext-inl.h
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #pragma once
18
19 namespace folly { namespace wangle {
20
21 class PipelineContext {
22  public:
23   virtual ~PipelineContext() = default;
24
25   virtual void attachPipeline() = 0;
26   virtual void detachPipeline() = 0;
27
28   template <class H, class HandlerContext>
29   void attachContext(H* handler, HandlerContext* ctx) {
30     if (++handler->attachCount_ == 1) {
31       handler->ctx_ = ctx;
32     } else {
33       handler->ctx_ = nullptr;
34     }
35   }
36
37   virtual void setNextIn(PipelineContext* ctx) = 0;
38   virtual void setNextOut(PipelineContext* ctx) = 0;
39
40   virtual HandlerDir getDirection() = 0;
41 };
42
43 template <class In>
44 class InboundLink {
45  public:
46   virtual ~InboundLink() = default;
47   virtual void read(In msg) = 0;
48   virtual void readEOF() = 0;
49   virtual void readException(exception_wrapper e) = 0;
50   virtual void transportActive() = 0;
51   virtual void transportInactive() = 0;
52 };
53
54 template <class Out>
55 class OutboundLink {
56  public:
57   virtual ~OutboundLink() = default;
58   virtual Future<Unit> write(Out msg) = 0;
59   virtual Future<Unit> close() = 0;
60 };
61
62 template <class H, class Context>
63 class ContextImplBase : public PipelineContext {
64  public:
65   ~ContextImplBase() = default;
66
67   H* getHandler() {
68     return handler_.get();
69   }
70
71   void initialize(PipelineBase* pipeline, std::shared_ptr<H> handler) {
72     pipeline_ = pipeline;
73     handler_ = std::move(handler);
74   }
75
76   // PipelineContext overrides
77   void attachPipeline() override {
78     if (!attached_) {
79       this->attachContext(handler_.get(), impl_);
80       handler_->attachPipeline(impl_);
81       attached_ = true;
82     }
83   }
84
85   void detachPipeline() override {
86     handler_->detachPipeline(impl_);
87     attached_ = false;
88   }
89
90   void setNextIn(PipelineContext* ctx) override {
91     if (!ctx) {
92       nextIn_ = nullptr;
93       return;
94     }
95     auto nextIn = dynamic_cast<InboundLink<typename H::rout>*>(ctx);
96     if (nextIn) {
97       nextIn_ = nextIn;
98     } else {
99       throw std::invalid_argument("inbound type mismatch");
100     }
101   }
102
103   void setNextOut(PipelineContext* ctx) override {
104     if (!ctx) {
105       nextOut_ = nullptr;
106       return;
107     }
108     auto nextOut = dynamic_cast<OutboundLink<typename H::wout>*>(ctx);
109     if (nextOut) {
110       nextOut_ = nextOut;
111     } else {
112       throw std::invalid_argument("outbound type mismatch");
113     }
114   }
115
116   HandlerDir getDirection() override {
117     return H::dir;
118   }
119
120  protected:
121   Context* impl_;
122   PipelineBase* pipeline_;
123   std::shared_ptr<H> handler_;
124   InboundLink<typename H::rout>* nextIn_{nullptr};
125   OutboundLink<typename H::wout>* nextOut_{nullptr};
126
127  private:
128   bool attached_{false};
129   using DestructorGuard = typename DelayedDestruction::DestructorGuard;
130 };
131
132 template <class H>
133 class ContextImpl
134   : public HandlerContext<typename H::rout,
135                           typename H::wout>,
136     public InboundLink<typename H::rin>,
137     public OutboundLink<typename H::win>,
138     public ContextImplBase<H, HandlerContext<typename H::rout,
139                                              typename H::wout>> {
140  public:
141   typedef typename H::rin Rin;
142   typedef typename H::rout Rout;
143   typedef typename H::win Win;
144   typedef typename H::wout Wout;
145   static const HandlerDir dir = HandlerDir::BOTH;
146
147   explicit ContextImpl(PipelineBase* pipeline, std::shared_ptr<H> handler) {
148     this->impl_ = this;
149     this->initialize(pipeline, std::move(handler));
150   }
151
152   // For StaticPipeline
153   ContextImpl() {
154     this->impl_ = this;
155   }
156
157   ~ContextImpl() = default;
158
159   // HandlerContext overrides
160   void fireRead(Rout msg) override {
161     DestructorGuard dg(this->pipeline_);
162     if (this->nextIn_) {
163       this->nextIn_->read(std::forward<Rout>(msg));
164     } else {
165       LOG(WARNING) << "read reached end of pipeline";
166     }
167   }
168
169   void fireReadEOF() override {
170     DestructorGuard dg(this->pipeline_);
171     if (this->nextIn_) {
172       this->nextIn_->readEOF();
173     } else {
174       LOG(WARNING) << "readEOF reached end of pipeline";
175     }
176   }
177
178   void fireReadException(exception_wrapper e) override {
179     DestructorGuard dg(this->pipeline_);
180     if (this->nextIn_) {
181       this->nextIn_->readException(std::move(e));
182     } else {
183       LOG(WARNING) << "readException reached end of pipeline";
184     }
185   }
186
187   void fireTransportActive() override {
188     DestructorGuard dg(this->pipeline_);
189     if (this->nextIn_) {
190       this->nextIn_->transportActive();
191     }
192   }
193
194   void fireTransportInactive() override {
195     DestructorGuard dg(this->pipeline_);
196     if (this->nextIn_) {
197       this->nextIn_->transportInactive();
198     }
199   }
200
201   Future<Unit> fireWrite(Wout msg) override {
202     DestructorGuard dg(this->pipeline_);
203     if (this->nextOut_) {
204       return this->nextOut_->write(std::forward<Wout>(msg));
205     } else {
206       LOG(WARNING) << "write reached end of pipeline";
207       return makeFuture();
208     }
209   }
210
211   Future<Unit> fireClose() override {
212     DestructorGuard dg(this->pipeline_);
213     if (this->nextOut_) {
214       return this->nextOut_->close();
215     } else {
216       LOG(WARNING) << "close reached end of pipeline";
217       return makeFuture();
218     }
219   }
220
221   PipelineBase* getPipeline() override {
222     return this->pipeline_;
223   }
224
225   void setWriteFlags(WriteFlags flags) override {
226     this->pipeline_->setWriteFlags(flags);
227   }
228
229   WriteFlags getWriteFlags() override {
230     return this->pipeline_->getWriteFlags();
231   }
232
233   void setReadBufferSettings(
234       uint64_t minAvailable,
235       uint64_t allocationSize) override {
236     this->pipeline_->setReadBufferSettings(minAvailable, allocationSize);
237   }
238
239   std::pair<uint64_t, uint64_t> getReadBufferSettings() override {
240     return this->pipeline_->getReadBufferSettings();
241   }
242
243   // InboundLink overrides
244   void read(Rin msg) override {
245     DestructorGuard dg(this->pipeline_);
246     this->handler_->read(this, std::forward<Rin>(msg));
247   }
248
249   void readEOF() override {
250     DestructorGuard dg(this->pipeline_);
251     this->handler_->readEOF(this);
252   }
253
254   void readException(exception_wrapper e) override {
255     DestructorGuard dg(this->pipeline_);
256     this->handler_->readException(this, std::move(e));
257   }
258
259   void transportActive() override {
260     DestructorGuard dg(this->pipeline_);
261     this->handler_->transportActive(this);
262   }
263
264   void transportInactive() override {
265     DestructorGuard dg(this->pipeline_);
266     this->handler_->transportInactive(this);
267   }
268
269   // OutboundLink overrides
270   Future<Unit> write(Win msg) override {
271     DestructorGuard dg(this->pipeline_);
272     return this->handler_->write(this, std::forward<Win>(msg));
273   }
274
275   Future<Unit> close() override {
276     DestructorGuard dg(this->pipeline_);
277     return this->handler_->close(this);
278   }
279
280  private:
281   using DestructorGuard = typename DelayedDestruction::DestructorGuard;
282 };
283
284 template <class H>
285 class InboundContextImpl
286   : public InboundHandlerContext<typename H::rout>,
287     public InboundLink<typename H::rin>,
288     public ContextImplBase<H, InboundHandlerContext<typename H::rout>> {
289  public:
290   typedef typename H::rin Rin;
291   typedef typename H::rout Rout;
292   typedef typename H::win Win;
293   typedef typename H::wout Wout;
294   static const HandlerDir dir = HandlerDir::IN;
295
296   explicit InboundContextImpl(
297       PipelineBase* pipeline,
298       std::shared_ptr<H> handler) {
299     this->impl_ = this;
300     this->initialize(pipeline, std::move(handler));
301   }
302
303   // For StaticPipeline
304   InboundContextImpl() {
305     this->impl_ = this;
306   }
307
308   ~InboundContextImpl() = default;
309
310   // InboundHandlerContext overrides
311   void fireRead(Rout msg) override {
312     DestructorGuard dg(this->pipeline_);
313     if (this->nextIn_) {
314       this->nextIn_->read(std::forward<Rout>(msg));
315     } else {
316       LOG(WARNING) << "read reached end of pipeline";
317     }
318   }
319
320   void fireReadEOF() override {
321     DestructorGuard dg(this->pipeline_);
322     if (this->nextIn_) {
323       this->nextIn_->readEOF();
324     } else {
325       LOG(WARNING) << "readEOF reached end of pipeline";
326     }
327   }
328
329   void fireReadException(exception_wrapper e) override {
330     DestructorGuard dg(this->pipeline_);
331     if (this->nextIn_) {
332       this->nextIn_->readException(std::move(e));
333     } else {
334       LOG(WARNING) << "readException reached end of pipeline";
335     }
336   }
337
338   void fireTransportActive() override {
339     DestructorGuard dg(this->pipeline_);
340     if (this->nextIn_) {
341       this->nextIn_->transportActive();
342     }
343   }
344
345   void fireTransportInactive() override {
346     DestructorGuard dg(this->pipeline_);
347     if (this->nextIn_) {
348       this->nextIn_->transportInactive();
349     }
350   }
351
352   PipelineBase* getPipeline() override {
353     return this->pipeline_;
354   }
355
356   // InboundLink overrides
357   void read(Rin msg) override {
358     DestructorGuard dg(this->pipeline_);
359     this->handler_->read(this, std::forward<Rin>(msg));
360   }
361
362   void readEOF() override {
363     DestructorGuard dg(this->pipeline_);
364     this->handler_->readEOF(this);
365   }
366
367   void readException(exception_wrapper e) override {
368     DestructorGuard dg(this->pipeline_);
369     this->handler_->readException(this, std::move(e));
370   }
371
372   void transportActive() override {
373     DestructorGuard dg(this->pipeline_);
374     this->handler_->transportActive(this);
375   }
376
377   void transportInactive() override {
378     DestructorGuard dg(this->pipeline_);
379     this->handler_->transportInactive(this);
380   }
381
382  private:
383   using DestructorGuard = typename DelayedDestruction::DestructorGuard;
384 };
385
386 template <class H>
387 class OutboundContextImpl
388   : public OutboundHandlerContext<typename H::wout>,
389     public OutboundLink<typename H::win>,
390     public ContextImplBase<H, OutboundHandlerContext<typename H::wout>> {
391  public:
392   typedef typename H::rin Rin;
393   typedef typename H::rout Rout;
394   typedef typename H::win Win;
395   typedef typename H::wout Wout;
396   static const HandlerDir dir = HandlerDir::OUT;
397
398   explicit OutboundContextImpl(
399       PipelineBase* pipeline,
400       std::shared_ptr<H> handler) {
401     this->impl_ = this;
402     this->initialize(pipeline, std::move(handler));
403   }
404
405   // For StaticPipeline
406   OutboundContextImpl() {
407     this->impl_ = this;
408   }
409
410   ~OutboundContextImpl() = default;
411
412   // OutboundHandlerContext overrides
413   Future<Unit> fireWrite(Wout msg) override {
414     DestructorGuard dg(this->pipeline_);
415     if (this->nextOut_) {
416       return this->nextOut_->write(std::forward<Wout>(msg));
417     } else {
418       LOG(WARNING) << "write reached end of pipeline";
419       return makeFuture();
420     }
421   }
422
423   Future<Unit> fireClose() override {
424     DestructorGuard dg(this->pipeline_);
425     if (this->nextOut_) {
426       return this->nextOut_->close();
427     } else {
428       LOG(WARNING) << "close reached end of pipeline";
429       return makeFuture();
430     }
431   }
432
433   PipelineBase* getPipeline() override {
434     return this->pipeline_;
435   }
436
437   // OutboundLink overrides
438   Future<Unit> write(Win msg) override {
439     DestructorGuard dg(this->pipeline_);
440     return this->handler_->write(this, std::forward<Win>(msg));
441   }
442
443   Future<Unit> close() override {
444     DestructorGuard dg(this->pipeline_);
445     return this->handler_->close(this);
446   }
447
448  private:
449   using DestructorGuard = typename DelayedDestruction::DestructorGuard;
450 };
451
452 template <class Handler>
453 struct ContextType {
454   typedef typename std::conditional<
455     Handler::dir == HandlerDir::BOTH,
456     ContextImpl<Handler>,
457     typename std::conditional<
458       Handler::dir == HandlerDir::IN,
459       InboundContextImpl<Handler>,
460       OutboundContextImpl<Handler>
461     >::type>::type
462   type;
463 };
464
465 }} // folly::wangle