Modernize use of std::make_unique
[folly.git] / folly / io / async / AsyncSSLSocket.cpp
index 3791db7dc4efc95593594f45365fd705ac6715a3..4bb8a9b43525a37d0f80d95d20b4f357f365003b 100644 (file)
 #include <fcntl.h>
 #include <sys/types.h>
 #include <chrono>
+#include <memory>
 
 #include <folly/Bits.h>
+#include <folly/Format.h>
 #include <folly/SocketAddress.h>
 #include <folly/SpinLock.h>
 #include <folly/io/Cursor.h>
@@ -183,7 +185,7 @@ static BIO_METHOD* getSSLBioMethod() {
   return instance;
 }
 
-void* initsslBioMethod(void) {
+void* initsslBioMethod() {
   auto sslBioMethod = getSSLBioMethod();
   // override the bwrite method for MSG_EOR support
   OpenSSLUtils::setCustomBioWriteMethod(sslBioMethod, AsyncSSLSocket::bioWrite);
@@ -199,7 +201,7 @@ void* initsslBioMethod(void) {
   return nullptr;
 }
 
-} // anonymous namespace
+} // namespace
 
 namespace folly {
 
@@ -406,7 +408,7 @@ size_t AsyncSSLSocket::getRawBytesWritten() const {
     return 0;
   }
   BIO* next = BIO_next(b);
-  while (next != NULL) {
+  while (next != nullptr) {
     b = next;
     next = BIO_next(b);
   }
@@ -450,20 +452,20 @@ void AsyncSSLSocket::sslAccept(
     std::chrono::milliseconds timeout,
     const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
   verifyPeer_ = verifyPeer;
 
   // Make sure we're in the uninitialized state
-  if (!server_ || (sslState_ != STATE_UNINIT &&
-                   sslState_ != STATE_UNENCRYPTED) ||
+  if (!server_ ||
+      (sslState_ != STATE_UNINIT && sslState_ != STATE_UNENCRYPTED) ||
       handshakeCallback_ != nullptr) {
     return invalidState(callback);
   }
 
   // Cache local and remote socket addresses to keep them available
   // after socket file descriptor is closed.
-  if (cacheAddrOnFailure_ && -1 != getFd()) {
-    cacheLocalPeerAddr();
+  if (cacheAddrOnFailure_) {
+    cacheAddresses();
   }
 
   handshakeStartTime_ = std::chrono::steady_clock::now();
@@ -503,9 +505,6 @@ void AsyncSSLSocket::attachSSLContext(
   // In order to call attachSSLContext, detachSSLContext must have been
   // previously called.
   // We need to update the initial_ctx if necessary
-  auto sslCtx = ctx->getSSLCtx();
-  SSL_CTX_up_ref(sslCtx);
-
   // The 'initial_ctx' inside an SSL* points to the context that it was created
   // with, which is also where session callbacks and servername callbacks
   // happen.
@@ -514,6 +513,7 @@ void AsyncSSLSocket::attachSSLContext(
   // NOTE: this will only work if we have access to ssl_ internals, so it may
   // not work on
   // OpenSSL version >= 1.1.0
+  auto sslCtx = ctx->getSSLCtx();
   OpenSSLUtils::setSSLInitialCtx(ssl_, sslCtx);
   // Detach sets the socket's context to the dummy context. Thus we must acquire
   // this lock.
@@ -667,19 +667,6 @@ void AsyncSSLSocket::invokeHandshakeCB() {
   }
 }
 
-void AsyncSSLSocket::cacheLocalPeerAddr() {
-  SocketAddress address;
-  try {
-    getLocalAddress(&address);
-    getPeerAddress(&address);
-  } catch (const std::system_error& e) {
-    // The handle can be still valid while the connection is already closed.
-    if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
-      throw;
-    }
-  }
-}
-
 void AsyncSSLSocket::connect(
     ConnectCallback* callback,
     const folly::SocketAddress& address,
@@ -699,13 +686,15 @@ void AsyncSSLSocket::connect(
     const folly::SocketAddress& bindAddr) noexcept {
   assert(!server_);
   assert(state_ == StateEnum::UNINIT);
-  assert(sslState_ == STATE_UNINIT);
+  assert(sslState_ == STATE_UNINIT || sslState_ == STATE_UNENCRYPTED);
   noTransparentTls_ = true;
   totalConnectTimeout_ = totalConnectTimeout;
-  AsyncSSLSocketConnector* connector =
-      new AsyncSSLSocketConnector(this, callback, totalConnectTimeout.count());
+  if (sslState_ != STATE_UNENCRYPTED) {
+    callback = new AsyncSSLSocketConnector(
+        this, callback, int(totalConnectTimeout.count()));
+  }
   AsyncSocket::connect(
-      connector, address, connectTimeout.count(), options, bindAddr);
+      callback, address, int(connectTimeout.count()), options, bindAddr);
 }
 
 bool AsyncSSLSocket::needsPeerVerification() const {
@@ -751,12 +740,12 @@ void AsyncSSLSocket::sslConn(
     std::chrono::milliseconds timeout,
     const SSLContext::SSLVerifyPeerEnum& verifyPeer) {
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   // Cache local and remote socket addresses to keep them available
   // after socket file descriptor is closed.
-  if (cacheAddrOnFailure_ && -1 != getFd()) {
-    cacheLocalPeerAddr();
+  if (cacheAddrOnFailure_) {
+    cacheAddresses();
   }
 
   verifyPeer_ = verifyPeer;
@@ -1659,7 +1648,8 @@ int AsyncSSLSocket::bioWrite(BIO* b, const char* in, int inl) {
     flags |= WriteFlags::CORK;
   }
 
-  int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(flags);
+  int msg_flags = tsslSock->getSendMsgParamsCB()->getFlags(
+      flags, false /*zeroCopyEnabled*/);
   msg.msg_controllen =
       tsslSock->getSendMsgParamsCB()->getAncillaryDataSize(flags);
   CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
@@ -1701,9 +1691,9 @@ int AsyncSSLSocket::bioRead(BIO* b, char* out, int outl) {
     queue.append(std::move(sslSock->preReceivedData_));
     queue.trimStart(len);
     sslSock->preReceivedData_ = queue.move();
-    return len;
+    return static_cast<int>(len);
   } else {
-    auto result = recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0);
+    auto result = int(recv(OpenSSLUtils::getBioFd(b, nullptr), out, outl, 0));
     if (result <= 0 && OpenSSLUtils::getBioShouldRetryWrite(result)) {
       BIO_set_retry_read(b);
     }
@@ -1727,7 +1717,7 @@ int AsyncSSLSocket::sslVerifyCallback(
 
 void AsyncSSLSocket::enableClientHelloParsing()  {
     parseClientHello_ = true;
-    clientHelloInfo_.reset(new ssl::ClientHelloInfo());
+    clientHelloInfo_ = std::make_unique<ssl::ClientHelloInfo>();
 }
 
 void AsyncSSLSocket::resetClientHelloParsing(SSL *ssl)  {