Adds writer test case for RCU
[folly.git] / folly / io / async / test / BlockingSocket.h
1 /*
2  * Copyright 2015-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/Optional.h>
19 #include <folly/io/async/AsyncSSLSocket.h>
20 #include <folly/io/async/AsyncSocket.h>
21 #include <folly/io/async/SSLContext.h>
22
23 class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
24                        public folly::AsyncTransportWrapper::ReadCallback,
25                        public folly::AsyncTransportWrapper::WriteCallback {
26  public:
27   explicit BlockingSocket(int fd)
28       : sock_(new folly::AsyncSocket(&eventBase_, fd)) {}
29
30   BlockingSocket(
31       folly::SocketAddress address,
32       std::shared_ptr<folly::SSLContext> sslContext)
33       : sock_(
34             sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_)
35                        : new folly::AsyncSocket(&eventBase_)),
36         address_(address) {}
37
38   explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
39       : sock_(std::move(socket)) {
40     sock_->attachEventBase(&eventBase_);
41   }
42
43   void enableTFO() {
44     sock_->enableTFO();
45   }
46
47   void setAddress(folly::SocketAddress address) {
48     address_ = address;
49   }
50
51   void open(
52       std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) {
53     sock_->connect(this, address_, timeout.count());
54     eventBase_.loop();
55     if (err_.hasValue()) {
56       throw err_.value();
57     }
58   }
59
60   void close() {
61     sock_->close();
62   }
63   void closeWithReset() {
64     sock_->closeWithReset();
65   }
66
67   int32_t write(uint8_t const* buf, size_t len) {
68     sock_->write(this, buf, len);
69     eventBase_.loop();
70     if (err_.hasValue()) {
71       throw err_.value();
72     }
73     return len;
74   }
75
76   void flush() {}
77
78   int32_t readAll(uint8_t* buf, size_t len) {
79     return readHelper(buf, len, true);
80   }
81
82   int32_t read(uint8_t* buf, size_t len) {
83     return readHelper(buf, len, false);
84   }
85
86   int getSocketFD() const {
87     return sock_->getFd();
88   }
89
90   folly::AsyncSocket* getSocket() {
91     return sock_.get();
92   }
93
94   folly::AsyncSSLSocket* getSSLSocket() {
95     return dynamic_cast<folly::AsyncSSLSocket*>(sock_.get());
96   }
97
98  private:
99   folly::EventBase eventBase_;
100   folly::AsyncSocket::UniquePtr sock_;
101   folly::Optional<folly::AsyncSocketException> err_;
102   uint8_t* readBuf_{nullptr};
103   size_t readLen_{0};
104   folly::SocketAddress address_;
105
106   void connectSuccess() noexcept override {}
107   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
108     err_ = ex;
109   }
110   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
111     *bufReturn = readBuf_;
112     *lenReturn = readLen_;
113   }
114   void readDataAvailable(size_t len) noexcept override {
115     readBuf_ += len;
116     readLen_ -= len;
117     if (readLen_ == 0) {
118       sock_->setReadCB(nullptr);
119     }
120   }
121   void readEOF() noexcept override {}
122   void readErr(const folly::AsyncSocketException& ex) noexcept override {
123     err_ = ex;
124   }
125   void writeSuccess() noexcept override {}
126   void writeErr(
127       size_t /* bytesWritten */,
128       const folly::AsyncSocketException& ex) noexcept override {
129     err_ = ex;
130   }
131
132   int32_t readHelper(uint8_t* buf, size_t len, bool all) {
133     if (!sock_->good()) {
134       return 0;
135     }
136
137     readBuf_ = buf;
138     readLen_ = len;
139     sock_->setReadCB(this);
140     while (!err_ && sock_->good() && readLen_ > 0) {
141       eventBase_.loopOnce();
142       if (!all) {
143         break;
144       }
145     }
146     sock_->setReadCB(nullptr);
147     if (err_.hasValue()) {
148       throw err_.value();
149     }
150     if (all && readLen_ > 0) {
151       throw folly::AsyncSocketException(
152           folly::AsyncSocketException::UNKNOWN, "eof");
153     }
154     return len - readLen_;
155   }
156 };