Stop abusing errno
[folly.git] / folly / io / async / test / BlockingSocket.h
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17
18 #include <folly/Optional.h>
19 #include <folly/io/async/SSLContext.h>
20 #include <folly/io/async/AsyncSocket.h>
21 #include <folly/io/async/AsyncSSLSocket.h>
22
23 class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
24                        public folly::AsyncTransportWrapper::ReadCallback,
25                        public folly::AsyncTransportWrapper::WriteCallback
26 {
27  public:
28   explicit BlockingSocket(int fd)
29     : sock_(new folly::AsyncSocket(&eventBase_, fd)) {
30   }
31
32   BlockingSocket(folly::SocketAddress address,
33                  std::shared_ptr<folly::SSLContext> sslContext)
34     : sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) :
35             new folly::AsyncSocket(&eventBase_)),
36     address_(address) {}
37
38   explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
39       : sock_(std::move(socket)) {
40     sock_->attachEventBase(&eventBase_);
41   }
42
43   void open() {
44     sock_->connect(this, address_);
45     eventBase_.loop();
46     if (err_.hasValue()) {
47       throw err_.value();
48     }
49   }
50   void close() {
51     sock_->close();
52   }
53   void closeWithReset() { sock_->closeWithReset(); }
54
55   int32_t write(uint8_t const* buf, size_t len) {
56     sock_->write(this, buf, len);
57     eventBase_.loop();
58     if (err_.hasValue()) {
59       throw err_.value();
60     }
61     return len;
62   }
63
64   void flush() {}
65
66   int32_t readAll(uint8_t *buf, size_t len) {
67     return readHelper(buf, len, true);
68   }
69
70   int32_t read(uint8_t *buf, size_t len) {
71     return readHelper(buf, len, false);
72   }
73
74   int getSocketFD() const {
75     return sock_->getFd();
76   }
77
78  private:
79   folly::EventBase eventBase_;
80   folly::AsyncSocket::UniquePtr sock_;
81   folly::Optional<folly::AsyncSocketException> err_;
82   uint8_t *readBuf_{nullptr};
83   size_t readLen_{0};
84   folly::SocketAddress address_;
85
86   void connectSuccess() noexcept override {}
87   void connectErr(const folly::AsyncSocketException& ex) noexcept override {
88     err_ = ex;
89   }
90   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
91     *bufReturn = readBuf_;
92     *lenReturn = readLen_;
93   }
94   void readDataAvailable(size_t len) noexcept override {
95     readBuf_ += len;
96     readLen_ -= len;
97     if (readLen_ == 0) {
98       sock_->setReadCB(nullptr);
99     }
100   }
101   void readEOF() noexcept override {
102   }
103   void readErr(const folly::AsyncSocketException& ex) noexcept override {
104     err_ = ex;
105   }
106   void writeSuccess() noexcept override {}
107   void writeErr(size_t /* bytesWritten */,
108                 const folly::AsyncSocketException& ex) noexcept override {
109     err_ = ex;
110   }
111
112   int32_t readHelper(uint8_t *buf, size_t len, bool all) {
113     readBuf_ = buf;
114     readLen_ = len;
115     sock_->setReadCB(this);
116     while (!err_ && sock_->good() && readLen_ > 0) {
117       eventBase_.loop();
118       if (!all) {
119         break;
120       }
121     }
122     sock_->setReadCB(nullptr);
123     if (err_.hasValue()) {
124       throw err_.value();
125     }
126     if (all && readLen_ > 0) {
127       throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN,
128                                         "eof");
129     }
130     return len - readLen_;
131   }
132 };