Add a const getter for X509 used in handshake (server-side)
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/noncopyable.hpp>
23 #include <errno.h>
24 #include <fcntl.h>
25 #include <openssl/err.h>
26 #include <openssl/asn1.h>
27 #include <openssl/ssl.h>
28 #include <sys/types.h>
29 #include <chrono>
30
31 #include <folly/Bits.h>
32 #include <folly/SocketAddress.h>
33 #include <folly/SpinLock.h>
34 #include <folly/io/IOBuf.h>
35 #include <folly/io/Cursor.h>
36 #include <folly/portability/Unistd.h>
37
38 using folly::SocketAddress;
39 using folly::SSLContext;
40 using std::string;
41 using std::shared_ptr;
42
43 using folly::Endian;
44 using folly::IOBuf;
45 using folly::SpinLock;
46 using folly::SpinLockGuard;
47 using folly::io::Cursor;
48 using std::unique_ptr;
49 using std::bind;
50
51 namespace {
52 using folly::AsyncSocket;
53 using folly::AsyncSocketException;
54 using folly::AsyncSSLSocket;
55 using folly::Optional;
56 using folly::SSLContext;
57 using folly::ssl::OpenSSLUtils;
58
59 // We have one single dummy SSL context so that we can implement attach
60 // and detach methods in a thread safe fashion without modifying opnessl.
61 static SSLContext *dummyCtx = nullptr;
62 static SpinLock dummyCtxLock;
63
64 // If given min write size is less than this, buffer will be allocated on
65 // stack, otherwise it is allocated on heap
66 const size_t MAX_STACK_BUF_SIZE = 2048;
67
68 // This converts "illegal" shutdowns into ZERO_RETURN
69 inline bool zero_return(int error, int rc) {
70   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
71 }
72
73 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
74                                 public AsyncSSLSocket::HandshakeCB {
75
76  private:
77   AsyncSSLSocket *sslSocket_;
78   AsyncSSLSocket::ConnectCallback *callback_;
79   int timeout_;
80   int64_t startTime_;
81
82  protected:
83   ~AsyncSSLSocketConnector() override {}
84
85  public:
86   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
87                            AsyncSocket::ConnectCallback *callback,
88                            int timeout) :
89       sslSocket_(sslSocket),
90       callback_(callback),
91       timeout_(timeout),
92       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
93                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
94   }
95
96   void connectSuccess() noexcept override {
97     VLOG(7) << "client socket connected";
98
99     int64_t timeoutLeft = 0;
100     if (timeout_ > 0) {
101       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
102         std::chrono::steady_clock::now().time_since_epoch()).count();
103
104       timeoutLeft = timeout_ - (curTime - startTime_);
105       if (timeoutLeft <= 0) {
106         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
107                                 "SSL connect timed out");
108         fail(ex);
109         delete this;
110         return;
111       }
112     }
113     sslSocket_->sslConn(this, timeoutLeft);
114   }
115
116   void connectErr(const AsyncSocketException& ex) noexcept override {
117     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
118     fail(ex);
119     delete this;
120   }
121
122   void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
123     VLOG(7) << "client handshake success";
124     if (callback_) {
125       callback_->connectSuccess();
126     }
127     delete this;
128   }
129
130   void handshakeErr(AsyncSSLSocket* /* socket */,
131                     const AsyncSocketException& ex) noexcept override {
132     LOG(ERROR) << "client handshakeErr: " << ex.what();
133     fail(ex);
134     delete this;
135   }
136
137   void fail(const AsyncSocketException &ex) {
138     // fail is a noop if called twice
139     if (callback_) {
140       AsyncSSLSocket::ConnectCallback *cb = callback_;
141       callback_ = nullptr;
142
143       cb->connectErr(ex);
144       sslSocket_->closeNow();
145       // closeNow can call handshakeErr if it hasn't been called already.
146       // So this may have been deleted, no member variable access beyond this
147       // point
148       // Note that closeNow may invoke writeError callbacks if the socket had
149       // write data pending connection completion.
150     }
151   }
152 };
153
154 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
155 #ifdef TCP_CORK // Linux-only
156 /**
157  * Utility class that corks a TCP socket upon construction or uncorks
158  * the socket upon destruction
159  */
160 class CorkGuard : private boost::noncopyable {
161  public:
162   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
163     fd_(fd), haveMore_(haveMore), corked_(corked) {
164     if (*corked_) {
165       // socket is already corked; nothing to do
166       return;
167     }
168     if (multipleWrites || haveMore) {
169       // We are performing multiple writes in this performWrite() call,
170       // and/or there are more calls to performWrite() that will be invoked
171       // later, so enable corking
172       int flag = 1;
173       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
174       *corked_ = true;
175     }
176   }
177
178   ~CorkGuard() {
179     if (haveMore_) {
180       // more data to come; don't uncork yet
181       return;
182     }
183     if (!*corked_) {
184       // socket isn't corked; nothing to do
185       return;
186     }
187
188     int flag = 0;
189     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
190     *corked_ = false;
191   }
192
193  private:
194   int fd_;
195   bool haveMore_;
196   bool* corked_;
197 };
198 #else
199 class CorkGuard : private boost::noncopyable {
200  public:
201   CorkGuard(int, bool, bool, bool*) {}
202 };
203 #endif
204
205 void setup_SSL_CTX(SSL_CTX *ctx) {
206 #ifdef SSL_MODE_RELEASE_BUFFERS
207   SSL_CTX_set_mode(ctx,
208                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
209                    SSL_MODE_ENABLE_PARTIAL_WRITE
210                    | SSL_MODE_RELEASE_BUFFERS
211                    );
212 #else
213   SSL_CTX_set_mode(ctx,
214                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
215                    SSL_MODE_ENABLE_PARTIAL_WRITE
216                    );
217 #endif
218 // SSL_CTX_set_mode is a Macro
219 #ifdef SSL_MODE_WRITE_IOVEC
220   SSL_CTX_set_mode(ctx,
221                    SSL_CTX_get_mode(ctx)
222                    | SSL_MODE_WRITE_IOVEC);
223 #endif
224
225 }
226
227 BIO_METHOD sslWriteBioMethod;
228
229 void* initsslWriteBioMethod(void) {
230   memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
231   // override the bwrite method for MSG_EOR support
232   OpenSSLUtils::setCustomBioWriteMethod(
233       &sslWriteBioMethod, AsyncSSLSocket::bioWrite);
234
235   // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
236   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
237   // then have specific handlings. The sslWriteBioWrite should be compatible
238   // with the one in openssl.
239
240   // Return something here to enable AsyncSSLSocket to call this method using
241   // a function-scoped static.
242   return nullptr;
243 }
244
245 } // anonymous namespace
246
247 namespace folly {
248
249 /**
250  * Create a client AsyncSSLSocket
251  */
252 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
253                                EventBase* evb, bool deferSecurityNegotiation) :
254     AsyncSocket(evb),
255     ctx_(ctx),
256     handshakeTimeout_(this, evb) {
257   init();
258   if (deferSecurityNegotiation) {
259     sslState_ = STATE_UNENCRYPTED;
260   }
261 }
262
263 /**
264  * Create a server/client AsyncSSLSocket
265  */
266 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
267                                EventBase* evb, int fd, bool server,
268                                bool deferSecurityNegotiation) :
269     AsyncSocket(evb, fd),
270     server_(server),
271     ctx_(ctx),
272     handshakeTimeout_(this, evb) {
273   init();
274   if (server) {
275     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
276                               AsyncSSLSocket::sslInfoCallback);
277   }
278   if (deferSecurityNegotiation) {
279     sslState_ = STATE_UNENCRYPTED;
280   }
281 }
282
283 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
284 /**
285  * Create a client AsyncSSLSocket and allow tlsext_hostname
286  * to be sent in Client Hello.
287  */
288 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
289                                  EventBase* evb,
290                                const std::string& serverName,
291                                bool deferSecurityNegotiation) :
292     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
293   tlsextHostname_ = serverName;
294 }
295
296 /**
297  * Create a client AsyncSSLSocket from an already connected fd
298  * and allow tlsext_hostname to be sent in Client Hello.
299  */
300 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
301                                  EventBase* evb, int fd,
302                                const std::string& serverName,
303                                bool deferSecurityNegotiation) :
304     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
305   tlsextHostname_ = serverName;
306 }
307 #endif
308
309 AsyncSSLSocket::~AsyncSSLSocket() {
310   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
311           << ", evb=" << eventBase_ << ", fd=" << fd_
312           << ", state=" << int(state_) << ", sslState="
313           << sslState_ << ", events=" << eventFlags_ << ")";
314 }
315
316 void AsyncSSLSocket::init() {
317   // Do this here to ensure we initialize this once before any use of
318   // AsyncSSLSocket instances and not as part of library load.
319   static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
320   (void)sslWriteBioMethodInitializer;
321
322   setup_SSL_CTX(ctx_->getSSLCtx());
323 }
324
325 void AsyncSSLSocket::closeNow() {
326   // Close the SSL connection.
327   if (ssl_ != nullptr && fd_ != -1) {
328     int rc = SSL_shutdown(ssl_);
329     if (rc == 0) {
330       rc = SSL_shutdown(ssl_);
331     }
332     if (rc < 0) {
333       ERR_clear_error();
334     }
335   }
336
337   if (sslSession_ != nullptr) {
338     SSL_SESSION_free(sslSession_);
339     sslSession_ = nullptr;
340   }
341
342   sslState_ = STATE_CLOSED;
343
344   if (handshakeTimeout_.isScheduled()) {
345     handshakeTimeout_.cancelTimeout();
346   }
347
348   DestructorGuard dg(this);
349
350   invokeHandshakeErr(
351       AsyncSocketException(
352         AsyncSocketException::END_OF_FILE,
353         "SSL connection closed locally"));
354
355   if (ssl_ != nullptr) {
356     SSL_free(ssl_);
357     ssl_ = nullptr;
358   }
359
360   // Close the socket.
361   AsyncSocket::closeNow();
362 }
363
364 void AsyncSSLSocket::shutdownWrite() {
365   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
366   //
367   // (Performing a full shutdown here is more desirable than doing nothing at
368   // all.  The purpose of shutdownWrite() is normally to notify the other end
369   // of the connection that no more data will be sent.  If we do nothing, the
370   // other end will never know that no more data is coming, and this may result
371   // in protocol deadlock.)
372   close();
373 }
374
375 void AsyncSSLSocket::shutdownWriteNow() {
376   closeNow();
377 }
378
379 bool AsyncSSLSocket::good() const {
380   return (AsyncSocket::good() &&
381           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
382            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
383 }
384
385 // The TAsyncTransport definition of 'good' states that the transport is
386 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
387 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
388 // is connected but we haven't initiated the call to SSL_connect.
389 bool AsyncSSLSocket::connecting() const {
390   return (!server_ &&
391           (AsyncSocket::connecting() ||
392            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
393                                      sslState_ == STATE_CONNECTING))));
394 }
395
396 std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
397   const unsigned char* protoName = nullptr;
398   unsigned protoLength;
399   if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
400     return std::string(reinterpret_cast<const char*>(protoName), protoLength);
401   }
402   return "";
403 }
404
405 bool AsyncSSLSocket::isEorTrackingEnabled() const {
406   return trackEor_;
407 }
408
409 void AsyncSSLSocket::setEorTracking(bool track) {
410   if (trackEor_ != track) {
411     trackEor_ = track;
412     appEorByteNo_ = 0;
413     minEorRawByteNo_ = 0;
414   }
415 }
416
417 size_t AsyncSSLSocket::getRawBytesWritten() const {
418   BIO *b;
419   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
420     return 0;
421   }
422
423   return BIO_number_written(b);
424 }
425
426 size_t AsyncSSLSocket::getRawBytesReceived() const {
427   BIO *b;
428   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
429     return 0;
430   }
431
432   return BIO_number_read(b);
433 }
434
435
436 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
437   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
438              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
439              << "events=" << eventFlags_ << ", server=" << short(server_)
440              << "): " << "sslAccept/Connect() called in invalid "
441              << "state, handshake callback " << handshakeCallback_
442              << ", new callback " << callback;
443   assert(!handshakeTimeout_.isScheduled());
444   sslState_ = STATE_ERROR;
445
446   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
447                          "sslAccept() called with socket in invalid state");
448
449   handshakeEndTime_ = std::chrono::steady_clock::now();
450   if (callback) {
451     callback->handshakeErr(this, ex);
452   }
453
454   // Check the socket state not the ssl state here.
455   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
456     failHandshake(__func__, ex);
457   }
458 }
459
460 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
461       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
462   DestructorGuard dg(this);
463   assert(eventBase_->isInEventBaseThread());
464   verifyPeer_ = verifyPeer;
465
466   // Make sure we're in the uninitialized state
467   if (!server_ || (sslState_ != STATE_UNINIT &&
468                    sslState_ != STATE_UNENCRYPTED) ||
469       handshakeCallback_ != nullptr) {
470     return invalidState(callback);
471   }
472
473   // Cache local and remote socket addresses to keep them available
474   // after socket file descriptor is closed.
475   if (cacheAddrOnFailure_ && -1 != getFd()) {
476     cacheLocalPeerAddr();
477   }
478
479   handshakeStartTime_ = std::chrono::steady_clock::now();
480   // Make end time at least >= start time.
481   handshakeEndTime_ = handshakeStartTime_;
482
483   sslState_ = STATE_ACCEPTING;
484   handshakeCallback_ = callback;
485
486   if (timeout > 0) {
487     handshakeTimeout_.scheduleTimeout(timeout);
488   }
489
490   /* register for a read operation (waiting for CLIENT HELLO) */
491   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
492 }
493
494 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
495 void AsyncSSLSocket::attachSSLContext(
496   const std::shared_ptr<SSLContext>& ctx) {
497
498   // Check to ensure we are in client mode. Changing a server's ssl
499   // context doesn't make sense since clients of that server would likely
500   // become confused when the server's context changes.
501   DCHECK(!server_);
502   DCHECK(!ctx_);
503   DCHECK(ctx);
504   DCHECK(ctx->getSSLCtx());
505   ctx_ = ctx;
506
507   // In order to call attachSSLContext, detachSSLContext must have been
508   // previously called which sets the socket's context to the dummy
509   // context. Thus we must acquire this lock.
510   SpinLockGuard guard(dummyCtxLock);
511   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
512 }
513
514 void AsyncSSLSocket::detachSSLContext() {
515   DCHECK(ctx_);
516   ctx_.reset();
517   // We aren't using the initial_ctx for now, and it can introduce race
518   // conditions in the destructor of the SSL object.
519 #ifndef OPENSSL_NO_TLSEXT
520   if (ssl_->initial_ctx) {
521     SSL_CTX_free(ssl_->initial_ctx);
522     ssl_->initial_ctx = nullptr;
523   }
524 #endif
525   SpinLockGuard guard(dummyCtxLock);
526   if (nullptr == dummyCtx) {
527     // We need to lazily initialize the dummy context so we don't
528     // accidentally override any programmatic settings to openssl
529     dummyCtx = new SSLContext;
530   }
531   // We must remove this socket's references to its context right now
532   // since this socket could get passed to any thread. If the context has
533   // had its locking disabled, just doing a set in attachSSLContext()
534   // would not be thread safe.
535   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
536 }
537 #endif
538
539 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
540 void AsyncSSLSocket::switchServerSSLContext(
541   const std::shared_ptr<SSLContext>& handshakeCtx) {
542   CHECK(server_);
543   if (sslState_ != STATE_ACCEPTING) {
544     // We log it here and allow the switch.
545     // It should not affect our re-negotiation support (which
546     // is not supported now).
547     VLOG(6) << "fd=" << getFd()
548             << " renegotation detected when switching SSL_CTX";
549   }
550
551   setup_SSL_CTX(handshakeCtx->getSSLCtx());
552   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
553                             AsyncSSLSocket::sslInfoCallback);
554   handshakeCtx_ = handshakeCtx;
555   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
556 }
557
558 bool AsyncSSLSocket::isServerNameMatch() const {
559   CHECK(!server_);
560
561   if (!ssl_) {
562     return false;
563   }
564
565   SSL_SESSION *ss = SSL_get_session(ssl_);
566   if (!ss) {
567     return false;
568   }
569
570   if(!ss->tlsext_hostname) {
571     return false;
572   }
573   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
574 }
575
576 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
577   tlsextHostname_ = std::move(serverName);
578 }
579
580 #endif
581
582 void AsyncSSLSocket::timeoutExpired() noexcept {
583   if (state_ == StateEnum::ESTABLISHED &&
584       (sslState_ == STATE_CACHE_LOOKUP ||
585        sslState_ == STATE_ASYNC_PENDING)) {
586     sslState_ = STATE_ERROR;
587     // We are expecting a callback in restartSSLAccept.  The cache lookup
588     // and rsa-call necessarily have pointers to this ssl socket, so delay
589     // the cleanup until he calls us back.
590   } else {
591     assert(state_ == StateEnum::ESTABLISHED &&
592            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
593     DestructorGuard dg(this);
594     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
595                            (sslState_ == STATE_CONNECTING) ?
596                            "SSL connect timed out" : "SSL accept timed out");
597     failHandshake(__func__, ex);
598   }
599 }
600
601 int AsyncSSLSocket::getSSLExDataIndex() {
602   static auto index = SSL_get_ex_new_index(
603       0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
604   return index;
605 }
606
607 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
608   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
609       getSSLExDataIndex()));
610 }
611
612 void AsyncSSLSocket::failHandshake(const char* /* fn */,
613                                    const AsyncSocketException& ex) {
614   startFail();
615   if (handshakeTimeout_.isScheduled()) {
616     handshakeTimeout_.cancelTimeout();
617   }
618   invokeHandshakeErr(ex);
619   finishFail();
620 }
621
622 void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
623   handshakeEndTime_ = std::chrono::steady_clock::now();
624   if (handshakeCallback_ != nullptr) {
625     HandshakeCB* callback = handshakeCallback_;
626     handshakeCallback_ = nullptr;
627     callback->handshakeErr(this, ex);
628   }
629 }
630
631 void AsyncSSLSocket::invokeHandshakeCB() {
632   handshakeEndTime_ = std::chrono::steady_clock::now();
633   if (handshakeTimeout_.isScheduled()) {
634     handshakeTimeout_.cancelTimeout();
635   }
636   if (handshakeCallback_) {
637     HandshakeCB* callback = handshakeCallback_;
638     handshakeCallback_ = nullptr;
639     callback->handshakeSuc(this);
640   }
641 }
642
643 void AsyncSSLSocket::cacheLocalPeerAddr() {
644   SocketAddress address;
645   try {
646     getLocalAddress(&address);
647     getPeerAddress(&address);
648   } catch (const std::system_error& e) {
649     // The handle can be still valid while the connection is already closed.
650     if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
651       throw;
652     }
653   }
654 }
655
656 void AsyncSSLSocket::connect(ConnectCallback* callback,
657                               const folly::SocketAddress& address,
658                               int timeout,
659                               const OptionMap &options,
660                               const folly::SocketAddress& bindAddr)
661                               noexcept {
662   assert(!server_);
663   assert(state_ == StateEnum::UNINIT);
664   assert(sslState_ == STATE_UNINIT);
665   AsyncSSLSocketConnector *connector =
666     new AsyncSSLSocketConnector(this, callback, timeout);
667   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
668 }
669
670 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
671   // apply the settings specified in verifyPeer_
672   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
673     if(ctx_->needsPeerVerification()) {
674       SSL_set_verify(ssl, ctx_->getVerificationMode(),
675         AsyncSSLSocket::sslVerifyCallback);
676     }
677   } else {
678     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
679         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
680       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
681         AsyncSSLSocket::sslVerifyCallback);
682     }
683   }
684 }
685
686 bool AsyncSSLSocket::setupSSLBio() {
687   auto wb = BIO_new(&sslWriteBioMethod);
688
689   if (!wb) {
690     return false;
691   }
692
693   OpenSSLUtils::setBioAppData(wb, this);
694   BIO_set_fd(wb, fd_, BIO_NOCLOSE);
695   SSL_set_bio(ssl_, wb, wb);
696   return true;
697 }
698
699 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
700         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
701   DestructorGuard dg(this);
702   assert(eventBase_->isInEventBaseThread());
703
704   // Cache local and remote socket addresses to keep them available
705   // after socket file descriptor is closed.
706   if (cacheAddrOnFailure_ && -1 != getFd()) {
707     cacheLocalPeerAddr();
708   }
709
710   verifyPeer_ = verifyPeer;
711
712   // Make sure we're in the uninitialized state
713   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
714                   STATE_UNENCRYPTED) ||
715       handshakeCallback_ != nullptr) {
716     return invalidState(callback);
717   }
718
719   handshakeStartTime_ = std::chrono::steady_clock::now();
720   // Make end time at least >= start time.
721   handshakeEndTime_ = handshakeStartTime_;
722
723   sslState_ = STATE_CONNECTING;
724   handshakeCallback_ = callback;
725
726   try {
727     ssl_ = ctx_->createSSL();
728   } catch (std::exception &e) {
729     sslState_ = STATE_ERROR;
730     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
731                            "error calling SSLContext::createSSL()");
732     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
733             << fd_ << "): " << e.what();
734     return failHandshake(__func__, ex);
735   }
736
737   if (!setupSSLBio()) {
738     sslState_ = STATE_ERROR;
739     AsyncSocketException ex(
740         AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
741     return failHandshake(__func__, ex);
742   }
743
744   applyVerificationOptions(ssl_);
745
746   if (sslSession_ != nullptr) {
747     SSL_set_session(ssl_, sslSession_);
748     SSL_SESSION_free(sslSession_);
749     sslSession_ = nullptr;
750   }
751 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
752   if (tlsextHostname_.size()) {
753     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
754   }
755 #endif
756
757   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
758
759   if (timeout > 0) {
760     handshakeTimeout_.scheduleTimeout(timeout);
761   }
762
763   handleConnect();
764 }
765
766 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
767   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
768     return SSL_get1_session(ssl_);
769   }
770
771   return sslSession_;
772 }
773
774 const SSL* AsyncSSLSocket::getSSL() const {
775   return ssl_;
776 }
777
778 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
779   sslSession_ = session;
780   if (!takeOwnership && session != nullptr) {
781     // Increment the reference count
782     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
783   }
784 }
785
786 void AsyncSSLSocket::getSelectedNextProtocol(
787     const unsigned char** protoName,
788     unsigned* protoLen,
789     SSLContext::NextProtocolType* protoType) const {
790   if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
791     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
792                               "NPN not supported");
793   }
794 }
795
796 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
797     const unsigned char** protoName,
798     unsigned* protoLen,
799     SSLContext::NextProtocolType* protoType) const {
800   *protoName = nullptr;
801   *protoLen = 0;
802 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
803   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
804   if (*protoLen > 0) {
805     if (protoType) {
806       *protoType = SSLContext::NextProtocolType::ALPN;
807     }
808     return true;
809   }
810 #endif
811 #ifdef OPENSSL_NPN_NEGOTIATED
812   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
813   if (protoType) {
814     *protoType = SSLContext::NextProtocolType::NPN;
815   }
816   return true;
817 #else
818   (void)protoType;
819   return false;
820 #endif
821 }
822
823 bool AsyncSSLSocket::getSSLSessionReused() const {
824   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
825     return SSL_session_reused(ssl_);
826   }
827   return false;
828 }
829
830 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
831   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
832 }
833
834 /* static */
835 const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
836   if (ssl == nullptr) {
837     return nullptr;
838   }
839 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
840   return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
841 #else
842   return nullptr;
843 #endif
844 }
845
846 const char *AsyncSSLSocket::getSSLServerName() const {
847 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
848   return getSSLServerNameFromSSL(ssl_);
849 #else
850   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
851                              "SNI not supported");
852 #endif
853 }
854
855 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
856   return getSSLServerNameFromSSL(ssl_);
857 }
858
859 int AsyncSSLSocket::getSSLVersion() const {
860   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
861 }
862
863 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
864   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
865   if (cert) {
866     int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
867     return OBJ_nid2ln(nid);
868   }
869   return nullptr;
870 }
871
872 int AsyncSSLSocket::getSSLCertSize() const {
873   int certSize = 0;
874   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
875   if (cert) {
876     EVP_PKEY *key = X509_get_pubkey(cert);
877     certSize = EVP_PKEY_bits(key);
878     EVP_PKEY_free(key);
879   }
880   return certSize;
881 }
882
883 const X509* AsyncSSLSocket::getSelfCert() const {
884   return (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
885 }
886
887 bool AsyncSSLSocket::willBlock(int ret,
888                                int* sslErrorOut,
889                                unsigned long* errErrorOut) noexcept {
890   *errErrorOut = 0;
891   int error = *sslErrorOut = SSL_get_error(ssl_, ret);
892   if (error == SSL_ERROR_WANT_READ) {
893     // Register for read event if not already.
894     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
895     return true;
896   } else if (error == SSL_ERROR_WANT_WRITE) {
897     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
898             << ", state=" << int(state_) << ", sslState="
899             << sslState_ << ", events=" << eventFlags_ << "): "
900             << "SSL_ERROR_WANT_WRITE";
901     // Register for write event if not already.
902     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
903     return true;
904 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
905   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
906     // We will block but we can't register our own socket.  The callback that
907     // triggered this code will re-call handleAccept at the appropriate time.
908
909     // We can only get here if the linked libssl.so has support for this feature
910     // as well, otherwise SSL_get_error cannot return our error code.
911     sslState_ = STATE_CACHE_LOOKUP;
912
913     // Unregister for all events while blocked here
914     updateEventRegistration(EventHandler::NONE,
915                             EventHandler::READ | EventHandler::WRITE);
916
917     // The timeout (if set) keeps running here
918     return true;
919 #endif
920   } else if (0
921 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
922       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
923 #endif
924 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
925       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
926 #endif
927       ) {
928     // Our custom openssl function has kicked off an async request to do
929     // rsa/ecdsa private key operation.  When that call returns, a callback will
930     // be invoked that will re-call handleAccept.
931     sslState_ = STATE_ASYNC_PENDING;
932
933     // Unregister for all events while blocked here
934     updateEventRegistration(
935       EventHandler::NONE,
936       EventHandler::READ | EventHandler::WRITE
937     );
938
939     // The timeout (if set) keeps running here
940     return true;
941   } else {
942     unsigned long lastError = *errErrorOut = ERR_get_error();
943     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
944             << "state=" << state_ << ", "
945             << "sslState=" << sslState_ << ", "
946             << "events=" << std::hex << eventFlags_ << "): "
947             << "SSL error: " << error << ", "
948             << "errno: " << errno << ", "
949             << "ret: " << ret << ", "
950             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
951             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
952             << "func: " << ERR_func_error_string(lastError) << ", "
953             << "reason: " << ERR_reason_error_string(lastError);
954     return false;
955   }
956 }
957
958 void AsyncSSLSocket::checkForImmediateRead() noexcept {
959   // openssl may have buffered data that it read from the socket already.
960   // In this case we have to process it immediately, rather than waiting for
961   // the socket to become readable again.
962   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
963     AsyncSocket::handleRead();
964   }
965 }
966
967 void
968 AsyncSSLSocket::restartSSLAccept()
969 {
970   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this
971           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
972           << "sslState=" << sslState_ << ", events=" << eventFlags_;
973   DestructorGuard dg(this);
974   assert(
975     sslState_ == STATE_CACHE_LOOKUP ||
976     sslState_ == STATE_ASYNC_PENDING ||
977     sslState_ == STATE_ERROR ||
978     sslState_ == STATE_CLOSED);
979   if (sslState_ == STATE_CLOSED) {
980     // I sure hope whoever closed this socket didn't delete it already,
981     // but this is not strictly speaking an error
982     return;
983   }
984   if (sslState_ == STATE_ERROR) {
985     // go straight to fail if timeout expired during lookup
986     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
987                            "SSL accept timed out");
988     failHandshake(__func__, ex);
989     return;
990   }
991   sslState_ = STATE_ACCEPTING;
992   this->handleAccept();
993 }
994
995 void
996 AsyncSSLSocket::handleAccept() noexcept {
997   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
998           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
999           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1000   assert(server_);
1001   assert(state_ == StateEnum::ESTABLISHED &&
1002          sslState_ == STATE_ACCEPTING);
1003   if (!ssl_) {
1004     /* lazily create the SSL structure */
1005     try {
1006       ssl_ = ctx_->createSSL();
1007     } catch (std::exception &e) {
1008       sslState_ = STATE_ERROR;
1009       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1010                              "error calling SSLContext::createSSL()");
1011       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
1012                  << ", fd=" << fd_ << "): " << e.what();
1013       return failHandshake(__func__, ex);
1014     }
1015
1016     if (!setupSSLBio()) {
1017       sslState_ = STATE_ERROR;
1018       AsyncSocketException ex(
1019           AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
1020       return failHandshake(__func__, ex);
1021     }
1022
1023     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
1024
1025     applyVerificationOptions(ssl_);
1026   }
1027
1028   if (server_ && parseClientHello_) {
1029     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
1030     SSL_set_msg_callback_arg(ssl_, this);
1031   }
1032
1033   int ret = SSL_accept(ssl_);
1034   if (ret <= 0) {
1035     int sslError;
1036     unsigned long errError;
1037     int errnoCopy = errno;
1038     if (willBlock(ret, &sslError, &errError)) {
1039       return;
1040     } else {
1041       sslState_ = STATE_ERROR;
1042       SSLException ex(sslError, errError, ret, errnoCopy);
1043       return failHandshake(__func__, ex);
1044     }
1045   }
1046
1047   handshakeComplete_ = true;
1048   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1049
1050   // Move into STATE_ESTABLISHED in the normal case that we are in
1051   // STATE_ACCEPTING.
1052   sslState_ = STATE_ESTABLISHED;
1053
1054   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
1055           << " successfully accepted; state=" << int(state_)
1056           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
1057
1058   // Remember the EventBase we are attached to, before we start invoking any
1059   // callbacks (since the callbacks may call detachEventBase()).
1060   EventBase* originalEventBase = eventBase_;
1061
1062   // Call the accept callback.
1063   invokeHandshakeCB();
1064
1065   // Note that the accept callback may have changed our state.
1066   // (set or unset the read callback, called write(), closed the socket, etc.)
1067   // The following code needs to handle these situations correctly.
1068   //
1069   // If the socket has been closed, readCallback_ and writeReqHead_ will
1070   // always be nullptr, so that will prevent us from trying to read or write.
1071   //
1072   // The main thing to check for is if eventBase_ is still originalEventBase.
1073   // If not, we have been detached from this event base, so we shouldn't
1074   // perform any more operations.
1075   if (eventBase_ != originalEventBase) {
1076     return;
1077   }
1078
1079   AsyncSocket::handleInitialReadWrite();
1080 }
1081
1082 void
1083 AsyncSSLSocket::handleConnect() noexcept {
1084   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1085           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1086           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1087   assert(!server_);
1088   if (state_ < StateEnum::ESTABLISHED) {
1089     return AsyncSocket::handleConnect();
1090   }
1091
1092   assert(
1093       (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
1094       sslState_ == STATE_CONNECTING);
1095   assert(ssl_);
1096
1097   int ret = SSL_connect(ssl_);
1098   if (ret <= 0) {
1099     int sslError;
1100     unsigned long errError;
1101     int errnoCopy = errno;
1102     if (willBlock(ret, &sslError, &errError)) {
1103       return;
1104     } else {
1105       sslState_ = STATE_ERROR;
1106       SSLException ex(sslError, errError, ret, errnoCopy);
1107       return failHandshake(__func__, ex);
1108     }
1109   }
1110
1111   handshakeComplete_ = true;
1112   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1113
1114   // Move into STATE_ESTABLISHED in the normal case that we are in
1115   // STATE_CONNECTING.
1116   sslState_ = STATE_ESTABLISHED;
1117
1118   VLOG(3) << "AsyncSSLSocket " << this << ": "
1119           << "fd " << fd_ << " successfully connected; "
1120           << "state=" << int(state_) << ", sslState=" << sslState_
1121           << ", events=" << eventFlags_;
1122
1123   // Remember the EventBase we are attached to, before we start invoking any
1124   // callbacks (since the callbacks may call detachEventBase()).
1125   EventBase* originalEventBase = eventBase_;
1126
1127   // Call the handshake callback.
1128   invokeHandshakeCB();
1129
1130   // Note that the connect callback may have changed our state.
1131   // (set or unset the read callback, called write(), closed the socket, etc.)
1132   // The following code needs to handle these situations correctly.
1133   //
1134   // If the socket has been closed, readCallback_ and writeReqHead_ will
1135   // always be nullptr, so that will prevent us from trying to read or write.
1136   //
1137   // The main thing to check for is if eventBase_ is still originalEventBase.
1138   // If not, we have been detached from this event base, so we shouldn't
1139   // perform any more operations.
1140   if (eventBase_ != originalEventBase) {
1141     return;
1142   }
1143
1144   AsyncSocket::handleInitialReadWrite();
1145 }
1146
1147 void AsyncSSLSocket::invokeConnectSuccess() {
1148   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
1149     // If we failed TFO, we'd fall back to trying to connect the socket,
1150     // when we succeed we should handle the writes that caused us to start
1151     // TFO.
1152     handleWrite();
1153   }
1154   AsyncSocket::invokeConnectSuccess();
1155 }
1156
1157 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
1158 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1159   // turn on the buffer movable in openssl
1160   if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
1161       callback != nullptr && callback->isBufferMovable()) {
1162     SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
1163     isBufferMovable_ = true;
1164   }
1165 #endif
1166
1167   AsyncSocket::setReadCB(callback);
1168 }
1169
1170 void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
1171   bufferMovableEnabled_ = enabled;
1172 }
1173
1174 void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1175   CHECK(readCallback_);
1176   if (isBufferMovable_) {
1177     *buf = nullptr;
1178     *buflen = 0;
1179   } else {
1180     // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
1181     readCallback_->getReadBuffer(buf, buflen);
1182   }
1183 }
1184
1185 void
1186 AsyncSSLSocket::handleRead() noexcept {
1187   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1188           << ", state=" << int(state_) << ", "
1189           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1190   if (state_ < StateEnum::ESTABLISHED) {
1191     return AsyncSocket::handleRead();
1192   }
1193
1194
1195   if (sslState_ == STATE_ACCEPTING) {
1196     assert(server_);
1197     handleAccept();
1198     return;
1199   }
1200   else if (sslState_ == STATE_CONNECTING) {
1201     assert(!server_);
1202     handleConnect();
1203     return;
1204   }
1205
1206   // Normal read
1207   AsyncSocket::handleRead();
1208 }
1209
1210 AsyncSocket::ReadResult
1211 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
1212   VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
1213           << ", buflen=" << *buflen;
1214
1215   if (sslState_ == STATE_UNENCRYPTED) {
1216     return AsyncSocket::performRead(buf, buflen, offset);
1217   }
1218
1219   ssize_t bytes = 0;
1220   if (!isBufferMovable_) {
1221     bytes = SSL_read(ssl_, *buf, *buflen);
1222   }
1223 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1224   else {
1225     bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
1226   }
1227 #endif
1228
1229   if (server_ && renegotiateAttempted_) {
1230     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1231                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1232                << "): client intitiated SSL renegotiation not permitted";
1233     return ReadResult(
1234         READ_ERROR,
1235         folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
1236   }
1237   if (bytes <= 0) {
1238     int error = SSL_get_error(ssl_, bytes);
1239     if (error == SSL_ERROR_WANT_READ) {
1240       // The caller will register for read event if not already.
1241       if (errno == EWOULDBLOCK || errno == EAGAIN) {
1242         return ReadResult(READ_BLOCKING);
1243       } else {
1244         return ReadResult(READ_ERROR);
1245       }
1246     } else if (error == SSL_ERROR_WANT_WRITE) {
1247       // TODO: Even though we are attempting to read data, SSL_read() may
1248       // need to write data if renegotiation is being performed.  We currently
1249       // don't support this and just fail the read.
1250       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1251                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1252                  << "): unsupported SSL renegotiation during read";
1253       return ReadResult(
1254           READ_ERROR,
1255           folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1256     } else {
1257       if (zero_return(error, bytes)) {
1258         return ReadResult(bytes);
1259       }
1260       long errError = ERR_get_error();
1261       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1262               << "state=" << state_ << ", "
1263               << "sslState=" << sslState_ << ", "
1264               << "events=" << std::hex << eventFlags_ << "): "
1265               << "bytes: " << bytes << ", "
1266               << "error: " << error << ", "
1267               << "errno: " << errno << ", "
1268               << "func: " << ERR_func_error_string(errError) << ", "
1269               << "reason: " << ERR_reason_error_string(errError);
1270       return ReadResult(
1271           READ_ERROR,
1272           folly::make_unique<SSLException>(error, errError, bytes, errno));
1273     }
1274   } else {
1275     appBytesReceived_ += bytes;
1276     return ReadResult(bytes);
1277   }
1278 }
1279
1280 void AsyncSSLSocket::handleWrite() noexcept {
1281   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1282           << ", state=" << int(state_) << ", "
1283           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1284   if (state_ < StateEnum::ESTABLISHED) {
1285     return AsyncSocket::handleWrite();
1286   }
1287
1288   if (sslState_ == STATE_ACCEPTING) {
1289     assert(server_);
1290     handleAccept();
1291     return;
1292   }
1293
1294   if (sslState_ == STATE_CONNECTING) {
1295     assert(!server_);
1296     handleConnect();
1297     return;
1298   }
1299
1300   // Normal write
1301   AsyncSocket::handleWrite();
1302 }
1303
1304 AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
1305   if (error == SSL_ERROR_WANT_READ) {
1306     // Even though we are attempting to write data, SSL_write() may
1307     // need to read data if renegotiation is being performed.  We currently
1308     // don't support this and just fail the write.
1309     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1310                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1311                << "): "
1312                << "unsupported SSL renegotiation during write";
1313     return WriteResult(
1314         WRITE_ERROR,
1315         folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1316   } else {
1317     if (zero_return(error, rc)) {
1318       return WriteResult(0);
1319     }
1320     auto errError = ERR_get_error();
1321     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1322             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1323             << "SSL error: " << error << ", errno: " << errno
1324             << ", func: " << ERR_func_error_string(errError)
1325             << ", reason: " << ERR_reason_error_string(errError);
1326     return WriteResult(
1327         WRITE_ERROR,
1328         folly::make_unique<SSLException>(error, errError, rc, errno));
1329   }
1330 }
1331
1332 AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
1333     const iovec* vec,
1334     uint32_t count,
1335     WriteFlags flags,
1336     uint32_t* countWritten,
1337     uint32_t* partialWritten) {
1338   if (sslState_ == STATE_UNENCRYPTED) {
1339     return AsyncSocket::performWrite(
1340       vec, count, flags, countWritten, partialWritten);
1341   }
1342   if (sslState_ != STATE_ESTABLISHED) {
1343     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1344                << ", sslState=" << sslState_
1345                << ", events=" << eventFlags_ << "): "
1346                << "TODO: AsyncSSLSocket currently does not support calling "
1347                << "write() before the handshake has fully completed";
1348     return WriteResult(
1349         WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
1350   }
1351
1352   bool cork = isSet(flags, WriteFlags::CORK);
1353   CorkGuard guard(fd_, count > 1, cork, &corked_);
1354
1355   // Declare a buffer used to hold small write requests.  It could point to a
1356   // memory block either on stack or on heap. If it is on heap, we release it
1357   // manually when scope exits
1358   char* combinedBuf{nullptr};
1359   SCOPE_EXIT {
1360     // Note, always keep this check consistent with what we do below
1361     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1362       delete[] combinedBuf;
1363     }
1364   };
1365
1366   *countWritten = 0;
1367   *partialWritten = 0;
1368   ssize_t totalWritten = 0;
1369   size_t bytesStolenFromNextBuffer = 0;
1370   for (uint32_t i = 0; i < count; i++) {
1371     const iovec* v = vec + i;
1372     size_t offset = bytesStolenFromNextBuffer;
1373     bytesStolenFromNextBuffer = 0;
1374     size_t len = v->iov_len - offset;
1375     const void* buf;
1376     if (len == 0) {
1377       (*countWritten)++;
1378       continue;
1379     }
1380     buf = ((const char*)v->iov_base) + offset;
1381
1382     ssize_t bytes;
1383     uint32_t buffersStolen = 0;
1384     if ((len < minWriteSize_) && ((i + 1) < count)) {
1385       // Combine this buffer with part or all of the next buffers in
1386       // order to avoid really small-grained calls to SSL_write().
1387       // Each call to SSL_write() produces a separate record in
1388       // the egress SSL stream, and we've found that some low-end
1389       // mobile clients can't handle receiving an HTTP response
1390       // header and the first part of the response body in two
1391       // separate SSL records (even if those two records are in
1392       // the same TCP packet).
1393
1394       if (combinedBuf == nullptr) {
1395         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1396           // Allocate the buffer on heap
1397           combinedBuf = new char[minWriteSize_];
1398         } else {
1399           // Allocate the buffer on stack
1400           combinedBuf = (char*)alloca(minWriteSize_);
1401         }
1402       }
1403       assert(combinedBuf != nullptr);
1404
1405       memcpy(combinedBuf, buf, len);
1406       do {
1407         // INVARIANT: i + buffersStolen == complete chunks serialized
1408         uint32_t nextIndex = i + buffersStolen + 1;
1409         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1410                                              minWriteSize_ - len);
1411         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1412                bytesStolenFromNextBuffer);
1413         len += bytesStolenFromNextBuffer;
1414         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1415           // couldn't steal the whole buffer
1416           break;
1417         } else {
1418           bytesStolenFromNextBuffer = 0;
1419           buffersStolen++;
1420         }
1421       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1422       bytes = eorAwareSSLWrite(
1423         ssl_, combinedBuf, len,
1424         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1425
1426     } else {
1427       bytes = eorAwareSSLWrite(ssl_, buf, len,
1428                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1429     }
1430
1431     if (bytes <= 0) {
1432       int error = SSL_get_error(ssl_, bytes);
1433       if (error == SSL_ERROR_WANT_WRITE) {
1434         // The caller will register for write event if not already.
1435         *partialWritten = offset;
1436         return WriteResult(totalWritten);
1437       }
1438       auto writeResult = interpretSSLError(bytes, error);
1439       if (writeResult.writeReturn < 0) {
1440         return writeResult;
1441       } // else fall through to below to correctly record totalWritten
1442     }
1443
1444     totalWritten += bytes;
1445
1446     if (bytes == (ssize_t)len) {
1447       // The full iovec is written.
1448       (*countWritten) += 1 + buffersStolen;
1449       i += buffersStolen;
1450       // continue
1451     } else {
1452       bytes += offset; // adjust bytes to account for all of v
1453       while (bytes >= (ssize_t)v->iov_len) {
1454         // We combined this buf with part or all of the next one, and
1455         // we managed to write all of this buf but not all of the bytes
1456         // from the next one that we'd hoped to write.
1457         bytes -= v->iov_len;
1458         (*countWritten)++;
1459         v = &(vec[++i]);
1460       }
1461       *partialWritten = bytes;
1462       return WriteResult(totalWritten);
1463     }
1464   }
1465
1466   return WriteResult(totalWritten);
1467 }
1468
1469 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1470                                       bool eor) {
1471   if (eor && trackEor_) {
1472     if (appEorByteNo_) {
1473       // cannot track for more than one app byte EOR
1474       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1475     } else {
1476       appEorByteNo_ = appBytesWritten_ + n;
1477     }
1478
1479     // 1. It is fine to keep updating minEorRawByteNo_.
1480     // 2. It is _min_ in the sense that SSL record will add some overhead.
1481     minEorRawByteNo_ = getRawBytesWritten() + n;
1482   }
1483
1484   n = sslWriteImpl(ssl, buf, n);
1485   if (n > 0) {
1486     appBytesWritten_ += n;
1487     if (appEorByteNo_) {
1488       if (getRawBytesWritten() >= minEorRawByteNo_) {
1489         minEorRawByteNo_ = 0;
1490       }
1491       if(appBytesWritten_ == appEorByteNo_) {
1492         appEorByteNo_ = 0;
1493       } else {
1494         CHECK(appBytesWritten_ < appEorByteNo_);
1495       }
1496     }
1497   }
1498   return n;
1499 }
1500
1501 void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
1502   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1503   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1504     sslSocket->renegotiateAttempted_ = true;
1505   }
1506   if (where & SSL_CB_READ_ALERT) {
1507     const char* type = SSL_alert_type_string(ret);
1508     if (type) {
1509       const char* desc = SSL_alert_desc_string(ret);
1510       sslSocket->alertsReceived_.emplace_back(
1511           *type, StringPiece(desc, std::strlen(desc)));
1512     }
1513   }
1514 }
1515
1516 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
1517   struct msghdr msg;
1518   struct iovec iov;
1519   int flags = 0;
1520   AsyncSSLSocket* tsslSock;
1521
1522   iov.iov_base = const_cast<char*>(in);
1523   iov.iov_len = inl;
1524   memset(&msg, 0, sizeof(msg));
1525   msg.msg_iov = &iov;
1526   msg.msg_iovlen = 1;
1527
1528   auto appData = OpenSSLUtils::getBioAppData(b);
1529   CHECK(appData);
1530
1531   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
1532   CHECK(tsslSock);
1533
1534   if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
1535       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1536     flags = MSG_EOR;
1537   }
1538
1539   auto result =
1540       tsslSock->sendSocketMessage(BIO_get_fd(b, nullptr), &msg, flags);
1541   BIO_clear_retry_flags(b);
1542   if (!result.exception && result.writeReturn <= 0) {
1543     if (OpenSSLUtils::getBioShouldRetryWrite(result.writeReturn)) {
1544       BIO_set_retry_write(b);
1545     }
1546   }
1547   return result.writeReturn;
1548 }
1549
1550 int AsyncSSLSocket::sslVerifyCallback(
1551     int preverifyOk,
1552     X509_STORE_CTX* x509Ctx) {
1553   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1554     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1555   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1556
1557   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1558           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1559   return (self->handshakeCallback_) ?
1560     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1561     preverifyOk;
1562 }
1563
1564 void AsyncSSLSocket::enableClientHelloParsing()  {
1565     parseClientHello_ = true;
1566     clientHelloInfo_.reset(new ssl::ClientHelloInfo());
1567 }
1568
1569 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1570   SSL_set_msg_callback(ssl, nullptr);
1571   SSL_set_msg_callback_arg(ssl, nullptr);
1572   clientHelloInfo_->clientHelloBuf_.clear();
1573 }
1574
1575 void AsyncSSLSocket::clientHelloParsingCallback(int written,
1576                                                 int /* version */,
1577                                                 int contentType,
1578                                                 const void* buf,
1579                                                 size_t len,
1580                                                 SSL* ssl,
1581                                                 void* arg) {
1582   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1583   if (written != 0) {
1584     sock->resetClientHelloParsing(ssl);
1585     return;
1586   }
1587   if (contentType != SSL3_RT_HANDSHAKE) {
1588     return;
1589   }
1590   if (len == 0) {
1591     return;
1592   }
1593
1594   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1595   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1596   try {
1597     Cursor cursor(clientHelloBuf.front());
1598     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1599       sock->resetClientHelloParsing(ssl);
1600       return;
1601     }
1602
1603     if (cursor.totalLength() < 3) {
1604       clientHelloBuf.trimEnd(len);
1605       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1606       return;
1607     }
1608
1609     uint32_t messageLength = cursor.read<uint8_t>();
1610     messageLength <<= 8;
1611     messageLength |= cursor.read<uint8_t>();
1612     messageLength <<= 8;
1613     messageLength |= cursor.read<uint8_t>();
1614     if (cursor.totalLength() < messageLength) {
1615       clientHelloBuf.trimEnd(len);
1616       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1617       return;
1618     }
1619
1620     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1621     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1622
1623     cursor.skip(4); // gmt_unix_time
1624     cursor.skip(28); // random_bytes
1625
1626     cursor.skip(cursor.read<uint8_t>()); // session_id
1627
1628     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1629     for (int i = 0; i < cipherSuitesLength; i += 2) {
1630       sock->clientHelloInfo_->
1631         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1632     }
1633
1634     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1635     for (int i = 0; i < compressionMethodsLength; ++i) {
1636       sock->clientHelloInfo_->
1637         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1638     }
1639
1640     if (cursor.totalLength() > 0) {
1641       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1642       while (extensionsLength) {
1643         ssl::TLSExtension extensionType =
1644             static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
1645         sock->clientHelloInfo_->
1646           clientHelloExtensions_.push_back(extensionType);
1647         extensionsLength -= 2;
1648         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1649         extensionsLength -= 2;
1650
1651         if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
1652           cursor.skip(2);
1653           extensionDataLength -= 2;
1654           while (extensionDataLength) {
1655             ssl::HashAlgorithm hashAlg =
1656                 static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
1657             ssl::SignatureAlgorithm sigAlg =
1658                 static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
1659             extensionDataLength -= 2;
1660             sock->clientHelloInfo_->
1661               clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
1662           }
1663         } else {
1664           cursor.skip(extensionDataLength);
1665           extensionsLength -= extensionDataLength;
1666         }
1667       }
1668     }
1669   } catch (std::out_of_range& e) {
1670     // we'll use what we found and cleanup below.
1671     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1672       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1673   }
1674
1675   sock->resetClientHelloParsing(ssl);
1676 }
1677
1678 } // namespace