a7167989505714fa2a0087e65a864154a4f504c7
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/noncopyable.hpp>
23 #include <errno.h>
24 #include <fcntl.h>
25 #include <openssl/err.h>
26 #include <openssl/asn1.h>
27 #include <openssl/ssl.h>
28 #include <sys/types.h>
29 #include <chrono>
30
31 #include <folly/Bits.h>
32 #include <folly/SocketAddress.h>
33 #include <folly/SpinLock.h>
34 #include <folly/io/IOBuf.h>
35 #include <folly/io/Cursor.h>
36 #include <folly/portability/Unistd.h>
37
38 using folly::SocketAddress;
39 using folly::SSLContext;
40 using std::string;
41 using std::shared_ptr;
42
43 using folly::Endian;
44 using folly::IOBuf;
45 using folly::SpinLock;
46 using folly::SpinLockGuard;
47 using folly::io::Cursor;
48 using std::unique_ptr;
49 using std::bind;
50
51 namespace {
52 using folly::AsyncSocket;
53 using folly::AsyncSocketException;
54 using folly::AsyncSSLSocket;
55 using folly::Optional;
56 using folly::SSLContext;
57 using folly::ssl::OpenSSLUtils;
58
59 // We have one single dummy SSL context so that we can implement attach
60 // and detach methods in a thread safe fashion without modifying opnessl.
61 static SSLContext *dummyCtx = nullptr;
62 static SpinLock dummyCtxLock;
63
64 // If given min write size is less than this, buffer will be allocated on
65 // stack, otherwise it is allocated on heap
66 const size_t MAX_STACK_BUF_SIZE = 2048;
67
68 // This converts "illegal" shutdowns into ZERO_RETURN
69 inline bool zero_return(int error, int rc) {
70   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
71 }
72
73 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
74                                 public AsyncSSLSocket::HandshakeCB {
75
76  private:
77   AsyncSSLSocket *sslSocket_;
78   AsyncSSLSocket::ConnectCallback *callback_;
79   int timeout_;
80   int64_t startTime_;
81
82  protected:
83   ~AsyncSSLSocketConnector() override {}
84
85  public:
86   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
87                            AsyncSocket::ConnectCallback *callback,
88                            int timeout) :
89       sslSocket_(sslSocket),
90       callback_(callback),
91       timeout_(timeout),
92       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
93                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
94   }
95
96   void connectSuccess() noexcept override {
97     VLOG(7) << "client socket connected";
98
99     int64_t timeoutLeft = 0;
100     if (timeout_ > 0) {
101       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
102         std::chrono::steady_clock::now().time_since_epoch()).count();
103
104       timeoutLeft = timeout_ - (curTime - startTime_);
105       if (timeoutLeft <= 0) {
106         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
107                                 "SSL connect timed out");
108         fail(ex);
109         delete this;
110         return;
111       }
112     }
113     sslSocket_->sslConn(this, timeoutLeft);
114   }
115
116   void connectErr(const AsyncSocketException& ex) noexcept override {
117     VLOG(1) << "TCP connect failed: " << ex.what();
118     fail(ex);
119     delete this;
120   }
121
122   void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
123     VLOG(7) << "client handshake success";
124     if (callback_) {
125       callback_->connectSuccess();
126     }
127     delete this;
128   }
129
130   void handshakeErr(AsyncSSLSocket* /* socket */,
131                     const AsyncSocketException& ex) noexcept override {
132     VLOG(1) << "client handshakeErr: " << ex.what();
133     fail(ex);
134     delete this;
135   }
136
137   void fail(const AsyncSocketException &ex) {
138     // fail is a noop if called twice
139     if (callback_) {
140       AsyncSSLSocket::ConnectCallback *cb = callback_;
141       callback_ = nullptr;
142
143       cb->connectErr(ex);
144       sslSocket_->closeNow();
145       // closeNow can call handshakeErr if it hasn't been called already.
146       // So this may have been deleted, no member variable access beyond this
147       // point
148       // Note that closeNow may invoke writeError callbacks if the socket had
149       // write data pending connection completion.
150     }
151   }
152 };
153
154 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
155 #ifdef TCP_CORK // Linux-only
156 /**
157  * Utility class that corks a TCP socket upon construction or uncorks
158  * the socket upon destruction
159  */
160 class CorkGuard : private boost::noncopyable {
161  public:
162   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
163     fd_(fd), haveMore_(haveMore), corked_(corked) {
164     if (*corked_) {
165       // socket is already corked; nothing to do
166       return;
167     }
168     if (multipleWrites || haveMore) {
169       // We are performing multiple writes in this performWrite() call,
170       // and/or there are more calls to performWrite() that will be invoked
171       // later, so enable corking
172       int flag = 1;
173       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
174       *corked_ = true;
175     }
176   }
177
178   ~CorkGuard() {
179     if (haveMore_) {
180       // more data to come; don't uncork yet
181       return;
182     }
183     if (!*corked_) {
184       // socket isn't corked; nothing to do
185       return;
186     }
187
188     int flag = 0;
189     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
190     *corked_ = false;
191   }
192
193  private:
194   int fd_;
195   bool haveMore_;
196   bool* corked_;
197 };
198 #else
199 class CorkGuard : private boost::noncopyable {
200  public:
201   CorkGuard(int, bool, bool, bool*) {}
202 };
203 #endif
204
205 void setup_SSL_CTX(SSL_CTX *ctx) {
206 #ifdef SSL_MODE_RELEASE_BUFFERS
207   SSL_CTX_set_mode(ctx,
208                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
209                    SSL_MODE_ENABLE_PARTIAL_WRITE
210                    | SSL_MODE_RELEASE_BUFFERS
211                    );
212 #else
213   SSL_CTX_set_mode(ctx,
214                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
215                    SSL_MODE_ENABLE_PARTIAL_WRITE
216                    );
217 #endif
218 // SSL_CTX_set_mode is a Macro
219 #ifdef SSL_MODE_WRITE_IOVEC
220   SSL_CTX_set_mode(ctx,
221                    SSL_CTX_get_mode(ctx)
222                    | SSL_MODE_WRITE_IOVEC);
223 #endif
224
225 }
226
227 BIO_METHOD sslWriteBioMethod;
228
229 void* initsslWriteBioMethod(void) {
230   memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
231   // override the bwrite method for MSG_EOR support
232   OpenSSLUtils::setCustomBioWriteMethod(
233       &sslWriteBioMethod, AsyncSSLSocket::bioWrite);
234
235   // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
236   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
237   // then have specific handlings. The sslWriteBioWrite should be compatible
238   // with the one in openssl.
239
240   // Return something here to enable AsyncSSLSocket to call this method using
241   // a function-scoped static.
242   return nullptr;
243 }
244
245 } // anonymous namespace
246
247 namespace folly {
248
249 /**
250  * Create a client AsyncSSLSocket
251  */
252 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
253                                EventBase* evb, bool deferSecurityNegotiation) :
254     AsyncSocket(evb),
255     ctx_(ctx),
256     handshakeTimeout_(this, evb),
257     connectionTimeout_(this, evb) {
258   init();
259   if (deferSecurityNegotiation) {
260     sslState_ = STATE_UNENCRYPTED;
261   }
262 }
263
264 /**
265  * Create a server/client AsyncSSLSocket
266  */
267 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
268                                EventBase* evb, int fd, bool server,
269                                bool deferSecurityNegotiation) :
270     AsyncSocket(evb, fd),
271     server_(server),
272     ctx_(ctx),
273     handshakeTimeout_(this, evb),
274     connectionTimeout_(this, evb) {
275   init();
276   if (server) {
277     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
278                               AsyncSSLSocket::sslInfoCallback);
279   }
280   if (deferSecurityNegotiation) {
281     sslState_ = STATE_UNENCRYPTED;
282   }
283 }
284
285 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
286 /**
287  * Create a client AsyncSSLSocket and allow tlsext_hostname
288  * to be sent in Client Hello.
289  */
290 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
291                                  EventBase* evb,
292                                const std::string& serverName,
293                                bool deferSecurityNegotiation) :
294     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
295   tlsextHostname_ = serverName;
296 }
297
298 /**
299  * Create a client AsyncSSLSocket from an already connected fd
300  * and allow tlsext_hostname to be sent in Client Hello.
301  */
302 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
303                                  EventBase* evb, int fd,
304                                const std::string& serverName,
305                                bool deferSecurityNegotiation) :
306     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
307   tlsextHostname_ = serverName;
308 }
309 #endif
310
311 AsyncSSLSocket::~AsyncSSLSocket() {
312   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
313           << ", evb=" << eventBase_ << ", fd=" << fd_
314           << ", state=" << int(state_) << ", sslState="
315           << sslState_ << ", events=" << eventFlags_ << ")";
316 }
317
318 void AsyncSSLSocket::init() {
319   // Do this here to ensure we initialize this once before any use of
320   // AsyncSSLSocket instances and not as part of library load.
321   static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
322   (void)sslWriteBioMethodInitializer;
323
324   setup_SSL_CTX(ctx_->getSSLCtx());
325 }
326
327 void AsyncSSLSocket::closeNow() {
328   // Close the SSL connection.
329   if (ssl_ != nullptr && fd_ != -1) {
330     int rc = SSL_shutdown(ssl_);
331     if (rc == 0) {
332       rc = SSL_shutdown(ssl_);
333     }
334     if (rc < 0) {
335       ERR_clear_error();
336     }
337   }
338
339   if (sslSession_ != nullptr) {
340     SSL_SESSION_free(sslSession_);
341     sslSession_ = nullptr;
342   }
343
344   sslState_ = STATE_CLOSED;
345
346   if (handshakeTimeout_.isScheduled()) {
347     handshakeTimeout_.cancelTimeout();
348   }
349
350   DestructorGuard dg(this);
351
352   invokeHandshakeErr(
353       AsyncSocketException(
354         AsyncSocketException::END_OF_FILE,
355         "SSL connection closed locally"));
356
357   if (ssl_ != nullptr) {
358     SSL_free(ssl_);
359     ssl_ = nullptr;
360   }
361
362   // Close the socket.
363   AsyncSocket::closeNow();
364 }
365
366 void AsyncSSLSocket::shutdownWrite() {
367   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
368   //
369   // (Performing a full shutdown here is more desirable than doing nothing at
370   // all.  The purpose of shutdownWrite() is normally to notify the other end
371   // of the connection that no more data will be sent.  If we do nothing, the
372   // other end will never know that no more data is coming, and this may result
373   // in protocol deadlock.)
374   close();
375 }
376
377 void AsyncSSLSocket::shutdownWriteNow() {
378   closeNow();
379 }
380
381 bool AsyncSSLSocket::good() const {
382   return (AsyncSocket::good() &&
383           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
384            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
385 }
386
387 // The TAsyncTransport definition of 'good' states that the transport is
388 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
389 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
390 // is connected but we haven't initiated the call to SSL_connect.
391 bool AsyncSSLSocket::connecting() const {
392   return (!server_ &&
393           (AsyncSocket::connecting() ||
394            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
395                                      sslState_ == STATE_CONNECTING))));
396 }
397
398 std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
399   const unsigned char* protoName = nullptr;
400   unsigned protoLength;
401   if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
402     return std::string(reinterpret_cast<const char*>(protoName), protoLength);
403   }
404   return "";
405 }
406
407 bool AsyncSSLSocket::isEorTrackingEnabled() const {
408   return trackEor_;
409 }
410
411 void AsyncSSLSocket::setEorTracking(bool track) {
412   if (trackEor_ != track) {
413     trackEor_ = track;
414     appEorByteNo_ = 0;
415     minEorRawByteNo_ = 0;
416   }
417 }
418
419 size_t AsyncSSLSocket::getRawBytesWritten() const {
420   // The bio(s) in the write path are in a chain
421   // each bio flushes to the next and finally written into the socket
422   // to get the rawBytesWritten on the socket,
423   // get the write bytes of the last bio
424   BIO *b;
425   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
426     return 0;
427   }
428   BIO* next = BIO_next(b);
429   while (next != NULL) {
430     b = next;
431     next = BIO_next(b);
432   }
433
434   return BIO_number_written(b);
435 }
436
437 size_t AsyncSSLSocket::getRawBytesReceived() const {
438   BIO *b;
439   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
440     return 0;
441   }
442
443   return BIO_number_read(b);
444 }
445
446
447 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
448   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
449              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
450              << "events=" << eventFlags_ << ", server=" << short(server_)
451              << "): " << "sslAccept/Connect() called in invalid "
452              << "state, handshake callback " << handshakeCallback_
453              << ", new callback " << callback;
454   assert(!handshakeTimeout_.isScheduled());
455   sslState_ = STATE_ERROR;
456
457   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
458                          "sslAccept() called with socket in invalid state");
459
460   handshakeEndTime_ = std::chrono::steady_clock::now();
461   if (callback) {
462     callback->handshakeErr(this, ex);
463   }
464
465   // Check the socket state not the ssl state here.
466   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
467     failHandshake(__func__, ex);
468   }
469 }
470
471 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
472       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
473   DestructorGuard dg(this);
474   assert(eventBase_->isInEventBaseThread());
475   verifyPeer_ = verifyPeer;
476
477   // Make sure we're in the uninitialized state
478   if (!server_ || (sslState_ != STATE_UNINIT &&
479                    sslState_ != STATE_UNENCRYPTED) ||
480       handshakeCallback_ != nullptr) {
481     return invalidState(callback);
482   }
483
484   // Cache local and remote socket addresses to keep them available
485   // after socket file descriptor is closed.
486   if (cacheAddrOnFailure_ && -1 != getFd()) {
487     cacheLocalPeerAddr();
488   }
489
490   handshakeStartTime_ = std::chrono::steady_clock::now();
491   // Make end time at least >= start time.
492   handshakeEndTime_ = handshakeStartTime_;
493
494   sslState_ = STATE_ACCEPTING;
495   handshakeCallback_ = callback;
496
497   if (timeout > 0) {
498     handshakeTimeout_.scheduleTimeout(timeout);
499   }
500
501   /* register for a read operation (waiting for CLIENT HELLO) */
502   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
503 }
504
505 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
506 void AsyncSSLSocket::attachSSLContext(
507   const std::shared_ptr<SSLContext>& ctx) {
508
509   // Check to ensure we are in client mode. Changing a server's ssl
510   // context doesn't make sense since clients of that server would likely
511   // become confused when the server's context changes.
512   DCHECK(!server_);
513   DCHECK(!ctx_);
514   DCHECK(ctx);
515   DCHECK(ctx->getSSLCtx());
516   ctx_ = ctx;
517
518   // In order to call attachSSLContext, detachSSLContext must have been
519   // previously called which sets the socket's context to the dummy
520   // context. Thus we must acquire this lock.
521   SpinLockGuard guard(dummyCtxLock);
522   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
523 }
524
525 void AsyncSSLSocket::detachSSLContext() {
526   DCHECK(ctx_);
527   ctx_.reset();
528   // We aren't using the initial_ctx for now, and it can introduce race
529   // conditions in the destructor of the SSL object.
530 #ifndef OPENSSL_NO_TLSEXT
531   if (ssl_->initial_ctx) {
532     SSL_CTX_free(ssl_->initial_ctx);
533     ssl_->initial_ctx = nullptr;
534   }
535 #endif
536   SpinLockGuard guard(dummyCtxLock);
537   if (nullptr == dummyCtx) {
538     // We need to lazily initialize the dummy context so we don't
539     // accidentally override any programmatic settings to openssl
540     dummyCtx = new SSLContext;
541   }
542   // We must remove this socket's references to its context right now
543   // since this socket could get passed to any thread. If the context has
544   // had its locking disabled, just doing a set in attachSSLContext()
545   // would not be thread safe.
546   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
547 }
548 #endif
549
550 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
551 void AsyncSSLSocket::switchServerSSLContext(
552   const std::shared_ptr<SSLContext>& handshakeCtx) {
553   CHECK(server_);
554   if (sslState_ != STATE_ACCEPTING) {
555     // We log it here and allow the switch.
556     // It should not affect our re-negotiation support (which
557     // is not supported now).
558     VLOG(6) << "fd=" << getFd()
559             << " renegotation detected when switching SSL_CTX";
560   }
561
562   setup_SSL_CTX(handshakeCtx->getSSLCtx());
563   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
564                             AsyncSSLSocket::sslInfoCallback);
565   handshakeCtx_ = handshakeCtx;
566   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
567 }
568
569 bool AsyncSSLSocket::isServerNameMatch() const {
570   CHECK(!server_);
571
572   if (!ssl_) {
573     return false;
574   }
575
576   SSL_SESSION *ss = SSL_get_session(ssl_);
577   if (!ss) {
578     return false;
579   }
580
581   if(!ss->tlsext_hostname) {
582     return false;
583   }
584   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
585 }
586
587 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
588   tlsextHostname_ = std::move(serverName);
589 }
590
591 #endif
592
593 void AsyncSSLSocket::timeoutExpired() noexcept {
594   if (state_ == StateEnum::ESTABLISHED &&
595       (sslState_ == STATE_CACHE_LOOKUP ||
596        sslState_ == STATE_ASYNC_PENDING)) {
597     sslState_ = STATE_ERROR;
598     // We are expecting a callback in restartSSLAccept.  The cache lookup
599     // and rsa-call necessarily have pointers to this ssl socket, so delay
600     // the cleanup until he calls us back.
601   } else if (state_ == StateEnum::CONNECTING) {
602     assert(sslState_ == STATE_CONNECTING);
603     DestructorGuard dg(this);
604     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
605                            "Fallback connect timed out during TFO");
606     failHandshake(__func__, ex);
607   } else {
608     assert(state_ == StateEnum::ESTABLISHED &&
609            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
610     DestructorGuard dg(this);
611     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
612                            (sslState_ == STATE_CONNECTING) ?
613                            "SSL connect timed out" : "SSL accept timed out");
614     failHandshake(__func__, ex);
615   }
616 }
617
618 int AsyncSSLSocket::getSSLExDataIndex() {
619   static auto index = SSL_get_ex_new_index(
620       0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
621   return index;
622 }
623
624 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
625   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
626       getSSLExDataIndex()));
627 }
628
629 void AsyncSSLSocket::failHandshake(const char* /* fn */,
630                                    const AsyncSocketException& ex) {
631   startFail();
632   if (handshakeTimeout_.isScheduled()) {
633     handshakeTimeout_.cancelTimeout();
634   }
635   invokeHandshakeErr(ex);
636   finishFail();
637 }
638
639 void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
640   handshakeEndTime_ = std::chrono::steady_clock::now();
641   if (handshakeCallback_ != nullptr) {
642     HandshakeCB* callback = handshakeCallback_;
643     handshakeCallback_ = nullptr;
644     callback->handshakeErr(this, ex);
645   }
646 }
647
648 void AsyncSSLSocket::invokeHandshakeCB() {
649   handshakeEndTime_ = std::chrono::steady_clock::now();
650   if (handshakeTimeout_.isScheduled()) {
651     handshakeTimeout_.cancelTimeout();
652   }
653   if (handshakeCallback_) {
654     HandshakeCB* callback = handshakeCallback_;
655     handshakeCallback_ = nullptr;
656     callback->handshakeSuc(this);
657   }
658 }
659
660 void AsyncSSLSocket::cacheLocalPeerAddr() {
661   SocketAddress address;
662   try {
663     getLocalAddress(&address);
664     getPeerAddress(&address);
665   } catch (const std::system_error& e) {
666     // The handle can be still valid while the connection is already closed.
667     if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
668       throw;
669     }
670   }
671 }
672
673 void AsyncSSLSocket::connect(ConnectCallback* callback,
674                               const folly::SocketAddress& address,
675                               int timeout,
676                               const OptionMap &options,
677                               const folly::SocketAddress& bindAddr)
678                               noexcept {
679   assert(!server_);
680   assert(state_ == StateEnum::UNINIT);
681   assert(sslState_ == STATE_UNINIT);
682   AsyncSSLSocketConnector *connector =
683     new AsyncSSLSocketConnector(this, callback, timeout);
684   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
685 }
686
687 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
688   // apply the settings specified in verifyPeer_
689   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
690     if(ctx_->needsPeerVerification()) {
691       SSL_set_verify(ssl, ctx_->getVerificationMode(),
692         AsyncSSLSocket::sslVerifyCallback);
693     }
694   } else {
695     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
696         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
697       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
698         AsyncSSLSocket::sslVerifyCallback);
699     }
700   }
701 }
702
703 bool AsyncSSLSocket::setupSSLBio() {
704   auto wb = BIO_new(&sslWriteBioMethod);
705
706   if (!wb) {
707     return false;
708   }
709
710   OpenSSLUtils::setBioAppData(wb, this);
711   OpenSSLUtils::setBioFd(wb, fd_, BIO_NOCLOSE);
712   SSL_set_bio(ssl_, wb, wb);
713   return true;
714 }
715
716 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
717         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
718   DestructorGuard dg(this);
719   assert(eventBase_->isInEventBaseThread());
720
721   // Cache local and remote socket addresses to keep them available
722   // after socket file descriptor is closed.
723   if (cacheAddrOnFailure_ && -1 != getFd()) {
724     cacheLocalPeerAddr();
725   }
726
727   verifyPeer_ = verifyPeer;
728
729   // Make sure we're in the uninitialized state
730   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
731                   STATE_UNENCRYPTED) ||
732       handshakeCallback_ != nullptr) {
733     return invalidState(callback);
734   }
735
736   sslState_ = STATE_CONNECTING;
737   handshakeCallback_ = callback;
738
739   try {
740     ssl_ = ctx_->createSSL();
741   } catch (std::exception &e) {
742     sslState_ = STATE_ERROR;
743     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
744                            "error calling SSLContext::createSSL()");
745     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
746             << fd_ << "): " << e.what();
747     return failHandshake(__func__, ex);
748   }
749
750   if (!setupSSLBio()) {
751     sslState_ = STATE_ERROR;
752     AsyncSocketException ex(
753         AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
754     return failHandshake(__func__, ex);
755   }
756
757   applyVerificationOptions(ssl_);
758
759   if (sslSession_ != nullptr) {
760     SSL_set_session(ssl_, sslSession_);
761     SSL_SESSION_free(sslSession_);
762     sslSession_ = nullptr;
763   }
764 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
765   if (tlsextHostname_.size()) {
766     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
767   }
768 #endif
769
770   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
771
772   handshakeConnectTimeout_ = timeout;
773   startSSLConnect();
774 }
775
776 // This could be called multiple times, during normal ssl connections
777 // and after TFO fallback.
778 void AsyncSSLSocket::startSSLConnect() {
779   handshakeStartTime_ = std::chrono::steady_clock::now();
780   // Make end time at least >= start time.
781   handshakeEndTime_ = handshakeStartTime_;
782   if (handshakeConnectTimeout_ > 0) {
783     handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_);
784   }
785   handleConnect();
786 }
787
788 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
789   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
790     return SSL_get1_session(ssl_);
791   }
792
793   return sslSession_;
794 }
795
796 const SSL* AsyncSSLSocket::getSSL() const {
797   return ssl_;
798 }
799
800 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
801   sslSession_ = session;
802   if (!takeOwnership && session != nullptr) {
803     // Increment the reference count
804     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
805   }
806 }
807
808 void AsyncSSLSocket::getSelectedNextProtocol(
809     const unsigned char** protoName,
810     unsigned* protoLen,
811     SSLContext::NextProtocolType* protoType) const {
812   if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
813     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
814                               "NPN not supported");
815   }
816 }
817
818 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
819     const unsigned char** protoName,
820     unsigned* protoLen,
821     SSLContext::NextProtocolType* protoType) const {
822   *protoName = nullptr;
823   *protoLen = 0;
824 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
825   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
826   if (*protoLen > 0) {
827     if (protoType) {
828       *protoType = SSLContext::NextProtocolType::ALPN;
829     }
830     return true;
831   }
832 #endif
833 #ifdef OPENSSL_NPN_NEGOTIATED
834   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
835   if (protoType) {
836     *protoType = SSLContext::NextProtocolType::NPN;
837   }
838   return true;
839 #else
840   (void)protoType;
841   return false;
842 #endif
843 }
844
845 bool AsyncSSLSocket::getSSLSessionReused() const {
846   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
847     return SSL_session_reused(ssl_);
848   }
849   return false;
850 }
851
852 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
853   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
854 }
855
856 /* static */
857 const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
858   if (ssl == nullptr) {
859     return nullptr;
860   }
861 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
862   return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
863 #else
864   return nullptr;
865 #endif
866 }
867
868 const char *AsyncSSLSocket::getSSLServerName() const {
869 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
870   return getSSLServerNameFromSSL(ssl_);
871 #else
872   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
873                              "SNI not supported");
874 #endif
875 }
876
877 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
878   return getSSLServerNameFromSSL(ssl_);
879 }
880
881 int AsyncSSLSocket::getSSLVersion() const {
882   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
883 }
884
885 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
886   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
887   if (cert) {
888     int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
889     return OBJ_nid2ln(nid);
890   }
891   return nullptr;
892 }
893
894 int AsyncSSLSocket::getSSLCertSize() const {
895   int certSize = 0;
896   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
897   if (cert) {
898     EVP_PKEY *key = X509_get_pubkey(cert);
899     certSize = EVP_PKEY_bits(key);
900     EVP_PKEY_free(key);
901   }
902   return certSize;
903 }
904
905 const X509* AsyncSSLSocket::getSelfCert() const {
906   return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
907 }
908
909 bool AsyncSSLSocket::willBlock(int ret,
910                                int* sslErrorOut,
911                                unsigned long* errErrorOut) noexcept {
912   *errErrorOut = 0;
913   int error = *sslErrorOut = SSL_get_error(ssl_, ret);
914   if (error == SSL_ERROR_WANT_READ) {
915     // Register for read event if not already.
916     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
917     return true;
918   } else if (error == SSL_ERROR_WANT_WRITE) {
919     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
920             << ", state=" << int(state_) << ", sslState="
921             << sslState_ << ", events=" << eventFlags_ << "): "
922             << "SSL_ERROR_WANT_WRITE";
923     // Register for write event if not already.
924     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
925     return true;
926 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
927   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
928     // We will block but we can't register our own socket.  The callback that
929     // triggered this code will re-call handleAccept at the appropriate time.
930
931     // We can only get here if the linked libssl.so has support for this feature
932     // as well, otherwise SSL_get_error cannot return our error code.
933     sslState_ = STATE_CACHE_LOOKUP;
934
935     // Unregister for all events while blocked here
936     updateEventRegistration(EventHandler::NONE,
937                             EventHandler::READ | EventHandler::WRITE);
938
939     // The timeout (if set) keeps running here
940     return true;
941 #endif
942   } else if (0
943 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
944       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
945 #endif
946 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
947       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
948 #endif
949       ) {
950     // Our custom openssl function has kicked off an async request to do
951     // rsa/ecdsa private key operation.  When that call returns, a callback will
952     // be invoked that will re-call handleAccept.
953     sslState_ = STATE_ASYNC_PENDING;
954
955     // Unregister for all events while blocked here
956     updateEventRegistration(
957       EventHandler::NONE,
958       EventHandler::READ | EventHandler::WRITE
959     );
960
961     // The timeout (if set) keeps running here
962     return true;
963   } else {
964     unsigned long lastError = *errErrorOut = ERR_get_error();
965     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
966             << "state=" << state_ << ", "
967             << "sslState=" << sslState_ << ", "
968             << "events=" << std::hex << eventFlags_ << "): "
969             << "SSL error: " << error << ", "
970             << "errno: " << errno << ", "
971             << "ret: " << ret << ", "
972             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
973             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
974             << "func: " << ERR_func_error_string(lastError) << ", "
975             << "reason: " << ERR_reason_error_string(lastError);
976     return false;
977   }
978 }
979
980 void AsyncSSLSocket::checkForImmediateRead() noexcept {
981   // openssl may have buffered data that it read from the socket already.
982   // In this case we have to process it immediately, rather than waiting for
983   // the socket to become readable again.
984   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
985     AsyncSocket::handleRead();
986   }
987 }
988
989 void
990 AsyncSSLSocket::restartSSLAccept()
991 {
992   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this
993           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
994           << "sslState=" << sslState_ << ", events=" << eventFlags_;
995   DestructorGuard dg(this);
996   assert(
997     sslState_ == STATE_CACHE_LOOKUP ||
998     sslState_ == STATE_ASYNC_PENDING ||
999     sslState_ == STATE_ERROR ||
1000     sslState_ == STATE_CLOSED);
1001   if (sslState_ == STATE_CLOSED) {
1002     // I sure hope whoever closed this socket didn't delete it already,
1003     // but this is not strictly speaking an error
1004     return;
1005   }
1006   if (sslState_ == STATE_ERROR) {
1007     // go straight to fail if timeout expired during lookup
1008     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
1009                            "SSL accept timed out");
1010     failHandshake(__func__, ex);
1011     return;
1012   }
1013   sslState_ = STATE_ACCEPTING;
1014   this->handleAccept();
1015 }
1016
1017 void
1018 AsyncSSLSocket::handleAccept() noexcept {
1019   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
1020           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1021           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1022   assert(server_);
1023   assert(state_ == StateEnum::ESTABLISHED &&
1024          sslState_ == STATE_ACCEPTING);
1025   if (!ssl_) {
1026     /* lazily create the SSL structure */
1027     try {
1028       ssl_ = ctx_->createSSL();
1029     } catch (std::exception &e) {
1030       sslState_ = STATE_ERROR;
1031       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1032                              "error calling SSLContext::createSSL()");
1033       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
1034                  << ", fd=" << fd_ << "): " << e.what();
1035       return failHandshake(__func__, ex);
1036     }
1037
1038     if (!setupSSLBio()) {
1039       sslState_ = STATE_ERROR;
1040       AsyncSocketException ex(
1041           AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
1042       return failHandshake(__func__, ex);
1043     }
1044
1045     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
1046
1047     applyVerificationOptions(ssl_);
1048   }
1049
1050   if (server_ && parseClientHello_) {
1051     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
1052     SSL_set_msg_callback_arg(ssl_, this);
1053   }
1054
1055   int ret = SSL_accept(ssl_);
1056   if (ret <= 0) {
1057     int sslError;
1058     unsigned long errError;
1059     int errnoCopy = errno;
1060     if (willBlock(ret, &sslError, &errError)) {
1061       return;
1062     } else {
1063       sslState_ = STATE_ERROR;
1064       SSLException ex(sslError, errError, ret, errnoCopy);
1065       return failHandshake(__func__, ex);
1066     }
1067   }
1068
1069   handshakeComplete_ = true;
1070   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1071
1072   // Move into STATE_ESTABLISHED in the normal case that we are in
1073   // STATE_ACCEPTING.
1074   sslState_ = STATE_ESTABLISHED;
1075
1076   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
1077           << " successfully accepted; state=" << int(state_)
1078           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
1079
1080   // Remember the EventBase we are attached to, before we start invoking any
1081   // callbacks (since the callbacks may call detachEventBase()).
1082   EventBase* originalEventBase = eventBase_;
1083
1084   // Call the accept callback.
1085   invokeHandshakeCB();
1086
1087   // Note that the accept callback may have changed our state.
1088   // (set or unset the read callback, called write(), closed the socket, etc.)
1089   // The following code needs to handle these situations correctly.
1090   //
1091   // If the socket has been closed, readCallback_ and writeReqHead_ will
1092   // always be nullptr, so that will prevent us from trying to read or write.
1093   //
1094   // The main thing to check for is if eventBase_ is still originalEventBase.
1095   // If not, we have been detached from this event base, so we shouldn't
1096   // perform any more operations.
1097   if (eventBase_ != originalEventBase) {
1098     return;
1099   }
1100
1101   AsyncSocket::handleInitialReadWrite();
1102 }
1103
1104 void
1105 AsyncSSLSocket::handleConnect() noexcept {
1106   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1107           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1108           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1109   assert(!server_);
1110   if (state_ < StateEnum::ESTABLISHED) {
1111     return AsyncSocket::handleConnect();
1112   }
1113
1114   assert(
1115       (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
1116       sslState_ == STATE_CONNECTING);
1117   assert(ssl_);
1118
1119   auto originalState = state_;
1120   int ret = SSL_connect(ssl_);
1121   if (ret <= 0) {
1122     int sslError;
1123     unsigned long errError;
1124     int errnoCopy = errno;
1125     if (willBlock(ret, &sslError, &errError)) {
1126       // We fell back to connecting state due to TFO
1127       if (state_ == StateEnum::CONNECTING) {
1128         DCHECK_EQ(StateEnum::FAST_OPEN, originalState);
1129         if (handshakeTimeout_.isScheduled()) {
1130           handshakeTimeout_.cancelTimeout();
1131         }
1132       }
1133       return;
1134     } else {
1135       sslState_ = STATE_ERROR;
1136       SSLException ex(sslError, errError, ret, errnoCopy);
1137       return failHandshake(__func__, ex);
1138     }
1139   }
1140
1141   handshakeComplete_ = true;
1142   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1143
1144   // Move into STATE_ESTABLISHED in the normal case that we are in
1145   // STATE_CONNECTING.
1146   sslState_ = STATE_ESTABLISHED;
1147
1148   VLOG(3) << "AsyncSSLSocket " << this << ": "
1149           << "fd " << fd_ << " successfully connected; "
1150           << "state=" << int(state_) << ", sslState=" << sslState_
1151           << ", events=" << eventFlags_;
1152
1153   // Remember the EventBase we are attached to, before we start invoking any
1154   // callbacks (since the callbacks may call detachEventBase()).
1155   EventBase* originalEventBase = eventBase_;
1156
1157   // Call the handshake callback.
1158   invokeHandshakeCB();
1159
1160   // Note that the connect callback may have changed our state.
1161   // (set or unset the read callback, called write(), closed the socket, etc.)
1162   // The following code needs to handle these situations correctly.
1163   //
1164   // If the socket has been closed, readCallback_ and writeReqHead_ will
1165   // always be nullptr, so that will prevent us from trying to read or write.
1166   //
1167   // The main thing to check for is if eventBase_ is still originalEventBase.
1168   // If not, we have been detached from this event base, so we shouldn't
1169   // perform any more operations.
1170   if (eventBase_ != originalEventBase) {
1171     return;
1172   }
1173
1174   AsyncSocket::handleInitialReadWrite();
1175 }
1176
1177 void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
1178   connectionTimeout_.cancelTimeout();
1179   AsyncSocket::invokeConnectErr(ex);
1180 }
1181
1182 void AsyncSSLSocket::invokeConnectSuccess() {
1183   connectionTimeout_.cancelTimeout();
1184   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
1185     // If we failed TFO, we'd fall back to trying to connect the socket,
1186     // to setup things like timeouts.
1187     startSSLConnect();
1188   }
1189   // still invoke the base class since it re-sets the connect time.
1190   AsyncSocket::invokeConnectSuccess();
1191 }
1192
1193 void AsyncSSLSocket::scheduleConnectTimeout() {
1194   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
1195     // We fell back from TFO, and need to set the timeouts.
1196     // We will not have a connect callback in this case, thus if the timer
1197     // expires we would have no-one to notify.
1198     // Thus we should reset even the connect timers to point to the handshake
1199     // timeouts.
1200     assert(connectCallback_ == nullptr);
1201     // We use a different connect timeout here than the handshake timeout, so
1202     // that we can disambiguate the 2 timers.
1203     int timeout = connectTimeout_.count();
1204     if (timeout > 0) {
1205       if (!connectionTimeout_.scheduleTimeout(timeout)) {
1206         throw AsyncSocketException(
1207             AsyncSocketException::INTERNAL_ERROR,
1208             withAddr("failed to schedule AsyncSSLSocket connect timeout"));
1209       }
1210     }
1211     return;
1212   }
1213   AsyncSocket::scheduleConnectTimeout();
1214 }
1215
1216 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
1217 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1218   // turn on the buffer movable in openssl
1219   if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
1220       callback != nullptr && callback->isBufferMovable()) {
1221     SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
1222     isBufferMovable_ = true;
1223   }
1224 #endif
1225
1226   AsyncSocket::setReadCB(callback);
1227 }
1228
1229 void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
1230   bufferMovableEnabled_ = enabled;
1231 }
1232
1233 void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1234   CHECK(readCallback_);
1235   if (isBufferMovable_) {
1236     *buf = nullptr;
1237     *buflen = 0;
1238   } else {
1239     // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
1240     readCallback_->getReadBuffer(buf, buflen);
1241   }
1242 }
1243
1244 void
1245 AsyncSSLSocket::handleRead() noexcept {
1246   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1247           << ", state=" << int(state_) << ", "
1248           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1249   if (state_ < StateEnum::ESTABLISHED) {
1250     return AsyncSocket::handleRead();
1251   }
1252
1253
1254   if (sslState_ == STATE_ACCEPTING) {
1255     assert(server_);
1256     handleAccept();
1257     return;
1258   }
1259   else if (sslState_ == STATE_CONNECTING) {
1260     assert(!server_);
1261     handleConnect();
1262     return;
1263   }
1264
1265   // Normal read
1266   AsyncSocket::handleRead();
1267 }
1268
1269 AsyncSocket::ReadResult
1270 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
1271   VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
1272           << ", buflen=" << *buflen;
1273
1274   if (sslState_ == STATE_UNENCRYPTED) {
1275     return AsyncSocket::performRead(buf, buflen, offset);
1276   }
1277
1278   ssize_t bytes = 0;
1279   if (!isBufferMovable_) {
1280     bytes = SSL_read(ssl_, *buf, *buflen);
1281   }
1282 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1283   else {
1284     bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
1285   }
1286 #endif
1287
1288   if (server_ && renegotiateAttempted_) {
1289     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1290                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1291                << "): client intitiated SSL renegotiation not permitted";
1292     return ReadResult(
1293         READ_ERROR,
1294         folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
1295   }
1296   if (bytes <= 0) {
1297     int error = SSL_get_error(ssl_, bytes);
1298     if (error == SSL_ERROR_WANT_READ) {
1299       // The caller will register for read event if not already.
1300       if (errno == EWOULDBLOCK || errno == EAGAIN) {
1301         return ReadResult(READ_BLOCKING);
1302       } else {
1303         return ReadResult(READ_ERROR);
1304       }
1305     } else if (error == SSL_ERROR_WANT_WRITE) {
1306       // TODO: Even though we are attempting to read data, SSL_read() may
1307       // need to write data if renegotiation is being performed.  We currently
1308       // don't support this and just fail the read.
1309       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1310                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1311                  << "): unsupported SSL renegotiation during read";
1312       return ReadResult(
1313           READ_ERROR,
1314           folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1315     } else {
1316       if (zero_return(error, bytes)) {
1317         return ReadResult(bytes);
1318       }
1319       long errError = ERR_get_error();
1320       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1321               << "state=" << state_ << ", "
1322               << "sslState=" << sslState_ << ", "
1323               << "events=" << std::hex << eventFlags_ << "): "
1324               << "bytes: " << bytes << ", "
1325               << "error: " << error << ", "
1326               << "errno: " << errno << ", "
1327               << "func: " << ERR_func_error_string(errError) << ", "
1328               << "reason: " << ERR_reason_error_string(errError);
1329       return ReadResult(
1330           READ_ERROR,
1331           folly::make_unique<SSLException>(error, errError, bytes, errno));
1332     }
1333   } else {
1334     appBytesReceived_ += bytes;
1335     return ReadResult(bytes);
1336   }
1337 }
1338
1339 void AsyncSSLSocket::handleWrite() noexcept {
1340   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1341           << ", state=" << int(state_) << ", "
1342           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1343   if (state_ < StateEnum::ESTABLISHED) {
1344     return AsyncSocket::handleWrite();
1345   }
1346
1347   if (sslState_ == STATE_ACCEPTING) {
1348     assert(server_);
1349     handleAccept();
1350     return;
1351   }
1352
1353   if (sslState_ == STATE_CONNECTING) {
1354     assert(!server_);
1355     handleConnect();
1356     return;
1357   }
1358
1359   // Normal write
1360   AsyncSocket::handleWrite();
1361 }
1362
1363 AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
1364   if (error == SSL_ERROR_WANT_READ) {
1365     // Even though we are attempting to write data, SSL_write() may
1366     // need to read data if renegotiation is being performed.  We currently
1367     // don't support this and just fail the write.
1368     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1369                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1370                << "): "
1371                << "unsupported SSL renegotiation during write";
1372     return WriteResult(
1373         WRITE_ERROR,
1374         folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1375   } else {
1376     if (zero_return(error, rc)) {
1377       return WriteResult(0);
1378     }
1379     auto errError = ERR_get_error();
1380     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1381             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1382             << "SSL error: " << error << ", errno: " << errno
1383             << ", func: " << ERR_func_error_string(errError)
1384             << ", reason: " << ERR_reason_error_string(errError);
1385     return WriteResult(
1386         WRITE_ERROR,
1387         folly::make_unique<SSLException>(error, errError, rc, errno));
1388   }
1389 }
1390
1391 AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
1392     const iovec* vec,
1393     uint32_t count,
1394     WriteFlags flags,
1395     uint32_t* countWritten,
1396     uint32_t* partialWritten) {
1397   if (sslState_ == STATE_UNENCRYPTED) {
1398     return AsyncSocket::performWrite(
1399       vec, count, flags, countWritten, partialWritten);
1400   }
1401   if (sslState_ != STATE_ESTABLISHED) {
1402     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1403                << ", sslState=" << sslState_
1404                << ", events=" << eventFlags_ << "): "
1405                << "TODO: AsyncSSLSocket currently does not support calling "
1406                << "write() before the handshake has fully completed";
1407     return WriteResult(
1408         WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
1409   }
1410
1411   bool cork = isSet(flags, WriteFlags::CORK);
1412   CorkGuard guard(fd_, count > 1, cork, &corked_);
1413
1414   // Declare a buffer used to hold small write requests.  It could point to a
1415   // memory block either on stack or on heap. If it is on heap, we release it
1416   // manually when scope exits
1417   char* combinedBuf{nullptr};
1418   SCOPE_EXIT {
1419     // Note, always keep this check consistent with what we do below
1420     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1421       delete[] combinedBuf;
1422     }
1423   };
1424
1425   *countWritten = 0;
1426   *partialWritten = 0;
1427   ssize_t totalWritten = 0;
1428   size_t bytesStolenFromNextBuffer = 0;
1429   for (uint32_t i = 0; i < count; i++) {
1430     const iovec* v = vec + i;
1431     size_t offset = bytesStolenFromNextBuffer;
1432     bytesStolenFromNextBuffer = 0;
1433     size_t len = v->iov_len - offset;
1434     const void* buf;
1435     if (len == 0) {
1436       (*countWritten)++;
1437       continue;
1438     }
1439     buf = ((const char*)v->iov_base) + offset;
1440
1441     ssize_t bytes;
1442     uint32_t buffersStolen = 0;
1443     if ((len < minWriteSize_) && ((i + 1) < count)) {
1444       // Combine this buffer with part or all of the next buffers in
1445       // order to avoid really small-grained calls to SSL_write().
1446       // Each call to SSL_write() produces a separate record in
1447       // the egress SSL stream, and we've found that some low-end
1448       // mobile clients can't handle receiving an HTTP response
1449       // header and the first part of the response body in two
1450       // separate SSL records (even if those two records are in
1451       // the same TCP packet).
1452
1453       if (combinedBuf == nullptr) {
1454         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1455           // Allocate the buffer on heap
1456           combinedBuf = new char[minWriteSize_];
1457         } else {
1458           // Allocate the buffer on stack
1459           combinedBuf = (char*)alloca(minWriteSize_);
1460         }
1461       }
1462       assert(combinedBuf != nullptr);
1463
1464       memcpy(combinedBuf, buf, len);
1465       do {
1466         // INVARIANT: i + buffersStolen == complete chunks serialized
1467         uint32_t nextIndex = i + buffersStolen + 1;
1468         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1469                                              minWriteSize_ - len);
1470         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1471                bytesStolenFromNextBuffer);
1472         len += bytesStolenFromNextBuffer;
1473         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1474           // couldn't steal the whole buffer
1475           break;
1476         } else {
1477           bytesStolenFromNextBuffer = 0;
1478           buffersStolen++;
1479         }
1480       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1481       bytes = eorAwareSSLWrite(
1482         ssl_, combinedBuf, len,
1483         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1484
1485     } else {
1486       bytes = eorAwareSSLWrite(ssl_, buf, len,
1487                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1488     }
1489
1490     if (bytes <= 0) {
1491       int error = SSL_get_error(ssl_, bytes);
1492       if (error == SSL_ERROR_WANT_WRITE) {
1493         // The caller will register for write event if not already.
1494         *partialWritten = offset;
1495         return WriteResult(totalWritten);
1496       }
1497       auto writeResult = interpretSSLError(bytes, error);
1498       if (writeResult.writeReturn < 0) {
1499         return writeResult;
1500       } // else fall through to below to correctly record totalWritten
1501     }
1502
1503     totalWritten += bytes;
1504
1505     if (bytes == (ssize_t)len) {
1506       // The full iovec is written.
1507       (*countWritten) += 1 + buffersStolen;
1508       i += buffersStolen;
1509       // continue
1510     } else {
1511       bytes += offset; // adjust bytes to account for all of v
1512       while (bytes >= (ssize_t)v->iov_len) {
1513         // We combined this buf with part or all of the next one, and
1514         // we managed to write all of this buf but not all of the bytes
1515         // from the next one that we'd hoped to write.
1516         bytes -= v->iov_len;
1517         (*countWritten)++;
1518         v = &(vec[++i]);
1519       }
1520       *partialWritten = bytes;
1521       return WriteResult(totalWritten);
1522     }
1523   }
1524
1525   return WriteResult(totalWritten);
1526 }
1527
1528 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1529                                       bool eor) {
1530   if (eor && trackEor_) {
1531     if (appEorByteNo_) {
1532       // cannot track for more than one app byte EOR
1533       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1534     } else {
1535       appEorByteNo_ = appBytesWritten_ + n;
1536     }
1537
1538     // 1. It is fine to keep updating minEorRawByteNo_.
1539     // 2. It is _min_ in the sense that SSL record will add some overhead.
1540     minEorRawByteNo_ = getRawBytesWritten() + n;
1541   }
1542
1543   n = sslWriteImpl(ssl, buf, n);
1544   if (n > 0) {
1545     appBytesWritten_ += n;
1546     if (appEorByteNo_) {
1547       if (getRawBytesWritten() >= minEorRawByteNo_) {
1548         minEorRawByteNo_ = 0;
1549       }
1550       if(appBytesWritten_ == appEorByteNo_) {
1551         appEorByteNo_ = 0;
1552       } else {
1553         CHECK(appBytesWritten_ < appEorByteNo_);
1554       }
1555     }
1556   }
1557   return n;
1558 }
1559
1560 void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
1561   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1562   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1563     sslSocket->renegotiateAttempted_ = true;
1564   }
1565   if (where & SSL_CB_READ_ALERT) {
1566     const char* type = SSL_alert_type_string(ret);
1567     if (type) {
1568       const char* desc = SSL_alert_desc_string(ret);
1569       sslSocket->alertsReceived_.emplace_back(
1570           *type, StringPiece(desc, std::strlen(desc)));
1571     }
1572   }
1573 }
1574
1575 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
1576   struct msghdr msg;
1577   struct iovec iov;
1578   int flags = 0;
1579   AsyncSSLSocket* tsslSock;
1580
1581   iov.iov_base = const_cast<char*>(in);
1582   iov.iov_len = inl;
1583   memset(&msg, 0, sizeof(msg));
1584   msg.msg_iov = &iov;
1585   msg.msg_iovlen = 1;
1586
1587   auto appData = OpenSSLUtils::getBioAppData(b);
1588   CHECK(appData);
1589
1590   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
1591   CHECK(tsslSock);
1592
1593   if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
1594       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1595     flags = MSG_EOR;
1596   }
1597
1598 #ifdef MSG_NOSIGNAL
1599   flags |= MSG_NOSIGNAL;
1600 #endif
1601
1602   auto result = tsslSock->sendSocketMessage(
1603       OpenSSLUtils::getBioFd(b, nullptr), &msg, flags);
1604   BIO_clear_retry_flags(b);
1605   if (!result.exception && result.writeReturn <= 0) {
1606     if (OpenSSLUtils::getBioShouldRetryWrite(result.writeReturn)) {
1607       BIO_set_retry_write(b);
1608     }
1609   }
1610   return result.writeReturn;
1611 }
1612
1613 int AsyncSSLSocket::sslVerifyCallback(
1614     int preverifyOk,
1615     X509_STORE_CTX* x509Ctx) {
1616   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1617     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1618   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1619
1620   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1621           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1622   return (self->handshakeCallback_) ?
1623     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1624     preverifyOk;
1625 }
1626
1627 void AsyncSSLSocket::enableClientHelloParsing()  {
1628     parseClientHello_ = true;
1629     clientHelloInfo_.reset(new ssl::ClientHelloInfo());
1630 }
1631
1632 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1633   SSL_set_msg_callback(ssl, nullptr);
1634   SSL_set_msg_callback_arg(ssl, nullptr);
1635   clientHelloInfo_->clientHelloBuf_.clear();
1636 }
1637
1638 void AsyncSSLSocket::clientHelloParsingCallback(int written,
1639                                                 int /* version */,
1640                                                 int contentType,
1641                                                 const void* buf,
1642                                                 size_t len,
1643                                                 SSL* ssl,
1644                                                 void* arg) {
1645   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1646   if (written != 0) {
1647     sock->resetClientHelloParsing(ssl);
1648     return;
1649   }
1650   if (contentType != SSL3_RT_HANDSHAKE) {
1651     return;
1652   }
1653   if (len == 0) {
1654     return;
1655   }
1656
1657   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1658   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1659   try {
1660     Cursor cursor(clientHelloBuf.front());
1661     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1662       sock->resetClientHelloParsing(ssl);
1663       return;
1664     }
1665
1666     if (cursor.totalLength() < 3) {
1667       clientHelloBuf.trimEnd(len);
1668       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1669       return;
1670     }
1671
1672     uint32_t messageLength = cursor.read<uint8_t>();
1673     messageLength <<= 8;
1674     messageLength |= cursor.read<uint8_t>();
1675     messageLength <<= 8;
1676     messageLength |= cursor.read<uint8_t>();
1677     if (cursor.totalLength() < messageLength) {
1678       clientHelloBuf.trimEnd(len);
1679       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1680       return;
1681     }
1682
1683     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1684     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1685
1686     cursor.skip(4); // gmt_unix_time
1687     cursor.skip(28); // random_bytes
1688
1689     cursor.skip(cursor.read<uint8_t>()); // session_id
1690
1691     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1692     for (int i = 0; i < cipherSuitesLength; i += 2) {
1693       sock->clientHelloInfo_->
1694         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1695     }
1696
1697     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1698     for (int i = 0; i < compressionMethodsLength; ++i) {
1699       sock->clientHelloInfo_->
1700         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1701     }
1702
1703     if (cursor.totalLength() > 0) {
1704       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1705       while (extensionsLength) {
1706         ssl::TLSExtension extensionType =
1707             static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
1708         sock->clientHelloInfo_->
1709           clientHelloExtensions_.push_back(extensionType);
1710         extensionsLength -= 2;
1711         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1712         extensionsLength -= 2;
1713         extensionsLength -= extensionDataLength;
1714
1715         if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
1716           cursor.skip(2);
1717           extensionDataLength -= 2;
1718           while (extensionDataLength) {
1719             ssl::HashAlgorithm hashAlg =
1720                 static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
1721             ssl::SignatureAlgorithm sigAlg =
1722                 static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
1723             extensionDataLength -= 2;
1724             sock->clientHelloInfo_->
1725               clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
1726           }
1727         } else {
1728           cursor.skip(extensionDataLength);
1729         }
1730       }
1731     }
1732   } catch (std::out_of_range& e) {
1733     // we'll use what we found and cleanup below.
1734     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1735       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1736   }
1737
1738   sock->resetClientHelloParsing(ssl);
1739 }
1740
1741 } // namespace