f66b3412231cfc11ea6397299b45d8cc7c4f4668
[folly.git] / folly / wangle / acceptor / Acceptor.cpp
1 /*
2  *  Copyright (c) 2015, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree. An additional grant
7  *  of patent rights can be found in the PATENTS file in the same directory.
8  *
9  */
10 #include <folly/wangle/acceptor/Acceptor.h>
11
12 #include <folly/wangle/acceptor/ManagedConnection.h>
13 #include <folly/wangle/ssl/SSLContextManager.h>
14
15 #include <boost/cast.hpp>
16 #include <fcntl.h>
17 #include <folly/ScopeGuard.h>
18 #include <folly/io/async/EventBase.h>
19 #include <fstream>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <gflags/gflags.h>
25 #include <unistd.h>
26
27 using folly::wangle::ConnectionManager;
28 using folly::wangle::ManagedConnection;
29 using std::chrono::microseconds;
30 using std::chrono::milliseconds;
31 using std::filebuf;
32 using std::ifstream;
33 using std::ios;
34 using std::shared_ptr;
35 using std::string;
36
37 namespace folly {
38
39 #ifndef NO_LIB_GFLAGS
40 DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
41              "closing idle conns");
42 #else
43 const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
44 #endif
45
46 static const std::string empty_string;
47 std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
48
49 /**
50  * Lightweight wrapper class to keep track of a newly
51  * accepted connection during SSL handshaking.
52  */
53 class AcceptorHandshakeHelper :
54       public AsyncSSLSocket::HandshakeCB,
55       public ManagedConnection {
56  public:
57   AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
58                           Acceptor* acceptor,
59                           const SocketAddress& clientAddr,
60                           std::chrono::steady_clock::time_point acceptTime,
61                           TransportInfo& tinfo)
62     : socket_(std::move(socket)), acceptor_(acceptor),
63       acceptTime_(acceptTime), clientAddr_(clientAddr),
64       tinfo_(tinfo) {
65     acceptor_->downstreamConnectionManager_->addConnection(this, true);
66     if(acceptor_->parseClientHello_)  {
67       socket_->enableClientHelloParsing();
68     }
69     socket_->sslAccept(this);
70   }
71
72   void timeoutExpired() noexcept override {
73     VLOG(4) << "SSL handshake timeout expired";
74     sslError_ = SSLErrorEnum::TIMEOUT;
75     dropConnection();
76   }
77   void describe(std::ostream& os) const override {
78     os << "pending handshake on " << clientAddr_;
79   }
80   bool isBusy() const override { return true; }
81   void notifyPendingShutdown() override {}
82   void closeWhenIdle() override {}
83
84   void dropConnection() override {
85     VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
86     socket_->closeNow();
87   }
88   void dumpConnectionState(uint8_t loglevel) override {}
89
90  private:
91   // AsyncSSLSocket::HandshakeCallback API
92   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
93
94     const unsigned char* nextProto = nullptr;
95     unsigned nextProtoLength = 0;
96     sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
97     if (VLOG_IS_ON(3)) {
98       if (nextProto) {
99         VLOG(3) << "Client selected next protocol " <<
100             string((const char*)nextProto, nextProtoLength);
101       } else {
102         VLOG(3) << "Client did not select a next protocol";
103       }
104     }
105
106     // fill in SSL-related fields from TransportInfo
107     // the other fields like RTT are filled in the Acceptor
108     tinfo_.ssl = true;
109     tinfo_.acceptTime = acceptTime_;
110     tinfo_.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(
111       std::chrono::steady_clock::now() - acceptTime_
112     );
113     tinfo_.sslSetupBytesRead = sock->getRawBytesReceived();
114     tinfo_.sslSetupBytesWritten = sock->getRawBytesWritten();
115     tinfo_.sslServerName = sock->getSSLServerName() ?
116       std::make_shared<std::string>(sock->getSSLServerName()) : nullptr;
117     tinfo_.sslCipher = sock->getNegotiatedCipherName() ?
118       std::make_shared<std::string>(sock->getNegotiatedCipherName()) : nullptr;
119     tinfo_.sslVersion = sock->getSSLVersion();
120     tinfo_.sslCertSize = sock->getSSLCertSize();
121     tinfo_.sslResume = SSLUtil::getResumeState(sock);
122     tinfo_.sslClientCiphers = std::make_shared<std::string>();
123     sock->getSSLClientCiphers(*tinfo_.sslClientCiphers);
124     tinfo_.sslServerCiphers = std::make_shared<std::string>();
125     sock->getSSLServerCiphers(*tinfo_.sslServerCiphers);
126     tinfo_.sslClientComprMethods =
127         std::make_shared<std::string>(sock->getSSLClientComprMethods());
128     tinfo_.sslClientExts =
129         std::make_shared<std::string>(sock->getSSLClientExts());
130     tinfo_.sslNextProtocol = std::make_shared<std::string>();
131     tinfo_.sslNextProtocol->assign(reinterpret_cast<const char*>(nextProto),
132                                   nextProtoLength);
133
134     acceptor_->updateSSLStats(
135       sock,
136       tinfo_.sslSetupTime,
137       SSLErrorEnum::NO_ERROR
138     );
139     acceptor_->downstreamConnectionManager_->removeConnection(this);
140     acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
141         nextProto ? string((const char*)nextProto, nextProtoLength) :
142                                   empty_string, tinfo_);
143     delete this;
144   }
145
146   void handshakeErr(AsyncSSLSocket* sock,
147                     const AsyncSocketException& ex) noexcept override {
148     auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
149     VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
150         " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
151         sock->getRawBytesWritten() << " bytes sent: " <<
152         ex.what();
153     acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
154     acceptor_->sslConnectionError();
155     delete this;
156   }
157
158   AsyncSSLSocket::UniquePtr socket_;
159   Acceptor* acceptor_;
160   std::chrono::steady_clock::time_point acceptTime_;
161   SocketAddress clientAddr_;
162   TransportInfo tinfo_;
163   SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
164 };
165
166 Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
167   accConfig_(accConfig),
168   socketOptions_(accConfig.getSocketOptions()) {
169 }
170
171 void
172 Acceptor::init(AsyncServerSocket* serverSocket,
173                EventBase* eventBase) {
174   CHECK(nullptr == this->base_);
175
176   if (accConfig_.isSSL()) {
177     if (!sslCtxManager_) {
178       sslCtxManager_ = folly::make_unique<SSLContextManager>(
179         eventBase,
180         "vip_" + getName(),
181         accConfig_.strictSSL, nullptr);
182     }
183     for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
184       sslCtxManager_->addSSLContextConfig(
185         sslCtxConfig,
186         accConfig_.sslCacheOptions,
187         &accConfig_.initialTicketSeeds,
188         accConfig_.bindAddress,
189         cacheProvider_);
190       parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
191     }
192
193     CHECK(sslCtxManager_->getDefaultSSLCtx());
194   }
195
196   base_ = eventBase;
197   state_ = State::kRunning;
198   downstreamConnectionManager_ = ConnectionManager::makeUnique(
199     eventBase, accConfig_.connectionIdleTimeout, this);
200
201   if (serverSocket) {
202     serverSocket->addAcceptCallback(this, eventBase);
203
204     for (auto& fd : serverSocket->getSockets()) {
205       if (fd < 0) {
206         continue;
207       }
208       for (const auto& opt: socketOptions_) {
209         opt.first.apply(fd, opt.second);
210       }
211     }
212   }
213 }
214
215 Acceptor::~Acceptor(void) {
216 }
217
218 void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
219   sslCtxManager_->addSSLContextConfig(sslCtxConfig,
220                                       accConfig_.sslCacheOptions,
221                                       &accConfig_.initialTicketSeeds,
222                                       accConfig_.bindAddress,
223                                       cacheProvider_);
224 }
225
226 void
227 Acceptor::drainAllConnections() {
228   if (downstreamConnectionManager_) {
229     downstreamConnectionManager_->initiateGracefulShutdown(
230       std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
231   }
232 }
233
234 void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
235                        IConnectionCounter* counter) {
236   loadShedConfig_ = from;
237   connectionCounter_ = counter;
238 }
239
240 bool Acceptor::canAccept(const SocketAddress& address) {
241   if (!connectionCounter_) {
242     return true;
243   }
244
245   uint64_t maxConnections = connectionCounter_->getMaxConnections();
246   if (maxConnections == 0) {
247     return true;
248   }
249
250   uint64_t currentConnections = connectionCounter_->getNumConnections();
251   if (currentConnections < maxConnections) {
252     return true;
253   }
254
255   if (loadShedConfig_.isWhitelisted(address)) {
256     return true;
257   }
258
259   // Take care of comparing connection count against max connections across
260   // all acceptors. Expensive since a lock must be taken to get the counter.
261   auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
262   if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
263     return true;
264   }
265
266   VLOG(4) << address.describe() << " not whitelisted";
267   return false;
268 }
269
270 void
271 Acceptor::connectionAccepted(
272     int fd, const SocketAddress& clientAddr) noexcept {
273   if (!canAccept(clientAddr)) {
274     close(fd);
275     return;
276   }
277   auto acceptTime = std::chrono::steady_clock::now();
278   for (const auto& opt: socketOptions_) {
279     opt.first.apply(fd, opt.second);
280   }
281
282   onDoneAcceptingConnection(fd, clientAddr, acceptTime);
283 }
284
285 void Acceptor::onDoneAcceptingConnection(
286     int fd,
287     const SocketAddress& clientAddr,
288     std::chrono::steady_clock::time_point acceptTime) noexcept {
289   TransportInfo tinfo;
290   processEstablishedConnection(fd, clientAddr, acceptTime, tinfo);
291 }
292
293 void
294 Acceptor::processEstablishedConnection(
295     int fd,
296     const SocketAddress& clientAddr,
297     std::chrono::steady_clock::time_point acceptTime,
298     TransportInfo& tinfo) noexcept {
299   if (accConfig_.isSSL()) {
300     CHECK(sslCtxManager_);
301     AsyncSSLSocket::UniquePtr sslSock(
302       makeNewAsyncSSLSocket(
303         sslCtxManager_->getDefaultSSLCtx(), base_, fd));
304     ++numPendingSSLConns_;
305     ++totalNumPendingSSLConns_;
306     if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
307       VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
308         " too many handshakes in progress";
309       updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
310                      SSLErrorEnum::DROPPED);
311       sslConnectionError();
312       return;
313     }
314     new AcceptorHandshakeHelper(
315       std::move(sslSock),
316       this,
317       clientAddr,
318       acceptTime,
319       tinfo
320     );
321   } else {
322     tinfo.ssl = false;
323     tinfo.acceptTime = acceptTime;
324     AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
325     connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
326   }
327 }
328
329 void
330 Acceptor::connectionReady(
331     AsyncSocket::UniquePtr sock,
332     const SocketAddress& clientAddr,
333     const string& nextProtocolName,
334     TransportInfo& tinfo) {
335   // Limit the number of reads from the socket per poll loop iteration,
336   // both to keep memory usage under control and to prevent one fast-
337   // writing client from starving other connections.
338   sock->setMaxReadsPerEvent(16);
339   tinfo.initWithSocket(sock.get());
340   onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
341 }
342
343 void
344 Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
345                              const SocketAddress& clientAddr,
346                              const string& nextProtocol,
347                              TransportInfo& tinfo) {
348   CHECK(numPendingSSLConns_ > 0);
349   connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
350   --numPendingSSLConns_;
351   --totalNumPendingSSLConns_;
352   if (state_ == State::kDraining) {
353     checkDrained();
354   }
355 }
356
357 void
358 Acceptor::sslConnectionError() {
359   CHECK(numPendingSSLConns_ > 0);
360   --numPendingSSLConns_;
361   --totalNumPendingSSLConns_;
362   if (state_ == State::kDraining) {
363     checkDrained();
364   }
365 }
366
367 void
368 Acceptor::acceptError(const std::exception& ex) noexcept {
369   // An error occurred.
370   // The most likely error is out of FDs.  AsyncServerSocket will back off
371   // briefly if we are out of FDs, then continue accepting later.
372   // Just log a message here.
373   LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
374 }
375
376 void
377 Acceptor::acceptStopped() noexcept {
378   VLOG(3) << "Acceptor " << this << " acceptStopped()";
379   // Drain the open client connections
380   drainAllConnections();
381
382   // If we haven't yet finished draining, begin doing so by marking ourselves
383   // as in the draining state. We must be sure to hit checkDrained() here, as
384   // if we're completely idle, we can should consider ourself drained
385   // immediately (as there is no outstanding work to complete to cause us to
386   // re-evaluate this).
387   if (state_ != State::kDone) {
388     state_ = State::kDraining;
389     checkDrained();
390   }
391 }
392
393 void
394 Acceptor::onEmpty(const ConnectionManager& cm) {
395   VLOG(3) << "Acceptor=" << this << " onEmpty()";
396   if (state_ == State::kDraining) {
397     checkDrained();
398   }
399 }
400
401 void
402 Acceptor::checkDrained() {
403   CHECK(state_ == State::kDraining);
404   if (forceShutdownInProgress_ ||
405       (downstreamConnectionManager_->getNumConnections() != 0) ||
406       (numPendingSSLConns_ != 0)) {
407     return;
408   }
409
410   VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
411           << base_;
412
413   downstreamConnectionManager_.reset();
414
415   state_ = State::kDone;
416
417   onConnectionsDrained();
418 }
419
420 milliseconds
421 Acceptor::getConnTimeout() const {
422   return accConfig_.connectionIdleTimeout;
423 }
424
425 void Acceptor::addConnection(ManagedConnection* conn) {
426   // Add the socket to the timeout manager so that it can be cleaned
427   // up after being left idle for a long time.
428   downstreamConnectionManager_->addConnection(conn, true);
429 }
430
431 void
432 Acceptor::forceStop() {
433   base_->runInEventBaseThread([&] { dropAllConnections(); });
434 }
435
436 void
437 Acceptor::dropAllConnections() {
438   if (downstreamConnectionManager_) {
439     VLOG(3) << "Dropping all connections from Acceptor=" << this <<
440       " in thread " << base_;
441     assert(base_->isInEventBaseThread());
442     forceShutdownInProgress_ = true;
443     downstreamConnectionManager_->dropAllConnections();
444     CHECK(downstreamConnectionManager_->getNumConnections() == 0);
445     downstreamConnectionManager_.reset();
446   }
447   CHECK(numPendingSSLConns_ == 0);
448
449   state_ = State::kDone;
450   onConnectionsDrained();
451 }
452
453 } // namespace