Fix AsyncSSLSocket handshake error reporting.
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSSLSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20
21 #include <boost/noncopyable.hpp>
22 #include <errno.h>
23 #include <fcntl.h>
24 #include <netinet/in.h>
25 #include <netinet/tcp.h>
26 #include <openssl/err.h>
27 #include <openssl/asn1.h>
28 #include <openssl/ssl.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <unistd.h>
32 #include <chrono>
33
34 #include <folly/Bits.h>
35 #include <folly/SocketAddress.h>
36 #include <folly/SpinLock.h>
37 #include <folly/io/IOBuf.h>
38 #include <folly/io/Cursor.h>
39
40 using folly::SocketAddress;
41 using folly::SSLContext;
42 using std::string;
43 using std::shared_ptr;
44
45 using folly::Endian;
46 using folly::IOBuf;
47 using folly::SpinLock;
48 using folly::SpinLockGuard;
49 using folly::io::Cursor;
50 using std::unique_ptr;
51 using std::bind;
52
53 namespace {
54 using folly::AsyncSocket;
55 using folly::AsyncSocketException;
56 using folly::AsyncSSLSocket;
57 using folly::Optional;
58 using folly::SSLContext;
59
60 // We have one single dummy SSL context so that we can implement attach
61 // and detach methods in a thread safe fashion without modifying opnessl.
62 static SSLContext *dummyCtx = nullptr;
63 static SpinLock dummyCtxLock;
64
65 // Numbers chosen as to not collide with functions in ssl.h
66 const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
67 const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
68
69 // If given min write size is less than this, buffer will be allocated on
70 // stack, otherwise it is allocated on heap
71 const size_t MAX_STACK_BUF_SIZE = 2048;
72
73 // This converts "illegal" shutdowns into ZERO_RETURN
74 inline bool zero_return(int error, int rc) {
75   return (error == SSL_ERROR_ZERO_RETURN || (rc == 0 && errno == 0));
76 }
77
78 class AsyncSSLSocketConnector: public AsyncSocket::ConnectCallback,
79                                 public AsyncSSLSocket::HandshakeCB {
80
81  private:
82   AsyncSSLSocket *sslSocket_;
83   AsyncSSLSocket::ConnectCallback *callback_;
84   int timeout_;
85   int64_t startTime_;
86
87  protected:
88   ~AsyncSSLSocketConnector() override {}
89
90  public:
91   AsyncSSLSocketConnector(AsyncSSLSocket *sslSocket,
92                            AsyncSocket::ConnectCallback *callback,
93                            int timeout) :
94       sslSocket_(sslSocket),
95       callback_(callback),
96       timeout_(timeout),
97       startTime_(std::chrono::duration_cast<std::chrono::milliseconds>(
98                    std::chrono::steady_clock::now().time_since_epoch()).count()) {
99   }
100
101   void connectSuccess() noexcept override {
102     VLOG(7) << "client socket connected";
103
104     int64_t timeoutLeft = 0;
105     if (timeout_ > 0) {
106       auto curTime = std::chrono::duration_cast<std::chrono::milliseconds>(
107         std::chrono::steady_clock::now().time_since_epoch()).count();
108
109       timeoutLeft = timeout_ - (curTime - startTime_);
110       if (timeoutLeft <= 0) {
111         AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
112                                 "SSL connect timed out");
113         fail(ex);
114         delete this;
115         return;
116       }
117     }
118     sslSocket_->sslConn(this, timeoutLeft);
119   }
120
121   void connectErr(const AsyncSocketException& ex) noexcept override {
122     LOG(ERROR) << "TCP connect failed: " <<  ex.what();
123     fail(ex);
124     delete this;
125   }
126
127   void handshakeSuc(AsyncSSLSocket* /* sock */) noexcept override {
128     VLOG(7) << "client handshake success";
129     if (callback_) {
130       callback_->connectSuccess();
131     }
132     delete this;
133   }
134
135   void handshakeErr(AsyncSSLSocket* /* socket */,
136                     const AsyncSocketException& ex) noexcept override {
137     LOG(ERROR) << "client handshakeErr: " << ex.what();
138     fail(ex);
139     delete this;
140   }
141
142   void fail(const AsyncSocketException &ex) {
143     // fail is a noop if called twice
144     if (callback_) {
145       AsyncSSLSocket::ConnectCallback *cb = callback_;
146       callback_ = nullptr;
147
148       cb->connectErr(ex);
149       sslSocket_->closeNow();
150       // closeNow can call handshakeErr if it hasn't been called already.
151       // So this may have been deleted, no member variable access beyond this
152       // point
153       // Note that closeNow may invoke writeError callbacks if the socket had
154       // write data pending connection completion.
155     }
156   }
157 };
158
159 // XXX: implement an equivalent to corking for platforms with TCP_NOPUSH?
160 #ifdef TCP_CORK // Linux-only
161 /**
162  * Utility class that corks a TCP socket upon construction or uncorks
163  * the socket upon destruction
164  */
165 class CorkGuard : private boost::noncopyable {
166  public:
167   CorkGuard(int fd, bool multipleWrites, bool haveMore, bool* corked):
168     fd_(fd), haveMore_(haveMore), corked_(corked) {
169     if (*corked_) {
170       // socket is already corked; nothing to do
171       return;
172     }
173     if (multipleWrites || haveMore) {
174       // We are performing multiple writes in this performWrite() call,
175       // and/or there are more calls to performWrite() that will be invoked
176       // later, so enable corking
177       int flag = 1;
178       setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
179       *corked_ = true;
180     }
181   }
182
183   ~CorkGuard() {
184     if (haveMore_) {
185       // more data to come; don't uncork yet
186       return;
187     }
188     if (!*corked_) {
189       // socket isn't corked; nothing to do
190       return;
191     }
192
193     int flag = 0;
194     setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag, sizeof(flag));
195     *corked_ = false;
196   }
197
198  private:
199   int fd_;
200   bool haveMore_;
201   bool* corked_;
202 };
203 #else
204 class CorkGuard : private boost::noncopyable {
205  public:
206   CorkGuard(int, bool, bool, bool*) {}
207 };
208 #endif
209
210 void setup_SSL_CTX(SSL_CTX *ctx) {
211 #ifdef SSL_MODE_RELEASE_BUFFERS
212   SSL_CTX_set_mode(ctx,
213                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
214                    SSL_MODE_ENABLE_PARTIAL_WRITE
215                    | SSL_MODE_RELEASE_BUFFERS
216                    );
217 #else
218   SSL_CTX_set_mode(ctx,
219                    SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
220                    SSL_MODE_ENABLE_PARTIAL_WRITE
221                    );
222 #endif
223 // SSL_CTX_set_mode is a Macro
224 #ifdef SSL_MODE_WRITE_IOVEC
225   SSL_CTX_set_mode(ctx,
226                    SSL_CTX_get_mode(ctx)
227                    | SSL_MODE_WRITE_IOVEC);
228 #endif
229
230 }
231
232 BIO_METHOD eorAwareBioMethod;
233
234 void* initEorBioMethod(void) {
235   memcpy(&eorAwareBioMethod, BIO_s_socket(), sizeof(eorAwareBioMethod));
236   // override the bwrite method for MSG_EOR support
237   eorAwareBioMethod.bwrite = AsyncSSLSocket::eorAwareBioWrite;
238
239   // Note that the eorAwareBioMethod.type and eorAwareBioMethod.name are not
240   // set here. openssl code seems to be checking ".type == BIO_TYPE_SOCKET" and
241   // then have specific handlings. The eorAwareBioWrite should be compatible
242   // with the one in openssl.
243
244   // Return something here to enable AsyncSSLSocket to call this method using
245   // a function-scoped static.
246   return nullptr;
247 }
248
249 std::string decodeOpenSSLError(int sslError,
250                                unsigned long errError,
251                                int sslOperationReturnValue) {
252   if (sslError == SSL_ERROR_SYSCALL && errError == 0) {
253     if (sslOperationReturnValue == 0) {
254       return "SSL_ERROR_SYSCALL: EOF";
255     } else {
256       // In this case errno is set, AsyncSocketException will add it.
257       return "SSL_ERROR_SYSCALL";
258     }
259   } else if (sslError == SSL_ERROR_ZERO_RETURN) {
260     // This signifies a TLS closure alert.
261     return "SSL_ERROR_ZERO_RETURN";
262   } else {
263     char buf[256];
264     std::string msg(ERR_error_string(errError, buf));
265     return msg;
266   }
267 }
268
269 } // anonymous namespace
270
271 namespace folly {
272
273 SSLException::SSLException(int sslError,
274                            unsigned long errError,
275                            int sslOperationReturnValue,
276                            int errno_copy)
277     : AsyncSocketException(
278           AsyncSocketException::SSL_ERROR,
279           decodeOpenSSLError(sslError, errError, sslOperationReturnValue),
280           sslError == SSL_ERROR_SYSCALL ? errno_copy : 0) {}
281
282 /**
283  * Create a client AsyncSSLSocket
284  */
285 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
286                                EventBase* evb, bool deferSecurityNegotiation) :
287     AsyncSocket(evb),
288     ctx_(ctx),
289     handshakeTimeout_(this, evb) {
290   init();
291   if (deferSecurityNegotiation) {
292     sslState_ = STATE_UNENCRYPTED;
293   }
294 }
295
296 /**
297  * Create a server/client AsyncSSLSocket
298  */
299 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
300                                EventBase* evb, int fd, bool server,
301                                bool deferSecurityNegotiation) :
302     AsyncSocket(evb, fd),
303     server_(server),
304     ctx_(ctx),
305     handshakeTimeout_(this, evb) {
306   init();
307   if (server) {
308     SSL_CTX_set_info_callback(ctx_->getSSLCtx(),
309                               AsyncSSLSocket::sslInfoCallback);
310   }
311   if (deferSecurityNegotiation) {
312     sslState_ = STATE_UNENCRYPTED;
313   }
314 }
315
316 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
317 /**
318  * Create a client AsyncSSLSocket and allow tlsext_hostname
319  * to be sent in Client Hello.
320  */
321 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext> &ctx,
322                                  EventBase* evb,
323                                const std::string& serverName,
324                                bool deferSecurityNegotiation) :
325     AsyncSSLSocket(ctx, evb, deferSecurityNegotiation) {
326   tlsextHostname_ = serverName;
327 }
328
329 /**
330  * Create a client AsyncSSLSocket from an already connected fd
331  * and allow tlsext_hostname to be sent in Client Hello.
332  */
333 AsyncSSLSocket::AsyncSSLSocket(const shared_ptr<SSLContext>& ctx,
334                                  EventBase* evb, int fd,
335                                const std::string& serverName,
336                                bool deferSecurityNegotiation) :
337     AsyncSSLSocket(ctx, evb, fd, false, deferSecurityNegotiation) {
338   tlsextHostname_ = serverName;
339 }
340 #endif
341
342 AsyncSSLSocket::~AsyncSSLSocket() {
343   VLOG(3) << "actual destruction of AsyncSSLSocket(this=" << this
344           << ", evb=" << eventBase_ << ", fd=" << fd_
345           << ", state=" << int(state_) << ", sslState="
346           << sslState_ << ", events=" << eventFlags_ << ")";
347 }
348
349 void AsyncSSLSocket::init() {
350   // Do this here to ensure we initialize this once before any use of
351   // AsyncSSLSocket instances and not as part of library load.
352   static const auto eorAwareBioMethodInitializer = initEorBioMethod();
353   (void)eorAwareBioMethodInitializer;
354
355   setup_SSL_CTX(ctx_->getSSLCtx());
356 }
357
358 void AsyncSSLSocket::closeNow() {
359   // Close the SSL connection.
360   if (ssl_ != nullptr && fd_ != -1) {
361     int rc = SSL_shutdown(ssl_);
362     if (rc == 0) {
363       rc = SSL_shutdown(ssl_);
364     }
365     if (rc < 0) {
366       ERR_clear_error();
367     }
368   }
369
370   if (sslSession_ != nullptr) {
371     SSL_SESSION_free(sslSession_);
372     sslSession_ = nullptr;
373   }
374
375   sslState_ = STATE_CLOSED;
376
377   if (handshakeTimeout_.isScheduled()) {
378     handshakeTimeout_.cancelTimeout();
379   }
380
381   DestructorGuard dg(this);
382
383   invokeHandshakeErr(
384       AsyncSocketException(
385         AsyncSocketException::END_OF_FILE,
386         "SSL connection closed locally"));
387
388   if (ssl_ != nullptr) {
389     SSL_free(ssl_);
390     ssl_ = nullptr;
391   }
392
393   // Close the socket.
394   AsyncSocket::closeNow();
395 }
396
397 void AsyncSSLSocket::shutdownWrite() {
398   // SSL sockets do not support half-shutdown, so just perform a full shutdown.
399   //
400   // (Performing a full shutdown here is more desirable than doing nothing at
401   // all.  The purpose of shutdownWrite() is normally to notify the other end
402   // of the connection that no more data will be sent.  If we do nothing, the
403   // other end will never know that no more data is coming, and this may result
404   // in protocol deadlock.)
405   close();
406 }
407
408 void AsyncSSLSocket::shutdownWriteNow() {
409   closeNow();
410 }
411
412 bool AsyncSSLSocket::good() const {
413   return (AsyncSocket::good() &&
414           (sslState_ == STATE_ACCEPTING || sslState_ == STATE_CONNECTING ||
415            sslState_ == STATE_ESTABLISHED || sslState_ == STATE_UNENCRYPTED));
416 }
417
418 // The TAsyncTransport definition of 'good' states that the transport is
419 // ready to perform reads and writes, so sslState_ == UNINIT must report !good.
420 // connecting can be true when the sslState_ == UNINIT because the AsyncSocket
421 // is connected but we haven't initiated the call to SSL_connect.
422 bool AsyncSSLSocket::connecting() const {
423   return (!server_ &&
424           (AsyncSocket::connecting() ||
425            (AsyncSocket::good() && (sslState_ == STATE_UNINIT ||
426                                      sslState_ == STATE_CONNECTING))));
427 }
428
429 std::string AsyncSSLSocket::getApplicationProtocol() noexcept {
430   const unsigned char* protoName = nullptr;
431   unsigned protoLength;
432   if (getSelectedNextProtocolNoThrow(&protoName, &protoLength)) {
433     return std::string(reinterpret_cast<const char*>(protoName), protoLength);
434   }
435   return "";
436 }
437
438 bool AsyncSSLSocket::isEorTrackingEnabled() const {
439   if (ssl_ == nullptr) {
440     return false;
441   }
442   const BIO *wb = SSL_get_wbio(ssl_);
443   return wb && wb->method == &eorAwareBioMethod;
444 }
445
446 void AsyncSSLSocket::setEorTracking(bool track) {
447   BIO *wb = SSL_get_wbio(ssl_);
448   if (!wb) {
449     throw AsyncSocketException(AsyncSocketException::INVALID_STATE,
450                               "setting EOR tracking without an initialized "
451                               "BIO");
452   }
453
454   if (track) {
455     if (wb->method != &eorAwareBioMethod) {
456       // only do this if we didn't
457       wb->method = &eorAwareBioMethod;
458       BIO_set_app_data(wb, this);
459       appEorByteNo_ = 0;
460       minEorRawByteNo_ = 0;
461     }
462   } else if (wb->method == &eorAwareBioMethod) {
463     wb->method = BIO_s_socket();
464     BIO_set_app_data(wb, nullptr);
465     appEorByteNo_ = 0;
466     minEorRawByteNo_ = 0;
467   } else {
468     CHECK(wb->method == BIO_s_socket());
469   }
470 }
471
472 size_t AsyncSSLSocket::getRawBytesWritten() const {
473   BIO *b;
474   if (!ssl_ || !(b = SSL_get_wbio(ssl_))) {
475     return 0;
476   }
477
478   return BIO_number_written(b);
479 }
480
481 size_t AsyncSSLSocket::getRawBytesReceived() const {
482   BIO *b;
483   if (!ssl_ || !(b = SSL_get_rbio(ssl_))) {
484     return 0;
485   }
486
487   return BIO_number_read(b);
488 }
489
490
491 void AsyncSSLSocket::invalidState(HandshakeCB* callback) {
492   LOG(ERROR) << "AsyncSSLSocket(this=" << this << ", fd=" << fd_
493              << ", state=" << int(state_) << ", sslState=" << sslState_ << ", "
494              << "events=" << eventFlags_ << ", server=" << short(server_) << "): "
495              << "sslAccept/Connect() called in invalid "
496              << "state, handshake callback " << handshakeCallback_ << ", new callback "
497              << callback;
498   assert(!handshakeTimeout_.isScheduled());
499   sslState_ = STATE_ERROR;
500
501   AsyncSocketException ex(AsyncSocketException::INVALID_STATE,
502                          "sslAccept() called with socket in invalid state");
503
504   handshakeEndTime_ = std::chrono::steady_clock::now();
505   if (callback) {
506     callback->handshakeErr(this, ex);
507   }
508
509   // Check the socket state not the ssl state here.
510   if (state_ != StateEnum::CLOSED || state_ != StateEnum::ERROR) {
511     failHandshake(__func__, ex);
512   }
513 }
514
515 void AsyncSSLSocket::sslAccept(HandshakeCB* callback, uint32_t timeout,
516       const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
517   DestructorGuard dg(this);
518   assert(eventBase_->isInEventBaseThread());
519   verifyPeer_ = verifyPeer;
520
521   // Make sure we're in the uninitialized state
522   if (!server_ || (sslState_ != STATE_UNINIT &&
523                    sslState_ != STATE_UNENCRYPTED) ||
524       handshakeCallback_ != nullptr) {
525     return invalidState(callback);
526   }
527
528   // Cache local and remote socket addresses to keep them available
529   // after socket file descriptor is closed.
530   if (cacheAddrOnFailure_ && -1 != getFd()) {
531     cacheLocalPeerAddr();
532   }
533
534   handshakeStartTime_ = std::chrono::steady_clock::now();
535   // Make end time at least >= start time.
536   handshakeEndTime_ = handshakeStartTime_;
537
538   sslState_ = STATE_ACCEPTING;
539   handshakeCallback_ = callback;
540
541   if (timeout > 0) {
542     handshakeTimeout_.scheduleTimeout(timeout);
543   }
544
545   /* register for a read operation (waiting for CLIENT HELLO) */
546   updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
547 }
548
549 #if OPENSSL_VERSION_NUMBER >= 0x009080bfL
550 void AsyncSSLSocket::attachSSLContext(
551   const std::shared_ptr<SSLContext>& ctx) {
552
553   // Check to ensure we are in client mode. Changing a server's ssl
554   // context doesn't make sense since clients of that server would likely
555   // become confused when the server's context changes.
556   DCHECK(!server_);
557   DCHECK(!ctx_);
558   DCHECK(ctx);
559   DCHECK(ctx->getSSLCtx());
560   ctx_ = ctx;
561
562   // In order to call attachSSLContext, detachSSLContext must have been
563   // previously called which sets the socket's context to the dummy
564   // context. Thus we must acquire this lock.
565   SpinLockGuard guard(dummyCtxLock);
566   SSL_set_SSL_CTX(ssl_, ctx->getSSLCtx());
567 }
568
569 void AsyncSSLSocket::detachSSLContext() {
570   DCHECK(ctx_);
571   ctx_.reset();
572   // We aren't using the initial_ctx for now, and it can introduce race
573   // conditions in the destructor of the SSL object.
574 #ifndef OPENSSL_NO_TLSEXT
575   if (ssl_->initial_ctx) {
576     SSL_CTX_free(ssl_->initial_ctx);
577     ssl_->initial_ctx = nullptr;
578   }
579 #endif
580   SpinLockGuard guard(dummyCtxLock);
581   if (nullptr == dummyCtx) {
582     // We need to lazily initialize the dummy context so we don't
583     // accidentally override any programmatic settings to openssl
584     dummyCtx = new SSLContext;
585   }
586   // We must remove this socket's references to its context right now
587   // since this socket could get passed to any thread. If the context has
588   // had its locking disabled, just doing a set in attachSSLContext()
589   // would not be thread safe.
590   SSL_set_SSL_CTX(ssl_, dummyCtx->getSSLCtx());
591 }
592 #endif
593
594 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
595 void AsyncSSLSocket::switchServerSSLContext(
596   const std::shared_ptr<SSLContext>& handshakeCtx) {
597   CHECK(server_);
598   if (sslState_ != STATE_ACCEPTING) {
599     // We log it here and allow the switch.
600     // It should not affect our re-negotiation support (which
601     // is not supported now).
602     VLOG(6) << "fd=" << getFd()
603             << " renegotation detected when switching SSL_CTX";
604   }
605
606   setup_SSL_CTX(handshakeCtx->getSSLCtx());
607   SSL_CTX_set_info_callback(handshakeCtx->getSSLCtx(),
608                             AsyncSSLSocket::sslInfoCallback);
609   handshakeCtx_ = handshakeCtx;
610   SSL_set_SSL_CTX(ssl_, handshakeCtx->getSSLCtx());
611 }
612
613 bool AsyncSSLSocket::isServerNameMatch() const {
614   CHECK(!server_);
615
616   if (!ssl_) {
617     return false;
618   }
619
620   SSL_SESSION *ss = SSL_get_session(ssl_);
621   if (!ss) {
622     return false;
623   }
624
625   if(!ss->tlsext_hostname) {
626     return false;
627   }
628   return (tlsextHostname_.compare(ss->tlsext_hostname) ? false : true);
629 }
630
631 void AsyncSSLSocket::setServerName(std::string serverName) noexcept {
632   tlsextHostname_ = std::move(serverName);
633 }
634
635 #endif
636
637 void AsyncSSLSocket::timeoutExpired() noexcept {
638   if (state_ == StateEnum::ESTABLISHED &&
639       (sslState_ == STATE_CACHE_LOOKUP ||
640        sslState_ == STATE_ASYNC_PENDING)) {
641     sslState_ = STATE_ERROR;
642     // We are expecting a callback in restartSSLAccept.  The cache lookup
643     // and rsa-call necessarily have pointers to this ssl socket, so delay
644     // the cleanup until he calls us back.
645   } else {
646     assert(state_ == StateEnum::ESTABLISHED &&
647            (sslState_ == STATE_CONNECTING || sslState_ == STATE_ACCEPTING));
648     DestructorGuard dg(this);
649     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
650                            (sslState_ == STATE_CONNECTING) ?
651                            "SSL connect timed out" : "SSL accept timed out");
652     failHandshake(__func__, ex);
653   }
654 }
655
656 int AsyncSSLSocket::getSSLExDataIndex() {
657   static auto index = SSL_get_ex_new_index(
658       0, (void*)"AsyncSSLSocket data index", nullptr, nullptr, nullptr);
659   return index;
660 }
661
662 AsyncSSLSocket* AsyncSSLSocket::getFromSSL(const SSL *ssl) {
663   return static_cast<AsyncSSLSocket *>(SSL_get_ex_data(ssl,
664       getSSLExDataIndex()));
665 }
666
667 void AsyncSSLSocket::failHandshake(const char* /* fn */,
668                                    const AsyncSocketException& ex) {
669   startFail();
670   if (handshakeTimeout_.isScheduled()) {
671     handshakeTimeout_.cancelTimeout();
672   }
673   invokeHandshakeErr(ex);
674   finishFail();
675 }
676
677 void AsyncSSLSocket::invokeHandshakeErr(const AsyncSocketException& ex) {
678   handshakeEndTime_ = std::chrono::steady_clock::now();
679   if (handshakeCallback_ != nullptr) {
680     HandshakeCB* callback = handshakeCallback_;
681     handshakeCallback_ = nullptr;
682     callback->handshakeErr(this, ex);
683   }
684 }
685
686 void AsyncSSLSocket::invokeHandshakeCB() {
687   handshakeEndTime_ = std::chrono::steady_clock::now();
688   if (handshakeTimeout_.isScheduled()) {
689     handshakeTimeout_.cancelTimeout();
690   }
691   if (handshakeCallback_) {
692     HandshakeCB* callback = handshakeCallback_;
693     handshakeCallback_ = nullptr;
694     callback->handshakeSuc(this);
695   }
696 }
697
698 void AsyncSSLSocket::cacheLocalPeerAddr() {
699   SocketAddress address;
700   try {
701     getLocalAddress(&address);
702     getPeerAddress(&address);
703   } catch (const std::system_error& e) {
704     // The handle can be still valid while the connection is already closed.
705     if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
706       throw;
707     }
708   }
709 }
710
711 void AsyncSSLSocket::connect(ConnectCallback* callback,
712                               const folly::SocketAddress& address,
713                               int timeout,
714                               const OptionMap &options,
715                               const folly::SocketAddress& bindAddr)
716                               noexcept {
717   assert(!server_);
718   assert(state_ == StateEnum::UNINIT);
719   assert(sslState_ == STATE_UNINIT);
720   AsyncSSLSocketConnector *connector =
721     new AsyncSSLSocketConnector(this, callback, timeout);
722   AsyncSocket::connect(connector, address, timeout, options, bindAddr);
723 }
724
725 void AsyncSSLSocket::applyVerificationOptions(SSL * ssl) {
726   // apply the settings specified in verifyPeer_
727   if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::USE_CTX) {
728     if(ctx_->needsPeerVerification()) {
729       SSL_set_verify(ssl, ctx_->getVerificationMode(),
730         AsyncSSLSocket::sslVerifyCallback);
731     }
732   } else {
733     if (verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY ||
734         verifyPeer_ == SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT) {
735       SSL_set_verify(ssl, SSLContext::getVerificationMode(verifyPeer_),
736         AsyncSSLSocket::sslVerifyCallback);
737     }
738   }
739 }
740
741 void AsyncSSLSocket::sslConn(HandshakeCB* callback, uint64_t timeout,
742         const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
743   DestructorGuard dg(this);
744   assert(eventBase_->isInEventBaseThread());
745
746   // Cache local and remote socket addresses to keep them available
747   // after socket file descriptor is closed.
748   if (cacheAddrOnFailure_ && -1 != getFd()) {
749     cacheLocalPeerAddr();
750   }
751
752   verifyPeer_ = verifyPeer;
753
754   // Make sure we're in the uninitialized state
755   if (server_ || (sslState_ != STATE_UNINIT && sslState_ !=
756                   STATE_UNENCRYPTED) ||
757       handshakeCallback_ != nullptr) {
758     return invalidState(callback);
759   }
760
761   handshakeStartTime_ = std::chrono::steady_clock::now();
762   // Make end time at least >= start time.
763   handshakeEndTime_ = handshakeStartTime_;
764
765   sslState_ = STATE_CONNECTING;
766   handshakeCallback_ = callback;
767
768   try {
769     ssl_ = ctx_->createSSL();
770   } catch (std::exception &e) {
771     sslState_ = STATE_ERROR;
772     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
773                            "error calling SSLContext::createSSL()");
774     LOG(ERROR) << "AsyncSSLSocket::sslConn(this=" << this << ", fd="
775             << fd_ << "): " << e.what();
776     return failHandshake(__func__, ex);
777   }
778
779   applyVerificationOptions(ssl_);
780
781   SSL_set_fd(ssl_, fd_);
782   if (sslSession_ != nullptr) {
783     SSL_set_session(ssl_, sslSession_);
784     SSL_SESSION_free(sslSession_);
785     sslSession_ = nullptr;
786   }
787 #if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
788   if (tlsextHostname_.size()) {
789     SSL_set_tlsext_host_name(ssl_, tlsextHostname_.c_str());
790   }
791 #endif
792
793   SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
794
795   if (timeout > 0) {
796     handshakeTimeout_.scheduleTimeout(timeout);
797   }
798
799   handleConnect();
800 }
801
802 SSL_SESSION *AsyncSSLSocket::getSSLSession() {
803   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
804     return SSL_get1_session(ssl_);
805   }
806
807   return sslSession_;
808 }
809
810 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
811   sslSession_ = session;
812   if (!takeOwnership && session != nullptr) {
813     // Increment the reference count
814     CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
815   }
816 }
817
818 void AsyncSSLSocket::getSelectedNextProtocol(
819     const unsigned char** protoName,
820     unsigned* protoLen,
821     SSLContext::NextProtocolType* protoType) const {
822   if (!getSelectedNextProtocolNoThrow(protoName, protoLen, protoType)) {
823     throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
824                               "NPN not supported");
825   }
826 }
827
828 bool AsyncSSLSocket::getSelectedNextProtocolNoThrow(
829     const unsigned char** protoName,
830     unsigned* protoLen,
831     SSLContext::NextProtocolType* protoType) const {
832   *protoName = nullptr;
833   *protoLen = 0;
834 #if OPENSSL_VERSION_NUMBER >= 0x1000200fL && !defined(OPENSSL_NO_TLSEXT)
835   SSL_get0_alpn_selected(ssl_, protoName, protoLen);
836   if (*protoLen > 0) {
837     if (protoType) {
838       *protoType = SSLContext::NextProtocolType::ALPN;
839     }
840     return true;
841   }
842 #endif
843 #ifdef OPENSSL_NPN_NEGOTIATED
844   SSL_get0_next_proto_negotiated(ssl_, protoName, protoLen);
845   if (protoType) {
846     *protoType = SSLContext::NextProtocolType::NPN;
847   }
848   return true;
849 #else
850   (void)protoType;
851   return false;
852 #endif
853 }
854
855 bool AsyncSSLSocket::getSSLSessionReused() const {
856   if (ssl_ != nullptr && sslState_ == STATE_ESTABLISHED) {
857     return SSL_session_reused(ssl_);
858   }
859   return false;
860 }
861
862 const char *AsyncSSLSocket::getNegotiatedCipherName() const {
863   return (ssl_ != nullptr) ? SSL_get_cipher_name(ssl_) : nullptr;
864 }
865
866 /* static */
867 const char* AsyncSSLSocket::getSSLServerNameFromSSL(SSL* ssl) {
868   if (ssl == nullptr) {
869     return nullptr;
870   }
871 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
872   return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name);
873 #else
874   return nullptr;
875 #endif
876 }
877
878 const char *AsyncSSLSocket::getSSLServerName() const {
879 #ifdef SSL_CTRL_SET_TLSEXT_SERVERNAME_CB
880   return getSSLServerNameFromSSL(ssl_);
881 #else
882   throw AsyncSocketException(AsyncSocketException::NOT_SUPPORTED,
883                              "SNI not supported");
884 #endif
885 }
886
887 const char *AsyncSSLSocket::getSSLServerNameNoThrow() const {
888   return getSSLServerNameFromSSL(ssl_);
889 }
890
891 int AsyncSSLSocket::getSSLVersion() const {
892   return (ssl_ != nullptr) ? SSL_version(ssl_) : 0;
893 }
894
895 const char *AsyncSSLSocket::getSSLCertSigAlgName() const {
896   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
897   if (cert) {
898     int nid = OBJ_obj2nid(cert->sig_alg->algorithm);
899     return OBJ_nid2ln(nid);
900   }
901   return nullptr;
902 }
903
904 int AsyncSSLSocket::getSSLCertSize() const {
905   int certSize = 0;
906   X509 *cert = (ssl_ != nullptr) ? SSL_get_certificate(ssl_) : nullptr;
907   if (cert) {
908     EVP_PKEY *key = X509_get_pubkey(cert);
909     certSize = EVP_PKEY_bits(key);
910     EVP_PKEY_free(key);
911   }
912   return certSize;
913 }
914
915 bool AsyncSSLSocket::willBlock(int ret,
916                                int* sslErrorOut,
917                                unsigned long* errErrorOut) noexcept {
918   *errErrorOut = 0;
919   int error = *sslErrorOut = SSL_get_error(ssl_, ret);
920   if (error == SSL_ERROR_WANT_READ) {
921     // Register for read event if not already.
922     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
923     return true;
924   } else if (error == SSL_ERROR_WANT_WRITE) {
925     VLOG(3) << "AsyncSSLSocket(fd=" << fd_
926             << ", state=" << int(state_) << ", sslState="
927             << sslState_ << ", events=" << eventFlags_ << "): "
928             << "SSL_ERROR_WANT_WRITE";
929     // Register for write event if not already.
930     updateEventRegistration(EventHandler::WRITE, EventHandler::READ);
931     return true;
932 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
933   } else if (error == SSL_ERROR_WANT_SESS_CACHE_LOOKUP) {
934     // We will block but we can't register our own socket.  The callback that
935     // triggered this code will re-call handleAccept at the appropriate time.
936
937     // We can only get here if the linked libssl.so has support for this feature
938     // as well, otherwise SSL_get_error cannot return our error code.
939     sslState_ = STATE_CACHE_LOOKUP;
940
941     // Unregister for all events while blocked here
942     updateEventRegistration(EventHandler::NONE,
943                             EventHandler::READ | EventHandler::WRITE);
944
945     // The timeout (if set) keeps running here
946     return true;
947 #endif
948   } else if (0
949 #ifdef SSL_ERROR_WANT_RSA_ASYNC_PENDING
950       || error == SSL_ERROR_WANT_RSA_ASYNC_PENDING
951 #endif
952 #ifdef SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
953       || error == SSL_ERROR_WANT_ECDSA_ASYNC_PENDING
954 #endif
955       ) {
956     // Our custom openssl function has kicked off an async request to do
957     // rsa/ecdsa private key operation.  When that call returns, a callback will
958     // be invoked that will re-call handleAccept.
959     sslState_ = STATE_ASYNC_PENDING;
960
961     // Unregister for all events while blocked here
962     updateEventRegistration(
963       EventHandler::NONE,
964       EventHandler::READ | EventHandler::WRITE
965     );
966
967     // The timeout (if set) keeps running here
968     return true;
969   } else {
970     // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
971     // in the log
972     unsigned long lastError = *errErrorOut = ERR_get_error();
973     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
974             << "state=" << state_ << ", "
975             << "sslState=" << sslState_ << ", "
976             << "events=" << std::hex << eventFlags_ << "): "
977             << "SSL error: " << error << ", "
978             << "errno: " << errno << ", "
979             << "ret: " << ret << ", "
980             << "read: " << BIO_number_read(SSL_get_rbio(ssl_)) << ", "
981             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
982             << "func: " << ERR_func_error_string(lastError) << ", "
983             << "reason: " << ERR_reason_error_string(lastError);
984     ERR_clear_error();
985     return false;
986   }
987 }
988
989 void AsyncSSLSocket::checkForImmediateRead() noexcept {
990   // openssl may have buffered data that it read from the socket already.
991   // In this case we have to process it immediately, rather than waiting for
992   // the socket to become readable again.
993   if (ssl_ != nullptr && SSL_pending(ssl_) > 0) {
994     AsyncSocket::handleRead();
995   }
996 }
997
998 void
999 AsyncSSLSocket::restartSSLAccept()
1000 {
1001   VLOG(3) << "AsyncSSLSocket::restartSSLAccept() this=" << this << ", fd=" << fd_
1002           << ", state=" << int(state_) << ", "
1003           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1004   DestructorGuard dg(this);
1005   assert(
1006     sslState_ == STATE_CACHE_LOOKUP ||
1007     sslState_ == STATE_ASYNC_PENDING ||
1008     sslState_ == STATE_ERROR ||
1009     sslState_ == STATE_CLOSED
1010   );
1011   if (sslState_ == STATE_CLOSED) {
1012     // I sure hope whoever closed this socket didn't delete it already,
1013     // but this is not strictly speaking an error
1014     return;
1015   }
1016   if (sslState_ == STATE_ERROR) {
1017     // go straight to fail if timeout expired during lookup
1018     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
1019                            "SSL accept timed out");
1020     failHandshake(__func__, ex);
1021     return;
1022   }
1023   sslState_ = STATE_ACCEPTING;
1024   this->handleAccept();
1025 }
1026
1027 void
1028 AsyncSSLSocket::handleAccept() noexcept {
1029   VLOG(3) << "AsyncSSLSocket::handleAccept() this=" << this
1030           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1031           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1032   assert(server_);
1033   assert(state_ == StateEnum::ESTABLISHED &&
1034          sslState_ == STATE_ACCEPTING);
1035   if (!ssl_) {
1036     /* lazily create the SSL structure */
1037     try {
1038       ssl_ = ctx_->createSSL();
1039     } catch (std::exception &e) {
1040       sslState_ = STATE_ERROR;
1041       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1042                              "error calling SSLContext::createSSL()");
1043       LOG(ERROR) << "AsyncSSLSocket::handleAccept(this=" << this
1044                  << ", fd=" << fd_ << "): " << e.what();
1045       return failHandshake(__func__, ex);
1046     }
1047     SSL_set_fd(ssl_, fd_);
1048     SSL_set_ex_data(ssl_, getSSLExDataIndex(), this);
1049
1050     applyVerificationOptions(ssl_);
1051   }
1052
1053   if (server_ && parseClientHello_) {
1054     SSL_set_msg_callback(ssl_, &AsyncSSLSocket::clientHelloParsingCallback);
1055     SSL_set_msg_callback_arg(ssl_, this);
1056   }
1057
1058   errno = 0;
1059   int ret = SSL_accept(ssl_);
1060   if (ret <= 0) {
1061     int sslError;
1062     unsigned long errError;
1063     int errnoCopy = errno;
1064     if (willBlock(ret, &sslError, &errError)) {
1065       return;
1066     } else {
1067       sslState_ = STATE_ERROR;
1068       SSLException ex(sslError, errError, ret, errnoCopy);
1069       return failHandshake(__func__, ex);
1070     }
1071   }
1072
1073   handshakeComplete_ = true;
1074   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1075
1076   // Move into STATE_ESTABLISHED in the normal case that we are in
1077   // STATE_ACCEPTING.
1078   sslState_ = STATE_ESTABLISHED;
1079
1080   VLOG(3) << "AsyncSSLSocket " << this << ": fd " << fd_
1081           << " successfully accepted; state=" << int(state_)
1082           << ", sslState=" << sslState_ << ", events=" << eventFlags_;
1083
1084   // Remember the EventBase we are attached to, before we start invoking any
1085   // callbacks (since the callbacks may call detachEventBase()).
1086   EventBase* originalEventBase = eventBase_;
1087
1088   // Call the accept callback.
1089   invokeHandshakeCB();
1090
1091   // Note that the accept callback may have changed our state.
1092   // (set or unset the read callback, called write(), closed the socket, etc.)
1093   // The following code needs to handle these situations correctly.
1094   //
1095   // If the socket has been closed, readCallback_ and writeReqHead_ will
1096   // always be nullptr, so that will prevent us from trying to read or write.
1097   //
1098   // The main thing to check for is if eventBase_ is still originalEventBase.
1099   // If not, we have been detached from this event base, so we shouldn't
1100   // perform any more operations.
1101   if (eventBase_ != originalEventBase) {
1102     return;
1103   }
1104
1105   AsyncSocket::handleInitialReadWrite();
1106 }
1107
1108 void
1109 AsyncSSLSocket::handleConnect() noexcept {
1110   VLOG(3) <<  "AsyncSSLSocket::handleConnect() this=" << this
1111           << ", fd=" << fd_ << ", state=" << int(state_) << ", "
1112           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1113   assert(!server_);
1114   if (state_ < StateEnum::ESTABLISHED) {
1115     return AsyncSocket::handleConnect();
1116   }
1117
1118   assert(state_ == StateEnum::ESTABLISHED &&
1119          sslState_ == STATE_CONNECTING);
1120   assert(ssl_);
1121
1122   errno = 0;
1123   int ret = SSL_connect(ssl_);
1124   if (ret <= 0) {
1125     int sslError;
1126     unsigned long errError;
1127     int errnoCopy = errno;
1128     if (willBlock(ret, &sslError, &errError)) {
1129       return;
1130     } else {
1131       sslState_ = STATE_ERROR;
1132       SSLException ex(sslError, errError, ret, errnoCopy);
1133       return failHandshake(__func__, ex);
1134     }
1135   }
1136
1137   handshakeComplete_ = true;
1138   updateEventRegistration(0, EventHandler::READ | EventHandler::WRITE);
1139
1140   // Move into STATE_ESTABLISHED in the normal case that we are in
1141   // STATE_CONNECTING.
1142   sslState_ = STATE_ESTABLISHED;
1143
1144   VLOG(3) << "AsyncSSLSocket %p: fd %d successfully connected; "
1145           << "state=" << int(state_) << ", sslState=" << sslState_
1146           << ", events=" << eventFlags_;
1147
1148   // Remember the EventBase we are attached to, before we start invoking any
1149   // callbacks (since the callbacks may call detachEventBase()).
1150   EventBase* originalEventBase = eventBase_;
1151
1152   // Call the handshake callback.
1153   invokeHandshakeCB();
1154
1155   // Note that the connect callback may have changed our state.
1156   // (set or unset the read callback, called write(), closed the socket, etc.)
1157   // The following code needs to handle these situations correctly.
1158   //
1159   // If the socket has been closed, readCallback_ and writeReqHead_ will
1160   // always be nullptr, so that will prevent us from trying to read or write.
1161   //
1162   // The main thing to check for is if eventBase_ is still originalEventBase.
1163   // If not, we have been detached from this event base, so we shouldn't
1164   // perform any more operations.
1165   if (eventBase_ != originalEventBase) {
1166     return;
1167   }
1168
1169   AsyncSocket::handleInitialReadWrite();
1170 }
1171
1172 void AsyncSSLSocket::setReadCB(ReadCallback *callback) {
1173 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1174   // turn on the buffer movable in openssl
1175   if (ssl_ != nullptr && !isBufferMovable_ &&
1176       callback != nullptr && callback->isBufferMovable()) {
1177     SSL_set_mode(ssl_, SSL_get_mode(ssl_) | SSL_MODE_MOVE_BUFFER_OWNERSHIP);
1178     isBufferMovable_ = true;
1179   }
1180 #endif
1181
1182   AsyncSocket::setReadCB(callback);
1183 }
1184
1185 void AsyncSSLSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1186   CHECK(readCallback_);
1187   if (isBufferMovable_) {
1188     *buf = nullptr;
1189     *buflen = 0;
1190   } else {
1191     // buf is necessary for SSLSocket without SSL_MODE_MOVE_BUFFER_OWNERSHIP
1192     readCallback_->getReadBuffer(buf, buflen);
1193   }
1194 }
1195
1196 void
1197 AsyncSSLSocket::handleRead() noexcept {
1198   VLOG(5) << "AsyncSSLSocket::handleRead() this=" << this << ", fd=" << fd_
1199           << ", state=" << int(state_) << ", "
1200           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1201   if (state_ < StateEnum::ESTABLISHED) {
1202     return AsyncSocket::handleRead();
1203   }
1204
1205
1206   if (sslState_ == STATE_ACCEPTING) {
1207     assert(server_);
1208     handleAccept();
1209     return;
1210   }
1211   else if (sslState_ == STATE_CONNECTING) {
1212     assert(!server_);
1213     handleConnect();
1214     return;
1215   }
1216
1217   // Normal read
1218   AsyncSocket::handleRead();
1219 }
1220
1221 ssize_t
1222 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
1223   VLOG(4) << "AsyncSSLSocket::performRead() this=" << this
1224           << ", buf=" << *buf << ", buflen=" << *buflen;
1225
1226   if (sslState_ == STATE_UNENCRYPTED) {
1227     return AsyncSocket::performRead(buf, buflen, offset);
1228   }
1229
1230   errno = 0;
1231   ssize_t bytes = 0;
1232   if (!isBufferMovable_) {
1233     bytes = SSL_read(ssl_, *buf, *buflen);
1234   }
1235 #ifdef SSL_MODE_MOVE_BUFFER_OWNERSHIP
1236   else {
1237     bytes = SSL_read_buf(ssl_, buf, (int *) offset, (int *) buflen);
1238   }
1239 #endif
1240
1241   if (server_ && renegotiateAttempted_) {
1242     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1243                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
1244                << "): client intitiated SSL renegotiation not permitted";
1245     // We pack our own SSLerr here with a dummy function
1246     errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
1247                      SSL_CLIENT_RENEGOTIATION_ATTEMPT);
1248     ERR_clear_error();
1249     return READ_ERROR;
1250   }
1251   if (bytes <= 0) {
1252     int error = SSL_get_error(ssl_, bytes);
1253     if (error == SSL_ERROR_WANT_READ) {
1254       // The caller will register for read event if not already.
1255       if (errno == EWOULDBLOCK || errno == EAGAIN) {
1256         return READ_BLOCKING;
1257       } else {
1258         return READ_ERROR;
1259       }
1260     } else if (error == SSL_ERROR_WANT_WRITE) {
1261       // TODO: Even though we are attempting to read data, SSL_read() may
1262       // need to write data if renegotiation is being performed.  We currently
1263       // don't support this and just fail the read.
1264       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1265                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
1266                  << "): unsupported SSL renegotiation during read",
1267       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
1268                        SSL_INVALID_RENEGOTIATION);
1269       ERR_clear_error();
1270       return READ_ERROR;
1271     } else {
1272       // TODO: Fix this code so that it can return a proper error message
1273       // to the callback, rather than relying on AsyncSocket code which
1274       // can't handle SSL errors.
1275       long lastError = ERR_get_error();
1276
1277       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
1278               << "state=" << state_ << ", "
1279               << "sslState=" << sslState_ << ", "
1280               << "events=" << std::hex << eventFlags_ << "): "
1281               << "bytes: " << bytes << ", "
1282               << "error: " << error << ", "
1283               << "errno: " << errno << ", "
1284               << "func: " << ERR_func_error_string(lastError) << ", "
1285               << "reason: " << ERR_reason_error_string(lastError);
1286       ERR_clear_error();
1287       if (zero_return(error, bytes)) {
1288         return bytes;
1289       }
1290       if (error != SSL_ERROR_SYSCALL) {
1291         if ((unsigned long)lastError < 0x8000) {
1292           errno = ENOSYS;
1293         } else {
1294           errno = lastError;
1295         }
1296       }
1297       return READ_ERROR;
1298     }
1299   } else {
1300     appBytesReceived_ += bytes;
1301     return bytes;
1302   }
1303 }
1304
1305 void AsyncSSLSocket::handleWrite() noexcept {
1306   VLOG(5) << "AsyncSSLSocket::handleWrite() this=" << this << ", fd=" << fd_
1307           << ", state=" << int(state_) << ", "
1308           << "sslState=" << sslState_ << ", events=" << eventFlags_;
1309   if (state_ < StateEnum::ESTABLISHED) {
1310     return AsyncSocket::handleWrite();
1311   }
1312
1313   if (sslState_ == STATE_ACCEPTING) {
1314     assert(server_);
1315     handleAccept();
1316     return;
1317   }
1318
1319   if (sslState_ == STATE_CONNECTING) {
1320     assert(!server_);
1321     handleConnect();
1322     return;
1323   }
1324
1325   // Normal write
1326   AsyncSocket::handleWrite();
1327 }
1328
1329 int AsyncSSLSocket::interpretSSLError(int rc, int error) {
1330   if (error == SSL_ERROR_WANT_READ) {
1331     // TODO: Even though we are attempting to write data, SSL_write() may
1332     // need to read data if renegotiation is being performed.  We currently
1333     // don't support this and just fail the write.
1334     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1335                << ", sslState=" << sslState_ << ", events=" << eventFlags_
1336                << "): " << "unsupported SSL renegotiation during write",
1337       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
1338                        SSL_INVALID_RENEGOTIATION);
1339     ERR_clear_error();
1340     return -1;
1341   } else {
1342     // TODO: Fix this code so that it can return a proper error message
1343     // to the callback, rather than relying on AsyncSocket code which
1344     // can't handle SSL errors.
1345     long lastError = ERR_get_error();
1346     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1347             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
1348             << "SSL error: " << error << ", errno: " << errno
1349             << ", func: " << ERR_func_error_string(lastError)
1350             << ", reason: " << ERR_reason_error_string(lastError);
1351     if (error != SSL_ERROR_SYSCALL) {
1352       if ((unsigned long)lastError < 0x8000) {
1353         errno = ENOSYS;
1354       } else {
1355         errno = lastError;
1356       }
1357     }
1358     ERR_clear_error();
1359     if (!zero_return(error, rc)) {
1360       return -1;
1361     } else {
1362       return 0;
1363     }
1364   }
1365 }
1366
1367 ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
1368                                       uint32_t count,
1369                                       WriteFlags flags,
1370                                       uint32_t* countWritten,
1371                                       uint32_t* partialWritten) {
1372   if (sslState_ == STATE_UNENCRYPTED) {
1373     return AsyncSocket::performWrite(
1374       vec, count, flags, countWritten, partialWritten);
1375   }
1376   if (sslState_ != STATE_ESTABLISHED) {
1377     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
1378                << ", sslState=" << sslState_
1379                << ", events=" << eventFlags_ << "): "
1380                << "TODO: AsyncSSLSocket currently does not support calling "
1381                << "write() before the handshake has fully completed";
1382       errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
1383                        SSL_EARLY_WRITE);
1384       return -1;
1385   }
1386
1387   bool cork = isSet(flags, WriteFlags::CORK) || persistentCork_;
1388   CorkGuard guard(fd_, count > 1, cork, &corked_);
1389
1390 #if 0
1391 //#ifdef SSL_MODE_WRITE_IOVEC
1392   if (ssl_->expand == nullptr &&
1393       ssl_->compress == nullptr &&
1394       (ssl_->mode & SSL_MODE_WRITE_IOVEC)) {
1395     return performWriteIovec(vec, count, flags, countWritten, partialWritten);
1396   }
1397 #endif
1398
1399   // Declare a buffer used to hold small write requests.  It could point to a
1400   // memory block either on stack or on heap. If it is on heap, we release it
1401   // manually when scope exits
1402   char* combinedBuf{nullptr};
1403   SCOPE_EXIT {
1404     // Note, always keep this check consistent with what we do below
1405     if (combinedBuf != nullptr && minWriteSize_ > MAX_STACK_BUF_SIZE) {
1406       delete[] combinedBuf;
1407     }
1408   };
1409
1410   *countWritten = 0;
1411   *partialWritten = 0;
1412   ssize_t totalWritten = 0;
1413   size_t bytesStolenFromNextBuffer = 0;
1414   for (uint32_t i = 0; i < count; i++) {
1415     const iovec* v = vec + i;
1416     size_t offset = bytesStolenFromNextBuffer;
1417     bytesStolenFromNextBuffer = 0;
1418     size_t len = v->iov_len - offset;
1419     const void* buf;
1420     if (len == 0) {
1421       (*countWritten)++;
1422       continue;
1423     }
1424     buf = ((const char*)v->iov_base) + offset;
1425
1426     ssize_t bytes;
1427     errno = 0;
1428     uint32_t buffersStolen = 0;
1429     if ((len < minWriteSize_) && ((i + 1) < count)) {
1430       // Combine this buffer with part or all of the next buffers in
1431       // order to avoid really small-grained calls to SSL_write().
1432       // Each call to SSL_write() produces a separate record in
1433       // the egress SSL stream, and we've found that some low-end
1434       // mobile clients can't handle receiving an HTTP response
1435       // header and the first part of the response body in two
1436       // separate SSL records (even if those two records are in
1437       // the same TCP packet).
1438
1439       if (combinedBuf == nullptr) {
1440         if (minWriteSize_ > MAX_STACK_BUF_SIZE) {
1441           // Allocate the buffer on heap
1442           combinedBuf = new char[minWriteSize_];
1443         } else {
1444           // Allocate the buffer on stack
1445           combinedBuf = (char*)alloca(minWriteSize_);
1446         }
1447       }
1448       assert(combinedBuf != nullptr);
1449
1450       memcpy(combinedBuf, buf, len);
1451       do {
1452         // INVARIANT: i + buffersStolen == complete chunks serialized
1453         uint32_t nextIndex = i + buffersStolen + 1;
1454         bytesStolenFromNextBuffer = std::min(vec[nextIndex].iov_len,
1455                                              minWriteSize_ - len);
1456         memcpy(combinedBuf + len, vec[nextIndex].iov_base,
1457                bytesStolenFromNextBuffer);
1458         len += bytesStolenFromNextBuffer;
1459         if (bytesStolenFromNextBuffer < vec[nextIndex].iov_len) {
1460           // couldn't steal the whole buffer
1461           break;
1462         } else {
1463           bytesStolenFromNextBuffer = 0;
1464           buffersStolen++;
1465         }
1466       } while ((i + buffersStolen + 1) < count && (len < minWriteSize_));
1467       bytes = eorAwareSSLWrite(
1468         ssl_, combinedBuf, len,
1469         (isSet(flags, WriteFlags::EOR) && i + buffersStolen + 1 == count));
1470
1471     } else {
1472       bytes = eorAwareSSLWrite(ssl_, buf, len,
1473                            (isSet(flags, WriteFlags::EOR) && i + 1 == count));
1474     }
1475
1476     if (bytes <= 0) {
1477       int error = SSL_get_error(ssl_, bytes);
1478       if (error == SSL_ERROR_WANT_WRITE) {
1479         // The caller will register for write event if not already.
1480         *partialWritten = offset;
1481         return totalWritten;
1482       }
1483       int rc = interpretSSLError(bytes, error);
1484       if (rc < 0) {
1485         return rc;
1486       } // else fall through to below to correctly record totalWritten
1487     }
1488
1489     totalWritten += bytes;
1490
1491     if (bytes == (ssize_t)len) {
1492       // The full iovec is written.
1493       (*countWritten) += 1 + buffersStolen;
1494       i += buffersStolen;
1495       // continue
1496     } else {
1497       bytes += offset; // adjust bytes to account for all of v
1498       while (bytes >= (ssize_t)v->iov_len) {
1499         // We combined this buf with part or all of the next one, and
1500         // we managed to write all of this buf but not all of the bytes
1501         // from the next one that we'd hoped to write.
1502         bytes -= v->iov_len;
1503         (*countWritten)++;
1504         v = &(vec[++i]);
1505       }
1506       *partialWritten = bytes;
1507       return totalWritten;
1508     }
1509   }
1510
1511   return totalWritten;
1512 }
1513
1514 #if 0
1515 //#ifdef SSL_MODE_WRITE_IOVEC
1516 ssize_t AsyncSSLSocket::performWriteIovec(const iovec* vec,
1517                                           uint32_t count,
1518                                           WriteFlags flags,
1519                                           uint32_t* countWritten,
1520                                           uint32_t* partialWritten) {
1521   size_t tot = 0;
1522   for (uint32_t j = 0; j < count; j++) {
1523     tot += vec[j].iov_len;
1524   }
1525
1526   ssize_t totalWritten = SSL_write_iovec(ssl_, vec, count);
1527
1528   *countWritten = 0;
1529   *partialWritten = 0;
1530   if (totalWritten <= 0) {
1531     return interpretSSLError(totalWritten, SSL_get_error(ssl_, totalWritten));
1532   } else {
1533     ssize_t bytes = totalWritten, i = 0;
1534     while (i < count && bytes >= (ssize_t)vec[i].iov_len) {
1535       // we managed to write all of this buf
1536       bytes -= vec[i].iov_len;
1537       (*countWritten)++;
1538       i++;
1539     }
1540     *partialWritten = bytes;
1541
1542     VLOG(4) << "SSL_write_iovec() writes " << tot
1543             << ", returns " << totalWritten << " bytes"
1544             << ", max_send_fragment=" << ssl_->max_send_fragment
1545             << ", count=" << count << ", countWritten=" << *countWritten;
1546
1547     return totalWritten;
1548   }
1549 }
1550 #endif
1551
1552 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
1553                                       bool eor) {
1554   if (eor && SSL_get_wbio(ssl)->method == &eorAwareBioMethod) {
1555     if (appEorByteNo_) {
1556       // cannot track for more than one app byte EOR
1557       CHECK(appEorByteNo_ == appBytesWritten_ + n);
1558     } else {
1559       appEorByteNo_ = appBytesWritten_ + n;
1560     }
1561
1562     // 1. It is fine to keep updating minEorRawByteNo_.
1563     // 2. It is _min_ in the sense that SSL record will add some overhead.
1564     minEorRawByteNo_ = getRawBytesWritten() + n;
1565   }
1566
1567   n = sslWriteImpl(ssl, buf, n);
1568   if (n > 0) {
1569     appBytesWritten_ += n;
1570     if (appEorByteNo_) {
1571       if (getRawBytesWritten() >= minEorRawByteNo_) {
1572         minEorRawByteNo_ = 0;
1573       }
1574       if(appBytesWritten_ == appEorByteNo_) {
1575         appEorByteNo_ = 0;
1576       } else {
1577         CHECK(appBytesWritten_ < appEorByteNo_);
1578       }
1579     }
1580   }
1581   return n;
1582 }
1583
1584 void AsyncSSLSocket::sslInfoCallback(const SSL* ssl, int where, int /* ret */) {
1585   AsyncSSLSocket *sslSocket = AsyncSSLSocket::getFromSSL(ssl);
1586   if (sslSocket->handshakeComplete_ && (where & SSL_CB_HANDSHAKE_START)) {
1587     sslSocket->renegotiateAttempted_ = true;
1588   }
1589 }
1590
1591 int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
1592   int ret;
1593   struct msghdr msg;
1594   struct iovec iov;
1595   int flags = 0;
1596   AsyncSSLSocket *tsslSock;
1597
1598   iov.iov_base = const_cast<char *>(in);
1599   iov.iov_len = inl;
1600   memset(&msg, 0, sizeof(msg));
1601   msg.msg_iov = &iov;
1602   msg.msg_iovlen = 1;
1603
1604   tsslSock =
1605     reinterpret_cast<AsyncSSLSocket*>(BIO_get_app_data(b));
1606   if (tsslSock &&
1607       tsslSock->minEorRawByteNo_ &&
1608       tsslSock->minEorRawByteNo_ <= BIO_number_written(b) + inl) {
1609     flags = MSG_EOR;
1610   }
1611
1612   errno = 0;
1613   ret = sendmsg(b->num, &msg, flags);
1614   BIO_clear_retry_flags(b);
1615   if (ret <= 0) {
1616     if (BIO_sock_should_retry(ret))
1617       BIO_set_retry_write(b);
1618   }
1619   return(ret);
1620 }
1621
1622 int AsyncSSLSocket::sslVerifyCallback(int preverifyOk,
1623                                        X509_STORE_CTX* x509Ctx) {
1624   SSL* ssl = (SSL*) X509_STORE_CTX_get_ex_data(
1625     x509Ctx, SSL_get_ex_data_X509_STORE_CTX_idx());
1626   AsyncSSLSocket* self = AsyncSSLSocket::getFromSSL(ssl);
1627
1628   VLOG(3) <<  "AsyncSSLSocket::sslVerifyCallback() this=" << self << ", "
1629           << "fd=" << self->fd_ << ", preverifyOk=" << preverifyOk;
1630   return (self->handshakeCallback_) ?
1631     self->handshakeCallback_->handshakeVer(self, preverifyOk, x509Ctx) :
1632     preverifyOk;
1633 }
1634
1635 void AsyncSSLSocket::enableClientHelloParsing()  {
1636     parseClientHello_ = true;
1637     clientHelloInfo_.reset(new ClientHelloInfo());
1638 }
1639
1640 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {
1641   SSL_set_msg_callback(ssl, nullptr);
1642   SSL_set_msg_callback_arg(ssl, nullptr);
1643   clientHelloInfo_->clientHelloBuf_.clear();
1644 }
1645
1646 void AsyncSSLSocket::clientHelloParsingCallback(int written,
1647                                                 int /* version */,
1648                                                 int contentType,
1649                                                 const void* buf,
1650                                                 size_t len,
1651                                                 SSL* ssl,
1652                                                 void* arg) {
1653   AsyncSSLSocket *sock = static_cast<AsyncSSLSocket*>(arg);
1654   if (written != 0) {
1655     sock->resetClientHelloParsing(ssl);
1656     return;
1657   }
1658   if (contentType != SSL3_RT_HANDSHAKE) {
1659     return;
1660   }
1661   if (len == 0) {
1662     return;
1663   }
1664
1665   auto& clientHelloBuf = sock->clientHelloInfo_->clientHelloBuf_;
1666   clientHelloBuf.append(IOBuf::wrapBuffer(buf, len));
1667   try {
1668     Cursor cursor(clientHelloBuf.front());
1669     if (cursor.read<uint8_t>() != SSL3_MT_CLIENT_HELLO) {
1670       sock->resetClientHelloParsing(ssl);
1671       return;
1672     }
1673
1674     if (cursor.totalLength() < 3) {
1675       clientHelloBuf.trimEnd(len);
1676       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1677       return;
1678     }
1679
1680     uint32_t messageLength = cursor.read<uint8_t>();
1681     messageLength <<= 8;
1682     messageLength |= cursor.read<uint8_t>();
1683     messageLength <<= 8;
1684     messageLength |= cursor.read<uint8_t>();
1685     if (cursor.totalLength() < messageLength) {
1686       clientHelloBuf.trimEnd(len);
1687       clientHelloBuf.append(IOBuf::copyBuffer(buf, len));
1688       return;
1689     }
1690
1691     sock->clientHelloInfo_->clientHelloMajorVersion_ = cursor.read<uint8_t>();
1692     sock->clientHelloInfo_->clientHelloMinorVersion_ = cursor.read<uint8_t>();
1693
1694     cursor.skip(4); // gmt_unix_time
1695     cursor.skip(28); // random_bytes
1696
1697     cursor.skip(cursor.read<uint8_t>()); // session_id
1698
1699     uint16_t cipherSuitesLength = cursor.readBE<uint16_t>();
1700     for (int i = 0; i < cipherSuitesLength; i += 2) {
1701       sock->clientHelloInfo_->
1702         clientHelloCipherSuites_.push_back(cursor.readBE<uint16_t>());
1703     }
1704
1705     uint8_t compressionMethodsLength = cursor.read<uint8_t>();
1706     for (int i = 0; i < compressionMethodsLength; ++i) {
1707       sock->clientHelloInfo_->
1708         clientHelloCompressionMethods_.push_back(cursor.readBE<uint8_t>());
1709     }
1710
1711     if (cursor.totalLength() > 0) {
1712       uint16_t extensionsLength = cursor.readBE<uint16_t>();
1713       while (extensionsLength) {
1714         TLSExtension extensionType = static_cast<TLSExtension>(
1715             cursor.readBE<uint16_t>());
1716         sock->clientHelloInfo_->
1717           clientHelloExtensions_.push_back(extensionType);
1718         extensionsLength -= 2;
1719         uint16_t extensionDataLength = cursor.readBE<uint16_t>();
1720         extensionsLength -= 2;
1721
1722         if (extensionType == TLSExtension::SIGNATURE_ALGORITHMS) {
1723           cursor.skip(2);
1724           extensionDataLength -= 2;
1725           while (extensionDataLength) {
1726             HashAlgorithm hashAlg = static_cast<HashAlgorithm>(
1727                 cursor.readBE<uint8_t>());
1728             SignatureAlgorithm sigAlg = static_cast<SignatureAlgorithm>(
1729                 cursor.readBE<uint8_t>());
1730             extensionDataLength -= 2;
1731             sock->clientHelloInfo_->
1732               clientHelloSigAlgs_.emplace_back(hashAlg, sigAlg);
1733           }
1734         } else {
1735           cursor.skip(extensionDataLength);
1736           extensionsLength -= extensionDataLength;
1737         }
1738       }
1739     }
1740   } catch (std::out_of_range& e) {
1741     // we'll use what we found and cleanup below.
1742     VLOG(4) << "AsyncSSLSocket::clientHelloParsingCallback(): "
1743       << "buffer finished unexpectedly." << " AsyncSSLSocket socket=" << sock;
1744   }
1745
1746   sock->resetClientHelloParsing(ssl);
1747 }
1748
1749 } // namespace