fix open source build
[folly.git] / folly / wangle / acceptor / Acceptor.cpp
1 /*
2  *  Copyright (c) 2014, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree. An additional grant
7  *  of patent rights can be found in the PATENTS file in the same directory.
8  *
9  */
10 #include <folly/wangle/acceptor/Acceptor.h>
11
12 #include <folly/wangle/acceptor/ManagedConnection.h>
13 #include <folly/wangle/ssl/SSLContextManager.h>
14
15 #include <boost/cast.hpp>
16 #include <fcntl.h>
17 #include <folly/ScopeGuard.h>
18 #include <folly/io/async/EventBase.h>
19 #include <fstream>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <folly/io/async/AsyncSSLSocket.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <gflags/gflags.h>
25 #include <unistd.h>
26
27 using folly::wangle::ConnectionManager;
28 using folly::wangle::ManagedConnection;
29 using std::chrono::microseconds;
30 using std::chrono::milliseconds;
31 using std::filebuf;
32 using std::ifstream;
33 using std::ios;
34 using std::shared_ptr;
35 using std::string;
36
37 namespace folly {
38
39 #ifndef NO_LIB_GFLAGS
40 DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before "
41              "closing idle conns");
42 #else
43 const int32_t FLAGS_shutdown_idle_grace_ms = 5000;
44 #endif
45
46 static const std::string empty_string;
47 std::atomic<uint64_t> Acceptor::totalNumPendingSSLConns_{0};
48
49 /**
50  * Lightweight wrapper class to keep track of a newly
51  * accepted connection during SSL handshaking.
52  */
53 class AcceptorHandshakeHelper :
54       public AsyncSSLSocket::HandshakeCB,
55       public ManagedConnection {
56  public:
57   AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket,
58                           Acceptor* acceptor,
59                           const SocketAddress& clientAddr,
60                           std::chrono::steady_clock::time_point acceptTime)
61     : socket_(std::move(socket)), acceptor_(acceptor),
62       acceptTime_(acceptTime), clientAddr_(clientAddr) {
63     acceptor_->downstreamConnectionManager_->addConnection(this, true);
64     if(acceptor_->parseClientHello_)  {
65       socket_->enableClientHelloParsing();
66     }
67     socket_->sslAccept(this);
68   }
69
70   virtual void timeoutExpired() noexcept {
71     VLOG(4) << "SSL handshake timeout expired";
72     sslError_ = SSLErrorEnum::TIMEOUT;
73     dropConnection();
74   }
75   virtual void describe(std::ostream& os) const {
76     os << "pending handshake on " << clientAddr_;
77   }
78   virtual bool isBusy() const {
79     return true;
80   }
81   virtual void notifyPendingShutdown() {}
82   virtual void closeWhenIdle() {}
83
84   virtual void dropConnection() {
85     VLOG(10) << "Dropping in progress handshake for " << clientAddr_;
86     socket_->closeNow();
87   }
88   virtual void dumpConnectionState(uint8_t loglevel) {
89   }
90
91  private:
92   // AsyncSSLSocket::HandshakeCallback API
93   virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept {
94
95     const unsigned char* nextProto = nullptr;
96     unsigned nextProtoLength = 0;
97     sock->getSelectedNextProtocol(&nextProto, &nextProtoLength);
98     if (VLOG_IS_ON(3)) {
99       if (nextProto) {
100         VLOG(3) << "Client selected next protocol " <<
101             string((const char*)nextProto, nextProtoLength);
102       } else {
103         VLOG(3) << "Client did not select a next protocol";
104       }
105     }
106
107     // fill in SSL-related fields from TransportInfo
108     // the other fields like RTT are filled in the Acceptor
109     TransportInfo tinfo;
110     tinfo.ssl = true;
111     tinfo.acceptTime = acceptTime_;
112     tinfo.sslSetupTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
113     tinfo.sslSetupBytesRead = sock->getRawBytesReceived();
114     tinfo.sslSetupBytesWritten = sock->getRawBytesWritten();
115     tinfo.sslServerName = sock->getSSLServerName() ?
116       std::make_shared<std::string>(sock->getSSLServerName()) : nullptr;
117     tinfo.sslCipher = sock->getNegotiatedCipherName() ?
118       std::make_shared<std::string>(sock->getNegotiatedCipherName()) : nullptr;
119     tinfo.sslVersion = sock->getSSLVersion();
120     tinfo.sslCertSize = sock->getSSLCertSize();
121     tinfo.sslResume = SSLUtil::getResumeState(sock);
122     tinfo.sslClientCiphers = std::make_shared<std::string>();
123     sock->getSSLClientCiphers(*tinfo.sslClientCiphers);
124     tinfo.sslServerCiphers = std::make_shared<std::string>();
125     sock->getSSLServerCiphers(*tinfo.sslServerCiphers);
126     tinfo.sslClientComprMethods =
127         std::make_shared<std::string>(sock->getSSLClientComprMethods());
128     tinfo.sslClientExts =
129         std::make_shared<std::string>(sock->getSSLClientExts());
130     tinfo.sslNextProtocol = std::make_shared<std::string>();
131     tinfo.sslNextProtocol->assign(reinterpret_cast<const char*>(nextProto),
132                                   nextProtoLength);
133
134     acceptor_->updateSSLStats(sock, tinfo.sslSetupTime, SSLErrorEnum::NO_ERROR);
135     acceptor_->downstreamConnectionManager_->removeConnection(this);
136     acceptor_->sslConnectionReady(std::move(socket_), clientAddr_,
137         nextProto ? string((const char*)nextProto, nextProtoLength) :
138                                   empty_string, tinfo);
139     delete this;
140   }
141
142   virtual void handshakeErr(AsyncSSLSocket* sock,
143                             const AsyncSocketException& ex) noexcept {
144     auto elapsedTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - acceptTime_);
145     VLOG(3) << "SSL handshake error after " << elapsedTime.count() <<
146         " ms; " << sock->getRawBytesReceived() << " bytes received & " <<
147         sock->getRawBytesWritten() << " bytes sent: " <<
148         ex.what();
149     acceptor_->updateSSLStats(sock, elapsedTime, sslError_);
150     acceptor_->sslConnectionError();
151     delete this;
152   }
153
154   AsyncSSLSocket::UniquePtr socket_;
155   Acceptor* acceptor_;
156   std::chrono::steady_clock::time_point acceptTime_;
157   SocketAddress clientAddr_;
158   SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR};
159 };
160
161 Acceptor::Acceptor(const ServerSocketConfig& accConfig) :
162   accConfig_(accConfig),
163   socketOptions_(accConfig.getSocketOptions()) {
164 }
165
166 void
167 Acceptor::init(AsyncServerSocket* serverSocket,
168                EventBase* eventBase) {
169   CHECK(nullptr == this->base_);
170
171   if (accConfig_.isSSL()) {
172     if (!sslCtxManager_) {
173       sslCtxManager_ = folly::make_unique<SSLContextManager>(
174         eventBase,
175         "vip_" + getName(),
176         accConfig_.strictSSL, nullptr);
177     }
178     for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) {
179       sslCtxManager_->addSSLContextConfig(
180         sslCtxConfig,
181         accConfig_.sslCacheOptions,
182         &accConfig_.initialTicketSeeds,
183         accConfig_.bindAddress,
184         cacheProvider_);
185       parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled;
186     }
187
188     CHECK(sslCtxManager_->getDefaultSSLCtx());
189   }
190
191   base_ = eventBase;
192   state_ = State::kRunning;
193   downstreamConnectionManager_ = ConnectionManager::makeUnique(
194     eventBase, accConfig_.connectionIdleTimeout, this);
195
196   if (serverSocket) {
197     serverSocket->addAcceptCallback(this, eventBase);
198
199     for (auto& fd : serverSocket->getSockets()) {
200       if (fd < 0) {
201         continue;
202       }
203       for (const auto& opt: socketOptions_) {
204         opt.first.apply(fd, opt.second);
205       }
206     }
207   }
208 }
209
210 Acceptor::~Acceptor(void) {
211 }
212
213 void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) {
214   sslCtxManager_->addSSLContextConfig(sslCtxConfig,
215                                       accConfig_.sslCacheOptions,
216                                       &accConfig_.initialTicketSeeds,
217                                       accConfig_.bindAddress,
218                                       cacheProvider_);
219 }
220
221 void
222 Acceptor::drainAllConnections() {
223   if (downstreamConnectionManager_) {
224     downstreamConnectionManager_->initiateGracefulShutdown(
225       std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms));
226   }
227 }
228
229 void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from,
230                        IConnectionCounter* counter) {
231   loadShedConfig_ = from;
232   connectionCounter_ = counter;
233 }
234
235 bool Acceptor::canAccept(const SocketAddress& address) {
236   if (!connectionCounter_) {
237     return true;
238   }
239
240   uint64_t maxConnections = connectionCounter_->getMaxConnections();
241   if (maxConnections == 0) {
242     return true;
243   }
244
245   uint64_t currentConnections = connectionCounter_->getNumConnections();
246   if (currentConnections < maxConnections) {
247     return true;
248   }
249
250   if (loadShedConfig_.isWhitelisted(address)) {
251     return true;
252   }
253
254   // Take care of comparing connection count against max connections across
255   // all acceptors. Expensive since a lock must be taken to get the counter.
256   auto connectionCountForLoadShedding = getConnectionCountForLoadShedding();
257   if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) {
258     return true;
259   }
260
261   VLOG(4) << address.describe() << " not whitelisted";
262   return false;
263 }
264
265 void
266 Acceptor::connectionAccepted(
267     int fd, const SocketAddress& clientAddr) noexcept {
268   if (!canAccept(clientAddr)) {
269     close(fd);
270     return;
271   }
272   auto acceptTime = std::chrono::steady_clock::now();
273   for (const auto& opt: socketOptions_) {
274     opt.first.apply(fd, opt.second);
275   }
276
277   onDoneAcceptingConnection(fd, clientAddr, acceptTime);
278 }
279
280 void Acceptor::onDoneAcceptingConnection(
281     int fd,
282     const SocketAddress& clientAddr,
283     std::chrono::steady_clock::time_point acceptTime) noexcept {
284   processEstablishedConnection(fd, clientAddr, acceptTime);
285 }
286
287 void
288 Acceptor::processEstablishedConnection(
289     int fd,
290     const SocketAddress& clientAddr,
291     std::chrono::steady_clock::time_point acceptTime) noexcept {
292   if (accConfig_.isSSL()) {
293     CHECK(sslCtxManager_);
294     AsyncSSLSocket::UniquePtr sslSock(
295       makeNewAsyncSSLSocket(
296         sslCtxManager_->getDefaultSSLCtx(), base_, fd));
297     ++numPendingSSLConns_;
298     ++totalNumPendingSSLConns_;
299     if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) {
300       VLOG(2) << "dropped SSL handshake on " << accConfig_.name <<
301         " too many handshakes in progress";
302       updateSSLStats(sslSock.get(), std::chrono::milliseconds(0),
303                      SSLErrorEnum::DROPPED);
304       sslConnectionError();
305       return;
306     }
307     new AcceptorHandshakeHelper(
308       std::move(sslSock), this, clientAddr, acceptTime);
309   } else {
310     TransportInfo tinfo;
311     tinfo.ssl = false;
312     tinfo.acceptTime = acceptTime;
313     AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd));
314     connectionReady(std::move(sock), clientAddr, empty_string, tinfo);
315   }
316 }
317
318 void
319 Acceptor::connectionReady(
320     AsyncSocket::UniquePtr sock,
321     const SocketAddress& clientAddr,
322     const string& nextProtocolName,
323     TransportInfo& tinfo) {
324   // Limit the number of reads from the socket per poll loop iteration,
325   // both to keep memory usage under control and to prevent one fast-
326   // writing client from starving other connections.
327   sock->setMaxReadsPerEvent(16);
328   tinfo.initWithSocket(sock.get());
329   onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo);
330 }
331
332 void
333 Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock,
334                              const SocketAddress& clientAddr,
335                              const string& nextProtocol,
336                              TransportInfo& tinfo) {
337   CHECK(numPendingSSLConns_ > 0);
338   connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo);
339   --numPendingSSLConns_;
340   --totalNumPendingSSLConns_;
341   if (state_ == State::kDraining) {
342     checkDrained();
343   }
344 }
345
346 void
347 Acceptor::sslConnectionError() {
348   CHECK(numPendingSSLConns_ > 0);
349   --numPendingSSLConns_;
350   --totalNumPendingSSLConns_;
351   if (state_ == State::kDraining) {
352     checkDrained();
353   }
354 }
355
356 void
357 Acceptor::acceptError(const std::exception& ex) noexcept {
358   // An error occurred.
359   // The most likely error is out of FDs.  AsyncServerSocket will back off
360   // briefly if we are out of FDs, then continue accepting later.
361   // Just log a message here.
362   LOG(ERROR) << "error accepting on acceptor socket: " << ex.what();
363 }
364
365 void
366 Acceptor::acceptStopped() noexcept {
367   VLOG(3) << "Acceptor " << this << " acceptStopped()";
368   // Drain the open client connections
369   drainAllConnections();
370
371   // If we haven't yet finished draining, begin doing so by marking ourselves
372   // as in the draining state. We must be sure to hit checkDrained() here, as
373   // if we're completely idle, we can should consider ourself drained
374   // immediately (as there is no outstanding work to complete to cause us to
375   // re-evaluate this).
376   if (state_ != State::kDone) {
377     state_ = State::kDraining;
378     checkDrained();
379   }
380 }
381
382 void
383 Acceptor::onEmpty(const ConnectionManager& cm) {
384   VLOG(3) << "Acceptor=" << this << " onEmpty()";
385   if (state_ == State::kDraining) {
386     checkDrained();
387   }
388 }
389
390 void
391 Acceptor::checkDrained() {
392   CHECK(state_ == State::kDraining);
393   if (forceShutdownInProgress_ ||
394       (downstreamConnectionManager_->getNumConnections() != 0) ||
395       (numPendingSSLConns_ != 0)) {
396     return;
397   }
398
399   VLOG(2) << "All connections drained from Acceptor=" << this << " in thread "
400           << base_;
401
402   downstreamConnectionManager_.reset();
403
404   state_ = State::kDone;
405
406   onConnectionsDrained();
407 }
408
409 milliseconds
410 Acceptor::getConnTimeout() const {
411   return accConfig_.connectionIdleTimeout;
412 }
413
414 void Acceptor::addConnection(ManagedConnection* conn) {
415   // Add the socket to the timeout manager so that it can be cleaned
416   // up after being left idle for a long time.
417   downstreamConnectionManager_->addConnection(conn, true);
418 }
419
420 void
421 Acceptor::forceStop() {
422   base_->runInEventBaseThread([&] { dropAllConnections(); });
423 }
424
425 void
426 Acceptor::dropAllConnections() {
427   if (downstreamConnectionManager_) {
428     VLOG(3) << "Dropping all connections from Acceptor=" << this <<
429       " in thread " << base_;
430     assert(base_->isInEventBaseThread());
431     forceShutdownInProgress_ = true;
432     downstreamConnectionManager_->dropAllConnections();
433     CHECK(downstreamConnectionManager_->getNumConnections() == 0);
434     downstreamConnectionManager_.reset();
435   }
436   CHECK(numPendingSSLConns_ == 0);
437
438   state_ = State::kDone;
439   onConnectionsDrained();
440 }
441
442 } // namespace