Stop abusing errno
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20
21 #include <boost/noncopyable.hpp>
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <openssl/err.h>
27 #include <openssl/asn1.h>
28 #include <openssl/ssl.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <unistd.h>
32 #include <chrono>
33
34 #include <folly/Bits.h>
35 #include <folly/SocketAddress.h>
36 #include <folly/SpinLock.h>
37 #include <folly/io/IOBuf.h>
38 #include <folly/io/Cursor.h>
39
40 using folly::SocketAddress;
41 using folly::SSLContext;
42 using std::string;
43 using std::shared_ptr;
44
45 using folly::Endian;
46 using folly::IOBuf;
47 using folly::SpinLock;
48 using folly::SpinLockGuard;
49 using folly::io::Cursor;
50 using std::unique_ptr;
51 using std::bind;
52
53 namespace {
54 using folly::AsyncSocket;
55 using folly::AsyncSocketException;
56 using folly::AsyncSSLSocket;
57 using folly::Optional;
58 using folly::SSLContext;
59
60 // We have one single dummy SSL context so that we can implement attach
61 // and detach methods in a thread safe fashion without modifying opnessl.
62 static SSLContext *dummyCtx = nullptr;
63 static SpinLock dummyCtxLock;
64
65 // If given min write size is less than this, buffer will be allocated on
66 // stack, otherwise it is allocated on heap
67 const size_t MAX_STACK_BUF_SIZE = 2048;
68
69 // This converts "illegal" shutdowns into ZERO_RETURN
70 inline bool zero_return(int error, int rc) {
71   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
72 }
73
74 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
75                                 public AsyncSSLSocket::HandshakeCB {
76
77  private:
78   AsyncSSLSocket *sslSocket_;
79   AsyncSSLSocket::ConnectCallback *callback_;
80   int timeout_;
81   int64_t startTime_;
82
83  protected:
84   ~AsyncSSLSocketConnector() override {}
85
86  public:
87   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
88                            AsyncSocket::ConnectCallback *callback,
89                            int timeout) :
90       sslSocket_(sslSocket),
91       callback_(callback),
92       timeout_(timeout),
93       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
94                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
95   }
96
97   void connectSuccess() noexcept override {
98     VLOG(7) << "client socket connected";
99
100     int64_t timeoutLeft = 0;
101     if (timeout_ > 0) {
102       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
103         std::chrono::steady_clock::now().time_since_epoch()).count();
104
105       timeoutLeft = timeout_ - (curTime - startTime_);
106       if (timeoutLeft <= 0) {
107         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
108                                 "SSL connect timed out");
109         fail(ex);
110         delete this;
111         return;
112       }
113     }
114     sslSocket_->sslConn(this, timeoutLeft);
115   }
116
117   void connectErr(const AsyncSocketException& ex) noexcept override {
118     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
119     fail(ex);
120     delete this;
121   }
122
123   void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
124     VLOG(7) << "client handshake success";
125     if (callback_) {
126       callback_->connectSuccess();
127     }
128     delete this;
129   }
130
131   void handshakeErr(AsyncSSLSocket* /* socket */,
132                     const AsyncSocketException& ex) noexcept override {
133     LOG(ERROR) << "client handshakeErr: " << ex.what();
134     fail(ex);
135     delete this;
136   }
137
138   void fail(const AsyncSocketException &ex) {
139     // fail is a noop if called twice
140     if (callback_) {
141       AsyncSSLSocket::ConnectCallback *cb = callback_;
142       callback_ = nullptr;
143
144       cb->connectErr(ex);
145       sslSocket_->closeNow();
146       // closeNow can call handshakeErr if it hasn't been called already.
147       // So this may have been deleted, no member variable access beyond this
148       // point
149       // Note that closeNow may invoke writeError callbacks if the socket had
150       // write data pending connection completion.
151     }
152   }
153 };
154
155 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
156 #ifdef TCP_CORK // Linux-only
157 /**
158  * Utility class that corks a TCP socket upon construction or uncorks
159  * the socket upon destruction
160  */
161 class CorkGuard : private boost::noncopyable {
162  public:
163   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
164     fd_(fd), haveMore_(haveMore), corked_(corked) {
165     if (*corked_) {
166       // socket is already corked; nothing to do
167       return;
168     }
169     if (multipleWrites || haveMore) {
170       // We are performing multiple writes in this performWrite() call,
171       // and/or there are more calls to performWrite() that will be invoked
172       // later, so enable corking
173       int flag = 1;
174       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
175       *corked_ = true;
176     }
177   }
178
179   ~CorkGuard() {
180     if (haveMore_) {
181       // more data to come; don't uncork yet
182       return;
183     }
184     if (!*corked_) {
185       // socket isn't corked; nothing to do
186       return;
187     }
188
189     int flag = 0;
190     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
191     *corked_ = false;
192   }
193
194  private:
195   int fd_;
196   bool haveMore_;
197   bool* corked_;
198 };
199 #else
200 class CorkGuard : private boost::noncopyable {
201  public:
202   CorkGuard(int, bool, bool, bool*) {}
203 };
204 #endif
205
206 void setup_SSL_CTX(SSL_CTX *ctx) {
207 #ifdef SSL_MODE_RELEASE_BUFFERS
208   SSL_CTX_set_mode(ctx,
209                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
210                    SSL_MODE_ENABLE_PARTIAL_WRITE
211                    | SSL_MODE_RELEASE_BUFFERS
212                    );
213 #else
214   SSL_CTX_set_mode(ctx,
215                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
216                    SSL_MODE_ENABLE_PARTIAL_WRITE
217                    );
218 #endif
219 // SSL_CTX_set_mode is a Macro
220 #ifdef SSL_MODE_WRITE_IOVEC
221   SSL_CTX_set_mode(ctx,
222                    SSL_CTX_get_mode(ctx)
223                    | SSL_MODE_WRITE_IOVEC);
224 #endif
225
226 }
227
228 BIO_METHOD eorAwareBioMethod;
229
230 void* initEorBioMethod(void) {
231   memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
232   // override the bwrite method for MSG_EOR support
233   eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
234
235   // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
236   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
237   // then have specific handlings. The eorAwareBioWrite should be compatible
238   // with the one in openssl.
239
240   // Return something here to enable AsyncSSLSocket to call this method using
241   // a function-scoped static.
242   return nullptr;
243 }
244
245 } // anonymous namespace
246
247 namespace folly {
248
249 /**
250  * Create a client AsyncSSLSocket
251  */
252 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
253                                EventBase* evb, bool deferSecurityNegotiation) :
254     AsyncSocket(evb),
255     ctx_(ctx),
256     handshakeTimeout_(this, evb) {
257   init();
258   if (deferSecurityNegotiation) {
259     sslState_ = STATE_UNENCRYPTED;
260   }
261 }
262
263 /**
264  * Create a server/client AsyncSSLSocket
265  */
266 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
267                                EventBase* evb, int fd, bool server,
268                                bool deferSecurityNegotiation) :
269     AsyncSocket(evb, fd),
270     server_(server),
271     ctx_(ctx),
272     handshakeTimeout_(this, evb) {
273   init();
274   if (server) {
275     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
276                               AsyncSSLSocket::sslInfoCallback);
277   }
278   if (deferSecurityNegotiation) {
279     sslState_ = STATE_UNENCRYPTED;
280   }
281 }
282
283 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
284 /**
285  * Create a client AsyncSSLSocket and allow tlsext_hostname
286  * to be sent in Client Hello.
287  */
288 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
289                                  EventBase* evb,
290                                const std::string& serverName,
291                                bool deferSecurityNegotiation) :
292     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
293   tlsextHostname_ = serverName;
294 }
295
296 /**
297  * Create a client AsyncSSLSocket from an already connected fd
298  * and allow tlsext_hostname to be sent in Client Hello.
299  */
300 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
301                                  EventBase* evb, int fd,
302                                const std::string& serverName,
303                                bool deferSecurityNegotiation) :
304     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
305   tlsextHostname_ = serverName;
306 }
307 #endif
308
309 AsyncSSLSocket::~AsyncSSLSocket() {
310   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
311           << ", evb=" << eventBase_ << ", fd=" << fd_
312           << ", state=" << int(state_) << ", sslState="
313           << sslState_ << ", events=" << eventFlags_ << ")";
314 }
315
316 void AsyncSSLSocket::init() {
317   // Do this here to ensure we initialize this once before any use of
318   // AsyncSSLSocket instances and not as part of library load.
319   static const auto eorAwareBioMethodInitializer = initEorBioMethod();
320   (void)eorAwareBioMethodInitializer;
321
322   setup_SSL_CTX(ctx_->getSSLCtx());
323 }
324
325 void AsyncSSLSocket::closeNow() {
326   // Close the SSL connection.
327   if (ssl_ != nullptr && fd_ != -1) {
328     int rc = SSL_shutdown(ssl_);
329     if (rc == 0) {
330       rc = SSL_shutdown(ssl_);
331     }
332     if (rc < 0) {
333       ERR_clear_error();
334     }
335   }
336
337   if (sslSession_ != nullptr) {
338     SSL_SESSION_free(sslSession_);
339     sslSession_ = nullptr;
340   }
341
342   sslState_ = STATE_CLOSED;
343
344   if (handshakeTimeout_.isScheduled()) {
345     handshakeTimeout_.cancelTimeout();
346   }
347
348   DestructorGuard dg(this);
349
350   invokeHandshakeErr(
351       AsyncSocketException(
352         AsyncSocketException::END_OF_FILE,
353         "SSL connection closed locally"));
354
355   if (ssl_ != nullptr) {
356     SSL_free(ssl_);
357     ssl_ = nullptr;
358   }
359
360   // Close the socket.
361   AsyncSocket::closeNow();
362 }
363
364 void AsyncSSLSocket::shutdownWrite() {
365   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
366   //
367   // (Performing a full shutdown here is more desirable than doing nothing at
368   // all.  The purpose of shutdownWrite() is normally to notify the other end
369   // of the connection that no more data will be sent.  If we do nothing, the
370   // other end will never know that no more data is coming, and this may result
371   // in protocol deadlock.)
372   close();
373 }
374
375 void AsyncSSLSocket::shutdownWriteNow() {
376   closeNow();
377 }
378
379 bool AsyncSSLSocket::good() const {
380   return (AsyncSocket::good() &&
381           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
382            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
383 }
384
385 // The TAsyncTransport definition of 'good' states that the transport is
386 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
387 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
388 // is connected but we haven't initiated the call to SSL_connect.
389 bool AsyncSSLSocket::connecting() const {
390   return (!server_ &&
391           (AsyncSocket::connecting() ||
392            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
393                                      sslState_ == STATE_CONNECTING))));
394 }
395
396 std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
397   const unsigned char* protoName = nullptr;
398   unsigned protoLength;
399   if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
400     return std::string(reinterpret_cast<const char*>(protoName), protoLength);
401   }
402   return "";
403 }
404
405 bool AsyncSSLSocket::isEorTrackingEnabled() const {
406   if (ssl_ == nullptr) {
407     return false;
408   }
409   const BIO *wb = SSL_get_wbio(ssl_);
410   return wb && wb->method == &eorAwareBioMethod;
411 }
412
413 void AsyncSSLSocket::setEorTracking(bool track) {
414   BIO *wb = SSL_get_wbio(ssl_);
415   if (!wb) {
416     throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
417                               "setting EOR tracking without an initialized "
418                               "BIO");
419   }
420
421   if (track) {
422     if (wb->method != &eorAwareBioMethod) {
423       // only do this if we didn't
424       wb->method = &eorAwareBioMethod;
425       BIO_set_app_data(wb, this);
426       appEorByteNo_ = 0;
427       minEorRawByteNo_ = 0;
428     }
429   } else if (wb->method == &eorAwareBioMethod) {
430     wb->method = BIO_s_socket();
431     BIO_set_app_data(wb, nullptr);
432     appEorByteNo_ = 0;
433     minEorRawByteNo_ = 0;
434   } else {
435     CHECK(wb->method == BIO_s_socket());
436   }
437 }
438
439 size_t AsyncSSLSocket::getRawBytesWritten() const {
440   BIO *b;
441   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
442     return 0;
443   }
444
445   return BIO_number_written(b);
446 }
447
448 size_t AsyncSSLSocket::getRawBytesReceived() const {
449   BIO *b;
450   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
451     return 0;
452   }
453
454   return BIO_number_read(b);
455 }
456
457
458 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
459   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
460              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
461              << "events=" << eventFlags_ << ", server=" << short(server_) << "): "
462              << "sslAccept/Connect() called in invalid "
463              << "state, handshake callback " << handshakeCallback_ << ", new callback "
464              << callback;
465   assert(!handshakeTimeout_.isScheduled());
466   sslState_ = STATE_ERROR;
467
468   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
469                          "sslAccept() called with socket in invalid state");
470
471   handshakeEndTime_ = std::chrono::steady_clock::now();
472   if (callback) {
473     callback->handshakeErr(this, ex);
474   }
475
476   // Check the socket state not the ssl state here.
477   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
478     failHandshake(__func__, ex);
479   }
480 }
481
482 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
483       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
484   DestructorGuard dg(this);
485   assert(eventBase_->isInEventBaseThread());
486   verifyPeer_ = verifyPeer;
487
488   // Make sure we're in the uninitialized state
489   if (!server_ || (sslState_ != STATE_UNINIT &&
490                    sslState_ != STATE_UNENCRYPTED) ||
491       handshakeCallback_ != nullptr) {
492     return invalidState(callback);
493   }
494
495   // Cache local and remote socket addresses to keep them available
496   // after socket file descriptor is closed.
497   if (cacheAddrOnFailure_ && -1 != getFd()) {
498     cacheLocalPeerAddr();
499   }
500
501   handshakeStartTime_ = std::chrono::steady_clock::now();
502   // Make end time at least >= start time.
503   handshakeEndTime_ = handshakeStartTime_;
504
505   sslState_ = STATE_ACCEPTING;
506   handshakeCallback_ = callback;
507
508   if (timeout > 0) {
509     handshakeTimeout_.scheduleTimeout(timeout);
510   }
511
512   /* register for a read operation (waiting for CLIENT HELLO) */
513   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
514 }
515
516 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
517 void AsyncSSLSocket::attachSSLContext(
518   const std::shared_ptr<SSLContext>& ctx) {
519
520   // Check to ensure we are in client mode. Changing a server's ssl
521   // context doesn't make sense since clients of that server would likely
522   // become confused when the server's context changes.
523   DCHECK(!server_);
524   DCHECK(!ctx_);
525   DCHECK(ctx);
526   DCHECK(ctx->getSSLCtx());
527   ctx_ = ctx;
528
529   // In order to call attachSSLContext, detachSSLContext must have been
530   // previously called which sets the socket's context to the dummy
531   // context. Thus we must acquire this lock.
532   SpinLockGuard guard(dummyCtxLock);
533   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
534 }
535
536 void AsyncSSLSocket::detachSSLContext() {
537   DCHECK(ctx_);
538   ctx_.reset();
539   // We aren't using the initial_ctx for now, and it can introduce race
540   // conditions in the destructor of the SSL object.
541 #ifndef OPENSSL_NO_TLSEXT
542   if (ssl_->initial_ctx) {
543     SSL_CTX_free(ssl_->initial_ctx);
544     ssl_->initial_ctx = nullptr;
545   }
546 #endif
547   SpinLockGuard guard(dummyCtxLock);
548   if (nullptr == dummyCtx) {
549     // We need to lazily initialize the dummy context so we don't
550     // accidentally override any programmatic settings to openssl
551     dummyCtx = new SSLContext;
552   }
553   // We must remove this socket's references to its context right now
554   // since this socket could get passed to any thread. If the context has
555   // had its locking disabled, just doing a set in attachSSLContext()
556   // would not be thread safe.
557   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
558 }
559 #endif
560
561 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
562 void AsyncSSLSocket::switchServerSSLContext(
563   const std::shared_ptr<SSLContext>& handshakeCtx) {
564   CHECK(server_);
565   if (sslState_ != STATE_ACCEPTING) {
566     // We log it here and allow the switch.
567     // It should not affect our re-negotiation support (which
568     // is not supported now).
569     VLOG(6) << "fd=" << getFd()
570             << " renegotation detected when switching SSL_CTX";
571   }
572
573   setup_SSL_CTX(handshakeCtx->getSSLCtx());
574   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
575                             AsyncSSLSocket::sslInfoCallback);
576   handshakeCtx_ = handshakeCtx;
577   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
578 }
579
580 bool AsyncSSLSocket::isServerNameMatch() const {
581   CHECK(!server_);
582
583   if (!ssl_) {
584     return false;
585   }
586
587   SSL_SESSION *ss = SSL_get_session(ssl_);
588   if (!ss) {
589     return false;
590   }
591
592   if(!ss->tlsext_hostname) {
593     return false;
594   }
595   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
596 }
597
598 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
599   tlsextHostname_ = std::move(serverName);
600 }
601
602 #endif
603
604 void AsyncSSLSocket::timeoutExpired() noexcept {
605   if (state_ == StateEnum::ESTABLISHED &&
606       (sslState_ == STATE_CACHE_LOOKUP ||
607        sslState_ == STATE_ASYNC_PENDING)) {
608     sslState_ = STATE_ERROR;
609     // We are expecting a callback in restartSSLAccept.  The cache lookup
610     // and rsa-call necessarily have pointers to this ssl socket, so delay
611     // the cleanup until he calls us back.
612   } else {
613     assert(state_ == StateEnum::ESTABLISHED &&
614            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
615     DestructorGuard dg(this);
616     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
617                            (sslState_ == STATE_CONNECTING) ?
618                            "SSL connect timed out" : "SSL accept timed out");
619     failHandshake(__func__, ex);
620   }
621 }
622
623 int AsyncSSLSocket::getSSLExDataIndex() {
624   static auto index = SSL_get_ex_new_index(
625       0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
626   return index;
627 }
628
629 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
630   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
631       getSSLExDataIndex()));
632 }
633
634 void AsyncSSLSocket::failHandshake(const char* /* fn */,
635                                    const AsyncSocketException& ex) {
636   startFail();
637   if (handshakeTimeout_.isScheduled()) {
638     handshakeTimeout_.cancelTimeout();
639   }
640   invokeHandshakeErr(ex);
641   finishFail();
642 }
643
644 void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
645   handshakeEndTime_ = std::chrono::steady_clock::now();
646   if (handshakeCallback_ != nullptr) {
647     HandshakeCB* callback = handshakeCallback_;
648     handshakeCallback_ = nullptr;
649     callback->handshakeErr(this, ex);
650   }
651 }
652
653 void AsyncSSLSocket::invokeHandshakeCB() {
654   handshakeEndTime_ = std::chrono::steady_clock::now();
655   if (handshakeTimeout_.isScheduled()) {
656     handshakeTimeout_.cancelTimeout();
657   }
658   if (handshakeCallback_) {
659     HandshakeCB* callback = handshakeCallback_;
660     handshakeCallback_ = nullptr;
661     callback->handshakeSuc(this);
662   }
663 }
664
665 void AsyncSSLSocket::cacheLocalPeerAddr() {
666   SocketAddress address;
667   try {
668     getLocalAddress(&address);
669     getPeerAddress(&address);
670   } catch (const std::system_error& e) {
671     // The handle can be still valid while the connection is already closed.
672     if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
673       throw;
674     }
675   }
676 }
677
678 void AsyncSSLSocket::connect(ConnectCallback* callback,
679                               const folly::SocketAddress& address,
680                               int timeout,
681                               const OptionMap &options,
682                               const folly::SocketAddress& bindAddr)
683                               noexcept {
684   assert(!server_);
685   assert(state_ == StateEnum::UNINIT);
686   assert(sslState_ == STATE_UNINIT);
687   AsyncSSLSocketConnector *connector =
688     new AsyncSSLSocketConnector(this, callback, timeout);
689   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
690 }
691
692 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
693   // apply the settings specified in verifyPeer_
694   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
695     if(ctx_->needsPeerVerification()) {
696       SSL_set_verify(ssl, ctx_->getVerificationMode(),
697         AsyncSSLSocket::sslVerifyCallback);
698     }
699   } else {
700     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
701         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
702       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
703         AsyncSSLSocket::sslVerifyCallback);
704     }
705   }
706 }
707
708 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
709         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
710   DestructorGuard dg(this);
711   assert(eventBase_->isInEventBaseThread());
712
713   // Cache local and remote socket addresses to keep them available
714   // after socket file descriptor is closed.
715   if (cacheAddrOnFailure_ && -1 != getFd()) {
716     cacheLocalPeerAddr();
717   }
718
719   verifyPeer_ = verifyPeer;
720
721   // Make sure we're in the uninitialized state
722   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
723                   STATE_UNENCRYPTED) ||
724       handshakeCallback_ != nullptr) {
725     return invalidState(callback);
726   }
727
728   handshakeStartTime_ = std::chrono::steady_clock::now();
729   // Make end time at least >= start time.
730   handshakeEndTime_ = handshakeStartTime_;
731
732   sslState_ = STATE_CONNECTING;
733   handshakeCallback_ = callback;
734
735   try {
736     ssl_ = ctx_->createSSL();
737   } catch (std::exception &e) {
738     sslState_ = STATE_ERROR;
739     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
740                            "error calling SSLContext::createSSL()");
741     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
742             << fd_ << "): " << e.what();
743     return failHandshake(__func__, ex);
744   }
745
746   applyVerificationOptions(ssl_);
747
748   SSL_set_fd(ssl_, fd_);
749   if (sslSession_ != nullptr) {
750     SSL_set_session(ssl_, sslSession_);
751     SSL_SESSION_free(sslSession_);
752     sslSession_ = nullptr;
753   }
754 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
755   if (tlsextHostname_.size()) {
756     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
757   }
758 #endif
759
760   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
761
762   if (timeout > 0) {
763     handshakeTimeout_.scheduleTimeout(timeout);
764   }
765
766   handleConnect();
767 }
768
769 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
770   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
771     return SSL_get1_session(ssl_);
772   }
773
774   return sslSession_;
775 }
776
777 const SSL* AsyncSSLSocket::getSSL() const {
778   return ssl_;
779 }
780
781 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
782   sslSession_ = session;
783   if (!takeOwnership && session != nullptr) {
784     // Increment the reference count
785     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
786   }
787 }
788
789 void AsyncSSLSocket::getSelectedNextProtocol(
790     const unsigned char** protoName,
791     unsigned* protoLen,
792     SSLContext::NextProtocolType* protoType) const {
793   if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
794     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
795                               "NPN not supported");
796   }
797 }
798
799 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
800     const unsigned char** protoName,
801     unsigned* protoLen,
802     SSLContext::NextProtocolType* protoType) const {
803   *protoName = nullptr;
804   *protoLen = 0;
805 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
806   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
807   if (*protoLen > 0) {
808     if (protoType) {
809       *protoType = SSLContext::NextProtocolType::ALPN;
810     }
811     return true;
812   }
813 #endif
814 #ifdef OPENSSL_NPN_NEGOTIATED
815   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
816   if (protoType) {
817     *protoType = SSLContext::NextProtocolType::NPN;
818   }
819   return true;
820 #else
821   (void)protoType;
822   return false;
823 #endif
824 }
825
826 bool AsyncSSLSocket::getSSLSessionReused() const {
827   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
828     return SSL_session_reused(ssl_);
829   }
830   return false;
831 }
832
833 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
834   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
835 }
836
837 /* static */
838 const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
839   if (ssl == nullptr) {
840     return nullptr;
841   }
842 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
843   return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
844 #else
845   return nullptr;
846 #endif
847 }
848
849 const char *AsyncSSLSocket::getSSLServerName() const {
850 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
851   return getSSLServerNameFromSSL(ssl_);
852 #else
853   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
854                              "SNI not supported");
855 #endif
856 }
857
858 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
859   return getSSLServerNameFromSSL(ssl_);
860 }
861
862 int AsyncSSLSocket::getSSLVersion() const {
863   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
864 }
865
866 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
867   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
868   if (cert) {
869     int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
870     return OBJ_nid2ln(nid);
871   }
872   return nullptr;
873 }
874
875 int AsyncSSLSocket::getSSLCertSize() const {
876   int certSize = 0;
877   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
878   if (cert) {
879     EVP_PKEY *key = X509_get_pubkey(cert);
880     certSize = EVP_PKEY_bits(key);
881     EVP_PKEY_free(key);
882   }
883   return certSize;
884 }
885
886 bool AsyncSSLSocket::willBlock(int ret,
887                                int* sslErrorOut,
888                                unsigned long* errErrorOut) noexcept {
889   *errErrorOut = 0;
890   int error = *sslErrorOut = SSL_get_error(ssl_, ret);
891   if (error == SSL_ERROR_WANT_READ) {
892     // Register for read event if not already.
893     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
894     return true;
895   } else if (error == SSL_ERROR_WANT_WRITE) {
896     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
897             << ", state=" << int(state_) << ", sslState="
898             << sslState_ << ", events=" << eventFlags_ << "): "
899             << "SSL_ERROR_WANT_WRITE";
900     // Register for write event if not already.
901     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
902     return true;
903 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
904   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
905     // We will block but we can't register our own socket.  The callback that
906     // triggered this code will re-call handleAccept at the appropriate time.
907
908     // We can only get here if the linked libssl.so has support for this feature
909     // as well, otherwise SSL_get_error cannot return our error code.
910     sslState_ = STATE_CACHE_LOOKUP;
911
912     // Unregister for all events while blocked here
913     updateEventRegistration(EventHandler::NONE,
914                             EventHandler::READ | EventHandler::WRITE);
915
916     // The timeout (if set) keeps running here
917     return true;
918 #endif
919   } else if (0
920 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
921       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
922 #endif
923 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
924       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
925 #endif
926       ) {
927     // Our custom openssl function has kicked off an async request to do
928     // rsa/ecdsa private key operation.  When that call returns, a callback will
929     // be invoked that will re-call handleAccept.
930     sslState_ = STATE_ASYNC_PENDING;
931
932     // Unregister for all events while blocked here
933     updateEventRegistration(
934       EventHandler::NONE,
935       EventHandler::READ | EventHandler::WRITE
936     );
937
938     // The timeout (if set) keeps running here
939     return true;
940   } else {
941     unsigned long lastError = *errErrorOut = ERR_get_error();
942     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
943             << "state=" << state_ << ", "
944             << "sslState=" << sslState_ << ", "
945             << "events=" << std::hex << eventFlags_ << "): "
946             << "SSL error: " << error << ", "
947             << "errno: " << errno << ", "
948             << "ret: " << ret << ", "
949             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
950             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
951             << "func: " << ERR_func_error_string(lastError) << ", "
952             << "reason: " << ERR_reason_error_string(lastError);
953     return false;
954   }
955 }
956
957 void AsyncSSLSocket::checkForImmediateRead() noexcept {
958   // openssl may have buffered data that it read from the socket already.
959   // In this case we have to process it immediately, rather than waiting for
960   // the socket to become readable again.
961   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
962     AsyncSocket::handleRead();
963   }
964 }
965
966 void
967 AsyncSSLSocket::restartSSLAccept()
968 {
969   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_
970           << ", state=" << int(state_) << ", "
971           << "sslState=" << sslState_ << ", events=" << eventFlags_;
972   DestructorGuard dg(this);
973   assert(
974     sslState_ == STATE_CACHE_LOOKUP ||
975     sslState_ == STATE_ASYNC_PENDING ||
976     sslState_ == STATE_ERROR ||
977     sslState_ == STATE_CLOSED
978   );
979   if (sslState_ == STATE_CLOSED) {
980     // I sure hope whoever closed this socket didn't delete it already,
981     // but this is not strictly speaking an error
982     return;
983   }
984   if (sslState_ == STATE_ERROR) {
985     // go straight to fail if timeout expired during lookup
986     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
987                            "SSL accept timed out");
988     failHandshake(__func__, ex);
989     return;
990   }
991   sslState_ = STATE_ACCEPTING;
992   this->handleAccept();
993 }
994
995 void
996 AsyncSSLSocket::handleAccept() noexcept {
997   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
998           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
999           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1000   assert(server_);
1001   assert(state_ == StateEnum::ESTABLISHED &&
1002          sslState_ == STATE_ACCEPTING);
1003   if (!ssl_) {
1004     /* lazily create the SSL structure */
1005     try {
1006       ssl_ = ctx_->createSSL();
1007     } catch (std::exception &e) {
1008       sslState_ = STATE_ERROR;
1009       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1010                              "error calling SSLContext::createSSL()");
1011       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
1012                  << ", fd=" << fd_ << "): " << e.what();
1013       return failHandshake(__func__, ex);
1014     }
1015     SSL_set_fd(ssl_, fd_);
1016     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
1017
1018     applyVerificationOptions(ssl_);
1019   }
1020
1021   if (server_ && parseClientHello_) {
1022     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
1023     SSL_set_msg_callback_arg(ssl_, this);
1024   }
1025
1026   int ret = SSL_accept(ssl_);
1027   if (ret <= 0) {
1028     int sslError;
1029     unsigned long errError;
1030     int errnoCopy = errno;
1031     if (willBlock(ret, &sslError, &errError)) {
1032       return;
1033     } else {
1034       sslState_ = STATE_ERROR;
1035       SSLException ex(sslError, errError, ret, errnoCopy);
1036       return failHandshake(__func__, ex);
1037     }
1038   }
1039
1040   handshakeComplete_ = true;
1041   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1042
1043   // Move into STATE_ESTABLISHED in the normal case that we are in
1044   // STATE_ACCEPTING.
1045   sslState_ = STATE_ESTABLISHED;
1046
1047   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
1048           << " successfully accepted; state=" << int(state_)
1049           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
1050
1051   // Remember the EventBase we are attached to, before we start invoking any
1052   // callbacks (since the callbacks may call detachEventBase()).
1053   EventBase* originalEventBase = eventBase_;
1054
1055   // Call the accept callback.
1056   invokeHandshakeCB();
1057
1058   // Note that the accept callback may have changed our state.
1059   // (set or unset the read callback, called write(), closed the socket, etc.)
1060   // The following code needs to handle these situations correctly.
1061   //
1062   // If the socket has been closed, readCallback_ and writeReqHead_ will
1063   // always be nullptr, so that will prevent us from trying to read or write.
1064   //
1065   // The main thing to check for is if eventBase_ is still originalEventBase.
1066   // If not, we have been detached from this event base, so we shouldn't
1067   // perform any more operations.
1068   if (eventBase_ != originalEventBase) {
1069     return;
1070   }
1071
1072   AsyncSocket::handleInitialReadWrite();
1073 }
1074
1075 void
1076 AsyncSSLSocket::handleConnect() noexcept {
1077   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1078           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1079           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1080   assert(!server_);
1081   if (state_ < StateEnum::ESTABLISHED) {
1082     return AsyncSocket::handleConnect();
1083   }
1084
1085   assert(state_ == StateEnum::ESTABLISHED &&
1086          sslState_ == STATE_CONNECTING);
1087   assert(ssl_);
1088
1089   int ret = SSL_connect(ssl_);
1090   if (ret <= 0) {
1091     int sslError;
1092     unsigned long errError;
1093     int errnoCopy = errno;
1094     if (willBlock(ret, &sslError, &errError)) {
1095       return;
1096     } else {
1097       sslState_ = STATE_ERROR;
1098       SSLException ex(sslError, errError, ret, errnoCopy);
1099       return failHandshake(__func__, ex);
1100     }
1101   }
1102
1103   handshakeComplete_ = true;
1104   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1105
1106   // Move into STATE_ESTABLISHED in the normal case that we are in
1107   // STATE_CONNECTING.
1108   sslState_ = STATE_ESTABLISHED;
1109
1110   VLOG(3) << "AsyncSSLSocket " << this << ": "
1111           << "fd " << fd_ << " successfully connected; "
1112           << "state=" << int(state_) << ", sslState=" << sslState_
1113           << ", events=" << eventFlags_;
1114
1115   // Remember the EventBase we are attached to, before we start invoking any
1116   // callbacks (since the callbacks may call detachEventBase()).
1117   EventBase* originalEventBase = eventBase_;
1118
1119   // Call the handshake callback.
1120   invokeHandshakeCB();
1121
1122   // Note that the connect callback may have changed our state.
1123   // (set or unset the read callback, called write(), closed the socket, etc.)
1124   // The following code needs to handle these situations correctly.
1125   //
1126   // If the socket has been closed, readCallback_ and writeReqHead_ will
1127   // always be nullptr, so that will prevent us from trying to read or write.
1128   //
1129   // The main thing to check for is if eventBase_ is still originalEventBase.
1130   // If not, we have been detached from this event base, so we shouldn't
1131   // perform any more operations.
1132   if (eventBase_ != originalEventBase) {
1133     return;
1134   }
1135
1136   AsyncSocket::handleInitialReadWrite();
1137 }
1138
1139 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
1140 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1141   // turn on the buffer movable in openssl
1142   if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
1143       callback != nullptr && callback->isBufferMovable()) {
1144     SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
1145     isBufferMovable_ = true;
1146   }
1147 #endif
1148
1149   AsyncSocket::setReadCB(callback);
1150 }
1151
1152 void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
1153   bufferMovableEnabled_ = enabled;
1154 }
1155
1156 void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1157   CHECK(readCallback_);
1158   if (isBufferMovable_) {
1159     *buf = nullptr;
1160     *buflen = 0;
1161   } else {
1162     // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
1163     readCallback_->getReadBuffer(buf, buflen);
1164   }
1165 }
1166
1167 void
1168 AsyncSSLSocket::handleRead() noexcept {
1169   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1170           << ", state=" << int(state_) << ", "
1171           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1172   if (state_ < StateEnum::ESTABLISHED) {
1173     return AsyncSocket::handleRead();
1174   }
1175
1176
1177   if (sslState_ == STATE_ACCEPTING) {
1178     assert(server_);
1179     handleAccept();
1180     return;
1181   }
1182   else if (sslState_ == STATE_CONNECTING) {
1183     assert(!server_);
1184     handleConnect();
1185     return;
1186   }
1187
1188   // Normal read
1189   AsyncSocket::handleRead();
1190 }
1191
1192 AsyncSocket::ReadResult
1193 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
1194   VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
1195           << ", buflen=" << *buflen;
1196
1197   if (sslState_ == STATE_UNENCRYPTED) {
1198     return AsyncSocket::performRead(buf, buflen, offset);
1199   }
1200
1201   ssize_t bytes = 0;
1202   if (!isBufferMovable_) {
1203     bytes = SSL_read(ssl_, *buf, *buflen);
1204   }
1205 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1206   else {
1207     bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
1208   }
1209 #endif
1210
1211   if (server_ && renegotiateAttempted_) {
1212     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1213                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1214                << "): client intitiated SSL renegotiation not permitted";
1215     return ReadResult(
1216         READ_ERROR,
1217         folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
1218   }
1219   if (bytes <= 0) {
1220     int error = SSL_get_error(ssl_, bytes);
1221     if (error == SSL_ERROR_WANT_READ) {
1222       // The caller will register for read event if not already.
1223       if (errno == EWOULDBLOCK || errno == EAGAIN) {
1224         return ReadResult(READ_BLOCKING);
1225       } else {
1226         return ReadResult(READ_ERROR);
1227       }
1228     } else if (error == SSL_ERROR_WANT_WRITE) {
1229       // TODO: Even though we are attempting to read data, SSL_read() may
1230       // need to write data if renegotiation is being performed.  We currently
1231       // don't support this and just fail the read.
1232       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1233                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1234                  << "): unsupported SSL renegotiation during read";
1235       return ReadResult(
1236           READ_ERROR,
1237           folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1238     } else {
1239       if (zero_return(error, bytes)) {
1240         return ReadResult(bytes);
1241       }
1242       long errError = ERR_get_error();
1243       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1244               << "state=" << state_ << ", "
1245               << "sslState=" << sslState_ << ", "
1246               << "events=" << std::hex << eventFlags_ << "): "
1247               << "bytes: " << bytes << ", "
1248               << "error: " << error << ", "
1249               << "errno: " << errno << ", "
1250               << "func: " << ERR_func_error_string(errError) << ", "
1251               << "reason: " << ERR_reason_error_string(errError);
1252       return ReadResult(
1253           READ_ERROR,
1254           folly::make_unique<SSLException>(error, errError, bytes, errno));
1255     }
1256   } else {
1257     appBytesReceived_ += bytes;
1258     return ReadResult(bytes);
1259   }
1260 }
1261
1262 void AsyncSSLSocket::handleWrite() noexcept {
1263   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1264           << ", state=" << int(state_) << ", "
1265           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1266   if (state_ < StateEnum::ESTABLISHED) {
1267     return AsyncSocket::handleWrite();
1268   }
1269
1270   if (sslState_ == STATE_ACCEPTING) {
1271     assert(server_);
1272     handleAccept();
1273     return;
1274   }
1275
1276   if (sslState_ == STATE_CONNECTING) {
1277     assert(!server_);
1278     handleConnect();
1279     return;
1280   }
1281
1282   // Normal write
1283   AsyncSocket::handleWrite();
1284 }
1285
1286 AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
1287   if (error == SSL_ERROR_WANT_READ) {
1288     // Even though we are attempting to write data, SSL_write() may
1289     // need to read data if renegotiation is being performed.  We currently
1290     // don't support this and just fail the write.
1291     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1292                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1293                << "): "
1294                << "unsupported SSL renegotiation during write";
1295     return WriteResult(
1296         WRITE_ERROR,
1297         folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1298   } else {
1299     if (zero_return(error, rc)) {
1300       return WriteResult(0);
1301     }
1302     auto errError = ERR_get_error();
1303     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1304             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1305             << "SSL error: " << error << ", errno: " << errno
1306             << ", func: " << ERR_func_error_string(errError)
1307             << ", reason: " << ERR_reason_error_string(errError);
1308     return WriteResult(
1309         WRITE_ERROR,
1310         folly::make_unique<SSLException>(error, errError, rc, errno));
1311   }
1312 }
1313
1314 AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
1315     const iovec* vec,
1316     uint32_t count,
1317     WriteFlags flags,
1318     uint32_t* countWritten,
1319     uint32_t* partialWritten) {
1320   if (sslState_ == STATE_UNENCRYPTED) {
1321     return AsyncSocket::performWrite(
1322       vec, count, flags, countWritten, partialWritten);
1323   }
1324   if (sslState_ != STATE_ESTABLISHED) {
1325     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1326                << ", sslState=" << sslState_
1327                << ", events=" << eventFlags_ << "): "
1328                << "TODO: AsyncSSLSocket currently does not support calling "
1329                << "write() before the handshake has fully completed";
1330     return WriteResult(
1331         WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
1332   }
1333
1334   bool cork = isSet(flags, WriteFlags::CORK);
1335   CorkGuard guard(fd_, count > 1, cork, &corked_);
1336
1337   // Declare a buffer used to hold small write requests.  It could point to a
1338   // memory block either on stack or on heap. If it is on heap, we release it
1339   // manually when scope exits
1340   char* combinedBuf{nullptr};
1341   SCOPE_EXIT {
1342     // Note, always keep this check consistent with what we do below
1343     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1344       delete[] combinedBuf;
1345     }
1346   };
1347
1348   *countWritten = 0;
1349   *partialWritten = 0;
1350   ssize_t totalWritten = 0;
1351   size_t bytesStolenFromNextBuffer = 0;
1352   for (uint32_t i = 0; i < count; i++) {
1353     const iovec* v = vec + i;
1354     size_t offset = bytesStolenFromNextBuffer;
1355     bytesStolenFromNextBuffer = 0;
1356     size_t len = v->iov_len - offset;
1357     const void* buf;
1358     if (len == 0) {
1359       (*countWritten)++;
1360       continue;
1361     }
1362     buf = ((const char*)v->iov_base) + offset;
1363
1364     ssize_t bytes;
1365     uint32_t buffersStolen = 0;
1366     if ((len < minWriteSize_) && ((i + 1) < count)) {
1367       // Combine this buffer with part or all of the next buffers in
1368       // order to avoid really small-grained calls to SSL_write().
1369       // Each call to SSL_write() produces a separate record in
1370       // the egress SSL stream, and we've found that some low-end
1371       // mobile clients can't handle receiving an HTTP response
1372       // header and the first part of the response body in two
1373       // separate SSL records (even if those two records are in
1374       // the same TCP packet).
1375
1376       if (combinedBuf == nullptr) {
1377         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1378           // Allocate the buffer on heap
1379           combinedBuf = new char[minWriteSize_];
1380         } else {
1381           // Allocate the buffer on stack
1382           combinedBuf = (char*)alloca(minWriteSize_);
1383         }
1384       }
1385       assert(combinedBuf != nullptr);
1386
1387       memcpy(combinedBuf, buf, len);
1388       do {
1389         // INVARIANT: i + buffersStolen == complete chunks serialized
1390         uint32_t nextIndex = i + buffersStolen + 1;
1391         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1392                                              minWriteSize_ - len);
1393         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1394                bytesStolenFromNextBuffer);
1395         len += bytesStolenFromNextBuffer;
1396         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1397           // couldn't steal the whole buffer
1398           break;
1399         } else {
1400           bytesStolenFromNextBuffer = 0;
1401           buffersStolen++;
1402         }
1403       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1404       bytes = eorAwareSSLWrite(
1405         ssl_, combinedBuf, len,
1406         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1407
1408     } else {
1409       bytes = eorAwareSSLWrite(ssl_, buf, len,
1410                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1411     }
1412
1413     if (bytes <= 0) {
1414       int error = SSL_get_error(ssl_, bytes);
1415       if (error == SSL_ERROR_WANT_WRITE) {
1416         // The caller will register for write event if not already.
1417         *partialWritten = offset;
1418         return WriteResult(totalWritten);
1419       }
1420       auto writeResult = interpretSSLError(bytes, error);
1421       if (writeResult.writeReturn < 0) {
1422         return writeResult;
1423       } // else fall through to below to correctly record totalWritten
1424     }
1425
1426     totalWritten += bytes;
1427
1428     if (bytes == (ssize_t)len) {
1429       // The full iovec is written.
1430       (*countWritten) += 1 + buffersStolen;
1431       i += buffersStolen;
1432       // continue
1433     } else {
1434       bytes += offset; // adjust bytes to account for all of v
1435       while (bytes >= (ssize_t)v->iov_len) {
1436         // We combined this buf with part or all of the next one, and
1437         // we managed to write all of this buf but not all of the bytes
1438         // from the next one that we'd hoped to write.
1439         bytes -= v->iov_len;
1440         (*countWritten)++;
1441         v = &(vec[++i]);
1442       }
1443       *partialWritten = bytes;
1444       return WriteResult(totalWritten);
1445     }
1446   }
1447
1448   return WriteResult(totalWritten);
1449 }
1450
1451 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1452                                       bool eor) {
1453   if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
1454     if (appEorByteNo_) {
1455       // cannot track for more than one app byte EOR
1456       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1457     } else {
1458       appEorByteNo_ = appBytesWritten_ + n;
1459     }
1460
1461     // 1. It is fine to keep updating minEorRawByteNo_.
1462     // 2. It is _min_ in the sense that SSL record will add some overhead.
1463     minEorRawByteNo_ = getRawBytesWritten() + n;
1464   }
1465
1466   n = sslWriteImpl(ssl, buf, n);
1467   if (n > 0) {
1468     appBytesWritten_ += n;
1469     if (appEorByteNo_) {
1470       if (getRawBytesWritten() >= minEorRawByteNo_) {
1471         minEorRawByteNo_ = 0;
1472       }
1473       if(appBytesWritten_ == appEorByteNo_) {
1474         appEorByteNo_ = 0;
1475       } else {
1476         CHECK(appBytesWritten_ < appEorByteNo_);
1477       }
1478     }
1479   }
1480   return n;
1481 }
1482
1483 void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
1484   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1485   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1486     sslSocket->renegotiateAttempted_ = true;
1487   }
1488   if (where & SSL_CB_READ_ALERT) {
1489     const char* type = SSL_alert_type_string(ret);
1490     if (type) {
1491       const char* desc = SSL_alert_desc_string(ret);
1492       sslSocket->alertsReceived_.emplace_back(
1493           *type, StringPiece(desc, std::strlen(desc)));
1494     }
1495   }
1496 }
1497
1498 int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
1499   int ret;
1500   struct msghdr msg;
1501   struct iovec iov;
1502   int flags = 0;
1503   AsyncSSLSocket *tsslSock;
1504
1505   iov.iov_base = const_cast<char *>(in);
1506   iov.iov_len = inl;
1507   memset(&msg, 0, sizeof(msg));
1508   msg.msg_iov = &iov;
1509   msg.msg_iovlen = 1;
1510
1511   tsslSock =
1512     reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
1513   if (tsslSock &&
1514       tsslSock->minEorRawByteNo_ &&
1515       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1516     flags = MSG_EOR;
1517   }
1518
1519   ret = sendmsg(b->num, &msg, flags);
1520   BIO_clear_retry_flags(b);
1521   if (ret <= 0) {
1522     if (BIO_sock_should_retry(ret))
1523       BIO_set_retry_write(b);
1524   }
1525   return(ret);
1526 }
1527
1528 int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
1529                                        X509_STORE_CTX* x509Ctx) {
1530   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1531     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1532   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1533
1534   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1535           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1536   return (self->handshakeCallback_) ?
1537     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1538     preverifyOk;
1539 }
1540
1541 void AsyncSSLSocket::enableClientHelloParsing()  {
1542     parseClientHello_ = true;
1543     clientHelloInfo_.reset(new ssl::ClientHelloInfo());
1544 }
1545
1546 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1547   SSL_set_msg_callback(ssl, nullptr);
1548   SSL_set_msg_callback_arg(ssl, nullptr);
1549   clientHelloInfo_->clientHelloBuf_.clear();
1550 }
1551
1552 void AsyncSSLSocket::clientHelloParsingCallback(int written,
1553                                                 int /* version */,
1554                                                 int contentType,
1555                                                 const void* buf,
1556                                                 size_t len,
1557                                                 SSL* ssl,
1558                                                 void* arg) {
1559   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1560   if (written != 0) {
1561     sock->resetClientHelloParsing(ssl);
1562     return;
1563   }
1564   if (contentType != SSL3_RT_HANDSHAKE) {
1565     return;
1566   }
1567   if (len == 0) {
1568     return;
1569   }
1570
1571   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1572   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1573   try {
1574     Cursor cursor(clientHelloBuf.front());
1575     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1576       sock->resetClientHelloParsing(ssl);
1577       return;
1578     }
1579
1580     if (cursor.totalLength() < 3) {
1581       clientHelloBuf.trimEnd(len);
1582       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1583       return;
1584     }
1585
1586     uint32_t messageLength = cursor.read<uint8_t>();
1587     messageLength <<= 8;
1588     messageLength |= cursor.read<uint8_t>();
1589     messageLength <<= 8;
1590     messageLength |= cursor.read<uint8_t>();
1591     if (cursor.totalLength() < messageLength) {
1592       clientHelloBuf.trimEnd(len);
1593       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1594       return;
1595     }
1596
1597     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1598     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1599
1600     cursor.skip(4); // gmt_unix_time
1601     cursor.skip(28); // random_bytes
1602
1603     cursor.skip(cursor.read<uint8_t>()); // session_id
1604
1605     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1606     for (int i = 0; i < cipherSuitesLength; i += 2) {
1607       sock->clientHelloInfo_->
1608         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1609     }
1610
1611     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1612     for (int i = 0; i < compressionMethodsLength; ++i) {
1613       sock->clientHelloInfo_->
1614         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1615     }
1616
1617     if (cursor.totalLength() > 0) {
1618       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1619       while (extensionsLength) {
1620         ssl::TLSExtension extensionType =
1621             static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
1622         sock->clientHelloInfo_->
1623           clientHelloExtensions_.push_back(extensionType);
1624         extensionsLength -= 2;
1625         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1626         extensionsLength -= 2;
1627
1628         if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
1629           cursor.skip(2);
1630           extensionDataLength -= 2;
1631           while (extensionDataLength) {
1632             ssl::HashAlgorithm hashAlg =
1633                 static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
1634             ssl::SignatureAlgorithm sigAlg =
1635                 static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
1636             extensionDataLength -= 2;
1637             sock->clientHelloInfo_->
1638               clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
1639           }
1640         } else {
1641           cursor.skip(extensionDataLength);
1642           extensionsLength -= extensionDataLength;
1643         }
1644       }
1645     }
1646   } catch (std::out_of_range& e) {
1647     // we'll use what we found and cleanup below.
1648     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1649       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1650   }
1651
1652   sock->resetClientHelloParsing(ssl);
1653 }
1654
1655 } // namespace