f4a443214cff1cc4965be1509324603aaabeda5e
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20
21 #include <boost/noncopyable.hpp>
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <openssl/err.h>
27 #include <openssl/asn1.h>
28 #include <openssl/ssl.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <unistd.h>
32 #include <chrono>
33
34 #include <folly/Bits.h>
35 #include <folly/SocketAddress.h>
36 #include <folly/SpinLock.h>
37 #include <folly/io/IOBuf.h>
38 #include <folly/io/Cursor.h>
39
40 using folly::SocketAddress;
41 using folly::SSLContext;
42 using std::string;
43 using std::shared_ptr;
44
45 using folly::Endian;
46 using folly::IOBuf;
47 using folly::SpinLock;
48 using folly::SpinLockGuard;
49 using folly::io::Cursor;
50 using std::unique_ptr;
51 using std::bind;
52
53 namespace {
54 using folly::AsyncSocket;
55 using folly::AsyncSocketException;
56 using folly::AsyncSSLSocket;
57 using folly::Optional;
58
59 // We have one single dummy SSL context so that we can implement attach
60 // and detach methods in a thread safe fashion without modifying opnessl.
61 static SSLContext *dummyCtx = nullptr;
62 static SpinLock dummyCtxLock;
63
64 // Numbers chosen as to not collide with functions in ssl.h
65 const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
66 const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
67
68 // If given min write size is less than this, buffer will be allocated on
69 // stack, otherwise it is allocated on heap
70 const size_t MAX_STACK_BUF_SIZE = 2048;
71
72 // This converts "illegal" shutdowns into ZERO_RETURN
73 inline bool zero_return(int error, int rc) {
74   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
75 }
76
77 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
78                                 public AsyncSSLSocket::HandshakeCB {
79
80  private:
81   AsyncSSLSocket *sslSocket_;
82   AsyncSSLSocket::ConnectCallback *callback_;
83   int timeout_;
84   int64_t startTime_;
85
86  protected:
87   virtual ~AsyncSSLSocketConnector() {
88   }
89
90  public:
91   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
92                            AsyncSocket::ConnectCallback *callback,
93                            int timeout) :
94       sslSocket_(sslSocket),
95       callback_(callback),
96       timeout_(timeout),
97       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
98                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
99   }
100
101   virtual void connectSuccess() noexcept {
102     VLOG(7) << "client socket connected";
103
104     int64_t timeoutLeft = 0;
105     if (timeout_ > 0) {
106       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
107         std::chrono::steady_clock::now().time_since_epoch()).count();
108
109       timeoutLeft = timeout_ - (curTime - startTime_);
110       if (timeoutLeft <= 0) {
111         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
112                                 "SSL connect timed out");
113         fail(ex);
114         delete this;
115         return;
116       }
117     }
118     sslSocket_->sslConn(this, timeoutLeft);
119   }
120
121   virtual void connectErr(const AsyncSocketException& ex) noexcept {
122     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
123     fail(ex);
124     delete this;
125   }
126
127   virtual void handshakeSuc(AsyncSSLSocket *sock) noexcept {
128     VLOG(7) << "client handshake success";
129     if (callback_) {
130       callback_->connectSuccess();
131     }
132     delete this;
133   }
134
135   virtual void handshakeErr(AsyncSSLSocket *socket,
136                               const AsyncSocketException& ex) noexcept {
137     LOG(ERROR) << "client handshakeErr: " << ex.what();
138     fail(ex);
139     delete this;
140   }
141
142   void fail(const AsyncSocketException &ex) {
143     // fail is a noop if called twice
144     if (callback_) {
145       AsyncSSLSocket::ConnectCallback *cb = callback_;
146       callback_ = nullptr;
147
148       cb->connectErr(ex);
149       sslSocket_->closeNow();
150       // closeNow can call handshakeErr if it hasn't been called already.
151       // So this may have been deleted, no member variable access beyond this
152       // point
153       // Note that closeNow may invoke writeError callbacks if the socket had
154       // write data pending connection completion.
155     }
156   }
157 };
158
159 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
160 #ifdef TCP_CORK // Linux-only
161 /**
162  * Utility class that corks a TCP socket upon construction or uncorks
163  * the socket upon destruction
164  */
165 class CorkGuard : private boost::noncopyable {
166  public:
167   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
168     fd_(fd), haveMore_(haveMore), corked_(corked) {
169     if (*corked_) {
170       // socket is already corked; nothing to do
171       return;
172     }
173     if (multipleWrites || haveMore) {
174       // We are performing multiple writes in this performWrite() call,
175       // and/or there are more calls to performWrite() that will be invoked
176       // later, so enable corking
177       int flag = 1;
178       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
179       *corked_ = true;
180     }
181   }
182
183   ~CorkGuard() {
184     if (haveMore_) {
185       // more data to come; don't uncork yet
186       return;
187     }
188     if (!*corked_) {
189       // socket isn't corked; nothing to do
190       return;
191     }
192
193     int flag = 0;
194     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
195     *corked_ = false;
196   }
197
198  private:
199   int fd_;
200   bool haveMore_;
201   bool* corked_;
202 };
203 #else
204 class CorkGuard : private boost::noncopyable {
205  public:
206   CorkGuard(int, bool, bool, bool*) {}
207 };
208 #endif
209
210 void setup_SSL_CTX(SSL_CTX *ctx) {
211 #ifdef SSL_MODE_RELEASE_BUFFERS
212   SSL_CTX_set_mode(ctx,
213                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
214                    SSL_MODE_ENABLE_PARTIAL_WRITE
215                    | SSL_MODE_RELEASE_BUFFERS
216                    );
217 #else
218   SSL_CTX_set_mode(ctx,
219                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
220                    SSL_MODE_ENABLE_PARTIAL_WRITE
221                    );
222 #endif
223 // SSL_CTX_set_mode is a Macro
224 #ifdef SSL_MODE_WRITE_IOVEC
225   SSL_CTX_set_mode(ctx,
226                    SSL_CTX_get_mode(ctx)
227                    | SSL_MODE_WRITE_IOVEC);
228 #endif
229
230 }
231
232 BIO_METHOD eorAwareBioMethod;
233
234 void* initEorBioMethod(void) {
235   memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
236   // override the bwrite method for MSG_EOR support
237   eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
238
239   // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
240   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
241   // then have specific handlings. The eorAwareBioWrite should be compatible
242   // with the one in openssl.
243
244   // Return something here to enable AsyncSSLSocket to call this method using
245   // a function-scoped static.
246   return nullptr;
247 }
248
249 } // anonymous namespace
250
251 namespace folly {
252
253 SSLException::SSLException(int sslError, int errno_copy):
254     AsyncSocketException(
255       AsyncSocketException::SSL_ERROR,
256       ERR_error_string(sslError, msg_),
257       sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {}
258
259 /**
260  * Create a client AsyncSSLSocket
261  */
262 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
263                                EventBase* evb, bool deferSecurityNegotiation) :
264     AsyncSocket(evb),
265     ctx_(ctx),
266     handshakeTimeout_(this, evb) {
267   init();
268   if (deferSecurityNegotiation) {
269     sslState_ = STATE_UNENCRYPTED;
270   }
271 }
272
273 /**
274  * Create a server/client AsyncSSLSocket
275  */
276 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
277                                EventBase* evb, int fd, bool server,
278                                bool deferSecurityNegotiation) :
279     AsyncSocket(evb, fd),
280     server_(server),
281     ctx_(ctx),
282     handshakeTimeout_(this, evb) {
283   init();
284   if (server) {
285     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
286                               AsyncSSLSocket::sslInfoCallback);
287   }
288   if (deferSecurityNegotiation) {
289     sslState_ = STATE_UNENCRYPTED;
290   }
291 }
292
293 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
294 /**
295  * Create a client AsyncSSLSocket and allow tlsext_hostname
296  * to be sent in Client Hello.
297  */
298 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
299                                  EventBase* evb,
300                                const std::string& serverName,
301                                bool deferSecurityNegotiation) :
302     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
303   tlsextHostname_ = serverName;
304 }
305
306 /**
307  * Create a client AsyncSSLSocket from an already connected fd
308  * and allow tlsext_hostname to be sent in Client Hello.
309  */
310 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
311                                  EventBase* evb, int fd,
312                                const std::string& serverName,
313                                bool deferSecurityNegotiation) :
314     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
315   tlsextHostname_ = serverName;
316 }
317 #endif
318
319 AsyncSSLSocket::~AsyncSSLSocket() {
320   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
321           << ", evb=" << eventBase_ << ", fd=" << fd_
322           << ", state=" << int(state_) << ", sslState="
323           << sslState_ << ", events=" << eventFlags_ << ")";
324 }
325
326 void AsyncSSLSocket::init() {
327   // Do this here to ensure we initialize this once before any use of
328   // AsyncSSLSocket instances and not as part of library load.
329   static const auto eorAwareBioMethodInitializer = initEorBioMethod();
330   setup_SSL_CTX(ctx_->getSSLCtx());
331 }
332
333 void AsyncSSLSocket::closeNow() {
334   // Close the SSL connection.
335   if (ssl_ != nullptr && fd_ != -1) {
336     int rc = SSL_shutdown(ssl_);
337     if (rc == 0) {
338       rc = SSL_shutdown(ssl_);
339     }
340     if (rc < 0) {
341       ERR_clear_error();
342     }
343   }
344
345   if (sslSession_ != nullptr) {
346     SSL_SESSION_free(sslSession_);
347     sslSession_ = nullptr;
348   }
349
350   sslState_ = STATE_CLOSED;
351
352   if (handshakeTimeout_.isScheduled()) {
353     handshakeTimeout_.cancelTimeout();
354   }
355
356   DestructorGuard dg(this);
357
358   if (handshakeCallback_) {
359     AsyncSocketException ex(AsyncSocketException::END_OF_FILE,
360                            "SSL connection closed locally");
361     HandshakeCB* callback = handshakeCallback_;
362     handshakeCallback_ = nullptr;
363     callback->handshakeErr(this, ex);
364   }
365
366   if (ssl_ != nullptr) {
367     SSL_free(ssl_);
368     ssl_ = nullptr;
369   }
370
371   // Close the socket.
372   AsyncSocket::closeNow();
373 }
374
375 void AsyncSSLSocket::shutdownWrite() {
376   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
377   //
378   // (Performing a full shutdown here is more desirable than doing nothing at
379   // all.  The purpose of shutdownWrite() is normally to notify the other end
380   // of the connection that no more data will be sent.  If we do nothing, the
381   // other end will never know that no more data is coming, and this may result
382   // in protocol deadlock.)
383   close();
384 }
385
386 void AsyncSSLSocket::shutdownWriteNow() {
387   closeNow();
388 }
389
390 bool AsyncSSLSocket::good() const {
391   return (AsyncSocket::good() &&
392           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
393            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
394 }
395
396 // The TAsyncTransport definition of 'good' states that the transport is
397 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
398 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
399 // is connected but we haven't initiated the call to SSL_connect.
400 bool AsyncSSLSocket::connecting() const {
401   return (!server_ &&
402           (AsyncSocket::connecting() ||
403            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
404                                      sslState_ == STATE_CONNECTING))));
405 }
406
407 bool AsyncSSLSocket::isEorTrackingEnabled() const {
408   const BIO *wb = SSL_get_wbio(ssl_);
409   return wb && wb->method == &eorAwareBioMethod;
410 }
411
412 void AsyncSSLSocket::setEorTracking(bool track) {
413   BIO *wb = SSL_get_wbio(ssl_);
414   if (!wb) {
415     throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
416                               "setting EOR tracking without an initialized "
417                               "BIO");
418   }
419
420   if (track) {
421     if (wb->method != &eorAwareBioMethod) {
422       // only do this if we didn't
423       wb->method = &eorAwareBioMethod;
424       BIO_set_app_data(wb, this);
425       appEorByteNo_ = 0;
426       minEorRawByteNo_ = 0;
427     }
428   } else if (wb->method == &eorAwareBioMethod) {
429     wb->method = BIO_s_socket();
430     BIO_set_app_data(wb, nullptr);
431     appEorByteNo_ = 0;
432     minEorRawByteNo_ = 0;
433   } else {
434     CHECK(wb->method == BIO_s_socket());
435   }
436 }
437
438 size_t AsyncSSLSocket::getRawBytesWritten() const {
439   BIO *b;
440   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
441     return 0;
442   }
443
444   return BIO_number_written(b);
445 }
446
447 size_t AsyncSSLSocket::getRawBytesReceived() const {
448   BIO *b;
449   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
450     return 0;
451   }
452
453   return BIO_number_read(b);
454 }
455
456
457 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
458   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
459              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
460              << "events=" << eventFlags_ << ", server=" << short(server_) << "): "
461              << "sslAccept/Connect() called in invalid "
462              << "state, handshake callback " << handshakeCallback_ << ", new callback "
463              << callback;
464   assert(!handshakeTimeout_.isScheduled());
465   sslState_ = STATE_ERROR;
466
467   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
468                          "sslAccept() called with socket in invalid state");
469
470   if (callback) {
471     callback->handshakeErr(this, ex);
472   }
473
474   // Check the socket state not the ssl state here.
475   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
476     failHandshake(__func__, ex);
477   }
478 }
479
480 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
481       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
482   DestructorGuard dg(this);
483   assert(eventBase_->isInEventBaseThread());
484   verifyPeer_ = verifyPeer;
485
486   // Make sure we're in the uninitialized state
487   if (!server_ || (sslState_ != STATE_UNINIT &&
488                    sslState_ != STATE_UNENCRYPTED) ||
489       handshakeCallback_ != nullptr) {
490     return invalidState(callback);
491   }
492
493   sslState_ = STATE_ACCEPTING;
494   handshakeCallback_ = callback;
495
496   if (timeout > 0) {
497     handshakeTimeout_.scheduleTimeout(timeout);
498   }
499
500   /* register for a read operation (waiting for CLIENT HELLO) */
501   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
502 }
503
504 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
505 void AsyncSSLSocket::attachSSLContext(
506   const std::shared_ptr<SSLContext>& ctx) {
507
508   // Check to ensure we are in client mode. Changing a server's ssl
509   // context doesn't make sense since clients of that server would likely
510   // become confused when the server's context changes.
511   DCHECK(!server_);
512   DCHECK(!ctx_);
513   DCHECK(ctx);
514   DCHECK(ctx->getSSLCtx());
515   ctx_ = ctx;
516
517   // In order to call attachSSLContext, detachSSLContext must have been
518   // previously called which sets the socket's context to the dummy
519   // context. Thus we must acquire this lock.
520   SpinLockGuard guard(dummyCtxLock);
521   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
522 }
523
524 void AsyncSSLSocket::detachSSLContext() {
525   DCHECK(ctx_);
526   ctx_.reset();
527   // We aren't using the initial_ctx for now, and it can introduce race
528   // conditions in the destructor of the SSL object.
529 #ifndef OPENSSL_NO_TLSEXT
530   if (ssl_->initial_ctx) {
531     SSL_CTX_free(ssl_->initial_ctx);
532     ssl_->initial_ctx = nullptr;
533   }
534 #endif
535   SpinLockGuard guard(dummyCtxLock);
536   if (nullptr == dummyCtx) {
537     // We need to lazily initialize the dummy context so we don't
538     // accidentally override any programmatic settings to openssl
539     dummyCtx = new SSLContext;
540   }
541   // We must remove this socket's references to its context right now
542   // since this socket could get passed to any thread. If the context has
543   // had its locking disabled, just doing a set in attachSSLContext()
544   // would not be thread safe.
545   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
546 }
547 #endif
548
549 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
550 void AsyncSSLSocket::switchServerSSLContext(
551   const std::shared_ptr<SSLContext>& handshakeCtx) {
552   CHECK(server_);
553   if (sslState_ != STATE_ACCEPTING) {
554     // We log it here and allow the switch.
555     // It should not affect our re-negotiation support (which
556     // is not supported now).
557     VLOG(6) << "fd=" << getFd()
558             << " renegotation detected when switching SSL_CTX";
559   }
560
561   setup_SSL_CTX(handshakeCtx->getSSLCtx());
562   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
563                             AsyncSSLSocket::sslInfoCallback);
564   handshakeCtx_ = handshakeCtx;
565   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
566 }
567
568 bool AsyncSSLSocket::isServerNameMatch() const {
569   CHECK(!server_);
570
571   if (!ssl_) {
572     return false;
573   }
574
575   SSL_SESSION *ss = SSL_get_session(ssl_);
576   if (!ss) {
577     return false;
578   }
579
580   if(!ss->tlsext_hostname) {
581     return false;
582   }
583   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
584 }
585
586 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
587   tlsextHostname_ = std::move(serverName);
588 }
589
590 #endif
591
592 void AsyncSSLSocket::timeoutExpired() noexcept {
593   if (state_ == StateEnum::ESTABLISHED &&
594       (sslState_ == STATE_CACHE_LOOKUP ||
595        sslState_ == STATE_RSA_ASYNC_PENDING)) {
596     sslState_ = STATE_ERROR;
597     // We are expecting a callback in restartSSLAccept.  The cache lookup
598     // and rsa-call necessarily have pointers to this ssl socket, so delay
599     // the cleanup until he calls us back.
600   } else {
601     assert(state_ == StateEnum::ESTABLISHED &&
602            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
603     DestructorGuard dg(this);
604     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
605                            (sslState_ == STATE_CONNECTING) ?
606                            "SSL connect timed out" : "SSL accept timed out");
607     failHandshake(__func__, ex);
608   }
609 }
610
611 int AsyncSSLSocket::sslExDataIndex_ = -1;
612 std::mutex AsyncSSLSocket::mutex_;
613
614 int AsyncSSLSocket::getSSLExDataIndex() {
615   if (sslExDataIndex_ < 0) {
616     std::lock_guard<std::mutex> g(mutex_);
617     if (sslExDataIndex_ < 0) {
618       sslExDataIndex_ = SSL_get_ex_new_index(0,
619           (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
620     }
621   }
622   return sslExDataIndex_;
623 }
624
625 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
626   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
627       getSSLExDataIndex()));
628 }
629
630 void AsyncSSLSocket::failHandshake(const char* fn,
631                                     const AsyncSocketException& ex) {
632   startFail();
633
634   if (handshakeTimeout_.isScheduled()) {
635     handshakeTimeout_.cancelTimeout();
636   }
637   if (handshakeCallback_ != nullptr) {
638     HandshakeCB* callback = handshakeCallback_;
639     handshakeCallback_ = nullptr;
640     callback->handshakeErr(this, ex);
641   }
642
643   finishFail();
644 }
645
646 void AsyncSSLSocket::invokeHandshakeCB() {
647   if (handshakeTimeout_.isScheduled()) {
648     handshakeTimeout_.cancelTimeout();
649   }
650   if (handshakeCallback_) {
651     HandshakeCB* callback = handshakeCallback_;
652     handshakeCallback_ = nullptr;
653     callback->handshakeSuc(this);
654   }
655 }
656
657 void AsyncSSLSocket::connect(ConnectCallback* callback,
658                               const folly::SocketAddress& address,
659                               int timeout,
660                               const OptionMap &options,
661                               const folly::SocketAddress& bindAddr)
662                               noexcept {
663   assert(!server_);
664   assert(state_ == StateEnum::UNINIT);
665   assert(sslState_ == STATE_UNINIT);
666   AsyncSSLSocketConnector *connector =
667     new AsyncSSLSocketConnector(this, callback, timeout);
668   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
669 }
670
671 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
672   // apply the settings specified in verifyPeer_
673   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
674     if(ctx_->needsPeerVerification()) {
675       SSL_set_verify(ssl, ctx_->getVerificationMode(),
676         AsyncSSLSocket::sslVerifyCallback);
677     }
678   } else {
679     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
680         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
681       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
682         AsyncSSLSocket::sslVerifyCallback);
683     }
684   }
685 }
686
687 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
688         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
689   DestructorGuard dg(this);
690   assert(eventBase_->isInEventBaseThread());
691
692   verifyPeer_ = verifyPeer;
693
694   // Make sure we're in the uninitialized state
695   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
696                   STATE_UNENCRYPTED) ||
697       handshakeCallback_ != nullptr) {
698     return invalidState(callback);
699   }
700
701   sslState_ = STATE_CONNECTING;
702   handshakeCallback_ = callback;
703
704   try {
705     ssl_ = ctx_->createSSL();
706   } catch (std::exception &e) {
707     sslState_ = STATE_ERROR;
708     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
709                            "error calling SSLContext::createSSL()");
710     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
711             << fd_ << "): " << e.what();
712     return failHandshake(__func__, ex);
713   }
714
715   applyVerificationOptions(ssl_);
716
717   SSL_set_fd(ssl_, fd_);
718   if (sslSession_ != nullptr) {
719     SSL_set_session(ssl_, sslSession_);
720     SSL_SESSION_free(sslSession_);
721     sslSession_ = nullptr;
722   }
723 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
724   if (tlsextHostname_.size()) {
725     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
726   }
727 #endif
728
729   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
730
731   if (timeout > 0) {
732     handshakeTimeout_.scheduleTimeout(timeout);
733   }
734
735   handleConnect();
736 }
737
738 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
739   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
740     return SSL_get1_session(ssl_);
741   }
742
743   return sslSession_;
744 }
745
746 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
747   sslSession_ = session;
748   if (!takeOwnership && session != nullptr) {
749     // Increment the reference count
750     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
751   }
752 }
753
754 void AsyncSSLSocket::getSelectedNextProtocol(const unsigned char** protoName,
755     unsigned* protoLen) const {
756   if (!getSelectedNextProtocolNoThrow(protoName, protoLen)) {
757     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
758                               "NPN not supported");
759   }
760 }
761
762 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
763   const unsigned char** protoName,
764   unsigned* protoLen) const {
765   *protoName = nullptr;
766   *protoLen = 0;
767 #ifdef OPENSSL_NPN_NEGOTIATED
768   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
769   return true;
770 #else
771   return false;
772 #endif
773 }
774
775 bool AsyncSSLSocket::getSSLSessionReused() const {
776   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
777     return SSL_session_reused(ssl_);
778   }
779   return false;
780 }
781
782 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
783   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
784 }
785
786 const char *AsyncSSLSocket::getSSLServerName() const {
787 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
788   return (ssl_ != nullptr) ? SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name)
789         : nullptr;
790 #else
791   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
792                             "SNI not supported");
793 #endif
794 }
795
796 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
797   try {
798     return getSSLServerName();
799   } catch (AsyncSocketException& ex) {
800     return nullptr;
801   }
802 }
803
804 int AsyncSSLSocket::getSSLVersion() const {
805   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
806 }
807
808 int AsyncSSLSocket::getSSLCertSize() const {
809   int certSize = 0;
810   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
811   if (cert) {
812     EVP_PKEY *key = X509_get_pubkey(cert);
813     certSize = EVP_PKEY_bits(key);
814     EVP_PKEY_free(key);
815   }
816   return certSize;
817 }
818
819 bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
820   int error = *errorOut = SSL_get_error(ssl_, ret);
821   if (error == SSL_ERROR_WANT_READ) {
822     // Register for read event if not already.
823     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
824     return true;
825   } else if (error == SSL_ERROR_WANT_WRITE) {
826     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
827             << ", state=" << int(state_) << ", sslState="
828             << sslState_ << ", events=" << eventFlags_ << "): "
829             << "SSL_ERROR_WANT_WRITE";
830     // Register for write event if not already.
831     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
832     return true;
833 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
834   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
835     // We will block but we can't register our own socket.  The callback that
836     // triggered this code will re-call handleAccept at the appropriate time.
837
838     // We can only get here if the linked libssl.so has support for this feature
839     // as well, otherwise SSL_get_error cannot return our error code.
840     sslState_ = STATE_CACHE_LOOKUP;
841
842     // Unregister for all events while blocked here
843     updateEventRegistration(EventHandler::NONE,
844                             EventHandler::READ | EventHandler::WRITE);
845
846     // The timeout (if set) keeps running here
847     return true;
848 #endif
849 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
850   } else if (error == SSL_ERROR_WANT_RSA_ASYNC_PENDING) {
851     // Our custom openssl function has kicked off an async request to do
852     // modular exponentiation.  When that call returns, a callback will
853     // be invoked that will re-call handleAccept.
854     sslState_ = STATE_RSA_ASYNC_PENDING;
855
856     // Unregister for all events while blocked here
857     updateEventRegistration(
858       EventHandler::NONE,
859       EventHandler::READ | EventHandler::WRITE
860     );
861
862     // The timeout (if set) keeps running here
863     return true;
864 #endif
865   } else {
866     // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
867     // in the log
868     long lastError = ERR_get_error();
869     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
870             << "state=" << state_ << ", "
871             << "sslState=" << sslState_ << ", "
872             << "events=" << std::hex << eventFlags_ << "): "
873             << "SSL error: " << error << ", "
874             << "errno: " << errno << ", "
875             << "ret: " << ret << ", "
876             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
877             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
878             << "func: " << ERR_func_error_string(lastError) << ", "
879             << "reason: " << ERR_reason_error_string(lastError);
880     if (error != SSL_ERROR_SYSCALL) {
881       if (error == SSL_ERROR_SSL) {
882         *errorOut = lastError;
883       }
884       if ((unsigned long)lastError < 0x8000) {
885         errno = ENOSYS;
886       } else {
887         errno = lastError;
888       }
889     }
890     ERR_clear_error();
891     return false;
892   }
893 }
894
895 void AsyncSSLSocket::checkForImmediateRead() noexcept {
896   // openssl may have buffered data that it read from the socket already.
897   // In this case we have to process it immediately, rather than waiting for
898   // the socket to become readable again.
899   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
900     AsyncSocket::handleRead();
901   }
902 }
903
904 void
905 AsyncSSLSocket::restartSSLAccept()
906 {
907   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_
908           << ", state=" << int(state_) << ", "
909           << "sslState=" << sslState_ << ", events=" << eventFlags_;
910   DestructorGuard dg(this);
911   assert(
912     sslState_ == STATE_CACHE_LOOKUP ||
913     sslState_ == STATE_RSA_ASYNC_PENDING ||
914     sslState_ == STATE_ERROR ||
915     sslState_ == STATE_CLOSED
916   );
917   if (sslState_ == STATE_CLOSED) {
918     // I sure hope whoever closed this socket didn't delete it already,
919     // but this is not strictly speaking an error
920     return;
921   }
922   if (sslState_ == STATE_ERROR) {
923     // go straight to fail if timeout expired during lookup
924     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
925                            "SSL accept timed out");
926     failHandshake(__func__, ex);
927     return;
928   }
929   sslState_ = STATE_ACCEPTING;
930   this->handleAccept();
931 }
932
933 void
934 AsyncSSLSocket::handleAccept() noexcept {
935   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
936           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
937           << "sslState=" << sslState_ << ", events=" << eventFlags_;
938   assert(server_);
939   assert(state_ == StateEnum::ESTABLISHED &&
940          sslState_ == STATE_ACCEPTING);
941   if (!ssl_) {
942     /* lazily create the SSL structure */
943     try {
944       ssl_ = ctx_->createSSL();
945     } catch (std::exception &e) {
946       sslState_ = STATE_ERROR;
947       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
948                              "error calling SSLContext::createSSL()");
949       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
950                  << ", fd=" << fd_ << "): " << e.what();
951       return failHandshake(__func__, ex);
952     }
953     SSL_set_fd(ssl_, fd_);
954     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
955
956     applyVerificationOptions(ssl_);
957   }
958
959   if (server_ && parseClientHello_) {
960     SSL_set_msg_callback_arg(ssl_, this);
961     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
962   }
963
964   errno = 0;
965   int ret = SSL_accept(ssl_);
966   if (ret <= 0) {
967     int error;
968     if (willBlock(ret, &error)) {
969       return;
970     } else {
971       sslState_ = STATE_ERROR;
972       SSLException ex(error, errno);
973       return failHandshake(__func__, ex);
974     }
975   }
976
977   handshakeComplete_ = true;
978   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
979
980   // Move into STATE_ESTABLISHED in the normal case that we are in
981   // STATE_ACCEPTING.
982   sslState_ = STATE_ESTABLISHED;
983
984   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
985           << " successfully accepted; state=" << int(state_)
986           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
987
988   // Remember the EventBase we are attached to, before we start invoking any
989   // callbacks (since the callbacks may call detachEventBase()).
990   EventBase* originalEventBase = eventBase_;
991
992   // Call the accept callback.
993   invokeHandshakeCB();
994
995   // Note that the accept callback may have changed our state.
996   // (set or unset the read callback, called write(), closed the socket, etc.)
997   // The following code needs to handle these situations correctly.
998   //
999   // If the socket has been closed, readCallback_ and writeReqHead_ will
1000   // always be nullptr, so that will prevent us from trying to read or write.
1001   //
1002   // The main thing to check for is if eventBase_ is still originalEventBase.
1003   // If not, we have been detached from this event base, so we shouldn't
1004   // perform any more operations.
1005   if (eventBase_ != originalEventBase) {
1006     return;
1007   }
1008
1009   AsyncSocket::handleInitialReadWrite();
1010 }
1011
1012 void
1013 AsyncSSLSocket::handleConnect() noexcept {
1014   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1015           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1016           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1017   assert(!server_);
1018   if (state_ < StateEnum::ESTABLISHED) {
1019     return AsyncSocket::handleConnect();
1020   }
1021
1022   assert(state_ == StateEnum::ESTABLISHED &&
1023          sslState_ == STATE_CONNECTING);
1024   assert(ssl_);
1025
1026   errno = 0;
1027   int ret = SSL_connect(ssl_);
1028   if (ret <= 0) {
1029     int error;
1030     if (willBlock(ret, &error)) {
1031       return;
1032     } else {
1033       sslState_ = STATE_ERROR;
1034       SSLException ex(error, errno);
1035       return failHandshake(__func__, ex);
1036     }
1037   }
1038
1039   handshakeComplete_ = true;
1040   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1041
1042   // Move into STATE_ESTABLISHED in the normal case that we are in
1043   // STATE_CONNECTING.
1044   sslState_ = STATE_ESTABLISHED;
1045
1046   VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; "
1047           << "state=" << int(state_) << ", sslState=" << sslState_
1048           << ", events=" << eventFlags_;
1049
1050   // Remember the EventBase we are attached to, before we start invoking any
1051   // callbacks (since the callbacks may call detachEventBase()).
1052   EventBase* originalEventBase = eventBase_;
1053
1054   // Call the handshake callback.
1055   invokeHandshakeCB();
1056
1057   // Note that the connect callback may have changed our state.
1058   // (set or unset the read callback, called write(), closed the socket, etc.)
1059   // The following code needs to handle these situations correctly.
1060   //
1061   // If the socket has been closed, readCallback_ and writeReqHead_ will
1062   // always be nullptr, so that will prevent us from trying to read or write.
1063   //
1064   // The main thing to check for is if eventBase_ is still originalEventBase.
1065   // If not, we have been detached from this event base, so we shouldn't
1066   // perform any more operations.
1067   if (eventBase_ != originalEventBase) {
1068     return;
1069   }
1070
1071   AsyncSocket::handleInitialReadWrite();
1072 }
1073
1074 void
1075 AsyncSSLSocket::handleRead() noexcept {
1076   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1077           << ", state=" << int(state_) << ", "
1078           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1079   if (state_ < StateEnum::ESTABLISHED) {
1080     return AsyncSocket::handleRead();
1081   }
1082
1083
1084   if (sslState_ == STATE_ACCEPTING) {
1085     assert(server_);
1086     handleAccept();
1087     return;
1088   }
1089   else if (sslState_ == STATE_CONNECTING) {
1090     assert(!server_);
1091     handleConnect();
1092     return;
1093   }
1094
1095   // Normal read
1096   AsyncSocket::handleRead();
1097 }
1098
1099 ssize_t
1100 AsyncSSLSocket::performRead(void* buf, size_t buflen) {
1101   if (sslState_ == STATE_UNENCRYPTED) {
1102     return AsyncSocket::performRead(buf, buflen);
1103   }
1104
1105   errno = 0;
1106   ssize_t bytes = SSL_read(ssl_, buf, buflen);
1107   if (server_ && renegotiateAttempted_) {
1108     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1109                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1110                << "): client intitiated SSL renegotiation not permitted";
1111     // We pack our own SSLerr here with a dummy function
1112     errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
1113                      SSL_CLIENT_RENEGOTIATION_ATTEMPT);
1114     ERR_clear_error();
1115     return READ_ERROR;
1116   }
1117   if (bytes <= 0) {
1118     int error = SSL_get_error(ssl_, bytes);
1119     if (error == SSL_ERROR_WANT_READ) {
1120       // The caller will register for read event if not already.
1121       return READ_BLOCKING;
1122     } else if (error == SSL_ERROR_WANT_WRITE) {
1123       // TODO: Even though we are attempting to read data, SSL_read() may
1124       // need to write data if renegotiation is being performed.  We currently
1125       // don't support this and just fail the read.
1126       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1127                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1128                  << "): unsupported SSL renegotiation during read",
1129       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
1130                        SSL_INVALID_RENEGOTIATION);
1131       ERR_clear_error();
1132       return READ_ERROR;
1133     } else {
1134       // TODO: Fix this code so that it can return a proper error message
1135       // to the callback, rather than relying on AsyncSocket code which
1136       // can't handle SSL errors.
1137       long lastError = ERR_get_error();
1138
1139       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1140               << "state=" << state_ << ", "
1141               << "sslState=" << sslState_ << ", "
1142               << "events=" << std::hex << eventFlags_ << "): "
1143               << "bytes: " << bytes << ", "
1144               << "error: " << error << ", "
1145               << "errno: " << errno << ", "
1146               << "func: " << ERR_func_error_string(lastError) << ", "
1147               << "reason: " << ERR_reason_error_string(lastError);
1148       ERR_clear_error();
1149       if (zero_return(error, bytes)) {
1150         return bytes;
1151       }
1152       if (error != SSL_ERROR_SYSCALL) {
1153         if ((unsigned long)lastError < 0x8000) {
1154           errno = ENOSYS;
1155         } else {
1156           errno = lastError;
1157         }
1158       }
1159       return READ_ERROR;
1160     }
1161   } else {
1162     appBytesReceived_ += bytes;
1163     return bytes;
1164   }
1165 }
1166
1167 void AsyncSSLSocket::handleWrite() noexcept {
1168   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1169           << ", state=" << int(state_) << ", "
1170           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1171   if (state_ < StateEnum::ESTABLISHED) {
1172     return AsyncSocket::handleWrite();
1173   }
1174
1175   if (sslState_ == STATE_ACCEPTING) {
1176     assert(server_);
1177     handleAccept();
1178     return;
1179   }
1180
1181   if (sslState_ == STATE_CONNECTING) {
1182     assert(!server_);
1183     handleConnect();
1184     return;
1185   }
1186
1187   // Normal write
1188   AsyncSocket::handleWrite();
1189 }
1190
1191 int AsyncSSLSocket::interpretSSLError(int rc, int error) {
1192   if (error == SSL_ERROR_WANT_READ) {
1193     // TODO: Even though we are attempting to write data, SSL_write() may
1194     // need to read data if renegotiation is being performed.  We currently
1195     // don't support this and just fail the write.
1196     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1197                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1198                << "): " << "unsupported SSL renegotiation during write",
1199       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
1200                        SSL_INVALID_RENEGOTIATION);
1201     ERR_clear_error();
1202     return -1;
1203   } else {
1204     // TODO: Fix this code so that it can return a proper error message
1205     // to the callback, rather than relying on AsyncSocket code which
1206     // can't handle SSL errors.
1207     long lastError = ERR_get_error();
1208     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1209             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1210             << "SSL error: " << error << ", errno: " << errno
1211             << ", func: " << ERR_func_error_string(lastError)
1212             << ", reason: " << ERR_reason_error_string(lastError);
1213     if (error != SSL_ERROR_SYSCALL) {
1214       if ((unsigned long)lastError < 0x8000) {
1215         errno = ENOSYS;
1216       } else {
1217         errno = lastError;
1218       }
1219     }
1220     ERR_clear_error();
1221     if (!zero_return(error, rc)) {
1222       return -1;
1223     } else {
1224       return 0;
1225     }
1226   }
1227 }
1228
1229 ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
1230                                       uint32_t count,
1231                                       WriteFlags flags,
1232                                       uint32_t* countWritten,
1233                                       uint32_t* partialWritten) {
1234   if (sslState_ == STATE_UNENCRYPTED) {
1235     return AsyncSocket::performWrite(
1236       vec, count, flags, countWritten, partialWritten);
1237   }
1238   if (sslState_ != STATE_ESTABLISHED) {
1239     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1240                << ", sslState=" << sslState_
1241                << ", events=" << eventFlags_ << "): "
1242                << "TODO: AsyncSSLSocket currently does not support calling "
1243                << "write() before the handshake has fully completed";
1244       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
1245                        SSL_EARLY_WRITE);
1246       return -1;
1247   }
1248
1249   bool cork = isSet(flags, WriteFlags::CORK);
1250   CorkGuard guard(fd_, count > 1, cork, &corked_);
1251
1252 #if 0
1253 //#ifdef SSL_MODE_WRITE_IOVEC
1254   if (ssl_->expand == nullptr &&
1255       ssl_->compress == nullptr &&
1256       (ssl_->mode & SSL_MODE_WRITE_IOVEC)) {
1257     return performWriteIovec(vec, count, flags, countWritten, partialWritten);
1258   }
1259 #endif
1260
1261   // Declare a buffer used to hold small write requests.  It could point to a
1262   // memory block either on stack or on heap. If it is on heap, we release it
1263   // manually when scope exits
1264   char* combinedBuf{nullptr};
1265   SCOPE_EXIT {
1266     // Note, always keep this check consistent with what we do below
1267     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1268       delete[] combinedBuf;
1269     }
1270   };
1271
1272   *countWritten = 0;
1273   *partialWritten = 0;
1274   ssize_t totalWritten = 0;
1275   size_t bytesStolenFromNextBuffer = 0;
1276   for (uint32_t i = 0; i < count; i++) {
1277     const iovec* v = vec + i;
1278     size_t offset = bytesStolenFromNextBuffer;
1279     bytesStolenFromNextBuffer = 0;
1280     size_t len = v->iov_len - offset;
1281     const void* buf;
1282     if (len == 0) {
1283       (*countWritten)++;
1284       continue;
1285     }
1286     buf = ((const char*)v->iov_base) + offset;
1287
1288     ssize_t bytes;
1289     errno = 0;
1290     uint32_t buffersStolen = 0;
1291     if ((len < minWriteSize_) && ((i + 1) < count)) {
1292       // Combine this buffer with part or all of the next buffers in
1293       // order to avoid really small-grained calls to SSL_write().
1294       // Each call to SSL_write() produces a separate record in
1295       // the egress SSL stream, and we've found that some low-end
1296       // mobile clients can't handle receiving an HTTP response
1297       // header and the first part of the response body in two
1298       // separate SSL records (even if those two records are in
1299       // the same TCP packet).
1300
1301       if (combinedBuf == nullptr) {
1302         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1303           // Allocate the buffer on heap
1304           combinedBuf = new char[minWriteSize_];
1305         } else {
1306           // Allocate the buffer on stack
1307           combinedBuf = (char*)alloca(minWriteSize_);
1308         }
1309       }
1310       assert(combinedBuf != nullptr);
1311
1312       memcpy(combinedBuf, buf, len);
1313       do {
1314         // INVARIANT: i + buffersStolen == complete chunks serialized
1315         uint32_t nextIndex = i + buffersStolen + 1;
1316         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1317                                              minWriteSize_ - len);
1318         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1319                bytesStolenFromNextBuffer);
1320         len += bytesStolenFromNextBuffer;
1321         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1322           // couldn't steal the whole buffer
1323           break;
1324         } else {
1325           bytesStolenFromNextBuffer = 0;
1326           buffersStolen++;
1327         }
1328       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1329       bytes = eorAwareSSLWrite(
1330         ssl_, combinedBuf, len,
1331         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1332
1333     } else {
1334       bytes = eorAwareSSLWrite(ssl_, buf, len,
1335                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1336     }
1337
1338     if (bytes <= 0) {
1339       int error = SSL_get_error(ssl_, bytes);
1340       if (error == SSL_ERROR_WANT_WRITE) {
1341         // The caller will register for write event if not already.
1342         *partialWritten = offset;
1343         return totalWritten;
1344       }
1345       int rc = interpretSSLError(bytes, error);
1346       if (rc < 0) {
1347         return rc;
1348       } // else fall through to below to correctly record totalWritten
1349     }
1350
1351     totalWritten += bytes;
1352
1353     if (bytes == (ssize_t)len) {
1354       // The full iovec is written.
1355       (*countWritten) += 1 + buffersStolen;
1356       i += buffersStolen;
1357       // continue
1358     } else {
1359       bytes += offset; // adjust bytes to account for all of v
1360       while (bytes >= (ssize_t)v->iov_len) {
1361         // We combined this buf with part or all of the next one, and
1362         // we managed to write all of this buf but not all of the bytes
1363         // from the next one that we'd hoped to write.
1364         bytes -= v->iov_len;
1365         (*countWritten)++;
1366         v = &(vec[++i]);
1367       }
1368       *partialWritten = bytes;
1369       return totalWritten;
1370     }
1371   }
1372
1373   return totalWritten;
1374 }
1375
1376 #if 0
1377 //#ifdef SSL_MODE_WRITE_IOVEC
1378 ssize_t AsyncSSLSocket::performWriteIovec(const iovec* vec,
1379                                           uint32_t count,
1380                                           WriteFlags flags,
1381                                           uint32_t* countWritten,
1382                                           uint32_t* partialWritten) {
1383   size_t tot = 0;
1384   for (uint32_t j = 0; j < count; j++) {
1385     tot += vec[j].iov_len;
1386   }
1387
1388   ssize_t totalWritten = SSL_write_iovec(ssl_, vec, count);
1389
1390   *countWritten = 0;
1391   *partialWritten = 0;
1392   if (totalWritten <= 0) {
1393     return interpretSSLError(totalWritten, SSL_get_error(ssl_, totalWritten));
1394   } else {
1395     ssize_t bytes = totalWritten, i = 0;
1396     while (i < count && bytes >= (ssize_t)vec[i].iov_len) {
1397       // we managed to write all of this buf
1398       bytes -= vec[i].iov_len;
1399       (*countWritten)++;
1400       i++;
1401     }
1402     *partialWritten = bytes;
1403
1404     VLOG(4) << "SSL_write_iovec() writes " << tot
1405             << ", returns " << totalWritten << " bytes"
1406             << ", max_send_fragment=" << ssl_->max_send_fragment
1407             << ", count=" << count << ", countWritten=" << *countWritten;
1408
1409     return totalWritten;
1410   }
1411 }
1412 #endif
1413
1414 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1415                                       bool eor) {
1416   if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
1417     if (appEorByteNo_) {
1418       // cannot track for more than one app byte EOR
1419       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1420     } else {
1421       appEorByteNo_ = appBytesWritten_ + n;
1422     }
1423
1424     // 1. It is fine to keep updating minEorRawByteNo_.
1425     // 2. It is _min_ in the sense that SSL record will add some overhead.
1426     minEorRawByteNo_ = getRawBytesWritten() + n;
1427   }
1428
1429   n = sslWriteImpl(ssl, buf, n);
1430   if (n > 0) {
1431     appBytesWritten_ += n;
1432     if (appEorByteNo_) {
1433       if (getRawBytesWritten() >= minEorRawByteNo_) {
1434         minEorRawByteNo_ = 0;
1435       }
1436       if(appBytesWritten_ == appEorByteNo_) {
1437         appEorByteNo_ = 0;
1438       } else {
1439         CHECK(appBytesWritten_ < appEorByteNo_);
1440       }
1441     }
1442   }
1443   return n;
1444 }
1445
1446 void
1447 AsyncSSLSocket::sslInfoCallback(const SSL *ssl, int where, int ret) {
1448   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1449   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1450     sslSocket->renegotiateAttempted_ = true;
1451   }
1452 }
1453
1454 int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
1455   int ret;
1456   struct msghdr msg;
1457   struct iovec iov;
1458   int flags = 0;
1459   AsyncSSLSocket *tsslSock;
1460
1461   iov.iov_base = const_cast<char *>(in);
1462   iov.iov_len = inl;
1463   memset(&msg, 0, sizeof(msg));
1464   msg.msg_iov = &iov;
1465   msg.msg_iovlen = 1;
1466
1467   tsslSock =
1468     reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
1469   if (tsslSock &&
1470       tsslSock->minEorRawByteNo_ &&
1471       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1472     flags = MSG_EOR;
1473   }
1474
1475   errno = 0;
1476   ret = sendmsg(b->num, &msg, flags);
1477   BIO_clear_retry_flags(b);
1478   if (ret <= 0) {
1479     if (BIO_sock_should_retry(ret))
1480       BIO_set_retry_write(b);
1481   }
1482   return(ret);
1483 }
1484
1485 int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
1486                                        X509_STORE_CTX* x509Ctx) {
1487   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1488     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1489   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1490
1491   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1492           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1493   return (self->handshakeCallback_) ?
1494     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1495     preverifyOk;
1496 }
1497
1498 void AsyncSSLSocket::enableClientHelloParsing()  {
1499     parseClientHello_ = true;
1500     clientHelloInfo_.reset(new ClientHelloInfo());
1501 }
1502
1503 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1504   SSL_set_msg_callback(ssl, nullptr);
1505   SSL_set_msg_callback_arg(ssl, nullptr);
1506   clientHelloInfo_->clientHelloBuf_.clear();
1507 }
1508
1509 void
1510 AsyncSSLSocket::clientHelloParsingCallback(int written, int version,
1511     int contentType, const void *buf, size_t len, SSL *ssl, void *arg)
1512 {
1513   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1514   if (written != 0) {
1515     sock->resetClientHelloParsing(ssl);
1516     return;
1517   }
1518   if (contentType != SSL3_RT_HANDSHAKE) {
1519     sock->resetClientHelloParsing(ssl);
1520     return;
1521   }
1522   if (len == 0) {
1523     return;
1524   }
1525
1526   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1527   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1528   try {
1529     Cursor cursor(clientHelloBuf.front());
1530     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1531       sock->resetClientHelloParsing(ssl);
1532       return;
1533     }
1534
1535     if (cursor.totalLength() < 3) {
1536       clientHelloBuf.trimEnd(len);
1537       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1538       return;
1539     }
1540
1541     uint32_t messageLength = cursor.read<uint8_t>();
1542     messageLength <<= 8;
1543     messageLength |= cursor.read<uint8_t>();
1544     messageLength <<= 8;
1545     messageLength |= cursor.read<uint8_t>();
1546     if (cursor.totalLength() < messageLength) {
1547       clientHelloBuf.trimEnd(len);
1548       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1549       return;
1550     }
1551
1552     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1553     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1554
1555     cursor.skip(4); // gmt_unix_time
1556     cursor.skip(28); // random_bytes
1557
1558     cursor.skip(cursor.read<uint8_t>()); // session_id
1559
1560     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1561     for (int i = 0; i < cipherSuitesLength; i += 2) {
1562       sock->clientHelloInfo_->
1563         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1564     }
1565
1566     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1567     for (int i = 0; i < compressionMethodsLength; ++i) {
1568       sock->clientHelloInfo_->
1569         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1570     }
1571
1572     if (cursor.totalLength() > 0) {
1573       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1574       while (extensionsLength) {
1575         sock->clientHelloInfo_->
1576           clientHelloExtensions_.push_back(cursor.readBE<uint16_t>());
1577         extensionsLength -= 2;
1578         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1579         extensionsLength -= 2;
1580         cursor.skip(extensionDataLength);
1581         extensionsLength -= extensionDataLength;
1582       }
1583     }
1584   } catch (std::out_of_range& e) {
1585     // we'll use what we found and cleanup below.
1586     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1587       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1588   }
1589
1590   sock->resetClientHelloParsing(ssl);
1591 }
1592
1593 } // namespace