24f805185d845b61767479a07075b48628124a50
[folly.git] / folly / wangle / ssl / test / SSLCacheTest.cpp
1 /*
2  *  Copyright (c) 2015, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree. An additional grant
7  *  of patent rights can be found in the PATENTS file in the same directory.
8  *
9  */
10 #include <folly/Portability.h>
11 #include <folly/io/async/EventBase.h>
12 #include <gflags/gflags.h>
13 #include <iostream>
14 #include <thread>
15 #include <folly/io/async/AsyncSSLSocket.h>
16 #include <folly/io/async/AsyncSocket.h>
17 #include <vector>
18
19 using namespace std;
20 using namespace folly;
21
22 DEFINE_int32(clients, 1, "Number of simulated SSL clients");
23 DEFINE_int32(threads, 1, "Number of threads to spread clients across");
24 DEFINE_int32(requests, 2, "Total number of requests per client");
25 DEFINE_int32(port, 9423, "Server port");
26 DEFINE_bool(sticky, false, "A given client sends all reqs to one "
27             "(random) server");
28 DEFINE_bool(global, false, "All clients in a thread use the same SSL session");
29 DEFINE_bool(handshakes, false, "Force 100% handshakes");
30
31 string f_servers[10];
32 int f_num_servers = 0;
33 int tnum = 0;
34
35 class ClientRunner {
36  public:
37
38   ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {}
39   void run();
40
41   int reqs;
42   int hits;
43   int miss;
44   int num;
45 };
46
47 class SSLCacheClient : public AsyncSocket::ConnectCallback,
48                        public AsyncSSLSocket::HandshakeCB
49 {
50 private:
51   EventBase* eventBase_;
52   int currReq_;
53   int serverIdx_;
54   AsyncSocket* socket_;
55   AsyncSSLSocket* sslSocket_;
56   SSL_SESSION* session_;
57   SSL_SESSION **pSess_;
58   std::shared_ptr<SSLContext> ctx_;
59   ClientRunner* cr_;
60
61 public:
62   SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr);
63   ~SSLCacheClient() override {
64     if (session_ && !FLAGS_global)
65       SSL_SESSION_free(session_);
66     if (socket_ != nullptr) {
67       if (sslSocket_ != nullptr) {
68         sslSocket_->destroy();
69         sslSocket_ = nullptr;
70       }
71       socket_->destroy();
72       socket_ = nullptr;
73     }
74   };
75
76   void start();
77
78   void connectSuccess() noexcept override;
79
80   void connectErr(const AsyncSocketException& ex) noexcept override;
81
82   void handshakeSuc(AsyncSSLSocket* sock) noexcept override;
83
84   void handshakeErr(AsyncSSLSocket* sock,
85                     const AsyncSocketException& ex) noexcept override;
86 };
87
88 int
89 main(int argc, char* argv[])
90 {
91   gflags::SetUsageMessage(std::string("\n\n"
92 "usage: sslcachetest [options] -c <clients> -t <threads> servers\n"
93 ));
94   gflags::ParseCommandLineFlags(&argc, &argv, true);
95   int reqs = 0;
96   int hits = 0;
97   int miss = 0;
98   struct timeval start;
99   struct timeval end;
100   struct timeval result;
101
102   srand((unsigned int)time(nullptr));
103
104   for (int i = 1; i < argc; i++) {
105     f_servers[f_num_servers++] = argv[i];
106   }
107   if (f_num_servers == 0) {
108     cout << "require at least one server\n";
109     return 1;
110   }
111
112   gettimeofday(&start, nullptr);
113   if (FLAGS_threads == 1) {
114     ClientRunner r;
115     r.run();
116     gettimeofday(&end, nullptr);
117     reqs = r.reqs;
118     hits = r.hits;
119     miss = r.miss;
120   }
121   else {
122     std::vector<ClientRunner> clients;
123     std::vector<std::thread> threads;
124     for (int t = 0; t < FLAGS_threads; t++) {
125       threads.emplace_back([&] {
126           clients[t].run();
127         });
128     }
129     for (auto& thr: threads) {
130       thr.join();
131     }
132     gettimeofday(&end, nullptr);
133
134     for (const auto& client: clients) {
135       reqs += client.reqs;
136       hits += client.hits;
137       miss += client.miss;
138     }
139   }
140
141   timersub(&end, &start, &result);
142
143   cout << "Requests: " << reqs << endl;
144   cout << "Handshakes: " << miss << endl;
145   cout << "Resumes: " << hits << endl;
146   cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 <<
147     endl;
148
149   cout << "ops/sec: " << (reqs * 1.0) /
150     ((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl;
151
152   return 0;
153 }
154
155 void
156 ClientRunner::run()
157 {
158   EventBase eb;
159   std::list<SSLCacheClient *> clients;
160   SSL_SESSION* session = nullptr;
161
162   for (int i = 0; i < FLAGS_clients; i++) {
163     SSLCacheClient* c = new SSLCacheClient(&eb, &session, this);
164     c->start();
165     clients.push_back(c);
166   }
167
168   eb.loop();
169
170   for (auto it = clients.begin(); it != clients.end(); it++) {
171     delete* it;
172   }
173
174   reqs += hits + miss;
175 }
176
177 SSLCacheClient::SSLCacheClient(EventBase* eb,
178                                SSL_SESSION **pSess,
179                                ClientRunner* cr)
180     : eventBase_(eb),
181       currReq_(0),
182       serverIdx_(0),
183       socket_(nullptr),
184       sslSocket_(nullptr),
185       session_(nullptr),
186       pSess_(pSess),
187       cr_(cr)
188 {
189   ctx_.reset(new SSLContext());
190   ctx_->setOptions(SSL_OP_NO_TICKET);
191 }
192
193 void
194 SSLCacheClient::start()
195 {
196   if (currReq_ >= FLAGS_requests) {
197     cout << "+";
198     return;
199   }
200
201   if (currReq_ == 0 || !FLAGS_sticky) {
202     serverIdx_ = rand() % f_num_servers;
203   }
204   if (socket_ != nullptr) {
205     if (sslSocket_ != nullptr) {
206       sslSocket_->destroy();
207       sslSocket_ = nullptr;
208     }
209     socket_->destroy();
210     socket_ = nullptr;
211   }
212   socket_ = new AsyncSocket(eventBase_);
213   socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port);
214 }
215
216 void
217 SSLCacheClient::connectSuccess() noexcept
218 {
219   sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(),
220                                    false);
221
222   if (!FLAGS_handshakes) {
223     if (session_ != nullptr)
224       sslSocket_->setSSLSession(session_);
225     else if (FLAGS_global && pSess_ != nullptr)
226       sslSocket_->setSSLSession(*pSess_);
227   }
228   sslSocket_->sslConn(this);
229 }
230
231 void
232 SSLCacheClient::connectErr(const AsyncSocketException& ex)
233   noexcept
234 {
235   cout << "connectError: " << ex.what() << endl;
236 }
237
238 void
239 SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept
240 {
241   if (sslSocket_->getSSLSessionReused()) {
242     cr_->hits++;
243   } else {
244     cr_->miss++;
245     if (session_ != nullptr) {
246       SSL_SESSION_free(session_);
247     }
248     session_ = sslSocket_->getSSLSession();
249     if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) {
250       *pSess_ = session_;
251     }
252   }
253   if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) {
254     cout << ".";
255     cout.flush();
256   }
257   sslSocket_->closeNow();
258   currReq_++;
259   this->start();
260 }
261
262 void
263 SSLCacheClient::handshakeErr(
264   AsyncSSLSocket* sock,
265   const AsyncSocketException& ex)
266   noexcept
267 {
268   cout << "handshakeError: " << ex.what() << endl;
269 }