Cleanup of how we use BIO/BIO_METHODs
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20 #include <folly/portability/Sockets.h>
21
22 #include <boost/noncopyable.hpp>
23 #include <errno.h>
24 #include <fcntl.h>
25 #include <openssl/err.h>
26 #include <openssl/asn1.h>
27 #include <openssl/ssl.h>
28 #include <sys/types.h>
29 #include <chrono>
30
31 #include <folly/Bits.h>
32 #include <folly/SocketAddress.h>
33 #include <folly/SpinLock.h>
34 #include <folly/io/IOBuf.h>
35 #include <folly/io/Cursor.h>
36 #include <folly/portability/Unistd.h>
37
38 using folly::SocketAddress;
39 using folly::SSLContext;
40 using std::string;
41 using std::shared_ptr;
42
43 using folly::Endian;
44 using folly::IOBuf;
45 using folly::SpinLock;
46 using folly::SpinLockGuard;
47 using folly::io::Cursor;
48 using std::unique_ptr;
49 using std::bind;
50
51 namespace {
52 using folly::AsyncSocket;
53 using folly::AsyncSocketException;
54 using folly::AsyncSSLSocket;
55 using folly::Optional;
56 using folly::SSLContext;
57 using folly::ssl::OpenSSLUtils;
58
59 // We have one single dummy SSL context so that we can implement attach
60 // and detach methods in a thread safe fashion without modifying opnessl.
61 static SSLContext *dummyCtx = nullptr;
62 static SpinLock dummyCtxLock;
63
64 // If given min write size is less than this, buffer will be allocated on
65 // stack, otherwise it is allocated on heap
66 const size_t MAX_STACK_BUF_SIZE = 2048;
67
68 // This converts "illegal" shutdowns into ZERO_RETURN
69 inline bool zero_return(int error, int rc) {
70   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
71 }
72
73 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
74                                 public AsyncSSLSocket::HandshakeCB {
75
76  private:
77   AsyncSSLSocket *sslSocket_;
78   AsyncSSLSocket::ConnectCallback *callback_;
79   int timeout_;
80   int64_t startTime_;
81
82  protected:
83   ~AsyncSSLSocketConnector() override {}
84
85  public:
86   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
87                            AsyncSocket::ConnectCallback *callback,
88                            int timeout) :
89       sslSocket_(sslSocket),
90       callback_(callback),
91       timeout_(timeout),
92       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
93                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
94   }
95
96   void connectSuccess() noexcept override {
97     VLOG(7) << "client socket connected";
98
99     int64_t timeoutLeft = 0;
100     if (timeout_ > 0) {
101       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
102         std::chrono::steady_clock::now().time_since_epoch()).count();
103
104       timeoutLeft = timeout_ - (curTime - startTime_);
105       if (timeoutLeft <= 0) {
106         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
107                                 "SSL connect timed out");
108         fail(ex);
109         delete this;
110         return;
111       }
112     }
113     sslSocket_->sslConn(this, timeoutLeft);
114   }
115
116   void connectErr(const AsyncSocketException& ex) noexcept override {
117     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
118     fail(ex);
119     delete this;
120   }
121
122   void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
123     VLOG(7) << "client handshake success";
124     if (callback_) {
125       callback_->connectSuccess();
126     }
127     delete this;
128   }
129
130   void handshakeErr(AsyncSSLSocket* /* socket */,
131                     const AsyncSocketException& ex) noexcept override {
132     LOG(ERROR) << "client handshakeErr: " << ex.what();
133     fail(ex);
134     delete this;
135   }
136
137   void fail(const AsyncSocketException &ex) {
138     // fail is a noop if called twice
139     if (callback_) {
140       AsyncSSLSocket::ConnectCallback *cb = callback_;
141       callback_ = nullptr;
142
143       cb->connectErr(ex);
144       sslSocket_->closeNow();
145       // closeNow can call handshakeErr if it hasn't been called already.
146       // So this may have been deleted, no member variable access beyond this
147       // point
148       // Note that closeNow may invoke writeError callbacks if the socket had
149       // write data pending connection completion.
150     }
151   }
152 };
153
154 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
155 #ifdef TCP_CORK // Linux-only
156 /**
157  * Utility class that corks a TCP socket upon construction or uncorks
158  * the socket upon destruction
159  */
160 class CorkGuard : private boost::noncopyable {
161  public:
162   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
163     fd_(fd), haveMore_(haveMore), corked_(corked) {
164     if (*corked_) {
165       // socket is already corked; nothing to do
166       return;
167     }
168     if (multipleWrites || haveMore) {
169       // We are performing multiple writes in this performWrite() call,
170       // and/or there are more calls to performWrite() that will be invoked
171       // later, so enable corking
172       int flag = 1;
173       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
174       *corked_ = true;
175     }
176   }
177
178   ~CorkGuard() {
179     if (haveMore_) {
180       // more data to come; don't uncork yet
181       return;
182     }
183     if (!*corked_) {
184       // socket isn't corked; nothing to do
185       return;
186     }
187
188     int flag = 0;
189     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
190     *corked_ = false;
191   }
192
193  private:
194   int fd_;
195   bool haveMore_;
196   bool* corked_;
197 };
198 #else
199 class CorkGuard : private boost::noncopyable {
200  public:
201   CorkGuard(int, bool, bool, bool*) {}
202 };
203 #endif
204
205 void setup_SSL_CTX(SSL_CTX *ctx) {
206 #ifdef SSL_MODE_RELEASE_BUFFERS
207   SSL_CTX_set_mode(ctx,
208                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
209                    SSL_MODE_ENABLE_PARTIAL_WRITE
210                    | SSL_MODE_RELEASE_BUFFERS
211                    );
212 #else
213   SSL_CTX_set_mode(ctx,
214                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
215                    SSL_MODE_ENABLE_PARTIAL_WRITE
216                    );
217 #endif
218 // SSL_CTX_set_mode is a Macro
219 #ifdef SSL_MODE_WRITE_IOVEC
220   SSL_CTX_set_mode(ctx,
221                    SSL_CTX_get_mode(ctx)
222                    | SSL_MODE_WRITE_IOVEC);
223 #endif
224
225 }
226
227 BIO_METHOD sslWriteBioMethod;
228
229 void* initsslWriteBioMethod(void) {
230   memcpy(&sslWriteBioMethod, BIO_s_socket(), sizeof(sslWriteBioMethod));
231   // override the bwrite method for MSG_EOR support
232   OpenSSLUtils::setCustomBioWriteMethod(
233       &sslWriteBioMethod, AsyncSSLSocket::bioWrite);
234
235   // Note that the sslWriteBioMethod.type and sslWriteBioMethod.name are not
236   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
237   // then have specific handlings. The sslWriteBioWrite should be compatible
238   // with the one in openssl.
239
240   // Return something here to enable AsyncSSLSocket to call this method using
241   // a function-scoped static.
242   return nullptr;
243 }
244
245 } // anonymous namespace
246
247 namespace folly {
248
249 /**
250  * Create a client AsyncSSLSocket
251  */
252 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
253                                EventBase* evb, bool deferSecurityNegotiation) :
254     AsyncSocket(evb),
255     ctx_(ctx),
256     handshakeTimeout_(this, evb) {
257   init();
258   if (deferSecurityNegotiation) {
259     sslState_ = STATE_UNENCRYPTED;
260   }
261 }
262
263 /**
264  * Create a server/client AsyncSSLSocket
265  */
266 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
267                                EventBase* evb, int fd, bool server,
268                                bool deferSecurityNegotiation) :
269     AsyncSocket(evb, fd),
270     server_(server),
271     ctx_(ctx),
272     handshakeTimeout_(this, evb) {
273   init();
274   if (server) {
275     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
276                               AsyncSSLSocket::sslInfoCallback);
277   }
278   if (deferSecurityNegotiation) {
279     sslState_ = STATE_UNENCRYPTED;
280   }
281 }
282
283 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
284 /**
285  * Create a client AsyncSSLSocket and allow tlsext_hostname
286  * to be sent in Client Hello.
287  */
288 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
289                                  EventBase* evb,
290                                const std::string& serverName,
291                                bool deferSecurityNegotiation) :
292     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
293   tlsextHostname_ = serverName;
294 }
295
296 /**
297  * Create a client AsyncSSLSocket from an already connected fd
298  * and allow tlsext_hostname to be sent in Client Hello.
299  */
300 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
301                                  EventBase* evb, int fd,
302                                const std::string& serverName,
303                                bool deferSecurityNegotiation) :
304     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
305   tlsextHostname_ = serverName;
306 }
307 #endif
308
309 AsyncSSLSocket::~AsyncSSLSocket() {
310   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
311           << ", evb=" << eventBase_ << ", fd=" << fd_
312           << ", state=" << int(state_) << ", sslState="
313           << sslState_ << ", events=" << eventFlags_ << ")";
314 }
315
316 void AsyncSSLSocket::init() {
317   // Do this here to ensure we initialize this once before any use of
318   // AsyncSSLSocket instances and not as part of library load.
319   static const auto sslWriteBioMethodInitializer = initsslWriteBioMethod();
320   (void)sslWriteBioMethodInitializer;
321
322   setup_SSL_CTX(ctx_->getSSLCtx());
323 }
324
325 void AsyncSSLSocket::closeNow() {
326   // Close the SSL connection.
327   if (ssl_ != nullptr && fd_ != -1) {
328     int rc = SSL_shutdown(ssl_);
329     if (rc == 0) {
330       rc = SSL_shutdown(ssl_);
331     }
332     if (rc < 0) {
333       ERR_clear_error();
334     }
335   }
336
337   if (sslSession_ != nullptr) {
338     SSL_SESSION_free(sslSession_);
339     sslSession_ = nullptr;
340   }
341
342   sslState_ = STATE_CLOSED;
343
344   if (handshakeTimeout_.isScheduled()) {
345     handshakeTimeout_.cancelTimeout();
346   }
347
348   DestructorGuard dg(this);
349
350   invokeHandshakeErr(
351       AsyncSocketException(
352         AsyncSocketException::END_OF_FILE,
353         "SSL connection closed locally"));
354
355   if (ssl_ != nullptr) {
356     SSL_free(ssl_);
357     ssl_ = nullptr;
358   }
359
360   // Close the socket.
361   AsyncSocket::closeNow();
362 }
363
364 void AsyncSSLSocket::shutdownWrite() {
365   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
366   //
367   // (Performing a full shutdown here is more desirable than doing nothing at
368   // all.  The purpose of shutdownWrite() is normally to notify the other end
369   // of the connection that no more data will be sent.  If we do nothing, the
370   // other end will never know that no more data is coming, and this may result
371   // in protocol deadlock.)
372   close();
373 }
374
375 void AsyncSSLSocket::shutdownWriteNow() {
376   closeNow();
377 }
378
379 bool AsyncSSLSocket::good() const {
380   return (AsyncSocket::good() &&
381           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
382            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
383 }
384
385 // The TAsyncTransport definition of 'good' states that the transport is
386 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
387 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
388 // is connected but we haven't initiated the call to SSL_connect.
389 bool AsyncSSLSocket::connecting() const {
390   return (!server_ &&
391           (AsyncSocket::connecting() ||
392            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
393                                      sslState_ == STATE_CONNECTING))));
394 }
395
396 std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
397   const unsigned char* protoName = nullptr;
398   unsigned protoLength;
399   if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
400     return std::string(reinterpret_cast<const char*>(protoName), protoLength);
401   }
402   return "";
403 }
404
405 bool AsyncSSLSocket::isEorTrackingEnabled() const {
406   return trackEor_;
407 }
408
409 void AsyncSSLSocket::setEorTracking(bool track) {
410   if (trackEor_ != track) {
411     trackEor_ = track;
412     appEorByteNo_ = 0;
413     minEorRawByteNo_ = 0;
414   }
415 }
416
417 size_t AsyncSSLSocket::getRawBytesWritten() const {
418   BIO *b;
419   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
420     return 0;
421   }
422
423   return BIO_number_written(b);
424 }
425
426 size_t AsyncSSLSocket::getRawBytesReceived() const {
427   BIO *b;
428   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
429     return 0;
430   }
431
432   return BIO_number_read(b);
433 }
434
435
436 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
437   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
438              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
439              << "events=" << eventFlags_ << ", server=" << short(server_)
440              << "): " << "sslAccept/Connect() called in invalid "
441              << "state, handshake callback " << handshakeCallback_
442              << ", new callback " << callback;
443   assert(!handshakeTimeout_.isScheduled());
444   sslState_ = STATE_ERROR;
445
446   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
447                          "sslAccept() called with socket in invalid state");
448
449   handshakeEndTime_ = std::chrono::steady_clock::now();
450   if (callback) {
451     callback->handshakeErr(this, ex);
452   }
453
454   // Check the socket state not the ssl state here.
455   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
456     failHandshake(__func__, ex);
457   }
458 }
459
460 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
461       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
462   DestructorGuard dg(this);
463   assert(eventBase_->isInEventBaseThread());
464   verifyPeer_ = verifyPeer;
465
466   // Make sure we're in the uninitialized state
467   if (!server_ || (sslState_ != STATE_UNINIT &&
468                    sslState_ != STATE_UNENCRYPTED) ||
469       handshakeCallback_ != nullptr) {
470     return invalidState(callback);
471   }
472
473   // Cache local and remote socket addresses to keep them available
474   // after socket file descriptor is closed.
475   if (cacheAddrOnFailure_ && -1 != getFd()) {
476     cacheLocalPeerAddr();
477   }
478
479   handshakeStartTime_ = std::chrono::steady_clock::now();
480   // Make end time at least >= start time.
481   handshakeEndTime_ = handshakeStartTime_;
482
483   sslState_ = STATE_ACCEPTING;
484   handshakeCallback_ = callback;
485
486   if (timeout > 0) {
487     handshakeTimeout_.scheduleTimeout(timeout);
488   }
489
490   /* register for a read operation (waiting for CLIENT HELLO) */
491   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
492 }
493
494 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
495 void AsyncSSLSocket::attachSSLContext(
496   const std::shared_ptr<SSLContext>& ctx) {
497
498   // Check to ensure we are in client mode. Changing a server's ssl
499   // context doesn't make sense since clients of that server would likely
500   // become confused when the server's context changes.
501   DCHECK(!server_);
502   DCHECK(!ctx_);
503   DCHECK(ctx);
504   DCHECK(ctx->getSSLCtx());
505   ctx_ = ctx;
506
507   // In order to call attachSSLContext, detachSSLContext must have been
508   // previously called which sets the socket's context to the dummy
509   // context. Thus we must acquire this lock.
510   SpinLockGuard guard(dummyCtxLock);
511   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
512 }
513
514 void AsyncSSLSocket::detachSSLContext() {
515   DCHECK(ctx_);
516   ctx_.reset();
517   // We aren't using the initial_ctx for now, and it can introduce race
518   // conditions in the destructor of the SSL object.
519 #ifndef OPENSSL_NO_TLSEXT
520   if (ssl_->initial_ctx) {
521     SSL_CTX_free(ssl_->initial_ctx);
522     ssl_->initial_ctx = nullptr;
523   }
524 #endif
525   SpinLockGuard guard(dummyCtxLock);
526   if (nullptr == dummyCtx) {
527     // We need to lazily initialize the dummy context so we don't
528     // accidentally override any programmatic settings to openssl
529     dummyCtx = new SSLContext;
530   }
531   // We must remove this socket's references to its context right now
532   // since this socket could get passed to any thread. If the context has
533   // had its locking disabled, just doing a set in attachSSLContext()
534   // would not be thread safe.
535   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
536 }
537 #endif
538
539 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
540 void AsyncSSLSocket::switchServerSSLContext(
541   const std::shared_ptr<SSLContext>& handshakeCtx) {
542   CHECK(server_);
543   if (sslState_ != STATE_ACCEPTING) {
544     // We log it here and allow the switch.
545     // It should not affect our re-negotiation support (which
546     // is not supported now).
547     VLOG(6) << "fd=" << getFd()
548             << " renegotation detected when switching SSL_CTX";
549   }
550
551   setup_SSL_CTX(handshakeCtx->getSSLCtx());
552   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
553                             AsyncSSLSocket::sslInfoCallback);
554   handshakeCtx_ = handshakeCtx;
555   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
556 }
557
558 bool AsyncSSLSocket::isServerNameMatch() const {
559   CHECK(!server_);
560
561   if (!ssl_) {
562     return false;
563   }
564
565   SSL_SESSION *ss = SSL_get_session(ssl_);
566   if (!ss) {
567     return false;
568   }
569
570   if(!ss->tlsext_hostname) {
571     return false;
572   }
573   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
574 }
575
576 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
577   tlsextHostname_ = std::move(serverName);
578 }
579
580 #endif
581
582 void AsyncSSLSocket::timeoutExpired() noexcept {
583   if (state_ == StateEnum::ESTABLISHED &&
584       (sslState_ == STATE_CACHE_LOOKUP ||
585        sslState_ == STATE_ASYNC_PENDING)) {
586     sslState_ = STATE_ERROR;
587     // We are expecting a callback in restartSSLAccept.  The cache lookup
588     // and rsa-call necessarily have pointers to this ssl socket, so delay
589     // the cleanup until he calls us back.
590   } else {
591     assert(state_ == StateEnum::ESTABLISHED &&
592            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
593     DestructorGuard dg(this);
594     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
595                            (sslState_ == STATE_CONNECTING) ?
596                            "SSL connect timed out" : "SSL accept timed out");
597     failHandshake(__func__, ex);
598   }
599 }
600
601 int AsyncSSLSocket::getSSLExDataIndex() {
602   static auto index = SSL_get_ex_new_index(
603       0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
604   return index;
605 }
606
607 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
608   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
609       getSSLExDataIndex()));
610 }
611
612 void AsyncSSLSocket::failHandshake(const char* /* fn */,
613                                    const AsyncSocketException& ex) {
614   startFail();
615   if (handshakeTimeout_.isScheduled()) {
616     handshakeTimeout_.cancelTimeout();
617   }
618   invokeHandshakeErr(ex);
619   finishFail();
620 }
621
622 void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
623   handshakeEndTime_ = std::chrono::steady_clock::now();
624   if (handshakeCallback_ != nullptr) {
625     HandshakeCB* callback = handshakeCallback_;
626     handshakeCallback_ = nullptr;
627     callback->handshakeErr(this, ex);
628   }
629 }
630
631 void AsyncSSLSocket::invokeHandshakeCB() {
632   handshakeEndTime_ = std::chrono::steady_clock::now();
633   if (handshakeTimeout_.isScheduled()) {
634     handshakeTimeout_.cancelTimeout();
635   }
636   if (handshakeCallback_) {
637     HandshakeCB* callback = handshakeCallback_;
638     handshakeCallback_ = nullptr;
639     callback->handshakeSuc(this);
640   }
641 }
642
643 void AsyncSSLSocket::cacheLocalPeerAddr() {
644   SocketAddress address;
645   try {
646     getLocalAddress(&address);
647     getPeerAddress(&address);
648   } catch (const std::system_error& e) {
649     // The handle can be still valid while the connection is already closed.
650     if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
651       throw;
652     }
653   }
654 }
655
656 void AsyncSSLSocket::connect(ConnectCallback* callback,
657                               const folly::SocketAddress& address,
658                               int timeout,
659                               const OptionMap &options,
660                               const folly::SocketAddress& bindAddr)
661                               noexcept {
662   assert(!server_);
663   assert(state_ == StateEnum::UNINIT);
664   assert(sslState_ == STATE_UNINIT);
665   AsyncSSLSocketConnector *connector =
666     new AsyncSSLSocketConnector(this, callback, timeout);
667   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
668 }
669
670 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
671   // apply the settings specified in verifyPeer_
672   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
673     if(ctx_->needsPeerVerification()) {
674       SSL_set_verify(ssl, ctx_->getVerificationMode(),
675         AsyncSSLSocket::sslVerifyCallback);
676     }
677   } else {
678     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
679         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
680       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
681         AsyncSSLSocket::sslVerifyCallback);
682     }
683   }
684 }
685
686 bool AsyncSSLSocket::setupSSLBio() {
687   auto wb = BIO_new(&sslWriteBioMethod);
688
689   if (!wb) {
690     return false;
691   }
692
693   OpenSSLUtils::setBioAppData(wb, this);
694   BIO_set_fd(wb, fd_, BIO_NOCLOSE);
695   SSL_set_bio(ssl_, wb, wb);
696   return true;
697 }
698
699 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
700         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
701   DestructorGuard dg(this);
702   assert(eventBase_->isInEventBaseThread());
703
704   // Cache local and remote socket addresses to keep them available
705   // after socket file descriptor is closed.
706   if (cacheAddrOnFailure_ && -1 != getFd()) {
707     cacheLocalPeerAddr();
708   }
709
710   verifyPeer_ = verifyPeer;
711
712   // Make sure we're in the uninitialized state
713   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
714                   STATE_UNENCRYPTED) ||
715       handshakeCallback_ != nullptr) {
716     return invalidState(callback);
717   }
718
719   handshakeStartTime_ = std::chrono::steady_clock::now();
720   // Make end time at least >= start time.
721   handshakeEndTime_ = handshakeStartTime_;
722
723   sslState_ = STATE_CONNECTING;
724   handshakeCallback_ = callback;
725
726   try {
727     ssl_ = ctx_->createSSL();
728   } catch (std::exception &e) {
729     sslState_ = STATE_ERROR;
730     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
731                            "error calling SSLContext::createSSL()");
732     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
733             << fd_ << "): " << e.what();
734     return failHandshake(__func__, ex);
735   }
736
737   if (!setupSSLBio()) {
738     sslState_ = STATE_ERROR;
739     AsyncSocketException ex(
740         AsyncSocketException::INTERNAL_ERROR, "error creating SSL bio");
741     return failHandshake(__func__, ex);
742   }
743
744   applyVerificationOptions(ssl_);
745
746   if (sslSession_ != nullptr) {
747     SSL_set_session(ssl_, sslSession_);
748     SSL_SESSION_free(sslSession_);
749     sslSession_ = nullptr;
750   }
751 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
752   if (tlsextHostname_.size()) {
753     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
754   }
755 #endif
756
757   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
758
759   if (timeout > 0) {
760     handshakeTimeout_.scheduleTimeout(timeout);
761   }
762
763   handleConnect();
764 }
765
766 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
767   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
768     return SSL_get1_session(ssl_);
769   }
770
771   return sslSession_;
772 }
773
774 const SSL* AsyncSSLSocket::getSSL() const {
775   return ssl_;
776 }
777
778 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
779   sslSession_ = session;
780   if (!takeOwnership && session != nullptr) {
781     // Increment the reference count
782     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
783   }
784 }
785
786 void AsyncSSLSocket::getSelectedNextProtocol(
787     const unsigned char** protoName,
788     unsigned* protoLen,
789     SSLContext::NextProtocolType* protoType) const {
790   if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
791     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
792                               "NPN not supported");
793   }
794 }
795
796 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
797     const unsigned char** protoName,
798     unsigned* protoLen,
799     SSLContext::NextProtocolType* protoType) const {
800   *protoName = nullptr;
801   *protoLen = 0;
802 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
803   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
804   if (*protoLen > 0) {
805     if (protoType) {
806       *protoType = SSLContext::NextProtocolType::ALPN;
807     }
808     return true;
809   }
810 #endif
811 #ifdef OPENSSL_NPN_NEGOTIATED
812   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
813   if (protoType) {
814     *protoType = SSLContext::NextProtocolType::NPN;
815   }
816   return true;
817 #else
818   (void)protoType;
819   return false;
820 #endif
821 }
822
823 bool AsyncSSLSocket::getSSLSessionReused() const {
824   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
825     return SSL_session_reused(ssl_);
826   }
827   return false;
828 }
829
830 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
831   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
832 }
833
834 /* static */
835 const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
836   if (ssl == nullptr) {
837     return nullptr;
838   }
839 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
840   return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
841 #else
842   return nullptr;
843 #endif
844 }
845
846 const char *AsyncSSLSocket::getSSLServerName() const {
847 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
848   return getSSLServerNameFromSSL(ssl_);
849 #else
850   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
851                              "SNI not supported");
852 #endif
853 }
854
855 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
856   return getSSLServerNameFromSSL(ssl_);
857 }
858
859 int AsyncSSLSocket::getSSLVersion() const {
860   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
861 }
862
863 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
864   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
865   if (cert) {
866     int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
867     return OBJ_nid2ln(nid);
868   }
869   return nullptr;
870 }
871
872 int AsyncSSLSocket::getSSLCertSize() const {
873   int certSize = 0;
874   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
875   if (cert) {
876     EVP_PKEY *key = X509_get_pubkey(cert);
877     certSize = EVP_PKEY_bits(key);
878     EVP_PKEY_free(key);
879   }
880   return certSize;
881 }
882
883 bool AsyncSSLSocket::willBlock(int ret,
884                                int* sslErrorOut,
885                                unsigned long* errErrorOut) noexcept {
886   *errErrorOut = 0;
887   int error = *sslErrorOut = SSL_get_error(ssl_, ret);
888   if (error == SSL_ERROR_WANT_READ) {
889     // Register for read event if not already.
890     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
891     return true;
892   } else if (error == SSL_ERROR_WANT_WRITE) {
893     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
894             << ", state=" << int(state_) << ", sslState="
895             << sslState_ << ", events=" << eventFlags_ << "): "
896             << "SSL_ERROR_WANT_WRITE";
897     // Register for write event if not already.
898     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
899     return true;
900 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
901   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
902     // We will block but we can't register our own socket.  The callback that
903     // triggered this code will re-call handleAccept at the appropriate time.
904
905     // We can only get here if the linked libssl.so has support for this feature
906     // as well, otherwise SSL_get_error cannot return our error code.
907     sslState_ = STATE_CACHE_LOOKUP;
908
909     // Unregister for all events while blocked here
910     updateEventRegistration(EventHandler::NONE,
911                             EventHandler::READ | EventHandler::WRITE);
912
913     // The timeout (if set) keeps running here
914     return true;
915 #endif
916   } else if (0
917 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
918       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
919 #endif
920 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
921       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
922 #endif
923       ) {
924     // Our custom openssl function has kicked off an async request to do
925     // rsa/ecdsa private key operation.  When that call returns, a callback will
926     // be invoked that will re-call handleAccept.
927     sslState_ = STATE_ASYNC_PENDING;
928
929     // Unregister for all events while blocked here
930     updateEventRegistration(
931       EventHandler::NONE,
932       EventHandler::READ | EventHandler::WRITE
933     );
934
935     // The timeout (if set) keeps running here
936     return true;
937   } else {
938     unsigned long lastError = *errErrorOut = ERR_get_error();
939     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
940             << "state=" << state_ << ", "
941             << "sslState=" << sslState_ << ", "
942             << "events=" << std::hex << eventFlags_ << "): "
943             << "SSL error: " << error << ", "
944             << "errno: " << errno << ", "
945             << "ret: " << ret << ", "
946             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
947             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
948             << "func: " << ERR_func_error_string(lastError) << ", "
949             << "reason: " << ERR_reason_error_string(lastError);
950     return false;
951   }
952 }
953
954 void AsyncSSLSocket::checkForImmediateRead() noexcept {
955   // openssl may have buffered data that it read from the socket already.
956   // In this case we have to process it immediately, rather than waiting for
957   // the socket to become readable again.
958   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
959     AsyncSocket::handleRead();
960   }
961 }
962
963 void
964 AsyncSSLSocket::restartSSLAccept()
965 {
966   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this
967           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
968           << "sslState=" << sslState_ << ", events=" << eventFlags_;
969   DestructorGuard dg(this);
970   assert(
971     sslState_ == STATE_CACHE_LOOKUP ||
972     sslState_ == STATE_ASYNC_PENDING ||
973     sslState_ == STATE_ERROR ||
974     sslState_ == STATE_CLOSED);
975   if (sslState_ == STATE_CLOSED) {
976     // I sure hope whoever closed this socket didn't delete it already,
977     // but this is not strictly speaking an error
978     return;
979   }
980   if (sslState_ == STATE_ERROR) {
981     // go straight to fail if timeout expired during lookup
982     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
983                            "SSL accept timed out");
984     failHandshake(__func__, ex);
985     return;
986   }
987   sslState_ = STATE_ACCEPTING;
988   this->handleAccept();
989 }
990
991 void
992 AsyncSSLSocket::handleAccept() noexcept {
993   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
994           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
995           << "sslState=" << sslState_ << ", events=" << eventFlags_;
996   assert(server_);
997   assert(state_ == StateEnum::ESTABLISHED &&
998          sslState_ == STATE_ACCEPTING);
999   if (!ssl_) {
1000     /* lazily create the SSL structure */
1001     try {
1002       ssl_ = ctx_->createSSL();
1003     } catch (std::exception &e) {
1004       sslState_ = STATE_ERROR;
1005       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1006                              "error calling SSLContext::createSSL()");
1007       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
1008                  << ", fd=" << fd_ << "): " << e.what();
1009       return failHandshake(__func__, ex);
1010     }
1011
1012     if (!setupSSLBio()) {
1013       sslState_ = STATE_ERROR;
1014       AsyncSocketException ex(
1015           AsyncSocketException::INTERNAL_ERROR, "error creating write bio");
1016       return failHandshake(__func__, ex);
1017     }
1018
1019     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
1020
1021     applyVerificationOptions(ssl_);
1022   }
1023
1024   if (server_ && parseClientHello_) {
1025     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
1026     SSL_set_msg_callback_arg(ssl_, this);
1027   }
1028
1029   int ret = SSL_accept(ssl_);
1030   if (ret <= 0) {
1031     int sslError;
1032     unsigned long errError;
1033     int errnoCopy = errno;
1034     if (willBlock(ret, &sslError, &errError)) {
1035       return;
1036     } else {
1037       sslState_ = STATE_ERROR;
1038       SSLException ex(sslError, errError, ret, errnoCopy);
1039       return failHandshake(__func__, ex);
1040     }
1041   }
1042
1043   handshakeComplete_ = true;
1044   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1045
1046   // Move into STATE_ESTABLISHED in the normal case that we are in
1047   // STATE_ACCEPTING.
1048   sslState_ = STATE_ESTABLISHED;
1049
1050   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
1051           << " successfully accepted; state=" << int(state_)
1052           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
1053
1054   // Remember the EventBase we are attached to, before we start invoking any
1055   // callbacks (since the callbacks may call detachEventBase()).
1056   EventBase* originalEventBase = eventBase_;
1057
1058   // Call the accept callback.
1059   invokeHandshakeCB();
1060
1061   // Note that the accept callback may have changed our state.
1062   // (set or unset the read callback, called write(), closed the socket, etc.)
1063   // The following code needs to handle these situations correctly.
1064   //
1065   // If the socket has been closed, readCallback_ and writeReqHead_ will
1066   // always be nullptr, so that will prevent us from trying to read or write.
1067   //
1068   // The main thing to check for is if eventBase_ is still originalEventBase.
1069   // If not, we have been detached from this event base, so we shouldn't
1070   // perform any more operations.
1071   if (eventBase_ != originalEventBase) {
1072     return;
1073   }
1074
1075   AsyncSocket::handleInitialReadWrite();
1076 }
1077
1078 void
1079 AsyncSSLSocket::handleConnect() noexcept {
1080   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1081           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1082           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1083   assert(!server_);
1084   if (state_ < StateEnum::ESTABLISHED) {
1085     return AsyncSocket::handleConnect();
1086   }
1087
1088   assert(
1089       (state_ == StateEnum::FAST_OPEN || state_ == StateEnum::ESTABLISHED) &&
1090       sslState_ == STATE_CONNECTING);
1091   assert(ssl_);
1092
1093   int ret = SSL_connect(ssl_);
1094   if (ret <= 0) {
1095     int sslError;
1096     unsigned long errError;
1097     int errnoCopy = errno;
1098     if (willBlock(ret, &sslError, &errError)) {
1099       return;
1100     } else {
1101       sslState_ = STATE_ERROR;
1102       SSLException ex(sslError, errError, ret, errnoCopy);
1103       return failHandshake(__func__, ex);
1104     }
1105   }
1106
1107   handshakeComplete_ = true;
1108   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1109
1110   // Move into STATE_ESTABLISHED in the normal case that we are in
1111   // STATE_CONNECTING.
1112   sslState_ = STATE_ESTABLISHED;
1113
1114   VLOG(3) << "AsyncSSLSocket " << this << ": "
1115           << "fd " << fd_ << " successfully connected; "
1116           << "state=" << int(state_) << ", sslState=" << sslState_
1117           << ", events=" << eventFlags_;
1118
1119   // Remember the EventBase we are attached to, before we start invoking any
1120   // callbacks (since the callbacks may call detachEventBase()).
1121   EventBase* originalEventBase = eventBase_;
1122
1123   // Call the handshake callback.
1124   invokeHandshakeCB();
1125
1126   // Note that the connect callback may have changed our state.
1127   // (set or unset the read callback, called write(), closed the socket, etc.)
1128   // The following code needs to handle these situations correctly.
1129   //
1130   // If the socket has been closed, readCallback_ and writeReqHead_ will
1131   // always be nullptr, so that will prevent us from trying to read or write.
1132   //
1133   // The main thing to check for is if eventBase_ is still originalEventBase.
1134   // If not, we have been detached from this event base, so we shouldn't
1135   // perform any more operations.
1136   if (eventBase_ != originalEventBase) {
1137     return;
1138   }
1139
1140   AsyncSocket::handleInitialReadWrite();
1141 }
1142
1143 void AsyncSSLSocket::invokeConnectSuccess() {
1144   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
1145     // If we failed TFO, we'd fall back to trying to connect the socket,
1146     // when we succeed we should handle the writes that caused us to start
1147     // TFO.
1148     handleWrite();
1149   }
1150   AsyncSocket::invokeConnectSuccess();
1151 }
1152
1153 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
1154 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1155   // turn on the buffer movable in openssl
1156   if (bufferMovableEnabled_ && ssl_ != nullptr && !isBufferMovable_ &&
1157       callback != nullptr && callback->isBufferMovable()) {
1158     SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
1159     isBufferMovable_ = true;
1160   }
1161 #endif
1162
1163   AsyncSocket::setReadCB(callback);
1164 }
1165
1166 void AsyncSSLSocket::setBufferMovableEnabled(bool enabled) {
1167   bufferMovableEnabled_ = enabled;
1168 }
1169
1170 void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1171   CHECK(readCallback_);
1172   if (isBufferMovable_) {
1173     *buf = nullptr;
1174     *buflen = 0;
1175   } else {
1176     // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
1177     readCallback_->getReadBuffer(buf, buflen);
1178   }
1179 }
1180
1181 void
1182 AsyncSSLSocket::handleRead() noexcept {
1183   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1184           << ", state=" << int(state_) << ", "
1185           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1186   if (state_ < StateEnum::ESTABLISHED) {
1187     return AsyncSocket::handleRead();
1188   }
1189
1190
1191   if (sslState_ == STATE_ACCEPTING) {
1192     assert(server_);
1193     handleAccept();
1194     return;
1195   }
1196   else if (sslState_ == STATE_CONNECTING) {
1197     assert(!server_);
1198     handleConnect();
1199     return;
1200   }
1201
1202   // Normal read
1203   AsyncSocket::handleRead();
1204 }
1205
1206 AsyncSocket::ReadResult
1207 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
1208   VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
1209           << ", buflen=" << *buflen;
1210
1211   if (sslState_ == STATE_UNENCRYPTED) {
1212     return AsyncSocket::performRead(buf, buflen, offset);
1213   }
1214
1215   ssize_t bytes = 0;
1216   if (!isBufferMovable_) {
1217     bytes = SSL_read(ssl_, *buf, *buflen);
1218   }
1219 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1220   else {
1221     bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
1222   }
1223 #endif
1224
1225   if (server_ && renegotiateAttempted_) {
1226     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1227                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1228                << "): client intitiated SSL renegotiation not permitted";
1229     return ReadResult(
1230         READ_ERROR,
1231         folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
1232   }
1233   if (bytes <= 0) {
1234     int error = SSL_get_error(ssl_, bytes);
1235     if (error == SSL_ERROR_WANT_READ) {
1236       // The caller will register for read event if not already.
1237       if (errno == EWOULDBLOCK || errno == EAGAIN) {
1238         return ReadResult(READ_BLOCKING);
1239       } else {
1240         return ReadResult(READ_ERROR);
1241       }
1242     } else if (error == SSL_ERROR_WANT_WRITE) {
1243       // TODO: Even though we are attempting to read data, SSL_read() may
1244       // need to write data if renegotiation is being performed.  We currently
1245       // don't support this and just fail the read.
1246       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1247                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1248                  << "): unsupported SSL renegotiation during read";
1249       return ReadResult(
1250           READ_ERROR,
1251           folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1252     } else {
1253       if (zero_return(error, bytes)) {
1254         return ReadResult(bytes);
1255       }
1256       long errError = ERR_get_error();
1257       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1258               << "state=" << state_ << ", "
1259               << "sslState=" << sslState_ << ", "
1260               << "events=" << std::hex << eventFlags_ << "): "
1261               << "bytes: " << bytes << ", "
1262               << "error: " << error << ", "
1263               << "errno: " << errno << ", "
1264               << "func: " << ERR_func_error_string(errError) << ", "
1265               << "reason: " << ERR_reason_error_string(errError);
1266       return ReadResult(
1267           READ_ERROR,
1268           folly::make_unique<SSLException>(error, errError, bytes, errno));
1269     }
1270   } else {
1271     appBytesReceived_ += bytes;
1272     return ReadResult(bytes);
1273   }
1274 }
1275
1276 void AsyncSSLSocket::handleWrite() noexcept {
1277   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1278           << ", state=" << int(state_) << ", "
1279           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1280   if (state_ < StateEnum::ESTABLISHED) {
1281     return AsyncSocket::handleWrite();
1282   }
1283
1284   if (sslState_ == STATE_ACCEPTING) {
1285     assert(server_);
1286     handleAccept();
1287     return;
1288   }
1289
1290   if (sslState_ == STATE_CONNECTING) {
1291     assert(!server_);
1292     handleConnect();
1293     return;
1294   }
1295
1296   // Normal write
1297   AsyncSocket::handleWrite();
1298 }
1299
1300 AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
1301   if (error == SSL_ERROR_WANT_READ) {
1302     // Even though we are attempting to write data, SSL_write() may
1303     // need to read data if renegotiation is being performed.  We currently
1304     // don't support this and just fail the write.
1305     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1306                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1307                << "): "
1308                << "unsupported SSL renegotiation during write";
1309     return WriteResult(
1310         WRITE_ERROR,
1311         folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
1312   } else {
1313     if (zero_return(error, rc)) {
1314       return WriteResult(0);
1315     }
1316     auto errError = ERR_get_error();
1317     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1318             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1319             << "SSL error: " << error << ", errno: " << errno
1320             << ", func: " << ERR_func_error_string(errError)
1321             << ", reason: " << ERR_reason_error_string(errError);
1322     return WriteResult(
1323         WRITE_ERROR,
1324         folly::make_unique<SSLException>(error, errError, rc, errno));
1325   }
1326 }
1327
1328 AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
1329     const iovec* vec,
1330     uint32_t count,
1331     WriteFlags flags,
1332     uint32_t* countWritten,
1333     uint32_t* partialWritten) {
1334   if (sslState_ == STATE_UNENCRYPTED) {
1335     return AsyncSocket::performWrite(
1336       vec, count, flags, countWritten, partialWritten);
1337   }
1338   if (sslState_ != STATE_ESTABLISHED) {
1339     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1340                << ", sslState=" << sslState_
1341                << ", events=" << eventFlags_ << "): "
1342                << "TODO: AsyncSSLSocket currently does not support calling "
1343                << "write() before the handshake has fully completed";
1344     return WriteResult(
1345         WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
1346   }
1347
1348   bool cork = isSet(flags, WriteFlags::CORK);
1349   CorkGuard guard(fd_, count > 1, cork, &corked_);
1350
1351   // Declare a buffer used to hold small write requests.  It could point to a
1352   // memory block either on stack or on heap. If it is on heap, we release it
1353   // manually when scope exits
1354   char* combinedBuf{nullptr};
1355   SCOPE_EXIT {
1356     // Note, always keep this check consistent with what we do below
1357     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1358       delete[] combinedBuf;
1359     }
1360   };
1361
1362   *countWritten = 0;
1363   *partialWritten = 0;
1364   ssize_t totalWritten = 0;
1365   size_t bytesStolenFromNextBuffer = 0;
1366   for (uint32_t i = 0; i < count; i++) {
1367     const iovec* v = vec + i;
1368     size_t offset = bytesStolenFromNextBuffer;
1369     bytesStolenFromNextBuffer = 0;
1370     size_t len = v->iov_len - offset;
1371     const void* buf;
1372     if (len == 0) {
1373       (*countWritten)++;
1374       continue;
1375     }
1376     buf = ((const char*)v->iov_base) + offset;
1377
1378     ssize_t bytes;
1379     uint32_t buffersStolen = 0;
1380     if ((len < minWriteSize_) && ((i + 1) < count)) {
1381       // Combine this buffer with part or all of the next buffers in
1382       // order to avoid really small-grained calls to SSL_write().
1383       // Each call to SSL_write() produces a separate record in
1384       // the egress SSL stream, and we've found that some low-end
1385       // mobile clients can't handle receiving an HTTP response
1386       // header and the first part of the response body in two
1387       // separate SSL records (even if those two records are in
1388       // the same TCP packet).
1389
1390       if (combinedBuf == nullptr) {
1391         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1392           // Allocate the buffer on heap
1393           combinedBuf = new char[minWriteSize_];
1394         } else {
1395           // Allocate the buffer on stack
1396           combinedBuf = (char*)alloca(minWriteSize_);
1397         }
1398       }
1399       assert(combinedBuf != nullptr);
1400
1401       memcpy(combinedBuf, buf, len);
1402       do {
1403         // INVARIANT: i + buffersStolen == complete chunks serialized
1404         uint32_t nextIndex = i + buffersStolen + 1;
1405         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1406                                              minWriteSize_ - len);
1407         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1408                bytesStolenFromNextBuffer);
1409         len += bytesStolenFromNextBuffer;
1410         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1411           // couldn't steal the whole buffer
1412           break;
1413         } else {
1414           bytesStolenFromNextBuffer = 0;
1415           buffersStolen++;
1416         }
1417       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1418       bytes = eorAwareSSLWrite(
1419         ssl_, combinedBuf, len,
1420         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1421
1422     } else {
1423       bytes = eorAwareSSLWrite(ssl_, buf, len,
1424                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1425     }
1426
1427     if (bytes <= 0) {
1428       int error = SSL_get_error(ssl_, bytes);
1429       if (error == SSL_ERROR_WANT_WRITE) {
1430         // The caller will register for write event if not already.
1431         *partialWritten = offset;
1432         return WriteResult(totalWritten);
1433       }
1434       auto writeResult = interpretSSLError(bytes, error);
1435       if (writeResult.writeReturn < 0) {
1436         return writeResult;
1437       } // else fall through to below to correctly record totalWritten
1438     }
1439
1440     totalWritten += bytes;
1441
1442     if (bytes == (ssize_t)len) {
1443       // The full iovec is written.
1444       (*countWritten) += 1 + buffersStolen;
1445       i += buffersStolen;
1446       // continue
1447     } else {
1448       bytes += offset; // adjust bytes to account for all of v
1449       while (bytes >= (ssize_t)v->iov_len) {
1450         // We combined this buf with part or all of the next one, and
1451         // we managed to write all of this buf but not all of the bytes
1452         // from the next one that we'd hoped to write.
1453         bytes -= v->iov_len;
1454         (*countWritten)++;
1455         v = &(vec[++i]);
1456       }
1457       *partialWritten = bytes;
1458       return WriteResult(totalWritten);
1459     }
1460   }
1461
1462   return WriteResult(totalWritten);
1463 }
1464
1465 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1466                                       bool eor) {
1467   if (eor && trackEor_) {
1468     if (appEorByteNo_) {
1469       // cannot track for more than one app byte EOR
1470       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1471     } else {
1472       appEorByteNo_ = appBytesWritten_ + n;
1473     }
1474
1475     // 1. It is fine to keep updating minEorRawByteNo_.
1476     // 2. It is _min_ in the sense that SSL record will add some overhead.
1477     minEorRawByteNo_ = getRawBytesWritten() + n;
1478   }
1479
1480   n = sslWriteImpl(ssl, buf, n);
1481   if (n > 0) {
1482     appBytesWritten_ += n;
1483     if (appEorByteNo_) {
1484       if (getRawBytesWritten() >= minEorRawByteNo_) {
1485         minEorRawByteNo_ = 0;
1486       }
1487       if(appBytesWritten_ == appEorByteNo_) {
1488         appEorByteNo_ = 0;
1489       } else {
1490         CHECK(appBytesWritten_ < appEorByteNo_);
1491       }
1492     }
1493   }
1494   return n;
1495 }
1496
1497 void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int ret) {
1498   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1499   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1500     sslSocket->renegotiateAttempted_ = true;
1501   }
1502   if (where & SSL_CB_READ_ALERT) {
1503     const char* type = SSL_alert_type_string(ret);
1504     if (type) {
1505       const char* desc = SSL_alert_desc_string(ret);
1506       sslSocket->alertsReceived_.emplace_back(
1507           *type, StringPiece(desc, std::strlen(desc)));
1508     }
1509   }
1510 }
1511
1512 int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
1513   struct msghdr msg;
1514   struct iovec iov;
1515   int flags = 0;
1516   AsyncSSLSocket* tsslSock;
1517
1518   iov.iov_base = const_cast<char*>(in);
1519   iov.iov_len = inl;
1520   memset(&msg, 0, sizeof(msg));
1521   msg.msg_iov = &iov;
1522   msg.msg_iovlen = 1;
1523
1524   auto appData = OpenSSLUtils::getBioAppData(b);
1525   CHECK(appData);
1526
1527   tsslSock = reinterpret_cast<AsyncSSLSocket*>(appData);
1528   CHECK(tsslSock);
1529
1530   if (tsslSock->trackEor_ && tsslSock->minEorRawByteNo_ &&
1531       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1532     flags = MSG_EOR;
1533   }
1534
1535   auto result =
1536       tsslSock->sendSocketMessage(BIO_get_fd(b, nullptr), &msg, flags);
1537   BIO_clear_retry_flags(b);
1538   if (!result.exception && result.writeReturn <= 0) {
1539     if (OpenSSLUtils::getBioShouldRetryWrite(result.writeReturn)) {
1540       BIO_set_retry_write(b);
1541     }
1542   }
1543   return result.writeReturn;
1544 }
1545
1546 int AsyncSSLSocket::sslVerifyCallback(
1547     int preverifyOk,
1548     X509_STORE_CTX* x509Ctx) {
1549   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1550     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1551   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1552
1553   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1554           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1555   return (self->handshakeCallback_) ?
1556     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1557     preverifyOk;
1558 }
1559
1560 void AsyncSSLSocket::enableClientHelloParsing()  {
1561     parseClientHello_ = true;
1562     clientHelloInfo_.reset(new ssl::ClientHelloInfo());
1563 }
1564
1565 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1566   SSL_set_msg_callback(ssl, nullptr);
1567   SSL_set_msg_callback_arg(ssl, nullptr);
1568   clientHelloInfo_->clientHelloBuf_.clear();
1569 }
1570
1571 void AsyncSSLSocket::clientHelloParsingCallback(int written,
1572                                                 int /* version */,
1573                                                 int contentType,
1574                                                 const void* buf,
1575                                                 size_t len,
1576                                                 SSL* ssl,
1577                                                 void* arg) {
1578   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1579   if (written != 0) {
1580     sock->resetClientHelloParsing(ssl);
1581     return;
1582   }
1583   if (contentType != SSL3_RT_HANDSHAKE) {
1584     return;
1585   }
1586   if (len == 0) {
1587     return;
1588   }
1589
1590   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1591   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1592   try {
1593     Cursor cursor(clientHelloBuf.front());
1594     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1595       sock->resetClientHelloParsing(ssl);
1596       return;
1597     }
1598
1599     if (cursor.totalLength() < 3) {
1600       clientHelloBuf.trimEnd(len);
1601       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1602       return;
1603     }
1604
1605     uint32_t messageLength = cursor.read<uint8_t>();
1606     messageLength <<= 8;
1607     messageLength |= cursor.read<uint8_t>();
1608     messageLength <<= 8;
1609     messageLength |= cursor.read<uint8_t>();
1610     if (cursor.totalLength() < messageLength) {
1611       clientHelloBuf.trimEnd(len);
1612       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1613       return;
1614     }
1615
1616     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1617     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1618
1619     cursor.skip(4); // gmt_unix_time
1620     cursor.skip(28); // random_bytes
1621
1622     cursor.skip(cursor.read<uint8_t>()); // session_id
1623
1624     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1625     for (int i = 0; i < cipherSuitesLength; i += 2) {
1626       sock->clientHelloInfo_->
1627         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1628     }
1629
1630     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1631     for (int i = 0; i < compressionMethodsLength; ++i) {
1632       sock->clientHelloInfo_->
1633         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1634     }
1635
1636     if (cursor.totalLength() > 0) {
1637       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1638       while (extensionsLength) {
1639         ssl::TLSExtension extensionType =
1640             static_cast<ssl::TLSExtension>(cursor.readBE<uint16_t>());
1641         sock->clientHelloInfo_->
1642           clientHelloExtensions_.push_back(extensionType);
1643         extensionsLength -= 2;
1644         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1645         extensionsLength -= 2;
1646
1647         if (extensionType == ssl::TLSExtension::SIGNATURE_ALGORITHMS) {
1648           cursor.skip(2);
1649           extensionDataLength -= 2;
1650           while (extensionDataLength) {
1651             ssl::HashAlgorithm hashAlg =
1652                 static_cast<ssl::HashAlgorithm>(cursor.readBE<uint8_t>());
1653             ssl::SignatureAlgorithm sigAlg =
1654                 static_cast<ssl::SignatureAlgorithm>(cursor.readBE<uint8_t>());
1655             extensionDataLength -= 2;
1656             sock->clientHelloInfo_->
1657               clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
1658           }
1659         } else {
1660           cursor.skip(extensionDataLength);
1661           extensionsLength -= extensionDataLength;
1662         }
1663       }
1664     }
1665   } catch (std::out_of_range& e) {
1666     // we'll use what we found and cleanup below.
1667     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1668       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1669   }
1670
1671   sock->resetClientHelloParsing(ssl);
1672 }
1673
1674 } // namespace