e72894cd9ac9bae46f0ad59e65713da6f8085e22
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20
21 #include <boost/noncopyable.hpp>
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <openssl/err.h>
27 #include <openssl/asn1.h>
28 #include <openssl/ssl.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <unistd.h>
32 #include <chrono>
33
34 #include <folly/Bits.h>
35 #include <folly/SocketAddress.h>
36 #include <folly/SpinLock.h>
37 #include <folly/io/IOBuf.h>
38 #include <folly/io/Cursor.h>
39
40 using folly::SocketAddress;
41 using folly::SSLContext;
42 using std::string;
43 using std::shared_ptr;
44
45 using folly::Endian;
46 using folly::IOBuf;
47 using folly::SpinLock;
48 using folly::SpinLockGuard;
49 using folly::io::Cursor;
50 using std::unique_ptr;
51 using std::bind;
52
53 namespace {
54 using folly::AsyncSocket;
55 using folly::AsyncSocketException;
56 using folly::AsyncSSLSocket;
57 using folly::Optional;
58 using folly::SSLContext;
59
60 // We have one single dummy SSL context so that we can implement attach
61 // and detach methods in a thread safe fashion without modifying opnessl.
62 static SSLContext *dummyCtx = nullptr;
63 static SpinLock dummyCtxLock;
64
65 // Numbers chosen as to not collide with functions in ssl.h
66 const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
67 const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
68
69 // If given min write size is less than this, buffer will be allocated on
70 // stack, otherwise it is allocated on heap
71 const size_t MAX_STACK_BUF_SIZE = 2048;
72
73 // This converts "illegal" shutdowns into ZERO_RETURN
74 inline bool zero_return(int error, int rc) {
75   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
76 }
77
78 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
79                                 public AsyncSSLSocket::HandshakeCB {
80
81  private:
82   AsyncSSLSocket *sslSocket_;
83   AsyncSSLSocket::ConnectCallback *callback_;
84   int timeout_;
85   int64_t startTime_;
86
87  protected:
88   ~AsyncSSLSocketConnector() override {}
89
90  public:
91   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
92                            AsyncSocket::ConnectCallback *callback,
93                            int timeout) :
94       sslSocket_(sslSocket),
95       callback_(callback),
96       timeout_(timeout),
97       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
98                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
99   }
100
101   void connectSuccess() noexcept override {
102     VLOG(7) << "client socket connected";
103
104     int64_t timeoutLeft = 0;
105     if (timeout_ > 0) {
106       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
107         std::chrono::steady_clock::now().time_since_epoch()).count();
108
109       timeoutLeft = timeout_ - (curTime - startTime_);
110       if (timeoutLeft <= 0) {
111         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
112                                 "SSL connect timed out");
113         fail(ex);
114         delete this;
115         return;
116       }
117     }
118     sslSocket_->sslConn(this, timeoutLeft);
119   }
120
121   void connectErr(const AsyncSocketException& ex) noexcept override {
122     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
123     fail(ex);
124     delete this;
125   }
126
127   void handshakeSuc(AsyncSSLSocket* sock) noexcept override {
128     VLOG(7) << "client handshake success";
129     if (callback_) {
130       callback_->connectSuccess();
131     }
132     delete this;
133   }
134
135   void handshakeErr(AsyncSSLSocket* socket,
136                     const AsyncSocketException& ex) noexcept override {
137     LOG(ERROR) << "client handshakeErr: " << ex.what();
138     fail(ex);
139     delete this;
140   }
141
142   void fail(const AsyncSocketException &ex) {
143     // fail is a noop if called twice
144     if (callback_) {
145       AsyncSSLSocket::ConnectCallback *cb = callback_;
146       callback_ = nullptr;
147
148       cb->connectErr(ex);
149       sslSocket_->closeNow();
150       // closeNow can call handshakeErr if it hasn't been called already.
151       // So this may have been deleted, no member variable access beyond this
152       // point
153       // Note that closeNow may invoke writeError callbacks if the socket had
154       // write data pending connection completion.
155     }
156   }
157 };
158
159 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
160 #ifdef TCP_CORK // Linux-only
161 /**
162  * Utility class that corks a TCP socket upon construction or uncorks
163  * the socket upon destruction
164  */
165 class CorkGuard : private boost::noncopyable {
166  public:
167   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
168     fd_(fd), haveMore_(haveMore), corked_(corked) {
169     if (*corked_) {
170       // socket is already corked; nothing to do
171       return;
172     }
173     if (multipleWrites || haveMore) {
174       // We are performing multiple writes in this performWrite() call,
175       // and/or there are more calls to performWrite() that will be invoked
176       // later, so enable corking
177       int flag = 1;
178       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
179       *corked_ = true;
180     }
181   }
182
183   ~CorkGuard() {
184     if (haveMore_) {
185       // more data to come; don't uncork yet
186       return;
187     }
188     if (!*corked_) {
189       // socket isn't corked; nothing to do
190       return;
191     }
192
193     int flag = 0;
194     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
195     *corked_ = false;
196   }
197
198  private:
199   int fd_;
200   bool haveMore_;
201   bool* corked_;
202 };
203 #else
204 class CorkGuard : private boost::noncopyable {
205  public:
206   CorkGuard(int, bool, bool, bool*) {}
207 };
208 #endif
209
210 void setup_SSL_CTX(SSL_CTX *ctx) {
211 #ifdef SSL_MODE_RELEASE_BUFFERS
212   SSL_CTX_set_mode(ctx,
213                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
214                    SSL_MODE_ENABLE_PARTIAL_WRITE
215                    | SSL_MODE_RELEASE_BUFFERS
216                    );
217 #else
218   SSL_CTX_set_mode(ctx,
219                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
220                    SSL_MODE_ENABLE_PARTIAL_WRITE
221                    );
222 #endif
223 // SSL_CTX_set_mode is a Macro
224 #ifdef SSL_MODE_WRITE_IOVEC
225   SSL_CTX_set_mode(ctx,
226                    SSL_CTX_get_mode(ctx)
227                    | SSL_MODE_WRITE_IOVEC);
228 #endif
229
230 }
231
232 BIO_METHOD eorAwareBioMethod;
233
234 void* initEorBioMethod(void) {
235   memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
236   // override the bwrite method for MSG_EOR support
237   eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
238
239   // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
240   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
241   // then have specific handlings. The eorAwareBioWrite should be compatible
242   // with the one in openssl.
243
244   // Return something here to enable AsyncSSLSocket to call this method using
245   // a function-scoped static.
246   return nullptr;
247 }
248
249 } // anonymous namespace
250
251 namespace folly {
252
253 SSLException::SSLException(int sslError, int errno_copy):
254     AsyncSocketException(
255       AsyncSocketException::SSL_ERROR,
256       ERR_error_string(sslError, msg_),
257       sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {}
258
259 /**
260  * Create a client AsyncSSLSocket
261  */
262 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
263                                EventBase* evb, bool deferSecurityNegotiation) :
264     AsyncSocket(evb),
265     ctx_(ctx),
266     handshakeTimeout_(this, evb) {
267   init();
268   if (deferSecurityNegotiation) {
269     sslState_ = STATE_UNENCRYPTED;
270   }
271 }
272
273 /**
274  * Create a server/client AsyncSSLSocket
275  */
276 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
277                                EventBase* evb, int fd, bool server,
278                                bool deferSecurityNegotiation) :
279     AsyncSocket(evb, fd),
280     server_(server),
281     ctx_(ctx),
282     handshakeTimeout_(this, evb) {
283   init();
284   if (server) {
285     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
286                               AsyncSSLSocket::sslInfoCallback);
287   }
288   if (deferSecurityNegotiation) {
289     sslState_ = STATE_UNENCRYPTED;
290   }
291 }
292
293 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
294 /**
295  * Create a client AsyncSSLSocket and allow tlsext_hostname
296  * to be sent in Client Hello.
297  */
298 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
299                                  EventBase* evb,
300                                const std::string& serverName,
301                                bool deferSecurityNegotiation) :
302     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
303   tlsextHostname_ = serverName;
304 }
305
306 /**
307  * Create a client AsyncSSLSocket from an already connected fd
308  * and allow tlsext_hostname to be sent in Client Hello.
309  */
310 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
311                                  EventBase* evb, int fd,
312                                const std::string& serverName,
313                                bool deferSecurityNegotiation) :
314     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
315   tlsextHostname_ = serverName;
316 }
317 #endif
318
319 AsyncSSLSocket::~AsyncSSLSocket() {
320   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
321           << ", evb=" << eventBase_ << ", fd=" << fd_
322           << ", state=" << int(state_) << ", sslState="
323           << sslState_ << ", events=" << eventFlags_ << ")";
324 }
325
326 void AsyncSSLSocket::init() {
327   // Do this here to ensure we initialize this once before any use of
328   // AsyncSSLSocket instances and not as part of library load.
329   static const auto eorAwareBioMethodInitializer = initEorBioMethod();
330   (void)eorAwareBioMethodInitializer;
331
332   setup_SSL_CTX(ctx_->getSSLCtx());
333 }
334
335 void AsyncSSLSocket::closeNow() {
336   // Close the SSL connection.
337   if (ssl_ != nullptr && fd_ != -1) {
338     int rc = SSL_shutdown(ssl_);
339     if (rc == 0) {
340       rc = SSL_shutdown(ssl_);
341     }
342     if (rc < 0) {
343       ERR_clear_error();
344     }
345   }
346
347   if (sslSession_ != nullptr) {
348     SSL_SESSION_free(sslSession_);
349     sslSession_ = nullptr;
350   }
351
352   sslState_ = STATE_CLOSED;
353
354   if (handshakeTimeout_.isScheduled()) {
355     handshakeTimeout_.cancelTimeout();
356   }
357
358   DestructorGuard dg(this);
359
360   invokeHandshakeErr(
361       AsyncSocketException(
362         AsyncSocketException::END_OF_FILE,
363         "SSL connection closed locally"));
364
365   if (ssl_ != nullptr) {
366     SSL_free(ssl_);
367     ssl_ = nullptr;
368   }
369
370   // Close the socket.
371   AsyncSocket::closeNow();
372 }
373
374 void AsyncSSLSocket::shutdownWrite() {
375   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
376   //
377   // (Performing a full shutdown here is more desirable than doing nothing at
378   // all.  The purpose of shutdownWrite() is normally to notify the other end
379   // of the connection that no more data will be sent.  If we do nothing, the
380   // other end will never know that no more data is coming, and this may result
381   // in protocol deadlock.)
382   close();
383 }
384
385 void AsyncSSLSocket::shutdownWriteNow() {
386   closeNow();
387 }
388
389 bool AsyncSSLSocket::good() const {
390   return (AsyncSocket::good() &&
391           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
392            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
393 }
394
395 // The TAsyncTransport definition of 'good' states that the transport is
396 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
397 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
398 // is connected but we haven't initiated the call to SSL_connect.
399 bool AsyncSSLSocket::connecting() const {
400   return (!server_ &&
401           (AsyncSocket::connecting() ||
402            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
403                                      sslState_ == STATE_CONNECTING))));
404 }
405
406 std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
407   const unsigned char* protoName = nullptr;
408   unsigned protoLength;
409   if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
410     return std::string(reinterpret_cast<const char*>(protoName), protoLength);
411   }
412   return "";
413 }
414
415 bool AsyncSSLSocket::isEorTrackingEnabled() const {
416   if (ssl_ == nullptr) {
417     return false;
418   }
419   const BIO *wb = SSL_get_wbio(ssl_);
420   return wb && wb->method == &eorAwareBioMethod;
421 }
422
423 void AsyncSSLSocket::setEorTracking(bool track) {
424   BIO *wb = SSL_get_wbio(ssl_);
425   if (!wb) {
426     throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
427                               "setting EOR tracking without an initialized "
428                               "BIO");
429   }
430
431   if (track) {
432     if (wb->method != &eorAwareBioMethod) {
433       // only do this if we didn't
434       wb->method = &eorAwareBioMethod;
435       BIO_set_app_data(wb, this);
436       appEorByteNo_ = 0;
437       minEorRawByteNo_ = 0;
438     }
439   } else if (wb->method == &eorAwareBioMethod) {
440     wb->method = BIO_s_socket();
441     BIO_set_app_data(wb, nullptr);
442     appEorByteNo_ = 0;
443     minEorRawByteNo_ = 0;
444   } else {
445     CHECK(wb->method == BIO_s_socket());
446   }
447 }
448
449 size_t AsyncSSLSocket::getRawBytesWritten() const {
450   BIO *b;
451   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
452     return 0;
453   }
454
455   return BIO_number_written(b);
456 }
457
458 size_t AsyncSSLSocket::getRawBytesReceived() const {
459   BIO *b;
460   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
461     return 0;
462   }
463
464   return BIO_number_read(b);
465 }
466
467
468 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
469   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
470              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
471              << "events=" << eventFlags_ << ", server=" << short(server_) << "): "
472              << "sslAccept/Connect() called in invalid "
473              << "state, handshake callback " << handshakeCallback_ << ", new callback "
474              << callback;
475   assert(!handshakeTimeout_.isScheduled());
476   sslState_ = STATE_ERROR;
477
478   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
479                          "sslAccept() called with socket in invalid state");
480
481   handshakeEndTime_ = std::chrono::steady_clock::now();
482   if (callback) {
483     callback->handshakeErr(this, ex);
484   }
485
486   // Check the socket state not the ssl state here.
487   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
488     failHandshake(__func__, ex);
489   }
490 }
491
492 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
493       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
494   DestructorGuard dg(this);
495   assert(eventBase_->isInEventBaseThread());
496   verifyPeer_ = verifyPeer;
497
498   // Make sure we're in the uninitialized state
499   if (!server_ || (sslState_ != STATE_UNINIT &&
500                    sslState_ != STATE_UNENCRYPTED) ||
501       handshakeCallback_ != nullptr) {
502     return invalidState(callback);
503   }
504   handshakeStartTime_ = std::chrono::steady_clock::now();
505   // Make end time at least >= start time.
506   handshakeEndTime_ = handshakeStartTime_;
507
508   sslState_ = STATE_ACCEPTING;
509   handshakeCallback_ = callback;
510
511   if (timeout > 0) {
512     handshakeTimeout_.scheduleTimeout(timeout);
513   }
514
515   /* register for a read operation (waiting for CLIENT HELLO) */
516   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
517 }
518
519 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
520 void AsyncSSLSocket::attachSSLContext(
521   const std::shared_ptr<SSLContext>& ctx) {
522
523   // Check to ensure we are in client mode. Changing a server's ssl
524   // context doesn't make sense since clients of that server would likely
525   // become confused when the server's context changes.
526   DCHECK(!server_);
527   DCHECK(!ctx_);
528   DCHECK(ctx);
529   DCHECK(ctx->getSSLCtx());
530   ctx_ = ctx;
531
532   // In order to call attachSSLContext, detachSSLContext must have been
533   // previously called which sets the socket's context to the dummy
534   // context. Thus we must acquire this lock.
535   SpinLockGuard guard(dummyCtxLock);
536   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
537 }
538
539 void AsyncSSLSocket::detachSSLContext() {
540   DCHECK(ctx_);
541   ctx_.reset();
542   // We aren't using the initial_ctx for now, and it can introduce race
543   // conditions in the destructor of the SSL object.
544 #ifndef OPENSSL_NO_TLSEXT
545   if (ssl_->initial_ctx) {
546     SSL_CTX_free(ssl_->initial_ctx);
547     ssl_->initial_ctx = nullptr;
548   }
549 #endif
550   SpinLockGuard guard(dummyCtxLock);
551   if (nullptr == dummyCtx) {
552     // We need to lazily initialize the dummy context so we don't
553     // accidentally override any programmatic settings to openssl
554     dummyCtx = new SSLContext;
555   }
556   // We must remove this socket's references to its context right now
557   // since this socket could get passed to any thread. If the context has
558   // had its locking disabled, just doing a set in attachSSLContext()
559   // would not be thread safe.
560   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
561 }
562 #endif
563
564 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
565 void AsyncSSLSocket::switchServerSSLContext(
566   const std::shared_ptr<SSLContext>& handshakeCtx) {
567   CHECK(server_);
568   if (sslState_ != STATE_ACCEPTING) {
569     // We log it here and allow the switch.
570     // It should not affect our re-negotiation support (which
571     // is not supported now).
572     VLOG(6) << "fd=" << getFd()
573             << " renegotation detected when switching SSL_CTX";
574   }
575
576   setup_SSL_CTX(handshakeCtx->getSSLCtx());
577   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
578                             AsyncSSLSocket::sslInfoCallback);
579   handshakeCtx_ = handshakeCtx;
580   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
581 }
582
583 bool AsyncSSLSocket::isServerNameMatch() const {
584   CHECK(!server_);
585
586   if (!ssl_) {
587     return false;
588   }
589
590   SSL_SESSION *ss = SSL_get_session(ssl_);
591   if (!ss) {
592     return false;
593   }
594
595   if(!ss->tlsext_hostname) {
596     return false;
597   }
598   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
599 }
600
601 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
602   tlsextHostname_ = std::move(serverName);
603 }
604
605 #endif
606
607 void AsyncSSLSocket::timeoutExpired() noexcept {
608   if (state_ == StateEnum::ESTABLISHED &&
609       (sslState_ == STATE_CACHE_LOOKUP ||
610        sslState_ == STATE_RSA_ASYNC_PENDING)) {
611     sslState_ = STATE_ERROR;
612     // We are expecting a callback in restartSSLAccept.  The cache lookup
613     // and rsa-call necessarily have pointers to this ssl socket, so delay
614     // the cleanup until he calls us back.
615   } else {
616     assert(state_ == StateEnum::ESTABLISHED &&
617            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
618     DestructorGuard dg(this);
619     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
620                            (sslState_ == STATE_CONNECTING) ?
621                            "SSL connect timed out" : "SSL accept timed out");
622     failHandshake(__func__, ex);
623   }
624 }
625
626 int AsyncSSLSocket::getSSLExDataIndex() {
627   static auto index = SSL_get_ex_new_index(
628       0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
629   return index;
630 }
631
632 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
633   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
634       getSSLExDataIndex()));
635 }
636
637 void AsyncSSLSocket::failHandshake(const char* fn,
638                                     const AsyncSocketException& ex) {
639   startFail();
640   if (handshakeTimeout_.isScheduled()) {
641     handshakeTimeout_.cancelTimeout();
642   }
643   invokeHandshakeErr(ex);
644   finishFail();
645 }
646
647 void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
648   handshakeEndTime_ = std::chrono::steady_clock::now();
649   if (handshakeCallback_ != nullptr) {
650     HandshakeCB* callback = handshakeCallback_;
651     handshakeCallback_ = nullptr;
652     callback->handshakeErr(this, ex);
653   }
654 }
655
656 void AsyncSSLSocket::invokeHandshakeCB() {
657   handshakeEndTime_ = std::chrono::steady_clock::now();
658   if (handshakeTimeout_.isScheduled()) {
659     handshakeTimeout_.cancelTimeout();
660   }
661   if (handshakeCallback_) {
662     HandshakeCB* callback = handshakeCallback_;
663     handshakeCallback_ = nullptr;
664     callback->handshakeSuc(this);
665   }
666 }
667
668 void AsyncSSLSocket::connect(ConnectCallback* callback,
669                               const folly::SocketAddress& address,
670                               int timeout,
671                               const OptionMap &options,
672                               const folly::SocketAddress& bindAddr)
673                               noexcept {
674   assert(!server_);
675   assert(state_ == StateEnum::UNINIT);
676   assert(sslState_ == STATE_UNINIT);
677   AsyncSSLSocketConnector *connector =
678     new AsyncSSLSocketConnector(this, callback, timeout);
679   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
680 }
681
682 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
683   // apply the settings specified in verifyPeer_
684   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
685     if(ctx_->needsPeerVerification()) {
686       SSL_set_verify(ssl, ctx_->getVerificationMode(),
687         AsyncSSLSocket::sslVerifyCallback);
688     }
689   } else {
690     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
691         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
692       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
693         AsyncSSLSocket::sslVerifyCallback);
694     }
695   }
696 }
697
698 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
699         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
700   DestructorGuard dg(this);
701   assert(eventBase_->isInEventBaseThread());
702
703   verifyPeer_ = verifyPeer;
704
705   // Make sure we're in the uninitialized state
706   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
707                   STATE_UNENCRYPTED) ||
708       handshakeCallback_ != nullptr) {
709     return invalidState(callback);
710   }
711
712   handshakeStartTime_ = std::chrono::steady_clock::now();
713   // Make end time at least >= start time.
714   handshakeEndTime_ = handshakeStartTime_;
715
716   sslState_ = STATE_CONNECTING;
717   handshakeCallback_ = callback;
718
719   try {
720     ssl_ = ctx_->createSSL();
721   } catch (std::exception &e) {
722     sslState_ = STATE_ERROR;
723     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
724                            "error calling SSLContext::createSSL()");
725     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
726             << fd_ << "): " << e.what();
727     return failHandshake(__func__, ex);
728   }
729
730   applyVerificationOptions(ssl_);
731
732   SSL_set_fd(ssl_, fd_);
733   if (sslSession_ != nullptr) {
734     SSL_set_session(ssl_, sslSession_);
735     SSL_SESSION_free(sslSession_);
736     sslSession_ = nullptr;
737   }
738 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
739   if (tlsextHostname_.size()) {
740     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
741   }
742 #endif
743
744   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
745
746   if (timeout > 0) {
747     handshakeTimeout_.scheduleTimeout(timeout);
748   }
749
750   handleConnect();
751 }
752
753 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
754   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
755     return SSL_get1_session(ssl_);
756   }
757
758   return sslSession_;
759 }
760
761 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
762   sslSession_ = session;
763   if (!takeOwnership && session != nullptr) {
764     // Increment the reference count
765     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
766   }
767 }
768
769 void AsyncSSLSocket::getSelectedNextProtocol(
770     const unsigned char** protoName,
771     unsigned* protoLen,
772     SSLContext::NextProtocolType* protoType) const {
773   if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
774     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
775                               "NPN not supported");
776   }
777 }
778
779 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
780     const unsigned char** protoName,
781     unsigned* protoLen,
782     SSLContext::NextProtocolType* protoType) const {
783   *protoName = nullptr;
784   *protoLen = 0;
785 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
786   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
787   if (*protoLen > 0) {
788     if (protoType) {
789       *protoType = SSLContext::NextProtocolType::ALPN;
790     }
791     return true;
792   }
793 #endif
794 #ifdef OPENSSL_NPN_NEGOTIATED
795   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
796   if (protoType) {
797     *protoType = SSLContext::NextProtocolType::NPN;
798   }
799   return true;
800 #else
801   return false;
802 #endif
803 }
804
805 bool AsyncSSLSocket::getSSLSessionReused() const {
806   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
807     return SSL_session_reused(ssl_);
808   }
809   return false;
810 }
811
812 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
813   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
814 }
815
816 const char *AsyncSSLSocket::getSSLServerName() const {
817 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
818   return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name)
819         : nullptr;
820 #else
821   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
822                             "SNI not supported");
823 #endif
824 }
825
826 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
827   try {
828     return getSSLServerName();
829   } catch (AsyncSocketException& ex) {
830     return nullptr;
831   }
832 }
833
834 int AsyncSSLSocket::getSSLVersion() const {
835   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
836 }
837
838 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
839   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
840   if (cert) {
841     int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
842     return OBJ_nid2ln(nid);
843   }
844   return nullptr;
845 }
846
847 int AsyncSSLSocket::getSSLCertSize() const {
848   int certSize = 0;
849   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
850   if (cert) {
851     EVP_PKEY *key = X509_get_pubkey(cert);
852     certSize = EVP_PKEY_bits(key);
853     EVP_PKEY_free(key);
854   }
855   return certSize;
856 }
857
858 bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
859   int error = *errorOut = SSL_get_error(ssl_, ret);
860   if (error == SSL_ERROR_WANT_READ) {
861     // Register for read event if not already.
862     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
863     return true;
864   } else if (error == SSL_ERROR_WANT_WRITE) {
865     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
866             << ", state=" << int(state_) << ", sslState="
867             << sslState_ << ", events=" << eventFlags_ << "): "
868             << "SSL_ERROR_WANT_WRITE";
869     // Register for write event if not already.
870     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
871     return true;
872 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
873   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
874     // We will block but we can't register our own socket.  The callback that
875     // triggered this code will re-call handleAccept at the appropriate time.
876
877     // We can only get here if the linked libssl.so has support for this feature
878     // as well, otherwise SSL_get_error cannot return our error code.
879     sslState_ = STATE_CACHE_LOOKUP;
880
881     // Unregister for all events while blocked here
882     updateEventRegistration(EventHandler::NONE,
883                             EventHandler::READ | EventHandler::WRITE);
884
885     // The timeout (if set) keeps running here
886     return true;
887 #endif
888 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
889   } else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) {
890     // Our custom openssl function has kicked off an async request to do
891     // modular exponentiation.  When that call returns, a callback will
892     // be invoked that will re-call handleAccept.
893     sslState_ = STATE_RSA_ASYNC_PENDING;
894
895     // Unregister for all events while blocked here
896     updateEventRegistration(
897       EventHandler::NONE,
898       EventHandler::READ | EventHandler::WRITE
899     );
900
901     // The timeout (if set) keeps running here
902     return true;
903 #endif
904   } else {
905     // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
906     // in the log
907     long lastError = ERR_get_error();
908     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
909             << "state=" << state_ << ", "
910             << "sslState=" << sslState_ << ", "
911             << "events=" << std::hex << eventFlags_ << "): "
912             << "SSL error: " << error << ", "
913             << "errno: " << errno << ", "
914             << "ret: " << ret << ", "
915             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
916             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
917             << "func: " << ERR_func_error_string(lastError) << ", "
918             << "reason: " << ERR_reason_error_string(lastError);
919     if (error != SSL_ERROR_SYSCALL) {
920       if (error == SSL_ERROR_SSL) {
921         *errorOut = lastError;
922       }
923       if ((unsigned long)lastError < 0x8000) {
924         errno = ENOSYS;
925       } else {
926         errno = lastError;
927       }
928     }
929     ERR_clear_error();
930     return false;
931   }
932 }
933
934 void AsyncSSLSocket::checkForImmediateRead() noexcept {
935   // openssl may have buffered data that it read from the socket already.
936   // In this case we have to process it immediately, rather than waiting for
937   // the socket to become readable again.
938   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
939     AsyncSocket::handleRead();
940   }
941 }
942
943 void
944 AsyncSSLSocket::restartSSLAccept()
945 {
946   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_
947           << ", state=" << int(state_) << ", "
948           << "sslState=" << sslState_ << ", events=" << eventFlags_;
949   DestructorGuard dg(this);
950   assert(
951     sslState_ == STATE_CACHE_LOOKUP ||
952     sslState_ == STATE_RSA_ASYNC_PENDING ||
953     sslState_ == STATE_ERROR ||
954     sslState_ == STATE_CLOSED
955   );
956   if (sslState_ == STATE_CLOSED) {
957     // I sure hope whoever closed this socket didn't delete it already,
958     // but this is not strictly speaking an error
959     return;
960   }
961   if (sslState_ == STATE_ERROR) {
962     // go straight to fail if timeout expired during lookup
963     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
964                            "SSL accept timed out");
965     failHandshake(__func__, ex);
966     return;
967   }
968   sslState_ = STATE_ACCEPTING;
969   this->handleAccept();
970 }
971
972 void
973 AsyncSSLSocket::handleAccept() noexcept {
974   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
975           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
976           << "sslState=" << sslState_ << ", events=" << eventFlags_;
977   assert(server_);
978   assert(state_ == StateEnum::ESTABLISHED &&
979          sslState_ == STATE_ACCEPTING);
980   if (!ssl_) {
981     /* lazily create the SSL structure */
982     try {
983       ssl_ = ctx_->createSSL();
984     } catch (std::exception &e) {
985       sslState_ = STATE_ERROR;
986       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
987                              "error calling SSLContext::createSSL()");
988       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
989                  << ", fd=" << fd_ << "): " << e.what();
990       return failHandshake(__func__, ex);
991     }
992     SSL_set_fd(ssl_, fd_);
993     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
994
995     applyVerificationOptions(ssl_);
996   }
997
998   if (server_ && parseClientHello_) {
999     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
1000     SSL_set_msg_callback_arg(ssl_, this);
1001   }
1002
1003   errno = 0;
1004   int ret = SSL_accept(ssl_);
1005   if (ret <= 0) {
1006     int error;
1007     if (willBlock(ret, &error)) {
1008       return;
1009     } else {
1010       sslState_ = STATE_ERROR;
1011       SSLException ex(error, errno);
1012       return failHandshake(__func__, ex);
1013     }
1014   }
1015
1016   handshakeComplete_ = true;
1017   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1018
1019   // Move into STATE_ESTABLISHED in the normal case that we are in
1020   // STATE_ACCEPTING.
1021   sslState_ = STATE_ESTABLISHED;
1022
1023   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
1024           << " successfully accepted; state=" << int(state_)
1025           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
1026
1027   // Remember the EventBase we are attached to, before we start invoking any
1028   // callbacks (since the callbacks may call detachEventBase()).
1029   EventBase* originalEventBase = eventBase_;
1030
1031   // Call the accept callback.
1032   invokeHandshakeCB();
1033
1034   // Note that the accept callback may have changed our state.
1035   // (set or unset the read callback, called write(), closed the socket, etc.)
1036   // The following code needs to handle these situations correctly.
1037   //
1038   // If the socket has been closed, readCallback_ and writeReqHead_ will
1039   // always be nullptr, so that will prevent us from trying to read or write.
1040   //
1041   // The main thing to check for is if eventBase_ is still originalEventBase.
1042   // If not, we have been detached from this event base, so we shouldn't
1043   // perform any more operations.
1044   if (eventBase_ != originalEventBase) {
1045     return;
1046   }
1047
1048   AsyncSocket::handleInitialReadWrite();
1049 }
1050
1051 void
1052 AsyncSSLSocket::handleConnect() noexcept {
1053   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1054           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1055           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1056   assert(!server_);
1057   if (state_ < StateEnum::ESTABLISHED) {
1058     return AsyncSocket::handleConnect();
1059   }
1060
1061   assert(state_ == StateEnum::ESTABLISHED &&
1062          sslState_ == STATE_CONNECTING);
1063   assert(ssl_);
1064
1065   errno = 0;
1066   int ret = SSL_connect(ssl_);
1067   if (ret <= 0) {
1068     int error;
1069     if (willBlock(ret, &error)) {
1070       return;
1071     } else {
1072       sslState_ = STATE_ERROR;
1073       SSLException ex(error, errno);
1074       return failHandshake(__func__, ex);
1075     }
1076   }
1077
1078   handshakeComplete_ = true;
1079   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1080
1081   // Move into STATE_ESTABLISHED in the normal case that we are in
1082   // STATE_CONNECTING.
1083   sslState_ = STATE_ESTABLISHED;
1084
1085   VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; "
1086           << "state=" << int(state_) << ", sslState=" << sslState_
1087           << ", events=" << eventFlags_;
1088
1089   // Remember the EventBase we are attached to, before we start invoking any
1090   // callbacks (since the callbacks may call detachEventBase()).
1091   EventBase* originalEventBase = eventBase_;
1092
1093   // Call the handshake callback.
1094   invokeHandshakeCB();
1095
1096   // Note that the connect callback may have changed our state.
1097   // (set or unset the read callback, called write(), closed the socket, etc.)
1098   // The following code needs to handle these situations correctly.
1099   //
1100   // If the socket has been closed, readCallback_ and writeReqHead_ will
1101   // always be nullptr, so that will prevent us from trying to read or write.
1102   //
1103   // The main thing to check for is if eventBase_ is still originalEventBase.
1104   // If not, we have been detached from this event base, so we shouldn't
1105   // perform any more operations.
1106   if (eventBase_ != originalEventBase) {
1107     return;
1108   }
1109
1110   AsyncSocket::handleInitialReadWrite();
1111 }
1112
1113 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
1114 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1115   // turn on the buffer movable in openssl
1116   if (ssl_ != nullptr && !isBufferMovable_ &&
1117       callback != nullptr && callback->isBufferMovable()) {
1118     SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
1119     isBufferMovable_ = true;
1120   }
1121 #endif
1122
1123   AsyncSocket::setReadCB(callback);
1124 }
1125
1126 void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1127   CHECK(readCallback_);
1128   if (isBufferMovable_) {
1129     *buf = nullptr;
1130     *buflen = 0;
1131   } else {
1132     // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
1133     readCallback_->getReadBuffer(buf, buflen);
1134   }
1135 }
1136
1137 void
1138 AsyncSSLSocket::handleRead() noexcept {
1139   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1140           << ", state=" << int(state_) << ", "
1141           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1142   if (state_ < StateEnum::ESTABLISHED) {
1143     return AsyncSocket::handleRead();
1144   }
1145
1146
1147   if (sslState_ == STATE_ACCEPTING) {
1148     assert(server_);
1149     handleAccept();
1150     return;
1151   }
1152   else if (sslState_ == STATE_CONNECTING) {
1153     assert(!server_);
1154     handleConnect();
1155     return;
1156   }
1157
1158   // Normal read
1159   AsyncSocket::handleRead();
1160 }
1161
1162 ssize_t
1163 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
1164   VLOG(4) << "AsyncSSLSocket::performRead() this=" << this
1165           << ", buf=" << *buf << ", buflen=" << *buflen;
1166
1167   if (sslState_ == STATE_UNENCRYPTED) {
1168     return AsyncSocket::performRead(buf, buflen, offset);
1169   }
1170
1171   errno = 0;
1172   ssize_t bytes = 0;
1173   if (!isBufferMovable_) {
1174     bytes = SSL_read(ssl_, *buf, *buflen);
1175   }
1176 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1177   else {
1178     bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
1179   }
1180 #endif
1181
1182   if (server_ && renegotiateAttempted_) {
1183     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1184                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1185                << "): client intitiated SSL renegotiation not permitted";
1186     // We pack our own SSLerr here with a dummy function
1187     errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
1188                      SSL_CLIENT_RENEGOTIATION_ATTEMPT);
1189     ERR_clear_error();
1190     return READ_ERROR;
1191   }
1192   if (bytes <= 0) {
1193     int error = SSL_get_error(ssl_, bytes);
1194     if (error == SSL_ERROR_WANT_READ) {
1195       // The caller will register for read event if not already.
1196       if (errno == EWOULDBLOCK || errno == EAGAIN) {
1197         return READ_BLOCKING;
1198       } else {
1199         return READ_ERROR;
1200       }
1201     } else if (error == SSL_ERROR_WANT_WRITE) {
1202       // TODO: Even though we are attempting to read data, SSL_read() may
1203       // need to write data if renegotiation is being performed.  We currently
1204       // don't support this and just fail the read.
1205       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1206                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1207                  << "): unsupported SSL renegotiation during read",
1208       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
1209                        SSL_INVALID_RENEGOTIATION);
1210       ERR_clear_error();
1211       return READ_ERROR;
1212     } else {
1213       // TODO: Fix this code so that it can return a proper error message
1214       // to the callback, rather than relying on AsyncSocket code which
1215       // can't handle SSL errors.
1216       long lastError = ERR_get_error();
1217
1218       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1219               << "state=" << state_ << ", "
1220               << "sslState=" << sslState_ << ", "
1221               << "events=" << std::hex << eventFlags_ << "): "
1222               << "bytes: " << bytes << ", "
1223               << "error: " << error << ", "
1224               << "errno: " << errno << ", "
1225               << "func: " << ERR_func_error_string(lastError) << ", "
1226               << "reason: " << ERR_reason_error_string(lastError);
1227       ERR_clear_error();
1228       if (zero_return(error, bytes)) {
1229         return bytes;
1230       }
1231       if (error != SSL_ERROR_SYSCALL) {
1232         if ((unsigned long)lastError < 0x8000) {
1233           errno = ENOSYS;
1234         } else {
1235           errno = lastError;
1236         }
1237       }
1238       return READ_ERROR;
1239     }
1240   } else {
1241     appBytesReceived_ += bytes;
1242     return bytes;
1243   }
1244 }
1245
1246 void AsyncSSLSocket::handleWrite() noexcept {
1247   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1248           << ", state=" << int(state_) << ", "
1249           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1250   if (state_ < StateEnum::ESTABLISHED) {
1251     return AsyncSocket::handleWrite();
1252   }
1253
1254   if (sslState_ == STATE_ACCEPTING) {
1255     assert(server_);
1256     handleAccept();
1257     return;
1258   }
1259
1260   if (sslState_ == STATE_CONNECTING) {
1261     assert(!server_);
1262     handleConnect();
1263     return;
1264   }
1265
1266   // Normal write
1267   AsyncSocket::handleWrite();
1268 }
1269
1270 int AsyncSSLSocket::interpretSSLError(int rc, int error) {
1271   if (error == SSL_ERROR_WANT_READ) {
1272     // TODO: Even though we are attempting to write data, SSL_write() may
1273     // need to read data if renegotiation is being performed.  We currently
1274     // don't support this and just fail the write.
1275     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1276                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1277                << "): " << "unsupported SSL renegotiation during write",
1278       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
1279                        SSL_INVALID_RENEGOTIATION);
1280     ERR_clear_error();
1281     return -1;
1282   } else {
1283     // TODO: Fix this code so that it can return a proper error message
1284     // to the callback, rather than relying on AsyncSocket code which
1285     // can't handle SSL errors.
1286     long lastError = ERR_get_error();
1287     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1288             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1289             << "SSL error: " << error << ", errno: " << errno
1290             << ", func: " << ERR_func_error_string(lastError)
1291             << ", reason: " << ERR_reason_error_string(lastError);
1292     if (error != SSL_ERROR_SYSCALL) {
1293       if ((unsigned long)lastError < 0x8000) {
1294         errno = ENOSYS;
1295       } else {
1296         errno = lastError;
1297       }
1298     }
1299     ERR_clear_error();
1300     if (!zero_return(error, rc)) {
1301       return -1;
1302     } else {
1303       return 0;
1304     }
1305   }
1306 }
1307
1308 ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
1309                                       uint32_t count,
1310                                       WriteFlags flags,
1311                                       uint32_t* countWritten,
1312                                       uint32_t* partialWritten) {
1313   if (sslState_ == STATE_UNENCRYPTED) {
1314     return AsyncSocket::performWrite(
1315       vec, count, flags, countWritten, partialWritten);
1316   }
1317   if (sslState_ != STATE_ESTABLISHED) {
1318     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1319                << ", sslState=" << sslState_
1320                << ", events=" << eventFlags_ << "): "
1321                << "TODO: AsyncSSLSocket currently does not support calling "
1322                << "write() before the handshake has fully completed";
1323       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
1324                        SSL_EARLY_WRITE);
1325       return -1;
1326   }
1327
1328   bool cork = isSet(flags, WriteFlags::CORK) || persistentCork_;
1329   CorkGuard guard(fd_, count > 1, cork, &corked_);
1330
1331 #if 0
1332 //#ifdef SSL_MODE_WRITE_IOVEC
1333   if (ssl_->expand == nullptr &&
1334       ssl_->compress == nullptr &&
1335       (ssl_->mode & SSL_MODE_WRITE_IOVEC)) {
1336     return performWriteIovec(vec, count, flags, countWritten, partialWritten);
1337   }
1338 #endif
1339
1340   // Declare a buffer used to hold small write requests.  It could point to a
1341   // memory block either on stack or on heap. If it is on heap, we release it
1342   // manually when scope exits
1343   char* combinedBuf{nullptr};
1344   SCOPE_EXIT {
1345     // Note, always keep this check consistent with what we do below
1346     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1347       delete[] combinedBuf;
1348     }
1349   };
1350
1351   *countWritten = 0;
1352   *partialWritten = 0;
1353   ssize_t totalWritten = 0;
1354   size_t bytesStolenFromNextBuffer = 0;
1355   for (uint32_t i = 0; i < count; i++) {
1356     const iovec* v = vec + i;
1357     size_t offset = bytesStolenFromNextBuffer;
1358     bytesStolenFromNextBuffer = 0;
1359     size_t len = v->iov_len - offset;
1360     const void* buf;
1361     if (len == 0) {
1362       (*countWritten)++;
1363       continue;
1364     }
1365     buf = ((const char*)v->iov_base) + offset;
1366
1367     ssize_t bytes;
1368     errno = 0;
1369     uint32_t buffersStolen = 0;
1370     if ((len < minWriteSize_) && ((i + 1) < count)) {
1371       // Combine this buffer with part or all of the next buffers in
1372       // order to avoid really small-grained calls to SSL_write().
1373       // Each call to SSL_write() produces a separate record in
1374       // the egress SSL stream, and we've found that some low-end
1375       // mobile clients can't handle receiving an HTTP response
1376       // header and the first part of the response body in two
1377       // separate SSL records (even if those two records are in
1378       // the same TCP packet).
1379
1380       if (combinedBuf == nullptr) {
1381         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1382           // Allocate the buffer on heap
1383           combinedBuf = new char[minWriteSize_];
1384         } else {
1385           // Allocate the buffer on stack
1386           combinedBuf = (char*)alloca(minWriteSize_);
1387         }
1388       }
1389       assert(combinedBuf != nullptr);
1390
1391       memcpy(combinedBuf, buf, len);
1392       do {
1393         // INVARIANT: i + buffersStolen == complete chunks serialized
1394         uint32_t nextIndex = i + buffersStolen + 1;
1395         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1396                                              minWriteSize_ - len);
1397         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1398                bytesStolenFromNextBuffer);
1399         len += bytesStolenFromNextBuffer;
1400         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1401           // couldn't steal the whole buffer
1402           break;
1403         } else {
1404           bytesStolenFromNextBuffer = 0;
1405           buffersStolen++;
1406         }
1407       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1408       bytes = eorAwareSSLWrite(
1409         ssl_, combinedBuf, len,
1410         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1411
1412     } else {
1413       bytes = eorAwareSSLWrite(ssl_, buf, len,
1414                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1415     }
1416
1417     if (bytes <= 0) {
1418       int error = SSL_get_error(ssl_, bytes);
1419       if (error == SSL_ERROR_WANT_WRITE) {
1420         // The caller will register for write event if not already.
1421         *partialWritten = offset;
1422         return totalWritten;
1423       }
1424       int rc = interpretSSLError(bytes, error);
1425       if (rc < 0) {
1426         return rc;
1427       } // else fall through to below to correctly record totalWritten
1428     }
1429
1430     totalWritten += bytes;
1431
1432     if (bytes == (ssize_t)len) {
1433       // The full iovec is written.
1434       (*countWritten) += 1 + buffersStolen;
1435       i += buffersStolen;
1436       // continue
1437     } else {
1438       bytes += offset; // adjust bytes to account for all of v
1439       while (bytes >= (ssize_t)v->iov_len) {
1440         // We combined this buf with part or all of the next one, and
1441         // we managed to write all of this buf but not all of the bytes
1442         // from the next one that we'd hoped to write.
1443         bytes -= v->iov_len;
1444         (*countWritten)++;
1445         v = &(vec[++i]);
1446       }
1447       *partialWritten = bytes;
1448       return totalWritten;
1449     }
1450   }
1451
1452   return totalWritten;
1453 }
1454
1455 #if 0
1456 //#ifdef SSL_MODE_WRITE_IOVEC
1457 ssize_t AsyncSSLSocket::performWriteIovec(const iovec* vec,
1458                                           uint32_t count,
1459                                           WriteFlags flags,
1460                                           uint32_t* countWritten,
1461                                           uint32_t* partialWritten) {
1462   size_t tot = 0;
1463   for (uint32_t j = 0; j < count; j++) {
1464     tot += vec[j].iov_len;
1465   }
1466
1467   ssize_t totalWritten = SSL_write_iovec(ssl_, vec, count);
1468
1469   *countWritten = 0;
1470   *partialWritten = 0;
1471   if (totalWritten <= 0) {
1472     return interpretSSLError(totalWritten, SSL_get_error(ssl_, totalWritten));
1473   } else {
1474     ssize_t bytes = totalWritten, i = 0;
1475     while (i < count && bytes >= (ssize_t)vec[i].iov_len) {
1476       // we managed to write all of this buf
1477       bytes -= vec[i].iov_len;
1478       (*countWritten)++;
1479       i++;
1480     }
1481     *partialWritten = bytes;
1482
1483     VLOG(4) << "SSL_write_iovec() writes " << tot
1484             << ", returns " << totalWritten << " bytes"
1485             << ", max_send_fragment=" << ssl_->max_send_fragment
1486             << ", count=" << count << ", countWritten=" << *countWritten;
1487
1488     return totalWritten;
1489   }
1490 }
1491 #endif
1492
1493 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1494                                       bool eor) {
1495   if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
1496     if (appEorByteNo_) {
1497       // cannot track for more than one app byte EOR
1498       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1499     } else {
1500       appEorByteNo_ = appBytesWritten_ + n;
1501     }
1502
1503     // 1. It is fine to keep updating minEorRawByteNo_.
1504     // 2. It is _min_ in the sense that SSL record will add some overhead.
1505     minEorRawByteNo_ = getRawBytesWritten() + n;
1506   }
1507
1508   n = sslWriteImpl(ssl, buf, n);
1509   if (n > 0) {
1510     appBytesWritten_ += n;
1511     if (appEorByteNo_) {
1512       if (getRawBytesWritten() >= minEorRawByteNo_) {
1513         minEorRawByteNo_ = 0;
1514       }
1515       if(appBytesWritten_ == appEorByteNo_) {
1516         appEorByteNo_ = 0;
1517       } else {
1518         CHECK(appBytesWritten_ < appEorByteNo_);
1519       }
1520     }
1521   }
1522   return n;
1523 }
1524
1525 void
1526 AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) {
1527   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1528   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1529     sslSocket->renegotiateAttempted_ = true;
1530   }
1531 }
1532
1533 int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
1534   int ret;
1535   struct msghdr msg;
1536   struct iovec iov;
1537   int flags = 0;
1538   AsyncSSLSocket *tsslSock;
1539
1540   iov.iov_base = const_cast<char *>(in);
1541   iov.iov_len = inl;
1542   memset(&msg, 0, sizeof(msg));
1543   msg.msg_iov = &iov;
1544   msg.msg_iovlen = 1;
1545
1546   tsslSock =
1547     reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
1548   if (tsslSock &&
1549       tsslSock->minEorRawByteNo_ &&
1550       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1551     flags = MSG_EOR;
1552   }
1553
1554   errno = 0;
1555   ret = sendmsg(b->num, &msg, flags);
1556   BIO_clear_retry_flags(b);
1557   if (ret <= 0) {
1558     if (BIO_sock_should_retry(ret))
1559       BIO_set_retry_write(b);
1560   }
1561   return(ret);
1562 }
1563
1564 int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
1565                                        X509_STORE_CTX* x509Ctx) {
1566   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1567     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1568   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1569
1570   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1571           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1572   return (self->handshakeCallback_) ?
1573     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1574     preverifyOk;
1575 }
1576
1577 void AsyncSSLSocket::enableClientHelloParsing()  {
1578     parseClientHello_ = true;
1579     clientHelloInfo_.reset(new ClientHelloInfo());
1580 }
1581
1582 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1583   SSL_set_msg_callback(ssl, nullptr);
1584   SSL_set_msg_callback_arg(ssl, nullptr);
1585   clientHelloInfo_->clientHelloBuf_.clear();
1586 }
1587
1588 void
1589 AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
1590     int contentType, const void *buf, size_t len, SSL *ssl, void *arg)
1591 {
1592   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1593   if (written != 0) {
1594     sock->resetClientHelloParsing(ssl);
1595     return;
1596   }
1597   if (contentType != SSL3_RT_HANDSHAKE) {
1598     return;
1599   }
1600   if (len == 0) {
1601     return;
1602   }
1603
1604   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1605   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1606   try {
1607     Cursor cursor(clientHelloBuf.front());
1608     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1609       sock->resetClientHelloParsing(ssl);
1610       return;
1611     }
1612
1613     if (cursor.totalLength() < 3) {
1614       clientHelloBuf.trimEnd(len);
1615       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1616       return;
1617     }
1618
1619     uint32_t messageLength = cursor.read<uint8_t>();
1620     messageLength <<= 8;
1621     messageLength |= cursor.read<uint8_t>();
1622     messageLength <<= 8;
1623     messageLength |= cursor.read<uint8_t>();
1624     if (cursor.totalLength() < messageLength) {
1625       clientHelloBuf.trimEnd(len);
1626       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1627       return;
1628     }
1629
1630     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1631     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1632
1633     cursor.skip(4); // gmt_unix_time
1634     cursor.skip(28); // random_bytes
1635
1636     cursor.skip(cursor.read<uint8_t>()); // session_id
1637
1638     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1639     for (int i = 0; i < cipherSuitesLength; i += 2) {
1640       sock->clientHelloInfo_->
1641         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1642     }
1643
1644     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1645     for (int i = 0; i < compressionMethodsLength; ++i) {
1646       sock->clientHelloInfo_->
1647         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1648     }
1649
1650     if (cursor.totalLength() > 0) {
1651       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1652       while (extensionsLength) {
1653         TLSExtension extensionType = static_cast<TLSExtension>(
1654             cursor.readBE<uint16_t>());
1655         sock->clientHelloInfo_->
1656           clientHelloExtensions_.push_back(extensionType);
1657         extensionsLength -= 2;
1658         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1659         extensionsLength -= 2;
1660
1661         if (extensionType == TLSExtension::SIGNATURE_ALGORITHMS) {
1662           cursor.skip(2);
1663           extensionDataLength -= 2;
1664           while (extensionDataLength) {
1665             HashAlgorithm hashAlg = static_cast<HashAlgorithm>(
1666                 cursor.readBE<uint8_t>());
1667             SignatureAlgorithm sigAlg = static_cast<SignatureAlgorithm>(
1668                 cursor.readBE<uint8_t>());
1669             extensionDataLength -= 2;
1670             sock->clientHelloInfo_->
1671               clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
1672           }
1673         } else {
1674           cursor.skip(extensionDataLength);
1675           extensionsLength -= extensionDataLength;
1676         }
1677       }
1678     }
1679   } catch (std::out_of_range& e) {
1680     // we'll use what we found and cleanup below.
1681     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1682       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1683   }
1684
1685   sock->resetClientHelloParsing(ssl);
1686 }
1687
1688 } // namespace