Fix ssl timeouts during TFO
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/noncopyable.hpp>
23 #include <errno.h>
24 #include <fcntl.h>
25 #include <openssl/err.h>
26 #include <openssl/asn1.h>
27 #include <openssl/ssl.h>
28 #include <sys/types.h>
29 #include <chrono>
30
31 #include <folly/Bits.h>
32 #include <folly/SocketAddress.h>
33 #include <folly/SpinLock.h>
34 #include <folly/io/IOBuf.h>
35 #include <folly/io/Cursor.h>
36 #include <folly/portability/Unistd.h>
37
38 using folly::SocketAddress;
39 using folly::SSLContext;
40 using std::string;
41 using std::shared_ptr;
42
43 using folly::Endian;
44 using folly::IOBuf;
45 using folly::SpinLock;
46 using folly::SpinLockGuard;
47 using folly::io::Cursor;
48 using std::unique_ptr;
49 using std::bind;
50
51 namespace {
52 using folly::AsyncSocket;
53 using folly::AsyncSocketException;
54 using folly::AsyncSSLSocket;
55 using folly::Optional;
56 using folly::SSLContext;
57 using folly::ssl::OpenSSLUtils;
58
59 // We have one single dummy SSL context so that we can implement attach
60 // and detach methods in a thread safe fashion without modifying opnessl.
61 static SSLContext *dummyCtx = nullptr;
62 static SpinLock dummyCtxLock;
63
64 // If given min write size is less than this, buffer will be allocated on
65 // stack, otherwise it is allocated on heap
66 const size_t MAX_STACK_BUF_SIZE = 2048;
67
68 // This converts "illegal" shutdowns into ZERO_RETURN
69 inline bool zero_return(int error, int rc) {
70   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
71 }
72
73 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
74                                 public AsyncSSLSocket::HandshakeCB {
75
76  private:
77   AsyncSSLSocket *sslSocket_;
78   AsyncSSLSocket::ConnectCallback *callback_;
79   int timeout_;
80   int64_t startTime_;
81
82  protected:
83   ~AsyncSSLSocketConnector() override {}
84
85  public:
86   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
87                            AsyncSocket::ConnectCallback *callback,
88                            int timeout) :
89       sslSocket_(sslSocket),
90       callback_(callback),
91       timeout_(timeout),
92       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
93                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
94   }
95
96   void connectSuccess() noexcept override {
97     VLOG(7) << "client socket connected";
98
99     int64_t timeoutLeft = 0;
100     if (timeout_ > 0) {
101       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
102         std::chrono::steady_clock::now().time_since_epoch()).count();
103
104       timeoutLeft = timeout_ - (curTime - startTime_);
105       if (timeoutLeft <= 0) {
106         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
107                                 "SSL connect timed out");
108         fail(ex);
109         delete this;
110         return;
111       }
112     }
113     sslSocket_->sslConn(this, timeoutLeft);
114   }
115
116   void connectErr(const AsyncSocketException& ex) noexcept override {
117     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
118     fail(ex);
119     delete this;
120   }
121
122   void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
123     VLOG(7) << "client handshake success";
124     if (callback_) {
125       callback_->connectSuccess();
126     }
127     delete this;
128   }
129
130   void handshakeErr(AsyncSSLSocket* /* socket */,
131                     const AsyncSocketException& ex) noexcept override {
132     LOG(ERROR) << "client handshakeErr: " << ex.what();
133     fail(ex);
134     delete this;
135   }
136
137   void fail(const AsyncSocketException &ex) {
138     // fail is a noop if called twice
139     if (callback_) {
140       AsyncSSLSocket::ConnectCallback *cb = callback_;
141       callback_ = nullptr;
142
143       cb->connectErr(ex);
144       sslSocket_->closeNow();
145       // closeNow can call handshakeErr if it hasn't been called already.
146       // So this may have been deleted, no member variable access beyond this
147       // point
148       // Note that closeNow may invoke writeError callbacks if the socket had
149       // write data pending connection completion.
150     }
151   }
152 };
153
154 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
155 #ifdef TCP_CORK // Linux-only
156 /**
157  * Utility class that corks a TCP socket upon construction or uncorks
158  * the socket upon destruction
159  */
160 class CorkGuard : private boost::noncopyable {
161  public:
162   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
163     fd_(fd), haveMore_(haveMore), corked_(corked) {
164     if (*corked_) {
165       // socket is already corked; nothing to do
166       return;
167     }
168     if (multipleWrites || haveMore) {
169       // We are performing multiple writes in this performWrite() call,
170       // and/or there are more calls to performWrite() that will be invoked
171       // later, so enable corking
172       int flag = 1;
173       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
174       *corked_ = true;
175     }
176   }
177
178   ~CorkGuard() {
179     if (haveMore_) {
180       // more data to come; don't uncork yet
181       return;
182     }
183     if (!*corked_) {
184       // socket isn't corked; nothing to do
185       return;
186     }
187
188     int flag = 0;
189     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
190     *corked_ = false;
191   }
192
193  private:
194   int fd_;
195   bool haveMore_;
196   bool* corked_;
197 };
198 #else
199 class CorkGuard : private boost::noncopyable {
200  public:
201   CorkGuard(int, bool, bool, bool*) {}
202 };
203 #endif
204
205 void setup_SSL_CTX(SSL_CTX *ctx) {
206 #ifdef SSL_MODE_RELEASE_BUFFERS
207   SSL_CTX_set_mode(ctx,
208                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
209                    SSL_MODE_ENABLE_PARTIAL_WRITE
210                    | SSL_MODE_RELEASE_BUFFERS
211                    );
212 #else
213   SSL_CTX_set_mode(ctx,
214                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
215                    SSL_MODE_ENABLE_PARTIAL_WRITE
216                    );
217 #endif
218 // SSL_CTX_set_mode is a Macro
219 #ifdef SSL_MODE_WRITE_IOVEC
220   SSL_CTX_set_mode(ctx,
221                    SSL_CTX_get_mode(ctx)
222                    | SSL_MODE_WRITE_IOVEC);
223 #endif
224
225 }
226
227 BIO_METHOD sslWriteBioMethod;
228
229 void* initsslWriteBioMethod(void) {
230   memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
231   // override the bwrite method for MSG_EOR support
232   OpenSSLUtils::setCustomBioWriteMethod(
233       &sslWriteBioMethod, AsyncSSLSocket::bioWrite);
234
235   // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
236   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
237   // then have specific handlings. The sslWriteBioWrite should be compatible
238   // with the one in openssl.
239
240   // Return something here to enable AsyncSSLSocket to call this method using
241   // a function-scoped static.
242   return nullptr;
243 }
244
245 } // anonymous namespace
246
247 namespace folly {
248
249 /**
250  * Create a client AsyncSSLSocket
251  */
252 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
253                                EventBase* evb, bool deferSecurityNegotiation) :
254     AsyncSocket(evb),
255     ctx_(ctx),
256     handshakeTimeout_(this, evb) {
257   init();
258   if (deferSecurityNegotiation) {
259     sslState_ = STATE_UNENCRYPTED;
260   }
261 }
262
263 /**
264  * Create a server/client AsyncSSLSocket
265  */
266 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
267                                EventBase* evb, int fd, bool server,
268                                bool deferSecurityNegotiation) :
269     AsyncSocket(evb, fd),
270     server_(server),
271     ctx_(ctx),
272     handshakeTimeout_(this, evb) {
273   init();
274   if (server) {
275     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
276                               AsyncSSLSocket::sslInfoCallback);
277   }
278   if (deferSecurityNegotiation) {
279     sslState_ = STATE_UNENCRYPTED;
280   }
281 }
282
283 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
284 /**
285  * Create a client AsyncSSLSocket and allow tlsext_hostname
286  * to be sent in Client Hello.
287  */
288 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
289                                  EventBase* evb,
290                                const std::string& serverName,
291                                bool deferSecurityNegotiation) :
292     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
293   tlsextHostname_ = serverName;
294 }
295
296 /**
297  * Create a client AsyncSSLSocket from an already connected fd
298  * and allow tlsext_hostname to be sent in Client Hello.
299  */
300 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
301                                  EventBase* evb, int fd,
302                                const std::string& serverName,
303                                bool deferSecurityNegotiation) :
304     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
305   tlsextHostname_ = serverName;
306 }
307 #endif
308
309 AsyncSSLSocket::~AsyncSSLSocket() {
310   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
311           << ", evb=" << eventBase_ << ", fd=" << fd_
312           << ", state=" << int(state_) << ", sslState="
313           << sslState_ << ", events=" << eventFlags_ << ")";
314 }
315
316 void AsyncSSLSocket::init() {
317   // Do this here to ensure we initialize this once before any use of
318   // AsyncSSLSocket instances and not as part of library load.
319   static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
320   (void)sslWriteBioMethodInitializer;
321
322   setup_SSL_CTX(ctx_->getSSLCtx());
323 }
324
325 void AsyncSSLSocket::closeNow() {
326   // Close the SSL connection.
327   if (ssl_ != nullptr && fd_ != -1) {
328     int rc = SSL_shutdown(ssl_);
329     if (rc == 0) {
330       rc = SSL_shutdown(ssl_);
331     }
332     if (rc < 0) {
333       ERR_clear_error();
334     }
335   }
336
337   if (sslSession_ != nullptr) {
338     SSL_SESSION_free(sslSession_);
339     sslSession_ = nullptr;
340   }
341
342   sslState_ = STATE_CLOSED;
343
344   if (handshakeTimeout_.isScheduled()) {
345     handshakeTimeout_.cancelTimeout();
346   }
347
348   DestructorGuard dg(this);
349
350   invokeHandshakeErr(
351       AsyncSocketException(
352         AsyncSocketException::END_OF_FILE,
353         "SSL connection closed locally"));
354
355   if (ssl_ != nullptr) {
356     SSL_free(ssl_);
357     ssl_ = nullptr;
358   }
359
360   // Close the socket.
361   AsyncSocket::closeNow();
362 }
363
364 void AsyncSSLSocket::shutdownWrite() {
365   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
366   //
367   // (Performing a full shutdown here is more desirable than doing nothing at
368   // all.  The purpose of shutdownWrite() is normally to notify the other end
369   // of the connection that no more data will be sent.  If we do nothing, the
370   // other end will never know that no more data is coming, and this may result
371   // in protocol deadlock.)
372   close();
373 }
374
375 void AsyncSSLSocket::shutdownWriteNow() {
376   closeNow();
377 }
378
379 bool AsyncSSLSocket::good() const {
380   return (AsyncSocket::good() &&
381           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
382            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
383 }
384
385 // The TAsyncTransport definition of 'good' states that the transport is
386 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
387 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
388 // is connected but we haven't initiated the call to SSL_connect.
389 bool AsyncSSLSocket::connecting() const {
390   return (!server_ &&
391           (AsyncSocket::connecting() ||
392            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
393                                      sslState_ == STATE_CONNECTING))));
394 }
395
396 std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
397   const unsigned char* protoName = nullptr;
398   unsigned protoLength;
399   if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
400     return std::string(reinterpret_cast<const char*>(protoName), protoLength);
401   }
402   return "";
403 }
404
405 bool AsyncSSLSocket::isEorTrackingEnabled() const {
406   return trackEor_;
407 }
408
409 void AsyncSSLSocket::setEorTracking(bool track) {
410   if (trackEor_ != track) {
411     trackEor_ = track;
412     appEorByteNo_ = 0;
413     minEorRawByteNo_ = 0;
414   }
415 }
416
417 size_t AsyncSSLSocket::getRawBytesWritten() const {
418   BIO *b;
419   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
420     return 0;
421   }
422
423   return BIO_number_written(b);
424 }
425
426 size_t AsyncSSLSocket::getRawBytesReceived() const {
427   BIO *b;
428   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
429     return 0;
430   }
431
432   return BIO_number_read(b);
433 }
434
435
436 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
437   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
438              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
439              << "events=" << eventFlags_ << ", server=" << short(server_)
440              << "): " << "sslAccept/Connect() called in invalid "
441              << "state, handshake callback " << handshakeCallback_
442              << ", new callback " << callback;
443   assert(!handshakeTimeout_.isScheduled());
444   sslState_ = STATE_ERROR;
445
446   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
447                          "sslAccept() called with socket in invalid state");
448
449   handshakeEndTime_ = std::chrono::steady_clock::now();
450   if (callback) {
451     callback->handshakeErr(this, ex);
452   }
453
454   // Check the socket state not the ssl state here.
455   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
456     failHandshake(__func__, ex);
457   }
458 }
459
460 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
461       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
462   DestructorGuard dg(this);
463   assert(eventBase_->isInEventBaseThread());
464   verifyPeer_ = verifyPeer;
465
466   // Make sure we're in the uninitialized state
467   if (!server_ || (sslState_ != STATE_UNINIT &&
468                    sslState_ != STATE_UNENCRYPTED) ||
469       handshakeCallback_ != nullptr) {
470     return invalidState(callback);
471   }
472
473   // Cache local and remote socket addresses to keep them available
474   // after socket file descriptor is closed.
475   if (cacheAddrOnFailure_ && -1 != getFd()) {
476     cacheLocalPeerAddr();
477   }
478
479   handshakeStartTime_ = std::chrono::steady_clock::now();
480   // Make end time at least >= start time.
481   handshakeEndTime_ = handshakeStartTime_;
482
483   sslState_ = STATE_ACCEPTING;
484   handshakeCallback_ = callback;
485
486   if (timeout > 0) {
487     handshakeTimeout_.scheduleTimeout(timeout);
488   }
489
490   /* register for a read operation (waiting for CLIENT HELLO) */
491   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
492 }
493
494 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
495 void AsyncSSLSocket::attachSSLContext(
496   const std::shared_ptr<SSLContext>& ctx) {
497
498   // Check to ensure we are in client mode. Changing a server's ssl
499   // context doesn't make sense since clients of that server would likely
500   // become confused when the server's context changes.
501   DCHECK(!server_);
502   DCHECK(!ctx_);
503   DCHECK(ctx);
504   DCHECK(ctx->getSSLCtx());
505   ctx_ = ctx;
506
507   // In order to call attachSSLContext, detachSSLContext must have been
508   // previously called which sets the socket's context to the dummy
509   // context. Thus we must acquire this lock.
510   SpinLockGuard guard(dummyCtxLock);
511   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
512 }
513
514 void AsyncSSLSocket::detachSSLContext() {
515   DCHECK(ctx_);
516   ctx_.reset();
517   // We aren't using the initial_ctx for now, and it can introduce race
518   // conditions in the destructor of the SSL object.
519 #ifndef OPENSSL_NO_TLSEXT
520   if (ssl_->initial_ctx) {
521     SSL_CTX_free(ssl_->initial_ctx);
522     ssl_->initial_ctx = nullptr;
523   }
524 #endif
525   SpinLockGuard guard(dummyCtxLock);
526   if (nullptr == dummyCtx) {
527     // We need to lazily initialize the dummy context so we don't
528     // accidentally override any programmatic settings to openssl
529     dummyCtx = new SSLContext;
530   }
531   // We must remove this socket's references to its context right now
532   // since this socket could get passed to any thread. If the context has
533   // had its locking disabled, just doing a set in attachSSLContext()
534   // would not be thread safe.
535   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
536 }
537 #endif
538
539 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
540 void AsyncSSLSocket::switchServerSSLContext(
541   const std::shared_ptr<SSLContext>& handshakeCtx) {
542   CHECK(server_);
543   if (sslState_ != STATE_ACCEPTING) {
544     // We log it here and allow the switch.
545     // It should not affect our re-negotiation support (which
546     // is not supported now).
547     VLOG(6) << "fd=" << getFd()
548             << " renegotation detected when switching SSL_CTX";
549   }
550
551   setup_SSL_CTX(handshakeCtx->getSSLCtx());
552   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
553                             AsyncSSLSocket::sslInfoCallback);
554   handshakeCtx_ = handshakeCtx;
555   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
556 }
557
558 bool AsyncSSLSocket::isServerNameMatch() const {
559   CHECK(!server_);
560
561   if (!ssl_) {
562     return false;
563   }
564
565   SSL_SESSION *ss = SSL_get_session(ssl_);
566   if (!ss) {
567     return false;
568   }
569
570   if(!ss->tlsext_hostname) {
571     return false;
572   }
573   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
574 }
575
576 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
577   tlsextHostname_ = std::move(serverName);
578 }
579
580 #endif
581
582 void AsyncSSLSocket::timeoutExpired() noexcept {
583   if (state_ == StateEnum::ESTABLISHED &&
584       (sslState_ == STATE_CACHE_LOOKUP ||
585        sslState_ == STATE_ASYNC_PENDING)) {
586     sslState_ = STATE_ERROR;
587     // We are expecting a callback in restartSSLAccept.  The cache lookup
588     // and rsa-call necessarily have pointers to this ssl socket, so delay
589     // the cleanup until he calls us back.
590   } else {
591     assert(state_ == StateEnum::ESTABLISHED &&
592            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
593     DestructorGuard dg(this);
594     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
595                            (sslState_ == STATE_CONNECTING) ?
596                            "SSL connect timed out" : "SSL accept timed out");
597     failHandshake(__func__, ex);
598   }
599 }
600
601 int AsyncSSLSocket::getSSLExDataIndex() {
602   static auto index = SSL_get_ex_new_index(
603       0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
604   return index;
605 }
606
607 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
608   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
609       getSSLExDataIndex()));
610 }
611
612 void AsyncSSLSocket::failHandshake(const char* /* fn */,
613                                    const AsyncSocketException& ex) {
614   startFail();
615   if (handshakeTimeout_.isScheduled()) {
616     handshakeTimeout_.cancelTimeout();
617   }
618   invokeHandshakeErr(ex);
619   finishFail();
620 }
621
622 void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
623   handshakeEndTime_ = std::chrono::steady_clock::now();
624   if (handshakeCallback_ != nullptr) {
625     HandshakeCB* callback = handshakeCallback_;
626     handshakeCallback_ = nullptr;
627     callback->handshakeErr(this, ex);
628   }
629 }
630
631 void AsyncSSLSocket::invokeHandshakeCB() {
632   handshakeEndTime_ = std::chrono::steady_clock::now();
633   if (handshakeTimeout_.isScheduled()) {
634     handshakeTimeout_.cancelTimeout();
635   }
636   if (handshakeCallback_) {
637     HandshakeCB* callback = handshakeCallback_;
638     handshakeCallback_ = nullptr;
639     callback->handshakeSuc(this);
640   }
641 }
642
643 void AsyncSSLSocket::cacheLocalPeerAddr() {
644   SocketAddress address;
645   try {
646     getLocalAddress(&address);
647     getPeerAddress(&address);
648   } catch (const std::system_error& e) {
649     // The handle can be still valid while the connection is already closed.
650     if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
651       throw;
652     }
653   }
654 }
655
656 void AsyncSSLSocket::connect(ConnectCallback* callback,
657                               const folly::SocketAddress& address,
658                               int timeout,
659                               const OptionMap &options,
660                               const folly::SocketAddress& bindAddr)
661                               noexcept {
662   assert(!server_);
663   assert(state_ == StateEnum::UNINIT);
664   assert(sslState_ == STATE_UNINIT);
665   AsyncSSLSocketConnector *connector =
666     new AsyncSSLSocketConnector(this, callback, timeout);
667   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
668 }
669
670 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
671   // apply the settings specified in verifyPeer_
672   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
673     if(ctx_->needsPeerVerification()) {
674       SSL_set_verify(ssl, ctx_->getVerificationMode(),
675         AsyncSSLSocket::sslVerifyCallback);
676     }
677   } else {
678     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
679         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
680       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
681         AsyncSSLSocket::sslVerifyCallback);
682     }
683   }
684 }
685
686 bool AsyncSSLSocket::setupSSLBio() {
687   auto wb = BIO_new(&sslWriteBioMethod);
688
689   if (!wb) {
690     return false;
691   }
692
693   OpenSSLUtils::setBioAppData(wb, this);
694   BIO_set_fd(wb, fd_, BIO_NOCLOSE);
695   SSL_set_bio(ssl_, wb, wb);
696   return true;
697 }
698
699 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
700         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
701   DestructorGuard dg(this);
702   assert(eventBase_->isInEventBaseThread());
703
704   // Cache local and remote socket addresses to keep them available
705   // after socket file descriptor is closed.
706   if (cacheAddrOnFailure_ && -1 != getFd()) {
707     cacheLocalPeerAddr();
708   }
709
710   verifyPeer_ = verifyPeer;
711
712   // Make sure we're in the uninitialized state
713   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
714                   STATE_UNENCRYPTED) ||
715       handshakeCallback_ != nullptr) {
716     return invalidState(callback);
717   }
718
719   sslState_ = STATE_CONNECTING;
720   handshakeCallback_ = callback;
721
722   try {
723     ssl_ = ctx_->createSSL();
724   } catch (std::exception &e) {
725     sslState_ = STATE_ERROR;
726     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
727                            "error calling SSLContext::createSSL()");
728     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
729             << fd_ << "): " << e.what();
730     return failHandshake(__func__, ex);
731   }
732
733   if (!setupSSLBio()) {
734     sslState_ = STATE_ERROR;
735     AsyncSocketException ex(
736         AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
737     return failHandshake(__func__, ex);
738   }
739
740   applyVerificationOptions(ssl_);
741
742   if (sslSession_ != nullptr) {
743     SSL_set_session(ssl_, sslSession_);
744     SSL_SESSION_free(sslSession_);
745     sslSession_ = nullptr;
746   }
747 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
748   if (tlsextHostname_.size()) {
749     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
750   }
751 #endif
752
753   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
754
755   handshakeConnectTimeout_ = timeout;
756   startSSLConnect();
757 }
758
759 // This could be called multiple times, during normal ssl connections
760 // and after TFO fallback.
761 void AsyncSSLSocket::startSSLConnect() {
762   handshakeStartTime_ = std::chrono::steady_clock::now();
763   // Make end time at least >= start time.
764   handshakeEndTime_ = handshakeStartTime_;
765   if (handshakeConnectTimeout_ > 0) {
766     handshakeTimeout_.scheduleTimeout(handshakeConnectTimeout_);
767   }
768   handleConnect();
769 }
770
771 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
772   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
773     return SSL_get1_session(ssl_);
774   }
775
776   return sslSession_;
777 }
778
779 const SSL* AsyncSSLSocket::getSSL() const {
780   return ssl_;
781 }
782
783 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
784   sslSession_ = session;
785   if (!takeOwnership && session != nullptr) {
786     // Increment the reference count
787     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
788   }
789 }
790
791 void AsyncSSLSocket::getSelectedNextProtocol(
792     const unsigned char** protoName,
793     unsigned* protoLen,
794     SSLContext::NextProtocolType* protoType) const {
795   if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
796     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
797                               "NPN not supported");
798   }
799 }
800
801 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
802     const unsigned char** protoName,
803     unsigned* protoLen,
804     SSLContext::NextProtocolType* protoType) const {
805   *protoName = nullptr;
806   *protoLen = 0;
807 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
808   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
809   if (*protoLen > 0) {
810     if (protoType) {
811       *protoType = SSLContext::NextProtocolType::ALPN;
812     }
813     return true;
814   }
815 #endif
816 #ifdef OPENSSL_NPN_NEGOTIATED
817   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
818   if (protoType) {
819     *protoType = SSLContext::NextProtocolType::NPN;
820   }
821   return true;
822 #else
823   (void)protoType;
824   return false;
825 #endif
826 }
827
828 bool AsyncSSLSocket::getSSLSessionReused() const {
829   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
830     return SSL_session_reused(ssl_);
831   }
832   return false;
833 }
834
835 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
836   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
837 }
838
839 /* static */
840 const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
841   if (ssl == nullptr) {
842     return nullptr;
843   }
844 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
845   return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
846 #else
847   return nullptr;
848 #endif
849 }
850
851 const char *AsyncSSLSocket::getSSLServerName() const {
852 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
853   return getSSLServerNameFromSSL(ssl_);
854 #else
855   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
856                              "SNI not supported");
857 #endif
858 }
859
860 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
861   return getSSLServerNameFromSSL(ssl_);
862 }
863
864 int AsyncSSLSocket::getSSLVersion() const {
865   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
866 }
867
868 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
869   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
870   if (cert) {
871     int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
872     return OBJ_nid2ln(nid);
873   }
874   return nullptr;
875 }
876
877 int AsyncSSLSocket::getSSLCertSize() const {
878   int certSize = 0;
879   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
880   if (cert) {
881     EVP_PKEY *key = X509_get_pubkey(cert);
882     certSize = EVP_PKEY_bits(key);
883     EVP_PKEY_free(key);
884   }
885   return certSize;
886 }
887
888 const X509* AsyncSSLSocket::getSelfCert() const {
889   return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
890 }
891
892 bool AsyncSSLSocket::willBlock(int ret,
893                                int* sslErrorOut,
894                                unsigned long* errErrorOut) noexcept {
895   *errErrorOut = 0;
896   int error = *sslErrorOut = SSL_get_error(ssl_, ret);
897   if (error == SSL_ERROR_WANT_READ) {
898     // Register for read event if not already.
899     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
900     return true;
901   } else if (error == SSL_ERROR_WANT_WRITE) {
902     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
903             << ", state=" << int(state_) << ", sslState="
904             << sslState_ << ", events=" << eventFlags_ << "): "
905             << "SSL_ERROR_WANT_WRITE";
906     // Register for write event if not already.
907     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
908     return true;
909 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
910   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
911     // We will block but we can't register our own socket.  The callback that
912     // triggered this code will re-call handleAccept at the appropriate time.
913
914     // We can only get here if the linked libssl.so has support for this feature
915     // as well, otherwise SSL_get_error cannot return our error code.
916     sslState_ = STATE_CACHE_LOOKUP;
917
918     // Unregister for all events while blocked here
919     updateEventRegistration(EventHandler::NONE,
920                             EventHandler::READ | EventHandler::WRITE);
921
922     // The timeout (if set) keeps running here
923     return true;
924 #endif
925   } else if (0
926 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
927       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
928 #endif
929 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
930       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
931 #endif
932       ) {
933     // Our custom openssl function has kicked off an async request to do
934     // rsa/ecdsa private key operation.  When that call returns, a callback will
935     // be invoked that will re-call handleAccept.
936     sslState_ = STATE_ASYNC_PENDING;
937
938     // Unregister for all events while blocked here
939     updateEventRegistration(
940       EventHandler::NONE,
941       EventHandler::READ | EventHandler::WRITE
942     );
943
944     // The timeout (if set) keeps running here
945     return true;
946   } else {
947     unsigned long lastError = *errErrorOut = ERR_get_error();
948     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
949             << "state=" << state_ << ", "
950             << "sslState=" << sslState_ << ", "
951             << "events=" << std::hex << eventFlags_ << "): "
952             << "SSL error: " << error << ", "
953             << "errno: " << errno << ", "
954             << "ret: " << ret << ", "
955             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
956             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
957             << "func: " << ERR_func_error_string(lastError) << ", "
958             << "reason: " << ERR_reason_error_string(lastError);
959     return false;
960   }
961 }
962
963 void AsyncSSLSocket::checkForImmediateRead() noexcept {
964   // openssl may have buffered data that it read from the socket already.
965   // In this case we have to process it immediately, rather than waiting for
966   // the socket to become readable again.
967   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
968     AsyncSocket::handleRead();
969   }
970 }
971
972 void
973 AsyncSSLSocket::restartSSLAccept()
974 {
975   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this
976           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
977           << "sslState=" << sslState_ << ", events=" << eventFlags_;
978   DestructorGuard dg(this);
979   assert(
980     sslState_ == STATE_CACHE_LOOKUP ||
981     sslState_ == STATE_ASYNC_PENDING ||
982     sslState_ == STATE_ERROR ||
983     sslState_ == STATE_CLOSED);
984   if (sslState_ == STATE_CLOSED) {
985     // I sure hope whoever closed this socket didn't delete it already,
986     // but this is not strictly speaking an error
987     return;
988   }
989   if (sslState_ == STATE_ERROR) {
990     // go straight to fail if timeout expired during lookup
991     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
992                            "SSL accept timed out");
993     failHandshake(__func__, ex);
994     return;
995   }
996   sslState_ = STATE_ACCEPTING;
997   this->handleAccept();
998 }
999
1000 void
1001 AsyncSSLSocket::handleAccept() noexcept {
1002   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
1003           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1004           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1005   assert(server_);
1006   assert(state_ == StateEnum::ESTABLISHED &&
1007          sslState_ == STATE_ACCEPTING);
1008   if (!ssl_) {
1009     /* lazily create the SSL structure */
1010     try {
1011       ssl_ = ctx_->createSSL();
1012     } catch (std::exception &e) {
1013       sslState_ = STATE_ERROR;
1014       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1015                              "error calling SSLContext::createSSL()");
1016       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
1017                  << ", fd=" << fd_ << "): " << e.what();
1018       return failHandshake(__func__, ex);
1019     }
1020
1021     if (!setupSSLBio()) {
1022       sslState_ = STATE_ERROR;
1023       AsyncSocketException ex(
1024           AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
1025       return failHandshake(__func__, ex);
1026     }
1027
1028     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
1029
1030     applyVerificationOptions(ssl_);
1031   }
1032
1033   if (server_ && parseClientHello_) {
1034     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
1035     SSL_set_msg_callback_arg(ssl_, this);
1036   }
1037
1038   int ret = SSL_accept(ssl_);
1039   if (ret <= 0) {
1040     int sslError;
1041     unsigned long errError;
1042     int errnoCopy = errno;
1043     if (willBlock(ret, &sslError, &errError)) {
1044       return;
1045     } else {
1046       sslState_ = STATE_ERROR;
1047       SSLException ex(sslError, errError, ret, errnoCopy);
1048       return failHandshake(__func__, ex);
1049     }
1050   }
1051
1052   handshakeComplete_ = true;
1053   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1054
1055   // Move into STATE_ESTABLISHED in the normal case that we are in
1056   // STATE_ACCEPTING.
1057   sslState_ = STATE_ESTABLISHED;
1058
1059   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
1060           << " successfully accepted; state=" << int(state_)
1061           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
1062
1063   // Remember the EventBase we are attached to, before we start invoking any
1064   // callbacks (since the callbacks may call detachEventBase()).
1065   EventBase* originalEventBase = eventBase_;
1066
1067   // Call the accept callback.
1068   invokeHandshakeCB();
1069
1070   // Note that the accept callback may have changed our state.
1071   // (set or unset the read callback, called write(), closed the socket, etc.)
1072   // The following code needs to handle these situations correctly.
1073   //
1074   // If the socket has been closed, readCallback_ and writeReqHead_ will
1075   // always be nullptr, so that will prevent us from trying to read or write.
1076   //
1077   // The main thing to check for is if eventBase_ is still originalEventBase.
1078   // If not, we have been detached from this event base, so we shouldn't
1079   // perform any more operations.
1080   if (eventBase_ != originalEventBase) {
1081     return;
1082   }
1083
1084   AsyncSocket::handleInitialReadWrite();
1085 }
1086
1087 void
1088 AsyncSSLSocket::handleConnect() noexcept {
1089   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1090           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1091           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1092   assert(!server_);
1093   if (state_ < StateEnum::ESTABLISHED) {
1094     return AsyncSocket::handleConnect();
1095   }
1096
1097   assert(
1098       (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
1099       sslState_ == STATE_CONNECTING);
1100   assert(ssl_);
1101
1102   auto originalState = state_;
1103   int ret = SSL_connect(ssl_);
1104   if (ret <= 0) {
1105     int sslError;
1106     unsigned long errError;
1107     int errnoCopy = errno;
1108     if (willBlock(ret, &sslError, &errError)) {
1109       // We fell back to connecting state due to TFO
1110       if (state_ == StateEnum::CONNECTING) {
1111         DCHECK_EQ(StateEnum::FAST_OPEN, originalState);
1112         if (handshakeTimeout_.isScheduled()) {
1113           handshakeTimeout_.cancelTimeout();
1114         }
1115       }
1116       return;
1117     } else {
1118       sslState_ = STATE_ERROR;
1119       SSLException ex(sslError, errError, ret, errnoCopy);
1120       return failHandshake(__func__, ex);
1121     }
1122   }
1123
1124   handshakeComplete_ = true;
1125   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1126
1127   // Move into STATE_ESTABLISHED in the normal case that we are in
1128   // STATE_CONNECTING.
1129   sslState_ = STATE_ESTABLISHED;
1130
1131   VLOG(3) << "AsyncSSLSocket " << this << ": "
1132           << "fd " << fd_ << " successfully connected; "
1133           << "state=" << int(state_) << ", sslState=" << sslState_
1134           << ", events=" << eventFlags_;
1135
1136   // Remember the EventBase we are attached to, before we start invoking any
1137   // callbacks (since the callbacks may call detachEventBase()).
1138   EventBase* originalEventBase = eventBase_;
1139
1140   // Call the handshake callback.
1141   invokeHandshakeCB();
1142
1143   // Note that the connect callback may have changed our state.
1144   // (set or unset the read callback, called write(), closed the socket, etc.)
1145   // The following code needs to handle these situations correctly.
1146   //
1147   // If the socket has been closed, readCallback_ and writeReqHead_ will
1148   // always be nullptr, so that will prevent us from trying to read or write.
1149   //
1150   // The main thing to check for is if eventBase_ is still originalEventBase.
1151   // If not, we have been detached from this event base, so we shouldn't
1152   // perform any more operations.
1153   if (eventBase_ != originalEventBase) {
1154     return;
1155   }
1156
1157   AsyncSocket::handleInitialReadWrite();
1158 }
1159
1160 void AsyncSSLSocket::invokeConnectSuccess() {
1161   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
1162     // If we failed TFO, we'd fall back to trying to connect the socket,
1163     // to setup things like timeouts.
1164     startSSLConnect();
1165   }
1166   AsyncSocket::invokeConnectSuccess();
1167 }
1168
1169 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
1170 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1171   // turn on the buffer movable in openssl
1172   if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
1173       callback != nullptr && callback->isBufferMovable()) {
1174     SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
1175     isBufferMovable_ = true;
1176   }
1177 #endif
1178
1179   AsyncSocket::setReadCB(callback);
1180 }
1181
1182 void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
1183   bufferMovableEnabled_ = enabled;
1184 }
1185
1186 void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1187   CHECK(readCallback_);
1188   if (isBufferMovable_) {
1189     *buf = nullptr;
1190     *buflen = 0;
1191   } else {
1192     // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
1193     readCallback_->getReadBuffer(buf, buflen);
1194   }
1195 }
1196
1197 void
1198 AsyncSSLSocket::handleRead() noexcept {
1199   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1200           << ", state=" << int(state_) << ", "
1201           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1202   if (state_ < StateEnum::ESTABLISHED) {
1203     return AsyncSocket::handleRead();
1204   }
1205
1206
1207   if (sslState_ == STATE_ACCEPTING) {
1208     assert(server_);
1209     handleAccept();
1210     return;
1211   }
1212   else if (sslState_ == STATE_CONNECTING) {
1213     assert(!server_);
1214     handleConnect();
1215     return;
1216   }
1217
1218   // Normal read
1219   AsyncSocket::handleRead();
1220 }
1221
1222 AsyncSocket::ReadResult
1223 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
1224   VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
1225           << ", buflen=" << *buflen;
1226
1227   if (sslState_ == STATE_UNENCRYPTED) {
1228     return AsyncSocket::performRead(buf, buflen, offset);
1229   }
1230
1231   ssize_t bytes = 0;
1232   if (!isBufferMovable_) {
1233     bytes = SSL_read(ssl_, *buf, *buflen);
1234   }
1235 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1236   else {
1237     bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
1238   }
1239 #endif
1240
1241   if (server_ && renegotiateAttempted_) {
1242     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1243                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1244                << "): client intitiated SSL renegotiation not permitted";
1245     return ReadResult(
1246         READ_ERROR,
1247         folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
1248   }
1249   if (bytes <= 0) {
1250     int error = SSL_get_error(ssl_, bytes);
1251     if (error == SSL_ERROR_WANT_READ) {
1252       // The caller will register for read event if not already.
1253       if (errno == EWOULDBLOCK || errno == EAGAIN) {
1254         return ReadResult(READ_BLOCKING);
1255       } else {
1256         return ReadResult(READ_ERROR);
1257       }
1258     } else if (error == SSL_ERROR_WANT_WRITE) {
1259       // TODO: Even though we are attempting to read data, SSL_read() may
1260       // need to write data if renegotiation is being performed.  We currently
1261       // don't support this and just fail the read.
1262       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1263                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1264                  << "): unsupported SSL renegotiation during read";
1265       return ReadResult(
1266           READ_ERROR,
1267           folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1268     } else {
1269       if (zero_return(error, bytes)) {
1270         return ReadResult(bytes);
1271       }
1272       long errError = ERR_get_error();
1273       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1274               << "state=" << state_ << ", "
1275               << "sslState=" << sslState_ << ", "
1276               << "events=" << std::hex << eventFlags_ << "): "
1277               << "bytes: " << bytes << ", "
1278               << "error: " << error << ", "
1279               << "errno: " << errno << ", "
1280               << "func: " << ERR_func_error_string(errError) << ", "
1281               << "reason: " << ERR_reason_error_string(errError);
1282       return ReadResult(
1283           READ_ERROR,
1284           folly::make_unique<SSLException>(error, errError, bytes, errno));
1285     }
1286   } else {
1287     appBytesReceived_ += bytes;
1288     return ReadResult(bytes);
1289   }
1290 }
1291
1292 void AsyncSSLSocket::handleWrite() noexcept {
1293   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1294           << ", state=" << int(state_) << ", "
1295           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1296   if (state_ < StateEnum::ESTABLISHED) {
1297     return AsyncSocket::handleWrite();
1298   }
1299
1300   if (sslState_ == STATE_ACCEPTING) {
1301     assert(server_);
1302     handleAccept();
1303     return;
1304   }
1305
1306   if (sslState_ == STATE_CONNECTING) {
1307     assert(!server_);
1308     handleConnect();
1309     return;
1310   }
1311
1312   // Normal write
1313   AsyncSocket::handleWrite();
1314 }
1315
1316 AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
1317   if (error == SSL_ERROR_WANT_READ) {
1318     // Even though we are attempting to write data, SSL_write() may
1319     // need to read data if renegotiation is being performed.  We currently
1320     // don't support this and just fail the write.
1321     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1322                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1323                << "): "
1324                << "unsupported SSL renegotiation during write";
1325     return WriteResult(
1326         WRITE_ERROR,
1327         folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1328   } else {
1329     if (zero_return(error, rc)) {
1330       return WriteResult(0);
1331     }
1332     auto errError = ERR_get_error();
1333     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1334             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1335             << "SSL error: " << error << ", errno: " << errno
1336             << ", func: " << ERR_func_error_string(errError)
1337             << ", reason: " << ERR_reason_error_string(errError);
1338     return WriteResult(
1339         WRITE_ERROR,
1340         folly::make_unique<SSLException>(error, errError, rc, errno));
1341   }
1342 }
1343
1344 AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
1345     const iovec* vec,
1346     uint32_t count,
1347     WriteFlags flags,
1348     uint32_t* countWritten,
1349     uint32_t* partialWritten) {
1350   if (sslState_ == STATE_UNENCRYPTED) {
1351     return AsyncSocket::performWrite(
1352       vec, count, flags, countWritten, partialWritten);
1353   }
1354   if (sslState_ != STATE_ESTABLISHED) {
1355     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1356                << ", sslState=" << sslState_
1357                << ", events=" << eventFlags_ << "): "
1358                << "TODO: AsyncSSLSocket currently does not support calling "
1359                << "write() before the handshake has fully completed";
1360     return WriteResult(
1361         WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
1362   }
1363
1364   bool cork = isSet(flags, WriteFlags::CORK);
1365   CorkGuard guard(fd_, count > 1, cork, &corked_);
1366
1367   // Declare a buffer used to hold small write requests.  It could point to a
1368   // memory block either on stack or on heap. If it is on heap, we release it
1369   // manually when scope exits
1370   char* combinedBuf{nullptr};
1371   SCOPE_EXIT {
1372     // Note, always keep this check consistent with what we do below
1373     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1374       delete[] combinedBuf;
1375     }
1376   };
1377
1378   *countWritten = 0;
1379   *partialWritten = 0;
1380   ssize_t totalWritten = 0;
1381   size_t bytesStolenFromNextBuffer = 0;
1382   for (uint32_t i = 0; i < count; i++) {
1383     const iovec* v = vec + i;
1384     size_t offset = bytesStolenFromNextBuffer;
1385     bytesStolenFromNextBuffer = 0;
1386     size_t len = v->iov_len - offset;
1387     const void* buf;
1388     if (len == 0) {
1389       (*countWritten)++;
1390       continue;
1391     }
1392     buf = ((const char*)v->iov_base) + offset;
1393
1394     ssize_t bytes;
1395     uint32_t buffersStolen = 0;
1396     if ((len < minWriteSize_) && ((i + 1) < count)) {
1397       // Combine this buffer with part or all of the next buffers in
1398       // order to avoid really small-grained calls to SSL_write().
1399       // Each call to SSL_write() produces a separate record in
1400       // the egress SSL stream, and we've found that some low-end
1401       // mobile clients can't handle receiving an HTTP response
1402       // header and the first part of the response body in two
1403       // separate SSL records (even if those two records are in
1404       // the same TCP packet).
1405
1406       if (combinedBuf == nullptr) {
1407         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1408           // Allocate the buffer on heap
1409           combinedBuf = new char[minWriteSize_];
1410         } else {
1411           // Allocate the buffer on stack
1412           combinedBuf = (char*)alloca(minWriteSize_);
1413         }
1414       }
1415       assert(combinedBuf != nullptr);
1416
1417       memcpy(combinedBuf, buf, len);
1418       do {
1419         // INVARIANT: i + buffersStolen == complete chunks serialized
1420         uint32_t nextIndex = i + buffersStolen + 1;
1421         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1422                                              minWriteSize_ - len);
1423         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1424                bytesStolenFromNextBuffer);
1425         len += bytesStolenFromNextBuffer;
1426         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1427           // couldn't steal the whole buffer
1428           break;
1429         } else {
1430           bytesStolenFromNextBuffer = 0;
1431           buffersStolen++;
1432         }
1433       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1434       bytes = eorAwareSSLWrite(
1435         ssl_, combinedBuf, len,
1436         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1437
1438     } else {
1439       bytes = eorAwareSSLWrite(ssl_, buf, len,
1440                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1441     }
1442
1443     if (bytes <= 0) {
1444       int error = SSL_get_error(ssl_, bytes);
1445       if (error == SSL_ERROR_WANT_WRITE) {
1446         // The caller will register for write event if not already.
1447         *partialWritten = offset;
1448         return WriteResult(totalWritten);
1449       }
1450       auto writeResult = interpretSSLError(bytes, error);
1451       if (writeResult.writeReturn < 0) {
1452         return writeResult;
1453       } // else fall through to below to correctly record totalWritten
1454     }
1455
1456     totalWritten += bytes;
1457
1458     if (bytes == (ssize_t)len) {
1459       // The full iovec is written.
1460       (*countWritten) += 1 + buffersStolen;
1461       i += buffersStolen;
1462       // continue
1463     } else {
1464       bytes += offset; // adjust bytes to account for all of v
1465       while (bytes >= (ssize_t)v->iov_len) {
1466         // We combined this buf with part or all of the next one, and
1467         // we managed to write all of this buf but not all of the bytes
1468         // from the next one that we'd hoped to write.
1469         bytes -= v->iov_len;
1470         (*countWritten)++;
1471         v = &(vec[++i]);
1472       }
1473       *partialWritten = bytes;
1474       return WriteResult(totalWritten);
1475     }
1476   }
1477
1478   return WriteResult(totalWritten);
1479 }
1480
1481 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1482                                       bool eor) {
1483   if (eor && trackEor_) {
1484     if (appEorByteNo_) {
1485       // cannot track for more than one app byte EOR
1486       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1487     } else {
1488       appEorByteNo_ = appBytesWritten_ + n;
1489     }
1490
1491     // 1. It is fine to keep updating minEorRawByteNo_.
1492     // 2. It is _min_ in the sense that SSL record will add some overhead.
1493     minEorRawByteNo_ = getRawBytesWritten() + n;
1494   }
1495
1496   n = sslWriteImpl(ssl, buf, n);
1497   if (n > 0) {
1498     appBytesWritten_ += n;
1499     if (appEorByteNo_) {
1500       if (getRawBytesWritten() >= minEorRawByteNo_) {
1501         minEorRawByteNo_ = 0;
1502       }
1503       if(appBytesWritten_ == appEorByteNo_) {
1504         appEorByteNo_ = 0;
1505       } else {
1506         CHECK(appBytesWritten_ < appEorByteNo_);
1507       }
1508     }
1509   }
1510   return n;
1511 }
1512
1513 void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
1514   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1515   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1516     sslSocket->renegotiateAttempted_ = true;
1517   }
1518   if (where & SSL_CB_READ_ALERT) {
1519     const char* type = SSL_alert_type_string(ret);
1520     if (type) {
1521       const char* desc = SSL_alert_desc_string(ret);
1522       sslSocket->alertsReceived_.emplace_back(
1523           *type, StringPiece(desc, std::strlen(desc)));
1524     }
1525   }
1526 }
1527
1528 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
1529   struct msghdr msg;
1530   struct iovec iov;
1531   int flags = 0;
1532   AsyncSSLSocket* tsslSock;
1533
1534   iov.iov_base = const_cast<char*>(in);
1535   iov.iov_len = inl;
1536   memset(&msg, 0, sizeof(msg));
1537   msg.msg_iov = &iov;
1538   msg.msg_iovlen = 1;
1539
1540   auto appData = OpenSSLUtils::getBioAppData(b);
1541   CHECK(appData);
1542
1543   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
1544   CHECK(tsslSock);
1545
1546   if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
1547       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1548     flags = MSG_EOR;
1549   }
1550
1551 #ifdef MSG_NOSIGNAL
1552   flags |= MSG_NOSIGNAL;
1553 #endif
1554
1555   auto result =
1556       tsslSock->sendSocketMessage(BIO_get_fd(b, nullptr), &msg, flags);
1557   BIO_clear_retry_flags(b);
1558   if (!result.exception && result.writeReturn <= 0) {
1559     if (OpenSSLUtils::getBioShouldRetryWrite(result.writeReturn)) {
1560       BIO_set_retry_write(b);
1561     }
1562   }
1563   return result.writeReturn;
1564 }
1565
1566 int AsyncSSLSocket::sslVerifyCallback(
1567     int preverifyOk,
1568     X509_STORE_CTX* x509Ctx) {
1569   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1570     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1571   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1572
1573   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1574           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1575   return (self->handshakeCallback_) ?
1576     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1577     preverifyOk;
1578 }
1579
1580 void AsyncSSLSocket::enableClientHelloParsing()  {
1581     parseClientHello_ = true;
1582     clientHelloInfo_.reset(new ssl::ClientHelloInfo());
1583 }
1584
1585 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1586   SSL_set_msg_callback(ssl, nullptr);
1587   SSL_set_msg_callback_arg(ssl, nullptr);
1588   clientHelloInfo_->clientHelloBuf_.clear();
1589 }
1590
1591 void AsyncSSLSocket::clientHelloParsingCallback(int written,
1592                                                 int /* version */,
1593                                                 int contentType,
1594                                                 const void* buf,
1595                                                 size_t len,
1596                                                 SSL* ssl,
1597                                                 void* arg) {
1598   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1599   if (written != 0) {
1600     sock->resetClientHelloParsing(ssl);
1601     return;
1602   }
1603   if (contentType != SSL3_RT_HANDSHAKE) {
1604     return;
1605   }
1606   if (len == 0) {
1607     return;
1608   }
1609
1610   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1611   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1612   try {
1613     Cursor cursor(clientHelloBuf.front());
1614     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1615       sock->resetClientHelloParsing(ssl);
1616       return;
1617     }
1618
1619     if (cursor.totalLength() < 3) {
1620       clientHelloBuf.trimEnd(len);
1621       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1622       return;
1623     }
1624
1625     uint32_t messageLength = cursor.read<uint8_t>();
1626     messageLength <<= 8;
1627     messageLength |= cursor.read<uint8_t>();
1628     messageLength <<= 8;
1629     messageLength |= cursor.read<uint8_t>();
1630     if (cursor.totalLength() < messageLength) {
1631       clientHelloBuf.trimEnd(len);
1632       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1633       return;
1634     }
1635
1636     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1637     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1638
1639     cursor.skip(4); // gmt_unix_time
1640     cursor.skip(28); // random_bytes
1641
1642     cursor.skip(cursor.read<uint8_t>()); // session_id
1643
1644     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1645     for (int i = 0; i < cipherSuitesLength; i += 2) {
1646       sock->clientHelloInfo_->
1647         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1648     }
1649
1650     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1651     for (int i = 0; i < compressionMethodsLength; ++i) {
1652       sock->clientHelloInfo_->
1653         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1654     }
1655
1656     if (cursor.totalLength() > 0) {
1657       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1658       while (extensionsLength) {
1659         ssl::TLSExtension extensionType =
1660             static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
1661         sock->clientHelloInfo_->
1662           clientHelloExtensions_.push_back(extensionType);
1663         extensionsLength -= 2;
1664         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1665         extensionsLength -= 2;
1666         extensionsLength -= extensionDataLength;
1667
1668         if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
1669           cursor.skip(2);
1670           extensionDataLength -= 2;
1671           while (extensionDataLength) {
1672             ssl::HashAlgorithm hashAlg =
1673                 static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
1674             ssl::SignatureAlgorithm sigAlg =
1675                 static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
1676             extensionDataLength -= 2;
1677             sock->clientHelloInfo_->
1678               clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
1679           }
1680         } else {
1681           cursor.skip(extensionDataLength);
1682         }
1683       }
1684     }
1685   } catch (std::out_of_range& e) {
1686     // we'll use what we found and cleanup below.
1687     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1688       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1689   }
1690
1691   sock->resetClientHelloParsing(ssl);
1692 }
1693
1694 } // namespace