Use Baton (again) in EventBase::runInEventBaseThreadAndWait
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index c7c1c357a9518ed6c84cc21789fd7828a55f036f..e6513f8259db83e1e18890554294e75e459c5d3b 100644 (file)
@@ -406,7 +406,7 @@ size_t AsyncSSLSocket::getRawBytesWritten() const {
     return 0;
   }
   BIO* next = BIO_next(b);
-  while (next != NULL) {
+  while (next != nullptr) {
     b = next;
     next = BIO_next(b);
   }
@@ -450,20 +450,20 @@ void AsyncSSLSocket::sslAccept(
     std::chrono::milliseconds timeout,
     const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
   verifyPeer_ = verifyPeer;
 
   // Make sure we're in the uninitialized state
-  if (!server_ || (sslState_ != STATE_UNINIT &&
-                   sslState_ != STATE_UNENCRYPTED) ||
+  if (!server_ ||
+      (sslState_ != STATE_UNINIT && sslState_ != STATE_UNENCRYPTED) ||
       handshakeCallback_ != nullptr) {
     return invalidState(callback);
   }
 
   // Cache local and remote socket addresses to keep them available
   // after socket file descriptor is closed.
-  if (cacheAddrOnFailure_ && -1 != getFd()) {
-    cacheLocalPeerAddr();
+  if (cacheAddrOnFailure_) {
+    cacheAddresses();
   }
 
   handshakeStartTime_ = std::chrono::steady_clock::now();
@@ -503,9 +503,6 @@ void AsyncSSLSocket::attachSSLContext(
   // In order to call attachSSLContext, detachSSLContext must have been
   // previously called.
   // We need to update the initial_ctx if necessary
-  auto sslCtx = ctx->getSSLCtx();
-  SSL_CTX_up_ref(sslCtx);
-
   // The 'initial_ctx' inside an SSL* points to the context that it was created
   // with, which is also where session callbacks and servername callbacks
   // happen.
@@ -514,6 +511,7 @@ void AsyncSSLSocket::attachSSLContext(
   // NOTE: this will only work if we have access to ssl_ internals, so it may
   // not work on
   // OpenSSL version >= 1.1.0
+  auto sslCtx = ctx->getSSLCtx();
   OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx);
   // Detach sets the socket's context to the dummy context. Thus we must acquire
   // this lock.
@@ -667,19 +665,6 @@ void AsyncSSLSocket::invokeHandshakeCB() {
   }
 }
 
-void AsyncSSLSocket::cacheLocalPeerAddr() {
-  SocketAddress address;
-  try {
-    getLocalAddress(&address);
-    getPeerAddress(&address);
-  } catch (const std::system_error& e) {
-    // The handle can be still valid while the connection is already closed.
-    if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
-      throw;
-    }
-  }
-}
-
 void AsyncSSLSocket::connect(
     ConnectCallback* callback,
     const folly::SocketAddress& address,
@@ -699,13 +684,15 @@ void AsyncSSLSocket::connect(
     const folly::SocketAddress& bindAddr) noexcept {
   assert(!server_);
   assert(state_ == StateEnum::UNINIT);
-  assert(sslState_ == STATE_UNINIT);
+  assert(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
   noTransparentTls_ = true;
   totalConnectTimeout_ = totalConnectTimeout;
-  AsyncSSLSocketConnector* connector =
-      new AsyncSSLSocketConnector(this, callback, totalConnectTimeout.count());
+  if (sslState_ != STATE_UNENCRYPTED) {
+    callback = new AsyncSSLSocketConnector(
+        this, callback, int(totalConnectTimeout.count()));
+  }
   AsyncSocket::connect(
-      connector, address, connectTimeout.count(), options, bindAddr);
+      callback, address, int(connectTimeout.count()), options, bindAddr);
 }
 
 bool AsyncSSLSocket::needsPeerVerification() const {
@@ -751,12 +738,12 @@ void AsyncSSLSocket::sslConn(
     std::chrono::milliseconds timeout,
     const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   // Cache local and remote socket addresses to keep them available
   // after socket file descriptor is closed.
-  if (cacheAddrOnFailure_ && -1 != getFd()) {
-    cacheLocalPeerAddr();
+  if (cacheAddrOnFailure_) {
+    cacheAddresses();
   }
 
   verifyPeer_ = verifyPeer;
@@ -1338,7 +1325,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
                << "): client intitiated SSL renegotiation not permitted";
     return ReadResult(
         READ_ERROR,
-        folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
+        std::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
   }
   if (bytes <= 0) {
     int error = SSL_get_error(ssl_, bytes);
@@ -1358,7 +1345,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
                  << "): unsupported SSL renegotiation during read";
       return ReadResult(
           READ_ERROR,
-          folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+          std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
     } else {
       if (zero_return(error, bytes)) {
         return ReadResult(bytes);
@@ -1375,7 +1362,7 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
               << "reason: " << ERR_reason_error_string(errError);
       return ReadResult(
           READ_ERROR,
-          folly::make_unique<SSLException>(error, errError, bytes, errno));
+          std::make_unique<SSLException>(error, errError, bytes, errno));
     }
   } else {
     appBytesReceived_ += bytes;
@@ -1418,7 +1405,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
                << "unsupported SSL renegotiation during write";
     return WriteResult(
         WRITE_ERROR,
-        folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
+        std::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
   } else {
     if (zero_return(error, rc)) {
       return WriteResult(0);
@@ -1431,7 +1418,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
             << ", reason: " << ERR_reason_error_string(errError);
     return WriteResult(
         WRITE_ERROR,
-        folly::make_unique<SSLException>(error, errError, rc, errno));
+        std::make_unique<SSLException>(error, errError, rc, errno));
   }
 }
 
@@ -1452,7 +1439,7 @@ AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
                << "TODO: AsyncSSLSocket currently does not support calling "
                << "write() before the handshake has fully completed";
     return WriteResult(
-        WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
+        WRITE_ERROR, std::make_unique<SSLException>(SSLError::EARLY_WRITE));
   }
 
   // Declare a buffer used to hold small write requests.  It could point to a
@@ -1701,9 +1688,9 @@ int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
     queue.append(std::move(sslSock->preReceivedData_));
     queue.trimStart(len);
     sslSock->preReceivedData_ = queue.move();
-    return len;
+    return static_cast<int>(len);
   } else {
-    auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
+    auto result = int(recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0));
     if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
       BIO_set_retry_read(b);
     }