Add evb change callback to SSL Socket
[folly.git] / folly / io / async / test / AsyncSSLSocketTest2.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/test/AsyncSSLSocketTest.h>
17
18 #include <pthread.h>
19
20 #include <folly/futures/Promise.h>
21 #include <folly/io/async/AsyncSSLSocket.h>
22 #include <folly/io/async/EventBase.h>
23 #include <folly/io/async/SSLContext.h>
24 #include <folly/io/async/ScopedEventBaseThread.h>
25 #include <folly/portability/GTest.h>
26
27 using std::string;
28 using std::vector;
29 using std::min;
30 using std::cerr;
31 using std::endl;
32 using std::list;
33
34 namespace folly {
35
36 struct EvbAndContext {
37   EvbAndContext() {
38     ctx_.reset(new SSLContext());
39     ctx_->setOptions(SSL_OP_NO_TICKET);
40     ctx_->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
41   }
42
43   std::shared_ptr<AsyncSSLSocket> createSocket() {
44     return AsyncSSLSocket::newSocket(ctx_, getEventBase());
45   }
46
47   EventBase* getEventBase() {
48     return evb_.getEventBase();
49   }
50
51   void attach(AsyncSSLSocket& socket) {
52     socket.attachEventBase(getEventBase());
53     socket.attachSSLContext(ctx_);
54   }
55
56   folly::ScopedEventBaseThread evb_;
57   std::shared_ptr<SSLContext> ctx_;
58 };
59
60 class AttachDetachClient : public AsyncSocket::ConnectCallback,
61                            public AsyncTransportWrapper::WriteCallback,
62                            public AsyncTransportWrapper::ReadCallback {
63  private:
64   // two threads here - we'll create the socket in one, connect
65   // in the other, and then read/write in the initial one
66   EvbAndContext t1_;
67   EvbAndContext t2_;
68   std::shared_ptr<AsyncSSLSocket> sslSocket_;
69   folly::SocketAddress address_;
70   char buf_[128];
71   char readbuf_[128];
72   uint32_t bytesRead_;
73   // promise to fulfill when done
74   folly::Promise<bool> promise_;
75
76   void detach() {
77     sslSocket_->detachEventBase();
78     sslSocket_->detachSSLContext();
79   }
80
81  public:
82   explicit AttachDetachClient(const folly::SocketAddress& address)
83       : address_(address), bytesRead_(0) {}
84
85   Future<bool> getFuture() {
86     return promise_.getFuture();
87   }
88
89   void connect() {
90     // create in one and then move to another
91     auto t1Evb = t1_.getEventBase();
92     t1Evb->runInEventBaseThread([this] {
93       sslSocket_ = t1_.createSocket();
94       // ensure we can detach and reattach the context before connecting
95       for (int i = 0; i < 1000; ++i) {
96         sslSocket_->detachSSLContext();
97         sslSocket_->attachSSLContext(t1_.ctx_);
98       }
99       // detach from t1 and connect in t2
100       detach();
101       auto t2Evb = t2_.getEventBase();
102       t2Evb->runInEventBaseThread([this] {
103         t2_.attach(*sslSocket_);
104         sslSocket_->connect(this, address_);
105       });
106     });
107   }
108
109   void connectSuccess() noexcept override {
110     auto t2Evb = t2_.getEventBase();
111     EXPECT_TRUE(t2Evb->isInEventBaseThread());
112     cerr << "client SSL socket connected" << endl;
113     for (int i = 0; i < 1000; ++i) {
114       sslSocket_->detachSSLContext();
115       sslSocket_->attachSSLContext(t2_.ctx_);
116     }
117
118     // detach from t2 and then read/write in t1
119     t2Evb->runInEventBaseThread([this] {
120       detach();
121       auto t1Evb = t1_.getEventBase();
122       t1Evb->runInEventBaseThread([this] {
123         t1_.attach(*sslSocket_);
124         sslSocket_->write(this, buf_, sizeof(buf_));
125         sslSocket_->setReadCB(this);
126         memset(readbuf_, 'b', sizeof(readbuf_));
127         bytesRead_ = 0;
128       });
129     });
130   }
131
132   void connectErr(const AsyncSocketException& ex) noexcept override
133   {
134     cerr << "AttachDetachClient::connectError: " << ex.what() << endl;
135     sslSocket_.reset();
136   }
137
138   void writeSuccess() noexcept override {
139     cerr << "client write success" << endl;
140   }
141
142   void writeErr(size_t /* bytesWritten */,
143                 const AsyncSocketException& ex) noexcept override {
144     cerr << "client writeError: " << ex.what() << endl;
145   }
146
147   void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
148     *bufReturn = readbuf_ + bytesRead_;
149     *lenReturn = sizeof(readbuf_) - bytesRead_;
150   }
151   void readEOF() noexcept override {
152     cerr << "client readEOF" << endl;
153   }
154
155   void readErr(const AsyncSocketException& ex) noexcept override {
156     cerr << "client readError: " << ex.what() << endl;
157     promise_.setException(ex);
158   }
159
160   void readDataAvailable(size_t len) noexcept override {
161     EXPECT_TRUE(t1_.getEventBase()->isInEventBaseThread());
162     EXPECT_EQ(sslSocket_->getEventBase(), t1_.getEventBase());
163     cerr << "client read data: " << len << endl;
164     bytesRead_ += len;
165     if (len == sizeof(buf_)) {
166       EXPECT_EQ(memcmp(buf_, readbuf_, bytesRead_), 0);
167       sslSocket_->closeNow();
168       sslSocket_.reset();
169       promise_.setValue(true);
170     }
171   }
172 };
173
174 /**
175  * Test passing contexts between threads
176  */
177 TEST(AsyncSSLSocketTest2, AttachDetachSSLContext) {
178   // Start listening on a local port
179   WriteCallbackBase writeCallback;
180   ReadCallback readCallback(&writeCallback);
181   HandshakeCallback handshakeCallback(&readCallback);
182   SSLServerAcceptCallbackDelay acceptCallback(&handshakeCallback);
183   TestSSLServer server(&acceptCallback);
184
185   std::shared_ptr<AttachDetachClient> client(
186       new AttachDetachClient(server.getAddress()));
187
188   auto f = client->getFuture();
189   client->connect();
190   EXPECT_TRUE(f.within(std::chrono::seconds(3)).get());
191 }
192
193 }  // folly
194
195 int main(int argc, char *argv[]) {
196 #ifdef SIGPIPE
197   signal(SIGPIPE, SIG_IGN);
198 #endif
199   folly::SSLContext::setSSLLockTypes({
200       {CRYPTO_LOCK_EVP_PKEY, folly::SSLContext::LOCK_NONE},
201       {CRYPTO_LOCK_SSL_SESSION, folly::SSLContext::LOCK_SPINLOCK},
202       {CRYPTO_LOCK_SSL_CTX, folly::SSLContext::LOCK_NONE}});
203   testing::InitGoogleTest(&argc, argv);
204   gflags::ParseCommandLineFlags(&argc, &argv, true);
205   return RUN_ALL_TESTS();
206 }