copy wangle back into folly
[folly.git] / folly / wangle / acceptor / Acceptor.cpp
1 /*
2  *  Copyright (c) 2015, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree. An additional grant
7  *  of patent rights can be found in the PATENTS file in the same directory.
8  *
9  */
10 #include <folly/wangle/acceptor/Acceptor.h>
11
12 #include <folly/wangle/acceptor/ManagedConnection.h>
13 #include <folly/wangle/ssl/SSLContextManager.h>
14
15 #include <boost/cast.hpp>
16 #include <fcntl.h>
17 #include <folly/ScopeGuard.h>
18 #include <folly/io/async/EventBase.h>
19 #include <fstream>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <gflags/gflags.h>
25 #include <unistd.h>
26
27 using folly::wangle::ConnectionManager;
28 using folly::wangle::ManagedConnection;
29 using std::chrono::microseconds;
30 using std::chrono::milliseconds;
31 using std::filebuf;
32 using std::ifstream;
33 using std::ios;
34 using std::shared_ptr;
35 using std::string;
36
37 namespace folly {
38
39 #ifndef NO_LIB_GFLAGS
40 DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
41              "closing idle conns");
42 #else
43 const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
44 #endif
45
46 static const std::string empty_string;
47 std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
48
49 /**
50  * Lightweight wrapper class to keep track of a newly
51  * accepted connection during SSL handshaking.
52  */
53 class AcceptorHandshakeHelper :
54       public AsyncSSLSocket::HandshakeCB,
55       public ManagedConnection {
56  public:
57   AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
58                           Acceptor* acceptor,
59                           const SocketAddress& clientAddr,
60                           std::chrono::steady_clock::time_point acceptTime,
61                           TransportInfo& tinfo)
62     : socket_(std::move(socket)), acceptor_(acceptor),
63       acceptTime_(acceptTime), clientAddr_(clientAddr),
64       tinfo_(tinfo) {
65     acceptor_->downstreamConnectionManager_->addConnection(this, true);
66     if(acceptor_->parseClientHello_)  {
67       socket_->enableClientHelloParsing();
68     }
69     socket_->sslAccept(this);
70   }
71
72   virtual void timeoutExpired() noexcept override {
73     VLOG(4) << "SSL handshake timeout expired";
74     sslError_ = SSLErrorEnum::TIMEOUT;
75     dropConnection();
76   }
77   virtual void describe(std::ostream& os) const override {
78     os << "pending handshake on " << clientAddr_;
79   }
80   virtual bool isBusy() const override {
81     return true;
82   }
83   virtual void notifyPendingShutdown() override {}
84   virtual void closeWhenIdle() override {}
85
86   virtual void dropConnection() override {
87     VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
88     socket_->closeNow();
89   }
90   virtual void dumpConnectionState(uint8_t loglevel) override {
91   }
92
93  private:
94   // AsyncSSLSocket::HandshakeCallback API
95   virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
96
97     const unsigned char* nextProto = nullptr;
98     unsigned nextProtoLength = 0;
99     sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
100     if (VLOG_IS_ON(3)) {
101       if (nextProto) {
102         VLOG(3) << "Client selected next protocol " <<
103             string((const char*)nextProto, nextProtoLength);
104       } else {
105         VLOG(3) << "Client did not select a next protocol";
106       }
107     }
108
109     // fill in SSL-related fields from TransportInfo
110     // the other fields like RTT are filled in the Acceptor
111     tinfo_.ssl = true;
112     tinfo_.acceptTime = acceptTime_;
113     tinfo_.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(
114       std::chrono::steady_clock::now() - acceptTime_
115     );
116     tinfo_.sslSetupBytesRead = sock->getRawBytesReceived();
117     tinfo_.sslSetupBytesWritten = sock->getRawBytesWritten();
118     tinfo_.sslServerName = sock->getSSLServerName() ?
119       std::make_shared<std::string>(sock->getSSLServerName()) : nullptr;
120     tinfo_.sslCipher = sock->getNegotiatedCipherName() ?
121       std::make_shared<std::string>(sock->getNegotiatedCipherName()) : nullptr;
122     tinfo_.sslVersion = sock->getSSLVersion();
123     tinfo_.sslCertSize = sock->getSSLCertSize();
124     tinfo_.sslResume = SSLUtil::getResumeState(sock);
125     tinfo_.sslClientCiphers = std::make_shared<std::string>();
126     sock->getSSLClientCiphers(*tinfo_.sslClientCiphers);
127     tinfo_.sslServerCiphers = std::make_shared<std::string>();
128     sock->getSSLServerCiphers(*tinfo_.sslServerCiphers);
129     tinfo_.sslClientComprMethods =
130         std::make_shared<std::string>(sock->getSSLClientComprMethods());
131     tinfo_.sslClientExts =
132         std::make_shared<std::string>(sock->getSSLClientExts());
133     tinfo_.sslNextProtocol = std::make_shared<std::string>();
134     tinfo_.sslNextProtocol->assign(reinterpret_cast<const char*>(nextProto),
135                                   nextProtoLength);
136
137     acceptor_->updateSSLStats(
138       sock,
139       tinfo_.sslSetupTime,
140       SSLErrorEnum::NO_ERROR
141     );
142     acceptor_->downstreamConnectionManager_->removeConnection(this);
143     acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
144         nextProto ? string((const char*)nextProto, nextProtoLength) :
145                                   empty_string, tinfo_);
146     delete this;
147   }
148
149   virtual void handshakeErr(AsyncSSLSocket* sock,
150                             const AsyncSocketException& ex) noexcept override {
151     auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
152     VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
153         " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
154         sock->getRawBytesWritten() << " bytes sent: " <<
155         ex.what();
156     acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
157     acceptor_->sslConnectionError();
158     delete this;
159   }
160
161   AsyncSSLSocket::UniquePtr socket_;
162   Acceptor* acceptor_;
163   std::chrono::steady_clock::time_point acceptTime_;
164   SocketAddress clientAddr_;
165   TransportInfo tinfo_;
166   SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
167 };
168
169 Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
170   accConfig_(accConfig),
171   socketOptions_(accConfig.getSocketOptions()) {
172 }
173
174 void
175 Acceptor::init(AsyncServerSocket* serverSocket,
176                EventBase* eventBase) {
177   CHECK(nullptr == this->base_);
178
179   if (accConfig_.isSSL()) {
180     if (!sslCtxManager_) {
181       sslCtxManager_ = folly::make_unique<SSLContextManager>(
182         eventBase,
183         "vip_" + getName(),
184         accConfig_.strictSSL, nullptr);
185     }
186     for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
187       sslCtxManager_->addSSLContextConfig(
188         sslCtxConfig,
189         accConfig_.sslCacheOptions,
190         &accConfig_.initialTicketSeeds,
191         accConfig_.bindAddress,
192         cacheProvider_);
193       parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
194     }
195
196     CHECK(sslCtxManager_->getDefaultSSLCtx());
197   }
198
199   base_ = eventBase;
200   state_ = State::kRunning;
201   downstreamConnectionManager_ = ConnectionManager::makeUnique(
202     eventBase, accConfig_.connectionIdleTimeout, this);
203
204   if (serverSocket) {
205     serverSocket->addAcceptCallback(this, eventBase);
206
207     for (auto& fd : serverSocket->getSockets()) {
208       if (fd < 0) {
209         continue;
210       }
211       for (const auto& opt: socketOptions_) {
212         opt.first.apply(fd, opt.second);
213       }
214     }
215   }
216 }
217
218 Acceptor::~Acceptor(void) {
219 }
220
221 void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
222   sslCtxManager_->addSSLContextConfig(sslCtxConfig,
223                                       accConfig_.sslCacheOptions,
224                                       &accConfig_.initialTicketSeeds,
225                                       accConfig_.bindAddress,
226                                       cacheProvider_);
227 }
228
229 void
230 Acceptor::drainAllConnections() {
231   if (downstreamConnectionManager_) {
232     downstreamConnectionManager_->initiateGracefulShutdown(
233       std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
234   }
235 }
236
237 void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
238                        IConnectionCounter* counter) {
239   loadShedConfig_ = from;
240   connectionCounter_ = counter;
241 }
242
243 bool Acceptor::canAccept(const SocketAddress& address) {
244   if (!connectionCounter_) {
245     return true;
246   }
247
248   uint64_t maxConnections = connectionCounter_->getMaxConnections();
249   if (maxConnections == 0) {
250     return true;
251   }
252
253   uint64_t currentConnections = connectionCounter_->getNumConnections();
254   if (currentConnections < maxConnections) {
255     return true;
256   }
257
258   if (loadShedConfig_.isWhitelisted(address)) {
259     return true;
260   }
261
262   // Take care of comparing connection count against max connections across
263   // all acceptors. Expensive since a lock must be taken to get the counter.
264   auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
265   if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
266     return true;
267   }
268
269   VLOG(4) << address.describe() << " not whitelisted";
270   return false;
271 }
272
273 void
274 Acceptor::connectionAccepted(
275     int fd, const SocketAddress& clientAddr) noexcept {
276   if (!canAccept(clientAddr)) {
277     close(fd);
278     return;
279   }
280   auto acceptTime = std::chrono::steady_clock::now();
281   for (const auto& opt: socketOptions_) {
282     opt.first.apply(fd, opt.second);
283   }
284
285   onDoneAcceptingConnection(fd, clientAddr, acceptTime);
286 }
287
288 void Acceptor::onDoneAcceptingConnection(
289     int fd,
290     const SocketAddress& clientAddr,
291     std::chrono::steady_clock::time_point acceptTime) noexcept {
292   TransportInfo tinfo;
293   processEstablishedConnection(fd, clientAddr, acceptTime, tinfo);
294 }
295
296 void
297 Acceptor::processEstablishedConnection(
298     int fd,
299     const SocketAddress& clientAddr,
300     std::chrono::steady_clock::time_point acceptTime,
301     TransportInfo& tinfo) noexcept {
302   if (accConfig_.isSSL()) {
303     CHECK(sslCtxManager_);
304     AsyncSSLSocket::UniquePtr sslSock(
305       makeNewAsyncSSLSocket(
306         sslCtxManager_->getDefaultSSLCtx(), base_, fd));
307     ++numPendingSSLConns_;
308     ++totalNumPendingSSLConns_;
309     if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
310       VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
311         " too many handshakes in progress";
312       updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
313                      SSLErrorEnum::DROPPED);
314       sslConnectionError();
315       return;
316     }
317     new AcceptorHandshakeHelper(
318       std::move(sslSock),
319       this,
320       clientAddr,
321       acceptTime,
322       tinfo
323     );
324   } else {
325     tinfo.ssl = false;
326     tinfo.acceptTime = acceptTime;
327     AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
328     connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
329   }
330 }
331
332 void
333 Acceptor::connectionReady(
334     AsyncSocket::UniquePtr sock,
335     const SocketAddress& clientAddr,
336     const string& nextProtocolName,
337     TransportInfo& tinfo) {
338   // Limit the number of reads from the socket per poll loop iteration,
339   // both to keep memory usage under control and to prevent one fast-
340   // writing client from starving other connections.
341   sock->setMaxReadsPerEvent(16);
342   tinfo.initWithSocket(sock.get());
343   onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
344 }
345
346 void
347 Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
348                              const SocketAddress& clientAddr,
349                              const string& nextProtocol,
350                              TransportInfo& tinfo) {
351   CHECK(numPendingSSLConns_ > 0);
352   connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
353   --numPendingSSLConns_;
354   --totalNumPendingSSLConns_;
355   if (state_ == State::kDraining) {
356     checkDrained();
357   }
358 }
359
360 void
361 Acceptor::sslConnectionError() {
362   CHECK(numPendingSSLConns_ > 0);
363   --numPendingSSLConns_;
364   --totalNumPendingSSLConns_;
365   if (state_ == State::kDraining) {
366     checkDrained();
367   }
368 }
369
370 void
371 Acceptor::acceptError(const std::exception& ex) noexcept {
372   // An error occurred.
373   // The most likely error is out of FDs.  AsyncServerSocket will back off
374   // briefly if we are out of FDs, then continue accepting later.
375   // Just log a message here.
376   LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
377 }
378
379 void
380 Acceptor::acceptStopped() noexcept {
381   VLOG(3) << "Acceptor " << this << " acceptStopped()";
382   // Drain the open client connections
383   drainAllConnections();
384
385   // If we haven't yet finished draining, begin doing so by marking ourselves
386   // as in the draining state. We must be sure to hit checkDrained() here, as
387   // if we're completely idle, we can should consider ourself drained
388   // immediately (as there is no outstanding work to complete to cause us to
389   // re-evaluate this).
390   if (state_ != State::kDone) {
391     state_ = State::kDraining;
392     checkDrained();
393   }
394 }
395
396 void
397 Acceptor::onEmpty(const ConnectionManager& cm) {
398   VLOG(3) << "Acceptor=" << this << " onEmpty()";
399   if (state_ == State::kDraining) {
400     checkDrained();
401   }
402 }
403
404 void
405 Acceptor::checkDrained() {
406   CHECK(state_ == State::kDraining);
407   if (forceShutdownInProgress_ ||
408       (downstreamConnectionManager_->getNumConnections() != 0) ||
409       (numPendingSSLConns_ != 0)) {
410     return;
411   }
412
413   VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
414           << base_;
415
416   downstreamConnectionManager_.reset();
417
418   state_ = State::kDone;
419
420   onConnectionsDrained();
421 }
422
423 milliseconds
424 Acceptor::getConnTimeout() const {
425   return accConfig_.connectionIdleTimeout;
426 }
427
428 void Acceptor::addConnection(ManagedConnection* conn) {
429   // Add the socket to the timeout manager so that it can be cleaned
430   // up after being left idle for a long time.
431   downstreamConnectionManager_->addConnection(conn, true);
432 }
433
434 void
435 Acceptor::forceStop() {
436   base_->runInEventBaseThread([&] { dropAllConnections(); });
437 }
438
439 void
440 Acceptor::dropAllConnections() {
441   if (downstreamConnectionManager_) {
442     VLOG(3) << "Dropping all connections from Acceptor=" << this <<
443       " in thread " << base_;
444     assert(base_->isInEventBaseThread());
445     forceShutdownInProgress_ = true;
446     downstreamConnectionManager_->dropAllConnections();
447     CHECK(downstreamConnectionManager_->getNumConnections() == 0);
448     downstreamConnectionManager_.reset();
449   }
450   CHECK(numPendingSSLConns_ == 0);
451
452   state_ = State::kDone;
453   onConnectionsDrained();
454 }
455
456 } // namespace