2 * Copyright (c) 2015, Facebook, Inc.
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree. An additional grant
7 * of patent rights can be found in the PATENTS file in the same directory.
10 #include <folly/wangle/acceptor/Acceptor.h>
12 #include <folly/wangle/acceptor/ManagedConnection.h>
13 #include <folly/wangle/ssl/SSLContextManager.h>
15 #include <boost/cast.hpp>
17 #include <folly/ScopeGuard.h>
18 #include <folly/io/async/EventBase.h>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <gflags/gflags.h>
27 using folly::wangle::ConnectionManager;
28 using folly::wangle::ManagedConnection;
29 using std::chrono::microseconds;
30 using std::chrono::milliseconds;
34 using std::shared_ptr;
// Grace period, in milliseconds, that Acceptor::drainAllConnections() passes
// to initiateGracefulShutdown(). Default: 5000.
// NOTE(review): both a gflags definition and a plain constant of the same
// name are visible, and the listing skips a line around each of them;
// presumably a preprocessor conditional selects exactly one -- confirm
// against the full file.
40 DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
41 "closing idle conns");
43 const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
// Shared empty string, used as the next-protocol name when no protocol was
// negotiated (plain-text connections and handshakes without a selection).
46 static const std::string empty_string;
// Definition of the static atomic counter of in-progress SSL handshakes.
// It is incremented in processEstablishedConnection() and decremented in
// sslConnectionReady() / sslConnectionError().
47 std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
50 * Lightweight wrapper class to keep track of a newly
51 * accepted connection during SSL handshaking.
// Represents one accepted connection while its SSL handshake is pending.
// The helper is the handshake callback passed to sslAccept() and is also a
// ManagedConnection, which is what lets it register itself with the
// acceptor's downstreamConnectionManager_.
// NOTE(review): the listing numbers skip in many places inside this class
// (access specifiers, closing braces, several statements). The comments
// below describe only the lines that are visible.
53 class AcceptorHandshakeHelper :
54 public AsyncSSLSocket::HandshakeCB,
55 public ManagedConnection {
// Constructor: takes ownership of the socket, registers the helper with the
// acceptor's connection manager, optionally turns on ClientHello parsing,
// and starts the handshake with this object as the callback.
// NOTE(review): the initializer list reads `acceptor`, but the line that
// declares that parameter is not visible; the final initializer after
// clientAddr_ is also missing from the listing.
57 AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
59 const SocketAddress& clientAddr,
60 std::chrono::steady_clock::time_point acceptTime,
62 : socket_(std::move(socket)), acceptor_(acceptor),
63 acceptTime_(acceptTime), clientAddr_(clientAddr),
// Same call (with the same `true` flag) that Acceptor::addConnection() makes.
65 acceptor_->downstreamConnectionManager_->addConnection(this, true);
// ClientHello parsing is enabled only when the acceptor's
// parseClientHello_ flag is set (see Acceptor::init()).
66 if(acceptor_->parseClientHello_) {
67 socket_->enableClientHelloParsing();
69 socket_->sslAccept(this);
// Timeout hook: records TIMEOUT in sslError_ so that handshakeErr() reports
// it through updateSSLStats().
// NOTE(review): a statement following the assignment is not visible.
72 virtual void timeoutExpired() noexcept {
73 VLOG(4) << "SSL handshake timeout expired";
74 sslError_ = SSLErrorEnum::TIMEOUT;
// Writes a one-line description ("pending handshake on <client address>")
// to the given stream.
77 virtual void describe(std::ostream& os) const {
78 os << "pending handshake on " << clientAddr_;
// NOTE(review): the body of isBusy() is not visible in this listing.
80 virtual bool isBusy() const {
// Shutdown notifications are deliberately no-ops for a pending handshake.
83 virtual void notifyPendingShutdown() {}
84 virtual void closeWhenIdle() {}
// Logs that the in-progress handshake is being dropped.
// NOTE(review): the statement that actually drops the connection is not
// visible.
86 virtual void dropConnection() {
87 VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
// NOTE(review): the body is not visible; `loglevel` is unused in the
// visible lines.
90 virtual void dumpConnectionState(uint8_t loglevel) {
94 // AsyncSSLSocket::HandshakeCallback API
// Success callback: logs the negotiated next protocol, copies SSL details
// into tinfo_, reports stats with NO_ERROR, unregisters from the connection
// manager and hands the socket to Acceptor::sslConnectionReady().
95 virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
// nextProto stays nullptr (length 0) if the socket does not overwrite it.
97 const unsigned char* nextProto = nullptr;
98 unsigned nextProtoLength = 0;
99 sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
// NOTE(review): the conditions that choose between the two log lines below
// are not visible.
102 VLOG(3) << "Client selected next protocol " <<
103 string((const char*)nextProto, nextProtoLength);
105 VLOG(3) << "Client did not select a next protocol";
109 // fill in SSL-related fields from TransportInfo
110 // the other fields like RTT are filled in the Acceptor
112 tinfo_.acceptTime = acceptTime_;
// Handshake duration: time elapsed since acceptTime_, in milliseconds.
113 tinfo_.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(
114 std::chrono::steady_clock::now() - acceptTime_
116 tinfo_.sslSetupBytesRead = sock->getRawBytesReceived();
117 tinfo_.sslSetupBytesWritten = sock->getRawBytesWritten();
// Server name and cipher are stored as shared strings, or nullptr when the
// socket returns a null pointer for them.
118 tinfo_.sslServerName = sock->getSSLServerName() ?
119 std::make_shared<std::string>(sock->getSSLServerName()) : nullptr;
120 tinfo_.sslCipher = sock->getNegotiatedCipherName() ?
121 std::make_shared<std::string>(sock->getNegotiatedCipherName()) : nullptr;
122 tinfo_.sslVersion = sock->getSSLVersion();
123 tinfo_.sslCertSize = sock->getSSLCertSize();
124 tinfo_.sslResume = SSLUtil::getResumeState(sock);
// The cipher lists are written by the socket into freshly allocated strings.
125 tinfo_.sslClientCiphers = std::make_shared<std::string>();
126 sock->getSSLClientCiphers(*tinfo_.sslClientCiphers);
127 tinfo_.sslServerCiphers = std::make_shared<std::string>();
128 sock->getSSLServerCiphers(*tinfo_.sslServerCiphers);
129 tinfo_.sslClientComprMethods =
130 std::make_shared<std::string>(sock->getSSLClientComprMethods());
131 tinfo_.sslClientExts =
132 std::make_shared<std::string>(sock->getSSLClientExts());
// NOTE(review): the length argument of assign() is on a line that is not
// visible.
133 tinfo_.sslNextProtocol = std::make_shared<std::string>();
134 tinfo_.sslNextProtocol->assign(reinterpret_cast<const char*>(nextProto),
// NOTE(review): only the last argument (NO_ERROR) of this call is visible.
137 acceptor_->updateSSLStats(
140 SSLErrorEnum::NO_ERROR
// Unregister before transferring the socket; the acceptor receives the
// protocol name (empty_string when none was selected) and the filled tinfo_.
142 acceptor_->downstreamConnectionManager_->removeConnection(this);
143 acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
144 nextProto ? string((const char*)nextProto, nextProtoLength) :
145 empty_string, tinfo_);
// Failure callback: logs elapsed time and byte counts, reports stats with
// sslError_, and tells the acceptor via sslConnectionError().
// NOTE(review): in the visible lines sslError_ is assigned only in
// timeoutExpired(), so other failures report its NO_ERROR initial value --
// confirm that this is intended.
149 virtual void handshakeErr(AsyncSSLSocket* sock,
150 const AsyncSocketException& ex) noexcept {
151 auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
// NOTE(review): the end of this log statement is not visible.
152 VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
153 " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
154 sock->getRawBytesWritten() << " bytes sent: " <<
156 acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
157 acceptor_->sslConnectionError();
// Socket being handshaked; moved out to the acceptor on success.
// NOTE(review): the declaration of acceptor_ is not visible in this listing.
161 AsyncSSLSocket::UniquePtr socket_;
// Time the connection was accepted; basis for the elapsed-time figures.
163 std::chrono::steady_clock::time_point acceptTime_;
164 SocketAddress clientAddr_;
165 TransportInfo tinfo_;
// Error reported by handshakeErr(); starts as NO_ERROR.
166 SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
// Constructs an acceptor from a server socket configuration: stores a copy
// of the config and caches its socket options in socketOptions_.
// NOTE(review): the constructor body is not visible in this listing.
169 Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
170 accConfig_(accConfig),
171 socketOptions_(accConfig.getSocketOptions()) {
// Prepares the acceptor to run on the given event base: sets up SSL
// contexts when the config is SSL, creates the connection manager, registers
// as the accept callback on serverSocket, and applies the configured socket
// options to the listening sockets.
// NOTE(review): the return-type line and several lines inside the body
// (including wherever base_ is assigned) are not visible.
175 Acceptor::init(AsyncServerSocket* serverSocket,
176 EventBase* eventBase) {
// base_ must still be null when init() runs.
177 CHECK(nullptr == this->base_);
179 if (accConfig_.isSSL()) {
// Create a context manager only if one has not been provided already.
// NOTE(review): the leading constructor arguments are not visible.
180 if (!sslCtxManager_) {
181 sslCtxManager_ = folly::make_unique<SSLContextManager>(
184 accConfig_.strictSSL, nullptr);
// Register every configured SSL context. Client-hello parsing is enabled
// for the acceptor if any one context asks for it.
186 for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
187 sslCtxManager_->addSSLContextConfig(
189 accConfig_.sslCacheOptions,
190 &accConfig_.initialTicketSeeds,
191 accConfig_.bindAddress,
193 parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
// An SSL acceptor must end up with a default context.
196 CHECK(sslCtxManager_->getDefaultSSLCtx());
200 state_ = State::kRunning;
// The manager is given the idle timeout from the config and this acceptor.
201 downstreamConnectionManager_ = ConnectionManager::makeUnique(
202 eventBase, accConfig_.connectionIdleTimeout, this);
// NOTE(review): serverSocket is dereferenced below; any null guard is on a
// line that is not visible.
205 serverSocket->addAcceptCallback(this, eventBase);
// Apply each configured socket option to each listening socket.
// NOTE(review): lines between the two loop headers are not visible.
207 for (auto& fd : serverSocket->getSockets()) {
211 for (const auto& opt: socketOptions_) {
212 opt.first.apply(fd, opt.second);
// Destructor.
// NOTE(review): the body is not visible in this listing.
218 Acceptor::~Acceptor(void) {
// Registers one more SSL context configuration with the context manager,
// using the cache options, ticket seeds and bind address from accConfig_.
// NOTE(review): sslCtxManager_ is dereferenced without a visible null check;
// presumably this must be called after init() on an SSL acceptor -- confirm.
// The trailing argument(s) of the call are not visible.
221 void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
222 sslCtxManager_->addSSLContextConfig(sslCtxConfig,
223 accConfig_.sslCacheOptions,
224 &accConfig_.initialTicketSeeds,
225 accConfig_.bindAddress,
// Starts a graceful shutdown of all managed connections, passing the
// shutdown_idle_grace_ms value as the grace period. Does nothing when no
// connection manager exists.
230 Acceptor::drainAllConnections() {
231 if (downstreamConnectionManager_) {
232 downstreamConnectionManager_->initiateGracefulShutdown(
233 std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
// Installs the load-shedding configuration (copied) and the connection
// counter (pointer stored as given) that canAccept() consults.
237 void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
238 IConnectionCounter* counter) {
239 loadShedConfig_ = from;
240 connectionCounter_ = counter;
// Decides whether a connection from `address` may be accepted under the
// load-shedding settings. The visible checks run in this order: no counter
// installed, a max of 0, current count below max, whitelisted address, and
// finally the load-shedding count against loadShedConfig_'s max.
// NOTE(review): every `return` line is missing from this listing, so the
// outcome of each branch is not visible. Presumably each check short-circuits
// to "accept" and the final fall-through rejects -- confirm.
243 bool Acceptor::canAccept(const SocketAddress& address) {
// No counter has been installed via setLoadShedConfig().
244 if (!connectionCounter_) {
// Counter reports a maximum of zero.
248 uint64_t maxConnections = connectionCounter_->getMaxConnections();
249 if (maxConnections == 0) {
// Current connection count is below the counter's maximum.
253 uint64_t currentConnections = connectionCounter_->getNumConnections();
254 if (currentConnections < maxConnections) {
// At or above the maximum: check the whitelist.
258 if (loadShedConfig_.isWhitelisted(address)) {
262 // Take care of comparing connection count against max connections across
263 // all acceptors. Expensive since a lock must be taken to get the counter.
264 auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
265 if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
// Reached only when none of the checks above ended the function.
269 VLOG(4) << address.describe() << " not whitelisted";
// Accept callback for a newly accepted fd: checks canAccept(), records the
// accept time, applies the configured socket options to the fd, then
// continues in onDoneAcceptingConnection().
274 Acceptor::connectionAccepted(
275 int fd, const SocketAddress& clientAddr) noexcept {
// NOTE(review): the rejection branch body (what happens to fd) is not
// visible.
276 if (!canAccept(clientAddr)) {
// The timestamp is taken before the socket options are applied.
280 auto acceptTime = std::chrono::steady_clock::now();
281 for (const auto& opt: socketOptions_) {
282 opt.first.apply(fd, opt.second);
285 onDoneAcceptingConnection(fd, clientAddr, acceptTime);
// Forwards the accepted connection to processEstablishedConnection() along
// with a TransportInfo named `tinfo`.
// NOTE(review): the fd parameter line and the declaration of `tinfo` are
// not visible in this listing.
288 void Acceptor::onDoneAcceptingConnection(
290 const SocketAddress& clientAddr,
291 std::chrono::steady_clock::time_point acceptTime) noexcept {
293 processEstablishedConnection(fd, clientAddr, acceptTime, tinfo);
// Turns an accepted fd into a connection. SSL acceptors wrap the fd in an
// AsyncSSLSocket and start a handshake through AcceptorHandshakeHelper,
// unless too many handshakes are in progress. Other acceptors wrap the fd in
// an AsyncSocket and call connectionReady() directly.
// NOTE(review): the return-type line, the fd parameter line, the helper's
// constructor arguments and the branch terminators are not visible.
297 Acceptor::processEstablishedConnection(
299 const SocketAddress& clientAddr,
300 std::chrono::steady_clock::time_point acceptTime,
301 TransportInfo& tinfo) noexcept {
302 if (accConfig_.isSSL()) {
303 CHECK(sslCtxManager_);
// The socket is created with the default SSL context.
304 AsyncSSLSocket::UniquePtr sslSock(
305 makeNewAsyncSSLSocket(
306 sslCtxManager_->getDefaultSSLCtx(), base_, fd));
// Count the handshake before checking the limit; the drop path below calls
// sslConnectionError(), which decrements both counters again.
307 ++numPendingSSLConns_;
308 ++totalNumPendingSSLConns_;
309 if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
310 VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
311 " too many handshakes in progress";
312 updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
313 SSLErrorEnum::DROPPED);
314 sslConnectionError();
// NOTE(review): no owner of this raw `new` is visible here; presumably the
// helper's lifetime is handled through its ManagedConnection / callback
// paths -- confirm.
317 new AcceptorHandshakeHelper(
// Plain-text path: no handshake, so the connection is ready immediately and
// has no next-protocol name.
326 tinfo.acceptTime = acceptTime;
327 AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
328 connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
// Final step for a ready socket: caps reads per event at 16, initialises
// tinfo from the socket, and passes ownership of the socket to
// onNewConnection() together with the client address and protocol name.
333 Acceptor::connectionReady(
334 AsyncSocket::UniquePtr sock,
335 const SocketAddress& clientAddr,
336 const string& nextProtocolName,
337 TransportInfo& tinfo) {
338 // Limit the number of reads from the socket per poll loop iteration,
339 // both to keep memory usage under control and to prevent one fast-
340 // writing client from starving other connections.
341 sock->setMaxReadsPerEvent(16);
342 tinfo.initWithSocket(sock.get());
// The address is passed by pointer; the socket is moved.
343 onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
// Called when an SSL handshake has succeeded: hands the socket to
// connectionReady(), then removes the handshake from both pending counters.
347 Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
348 const SocketAddress& clientAddr,
349 const string& nextProtocol,
350 TransportInfo& tinfo) {
// A pending handshake must have been counted for this connection.
351 CHECK(numPendingSSLConns_ > 0);
352 connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
353 --numPendingSSLConns_;
354 --totalNumPendingSSLConns_;
// NOTE(review): the body of this branch is not visible; presumably it
// re-evaluates draining via checkDrained() -- confirm.
355 if (state_ == State::kDraining) {
// Called when an SSL handshake has failed or was dropped: removes it from
// both pending counters.
361 Acceptor::sslConnectionError() {
// A pending handshake must have been counted for this connection.
362 CHECK(numPendingSSLConns_ > 0);
363 --numPendingSSLConns_;
364 --totalNumPendingSSLConns_;
// NOTE(review): the body of this branch is not visible; presumably it
// re-evaluates draining via checkDrained() -- confirm.
365 if (state_ == State::kDraining) {
// Accept-error callback: logs the exception message at ERROR level and
// takes no other action in the visible lines.
371 Acceptor::acceptError(const std::exception& ex) noexcept {
372 // An error occurred.
373 // The most likely error is out of FDs. AsyncServerSocket will back off
374 // briefly if we are out of FDs, then continue accepting later.
375 // Just log a message here.
376 LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
// Callback for when accepting has stopped: starts draining the open
// connections and, unless the acceptor is already kDone, moves it to
// kDraining.
// NOTE(review): the statements after the state change (the comment below
// refers to checkDrained()) are not visible in this listing.
380 Acceptor::acceptStopped() noexcept {
381 VLOG(3) << "Acceptor " << this << " acceptStopped()";
382 // Drain the open client connections
383 drainAllConnections();
385 // If we haven't yet finished draining, begin doing so by marking ourselves
386 // as in the draining state. We must be sure to hit checkDrained() here, as
387 // if we're completely idle, we should consider ourselves drained
388 // immediately (as there is no outstanding work to complete to cause us to
389 // re-evaluate this).
390 if (state_ != State::kDone) {
391 state_ = State::kDraining;
// Callback taking the ConnectionManager; `cm` is unused in the visible
// lines. Logs, then acts only while the acceptor is in kDraining.
// NOTE(review): the body of the branch is not visible; presumably it calls
// checkDrained() -- confirm.
397 Acceptor::onEmpty(const ConnectionManager& cm) {
398 VLOG(3) << "Acceptor=" << this << " onEmpty()";
399 if (state_ == State::kDraining) {
// Completes draining once nothing is outstanding: releases the connection
// manager, marks the acceptor kDone and fires onConnectionsDrained().
// Must be called only in the kDraining state (enforced by CHECK).
405 Acceptor::checkDrained() {
406 CHECK(state_ == State::kDraining);
// Not drained yet while a forced shutdown is running, managed connections
// remain, or SSL handshakes are pending.
// NOTE(review): the body of this branch is not visible; presumably an early
// return -- confirm.
407 if (forceShutdownInProgress_ ||
408 (downstreamConnectionManager_->getNumConnections() != 0) ||
409 (numPendingSSLConns_ != 0)) {
// NOTE(review): the end of this log statement is not visible.
413 VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
416 downstreamConnectionManager_.reset();
418 state_ = State::kDone;
420 onConnectionsDrained();
// Returns the configured connection idle timeout, the same value init()
// passes to the connection manager.
// NOTE(review): the return-type line is not visible in this listing.
424 Acceptor::getConnTimeout() const {
425 return accConfig_.connectionIdleTimeout;
// Registers a connection with the acceptor's connection manager.
// NOTE(review): downstreamConnectionManager_ is dereferenced without a
// visible null check; it is created in init() and reset when draining
// completes.
428 void Acceptor::addConnection(ManagedConnection* conn) {
429 // Add the socket to the timeout manager so that it can be cleaned
430 // up after being left idle for a long time.
431 downstreamConnectionManager_->addConnection(conn, true);
// Asks the acceptor's event base to run dropAllConnections().
// NOTE(review): the lambda uses only `this`, captured implicitly through
// [&]. If runInEventBaseThread() defers execution, the Acceptor must still
// be alive when the lambda runs -- confirm; an explicit [this] capture would
// state the intent more clearly.
435 Acceptor::forceStop() {
436 base_->runInEventBaseThread([&] { dropAllConnections(); });
// Forcibly drops every managed connection, releases the connection manager,
// marks the acceptor kDone and fires onConnectionsDrained().
// NOTE(review): closing braces are missing from this listing, so it is not
// visible which of the trailing statements sit inside the `if`.
440 Acceptor::dropAllConnections() {
441 if (downstreamConnectionManager_) {
442 VLOG(3) << "Dropping all connections from Acceptor=" << this <<
443 " in thread " << base_;
// Expected to run on the event base thread. This is a plain assert, so it
// is not checked when NDEBUG is defined.
444 assert(base_->isInEventBaseThread());
// While this flag is set, checkDrained() takes its "not drained yet" branch.
445 forceShutdownInProgress_ = true;
446 downstreamConnectionManager_->dropAllConnections();
447 CHECK(downstreamConnectionManager_->getNumConnections() == 0);
448 downstreamConnectionManager_.reset();
// No SSL handshake may still be pending at this point.
450 CHECK(numPendingSSLConns_ == 0);
452 state_ = State::kDone;
453 onConnectionsDrained();