Fix RequestContext held too long issue in EventBase
[folly.git] / folly / io / async / test / AsyncSSLSocketTest2.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <folly/futures/Promise.h>
19 #include <folly/init/Init.h>
20 #include <folly/io/async/AsyncSSLSocket.h>
21 #include <folly/io/async/EventBase.h>
22 #include <folly/io/async/SSLContext.h>
23 #include <folly/io/async/ScopedEventBaseThread.h>
24 #include <folly/portability/GTest.h>
25 #include <folly/portability/PThread.h>
26
27 using std::string;
28 using std::vector;
29 using std::min;
30 using std::cerr;
31 using std::endl;
32 using std::list;
33
34 namespace folly {
35
36 struct EvbAndContext {
37   EvbAndContext() {
38     ctx_.reset(new SSLContext());
39     ctx_->setOptions(SSL_OP_NO_TICKET);
40     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
41   }
42
43   std::shared_ptr<AsyncSSLSocket> createSocket() {
44     return AsyncSSLSocket::newSocket(ctx_, getEventBase());
45   }
46
47   EventBase* getEventBase() {
48     return evb_.getEventBase();
49   }
50
51   void attach(AsyncSSLSocket& socket) {
52     socket.attachEventBase(getEventBase());
53     socket.attachSSLContext(ctx_);
54   }
55
56   folly::ScopedEventBaseThread evb_;
57   std::shared_ptr<SSLContext> ctx_;
58 };
59
60 class AttachDetachClient : public AsyncSocket::ConnectCallback,
61                            public AsyncTransportWrapper::WriteCallback,
62                            public AsyncTransportWrapper::ReadCallback {
63  private:
64   // two threads here - we'll create the socket in one, connect
65   // in the other, and then read/write in the initial one
66   EvbAndContext t1_;
67   EvbAndContext t2_;
68   std::shared_ptr<AsyncSSLSocket> sslSocket_;
69   folly::SocketAddress address_;
70   char buf_[128];
71   char readbuf_[128];
72   uint32_t bytesRead_;
73   // promise to fulfill when done
74   folly::Promise<bool> promise_;
75
76   void detach() {
77     sslSocket_->detachEventBase();
78     sslSocket_->detachSSLContext();
79   }
80
81  public:
82   explicit AttachDetachClient(const folly::SocketAddress& address)
83       : address_(address), bytesRead_(0) {}
84
85   Future<bool> getFuture() {
86     return promise_.getFuture();
87   }
88
89   void connect() {
90     // create in one and then move to another
91     auto t1Evb = t1_.getEventBase();
92     t1Evb->runInEventBaseThread([this] {
93       sslSocket_ = t1_.createSocket();
94       // ensure we can detach and reattach the context before connecting
95       for (int i = 0; i < 1000; ++i) {
96         sslSocket_->detachSSLContext();
97         sslSocket_->attachSSLContext(t1_.ctx_);
98       }
99       // detach from t1 and connect in t2
100       detach();
101       auto t2Evb = t2_.getEventBase();
102       t2Evb->runInEventBaseThread([this] {
103         t2_.attach(*sslSocket_);
104         sslSocket_->connect(this, address_);
105       });
106     });
107   }
108
109   void connectSuccess() noexcept override {
110     auto t2Evb = t2_.getEventBase();
111     EXPECT_TRUE(t2Evb->isInEventBaseThread());
112     cerr << "client SSL socket connected" << endl;
113     for (int i = 0; i < 1000; ++i) {
114       sslSocket_->detachSSLContext();
115       sslSocket_->attachSSLContext(t2_.ctx_);
116     }
117
118     // detach from t2 and then read/write in t1
119     t2Evb->runInEventBaseThread([this] {
120       detach();
121       auto t1Evb = t1_.getEventBase();
122       t1Evb->runInEventBaseThread([this] {
123         t1_.attach(*sslSocket_);
124         sslSocket_->write(this, buf_, sizeof(buf_));
125         sslSocket_->setReadCB(this);
126         memset(readbuf_, 'b', sizeof(readbuf_));
127         bytesRead_ = 0;
128       });
129     });
130   }
131
132   void connectErr(const AsyncSocketException& ex) noexcept override
133   {
134     cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
135     sslSocket_.reset();
136   }
137
138   void writeSuccess() noexcept override {
139     cerr << "client write success" << endl;
140   }
141
142   void writeErr(size_t /* bytesWritten */,
143                 const AsyncSocketException& ex) noexcept override {
144     cerr << "client writeError: " << ex.what() << endl;
145   }
146
147   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
148     *bufReturn = readbuf_ + bytesRead_;
149     *lenReturn = sizeof(readbuf_) - bytesRead_;
150   }
151   void readEOF() noexcept override {
152     cerr << "client readEOF" << endl;
153   }
154
155   void readErr(const AsyncSocketException& ex) noexcept override {
156     cerr << "client readError: " << ex.what() << endl;
157     promise_.setException(ex);
158   }
159
160   void readDataAvailable(size_t len) noexcept override {
161     EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
162     EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
163     cerr << "client read data: " << len << endl;
164     bytesRead_ += len;
165     if (len == sizeof(buf_)) {
166       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
167       sslSocket_->closeNow();
168       sslSocket_.reset();
169       promise_.setValue(true);
170     }
171   }
172 };
173
174 /**
175  * Test passing contexts between threads
176  */
177 TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
178   // Start listening on a local port
179   WriteCallbackBase writeCallback;
180   ReadCallback readCallback(&writeCallback);
181   HandshakeCallback handshakeCallback(&readCallback);
182   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
183   TestSSLServer server(&acceptCallback);
184
185   std::shared_ptr<AttachDetachClient> client(
186       new AttachDetachClient(server.getAddress()));
187
188   auto f = client->getFuture();
189   client->connect();
190   EXPECT_TRUE(f.within(std::chrono::seconds(3)).get());
191 }
192
193 class ConnectClient : public AsyncSocket::ConnectCallback {
194  public:
195   ConnectClient() = default;
196
197   Future<bool> getFuture() {
198     return promise_.getFuture();
199   }
200
201   void connect(const folly::SocketAddress& addr) {
202     t1_.getEventBase()->runInEventBaseThread([&] {
203       socket_ = t1_.createSocket();
204       socket_->connect(this, addr);
205     });
206   }
207
208   void connectSuccess() noexcept override {
209     socket_.reset();
210     promise_.setValue(true);
211   }
212
213   void connectErr(const AsyncSocketException& /* ex */) noexcept override {
214     socket_.reset();
215     promise_.setValue(false);
216   }
217
218   void setCtx(std::shared_ptr<SSLContext> ctx) {
219     t1_.ctx_ = ctx;
220   }
221
222  private:
223   EvbAndContext t1_;
224   // promise to fulfill when done with a value of true if connect succeeded
225   folly::Promise<bool> promise_;
226   std::shared_ptr<AsyncSSLSocket> socket_;
227 };
228
229 class NoopReadCallback : public ReadCallbackBase {
230  public:
231   NoopReadCallback() : ReadCallbackBase(nullptr) {
232     state = STATE_SUCCEEDED;
233   }
234
235   void getReadBuffer(void** buf, size_t* lenReturn) override {
236     *buf = &buffer_;
237     *lenReturn = 1;
238   }
239   void readDataAvailable(size_t) noexcept override {}
240
241   uint8_t buffer_{0};
242 };
243
244 TEST(AsyncSSLSocketTest2, TestTLS12DefaultClient) {
245   // Start listening on a local port
246   NoopReadCallback readCallback;
247   HandshakeCallback handshakeCallback(&readCallback);
248   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
249   auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
250   TestSSLServer server(&acceptCallback, ctx);
251   server.loadTestCerts();
252
253   // create a default client
254   auto c1 = std::make_unique<ConnectClient>();
255   auto f1 = c1->getFuture();
256   c1->connect(server.getAddress());
257   EXPECT_TRUE(f1.within(std::chrono::seconds(3)).get());
258 }
259
260 TEST(AsyncSSLSocketTest2, TestTLS12BadClient) {
261   // Start listening on a local port
262   NoopReadCallback readCallback;
263   HandshakeCallback handshakeCallback(
264       &readCallback, HandshakeCallback::EXPECT_ERROR);
265   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
266   auto ctx = std::make_shared<SSLContext>(SSLContext::TLSv1_2);
267   TestSSLServer server(&acceptCallback, ctx);
268   server.loadTestCerts();
269
270   // create a client that doesn't speak TLS 1.2
271   auto c2 = std::make_unique<ConnectClient>();
272   auto clientCtx = std::make_shared<SSLContext>();
273   clientCtx->setOptions(SSL_OP_NO_TLSv1_2);
274   c2->setCtx(clientCtx);
275   auto f2 = c2->getFuture();
276   c2->connect(server.getAddress());
277   EXPECT_FALSE(f2.within(std::chrono::seconds(3)).get());
278 }
279
280 } // namespace folly
281
282 int main(int argc, char *argv[]) {
283 #ifdef SIGPIPE
284   signal(SIGPIPE, SIG_IGN);
285 #endif
286   testing::InitGoogleTest(&argc, argv);
287   folly::init(&argc, &argv);
288   return RUN_ALL_TESTS();
289 }