copy wangle back into folly
[folly.git] / folly / wangle / ssl / test / SSLCacheTest.cpp
1 /*
2  *  Copyright (c) 2015, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree. An additional grant
7  *  of patent rights can be found in the PATENTS file in the same directory.
8  *
9  */
10 #include <folly/Portability.h>
11 #include <folly/io/async/EventBase.h>
12 #include <gflags/gflags.h>
13 #include <iostream>
14 #include <thread>
15 #include <folly/io/async/AsyncSSLSocket.h>
16 #include <folly/io/async/AsyncSocket.h>
17 #include <vector>
18
19 using namespace std;
20 using namespace folly;
21
22 DEFINE_int32(clients, 1, "Number of simulated SSL clients");
23 DEFINE_int32(threads, 1, "Number of threads to spread clients across");
24 DEFINE_int32(requests, 2, "Total number of requests per client");
25 DEFINE_int32(port, 9423, "Server port");
26 DEFINE_bool(sticky, false, "A given client sends all reqs to one "
27             "(random) server");
28 DEFINE_bool(global, false, "All clients in a thread use the same SSL session");
29 DEFINE_bool(handshakes, false, "Force 100% handshakes");
30
31 string f_servers[10];
32 int f_num_servers = 0;
33 int tnum = 0;
34
35 class ClientRunner {
36  public:
37
38   ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {}
39   void run();
40
41   int reqs;
42   int hits;
43   int miss;
44   int num;
45 };
46
47 class SSLCacheClient : public AsyncSocket::ConnectCallback,
48                        public AsyncSSLSocket::HandshakeCB
49 {
50 private:
51   EventBase* eventBase_;
52   int currReq_;
53   int serverIdx_;
54   AsyncSocket* socket_;
55   AsyncSSLSocket* sslSocket_;
56   SSL_SESSION* session_;
57   SSL_SESSION **pSess_;
58   std::shared_ptr<SSLContext> ctx_;
59   ClientRunner* cr_;
60
61 public:
62   SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr);
63   ~SSLCacheClient() {
64     if (session_ && !FLAGS_global)
65       SSL_SESSION_free(session_);
66     if (socket_ != nullptr) {
67       if (sslSocket_ != nullptr) {
68         sslSocket_->destroy();
69         sslSocket_ = nullptr;
70       }
71       socket_->destroy();
72       socket_ = nullptr;
73     }
74   };
75
76   void start();
77
78   virtual void connectSuccess() noexcept;
79
80   virtual void connectErr(const AsyncSocketException& ex)
81     noexcept ;
82
83   virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept;
84
85   virtual void handshakeErr(
86     AsyncSSLSocket* sock,
87     const AsyncSocketException& ex) noexcept;
88
89 };
90
91 int
92 main(int argc, char* argv[])
93 {
94   gflags::SetUsageMessage(std::string("\n\n"
95 "usage: sslcachetest [options] -c <clients> -t <threads> servers\n"
96 ));
97   gflags::ParseCommandLineFlags(&argc, &argv, true);
98   int reqs = 0;
99   int hits = 0;
100   int miss = 0;
101   struct timeval start;
102   struct timeval end;
103   struct timeval result;
104
105   srand((unsigned int)time(nullptr));
106
107   for (int i = 1; i < argc; i++) {
108     f_servers[f_num_servers++] = argv[i];
109   }
110   if (f_num_servers == 0) {
111     cout << "require at least one server\n";
112     return 1;
113   }
114
115   gettimeofday(&start, nullptr);
116   if (FLAGS_threads == 1) {
117     ClientRunner r;
118     r.run();
119     gettimeofday(&end, nullptr);
120     reqs = r.reqs;
121     hits = r.hits;
122     miss = r.miss;
123   }
124   else {
125     std::vector<ClientRunner> clients;
126     std::vector<std::thread> threads;
127     for (int t = 0; t < FLAGS_threads; t++) {
128       threads.emplace_back([&] {
129           clients[t].run();
130         });
131     }
132     for (auto& thr: threads) {
133       thr.join();
134     }
135     gettimeofday(&end, nullptr);
136
137     for (const auto& client: clients) {
138       reqs += client.reqs;
139       hits += client.hits;
140       miss += client.miss;
141     }
142   }
143
144   timersub(&end, &start, &result);
145
146   cout << "Requests: " << reqs << endl;
147   cout << "Handshakes: " << miss << endl;
148   cout << "Resumes: " << hits << endl;
149   cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 <<
150     endl;
151
152   cout << "ops/sec: " << (reqs * 1.0) /
153     ((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl;
154
155   return 0;
156 }
157
158 void
159 ClientRunner::run()
160 {
161   EventBase eb;
162   std::list<SSLCacheClient *> clients;
163   SSL_SESSION* session = nullptr;
164
165   for (int i = 0; i < FLAGS_clients; i++) {
166     SSLCacheClient* c = new SSLCacheClient(&eb, &session, this);
167     c->start();
168     clients.push_back(c);
169   }
170
171   eb.loop();
172
173   for (auto it = clients.begin(); it != clients.end(); it++) {
174     delete* it;
175   }
176
177   reqs += hits + miss;
178 }
179
180 SSLCacheClient::SSLCacheClient(EventBase* eb,
181                                SSL_SESSION **pSess,
182                                ClientRunner* cr)
183     : eventBase_(eb),
184       currReq_(0),
185       serverIdx_(0),
186       socket_(nullptr),
187       sslSocket_(nullptr),
188       session_(nullptr),
189       pSess_(pSess),
190       cr_(cr)
191 {
192   ctx_.reset(new SSLContext());
193   ctx_->setOptions(SSL_OP_NO_TICKET);
194 }
195
196 void
197 SSLCacheClient::start()
198 {
199   if (currReq_ >= FLAGS_requests) {
200     cout << "+";
201     return;
202   }
203
204   if (currReq_ == 0 || !FLAGS_sticky) {
205     serverIdx_ = rand() % f_num_servers;
206   }
207   if (socket_ != nullptr) {
208     if (sslSocket_ != nullptr) {
209       sslSocket_->destroy();
210       sslSocket_ = nullptr;
211     }
212     socket_->destroy();
213     socket_ = nullptr;
214   }
215   socket_ = new AsyncSocket(eventBase_);
216   socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port);
217 }
218
219 void
220 SSLCacheClient::connectSuccess() noexcept
221 {
222   sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(),
223                                    false);
224
225   if (!FLAGS_handshakes) {
226     if (session_ != nullptr)
227       sslSocket_->setSSLSession(session_);
228     else if (FLAGS_global && pSess_ != nullptr)
229       sslSocket_->setSSLSession(*pSess_);
230   }
231   sslSocket_->sslConn(this);
232 }
233
234 void
235 SSLCacheClient::connectErr(const AsyncSocketException& ex)
236   noexcept
237 {
238   cout << "connectError: " << ex.what() << endl;
239 }
240
241 void
242 SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept
243 {
244   if (sslSocket_->getSSLSessionReused()) {
245     cr_->hits++;
246   } else {
247     cr_->miss++;
248     if (session_ != nullptr) {
249       SSL_SESSION_free(session_);
250     }
251     session_ = sslSocket_->getSSLSession();
252     if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) {
253       *pSess_ = session_;
254     }
255   }
256   if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) {
257     cout << ".";
258     cout.flush();
259   }
260   sslSocket_->closeNow();
261   currReq_++;
262   this->start();
263 }
264
265 void
266 SSLCacheClient::handshakeErr(
267   AsyncSSLSocket* sock,
268   const AsyncSocketException& ex)
269   noexcept
270 {
271   cout << "handshakeError: " << ex.what() << endl;
272 }