Move address caching logic from AsyncSSLSocket to AsyncSocket.
[folly.git] / folly / io / async / AsyncSocket.cpp
index ddbab294ab94152eb6a5c4f273e109b79a70f197..f32da3a75ad8e25372cd75854ba222b9c48ff9c4 100644 (file)
@@ -266,7 +266,9 @@ AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
 // init() method, since constructor forwarding isn't supported in most
 // compilers yet.
 void AsyncSocket::init() {
-  assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+  if (eventBase_) {
+    eventBase_->dcheckIsInEventBaseThread();
+  }
   shutdownFlags_ = 0;
   state_ = StateEnum::UNINIT;
   eventFlags_ = EventHandler::NONE;
@@ -356,7 +358,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
                            const OptionMap &options,
                            const folly::SocketAddress& bindAddr) noexcept {
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   addr_ = address;
 
@@ -520,6 +522,11 @@ int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
     // Ignore return value, errors are ok
     setsockopt(fd_, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
   }
+  if (noTSocks_) {
+    VLOG(4) << "Disabling TSOCKS for fd " << fd_;
+    // Ignore return value, errors are ok
+    setsockopt(fd_, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
+  }
 #endif
   int rv = fsp::connect(fd_, saddr, len);
   if (rv < 0) {
@@ -587,7 +594,9 @@ void AsyncSocket::cancelConnect() {
 
 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
   sendTimeout_ = milliseconds;
-  assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+  if (eventBase_) {
+    eventBase_->dcheckIsInEventBaseThread();
+  }
 
   // If we are currently pending on write requests, immediately update
   // writeTimeout_ with the new value.
@@ -612,7 +621,7 @@ void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
           << ", fd=" << fd_ << ", callback=" << callback
           << ", state=" << state_;
 
-  // Short circuit if callback is the same as the existing timestampCallback_.
+  // Short circuit if callback is the same as the existing errMessageCallback_.
   if (callback == errMessageCallback_) {
     return;
   }
@@ -623,7 +632,15 @@ void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
   }
 
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
+
+  if (callback == nullptr) {
+    // We should be able to reset the callback regardless of the
+    // socket state. It's important to have a reliable callback
+    // cancellation mechanism.
+    errMessageCallback_ = callback;
+    return;
+  }
 
   switch ((StateEnum)state_) {
     case StateEnum::CONNECTING:
@@ -701,7 +718,7 @@ void AsyncSocket::setReadCB(ReadCallback *callback) {
   }
 
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   switch ((StateEnum)state_) {
     case StateEnum::CONNECTING:
@@ -776,10 +793,10 @@ void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
   size_t count = buf->countChainElements();
   if (count <= kSmallSizeMax) {
     // suppress "warning: variable length array 'vec' is used [-Wvla]"
-    FOLLY_PUSH_WARNING;
-    FOLLY_GCC_DISABLE_WARNING(vla);
+    FOLLY_PUSH_WARNING
+    FOLLY_GCC_DISABLE_WARNING("-Wvla")
     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
-    FOLLY_POP_WARNING;
+    FOLLY_POP_WARNING
 
     writeChainImpl(callback, vec, count, std::move(buf), flags);
   } else {
@@ -803,7 +820,7 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
           << ", state=" << state_;
   DestructorGuard dg(this);
   unique_ptr<IOBuf>ioBuf(std::move(buf));
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
     // No new writes may be performed after the write side of the socket has
@@ -953,7 +970,7 @@ void AsyncSocket::close() {
   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
   // destroyed until close() returns.
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   // Since there are write requests pending, we have to set the
   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
@@ -985,7 +1002,9 @@ void AsyncSocket::closeNow() {
           << ", state=" << state_ << ", shutdownFlags="
           << std::hex << (int) shutdownFlags_;
   DestructorGuard dg(this);
-  assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+  if (eventBase_) {
+    eventBase_->dcheckIsInEventBaseThread();
+  }
 
   switch (state_) {
     case StateEnum::ESTABLISHED:
@@ -1077,7 +1096,7 @@ void AsyncSocket::shutdownWrite() {
     return;
   }
 
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
   // shutdown will be performed once all writes complete.
@@ -1104,7 +1123,9 @@ void AsyncSocket::shutdownWriteNow() {
   }
 
   DestructorGuard dg(this);
-  assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
+  if (eventBase_) {
+    eventBase_->dcheckIsInEventBaseThread();
+  }
 
   switch (static_cast<StateEnum>(state_)) {
     case StateEnum::ESTABLISHED:
@@ -1181,6 +1202,18 @@ bool AsyncSocket::readable() const {
   return rc == 1;
 }
 
+bool AsyncSocket::writable() const {
+  if (fd_ == -1) {
+    return false;
+  }
+  struct pollfd fds[1];
+  fds[0].fd = fd_;
+  fds[0].events = POLLOUT;
+  fds[0].revents = 0;
+  int rc = poll(fds, 1, 0);
+  return rc == 1;
+}
+
 bool AsyncSocket::isPending() const {
   return ioHandler_.isPending();
 }
@@ -1220,7 +1253,7 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) {
           << ", state=" << state_ << ", events="
           << std::hex << eventFlags_ << ")";
   assert(eventBase_ == nullptr);
-  assert(eventBase->isInEventBaseThread());
+  eventBase->dcheckIsInEventBaseThread();
 
   eventBase_ = eventBase;
   ioHandler_.attachEventBase(eventBase);
@@ -1235,7 +1268,7 @@ void AsyncSocket::detachEventBase() {
           << ", old evb=" << eventBase_ << ", state=" << state_
           << ", events=" << std::hex << eventFlags_ << ")";
   assert(eventBase_ != nullptr);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   eventBase_ = nullptr;
   ioHandler_.detachEventBase();
@@ -1247,22 +1280,44 @@ void AsyncSocket::detachEventBase() {
 
 bool AsyncSocket::isDetachable() const {
   DCHECK(eventBase_ != nullptr);
-  DCHECK(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
 }
 
-void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
+void AsyncSocket::cacheAddresses() {
+  if (fd_ >= 0) {
+    try {
+      cacheLocalAddress();
+      cachePeerAddress();
+    } catch (const std::system_error& e) {
+      if (e.code() != std::error_code(ENOTCONN, std::system_category())) {
+        VLOG(1) << "Error caching addresses: " << e.code().value() << ", "
+                << e.code().message();
+      }
+    }
+  }
+}
+
+void AsyncSocket::cacheLocalAddress() const {
   if (!localAddr_.isInitialized()) {
     localAddr_.setFromLocalAddress(fd_);
   }
-  *address = localAddr_;
 }
 
-void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
+void AsyncSocket::cachePeerAddress() const {
   if (!addr_.isInitialized()) {
     addr_.setFromPeerAddress(fd_);
   }
+}
+
+void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
+  cacheLocalAddress();
+  *address = localAddr_;
+}
+
+void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
+  cachePeerAddress();
   *address = addr_;
 }
 
@@ -1403,7 +1458,7 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
           << ", events=" << std::hex << events << ", state=" << state_;
   DestructorGuard dg(this);
   assert(events & EventHandler::READ_WRITE);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   uint16_t relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
   EventBase* originalEventBase = eventBase_;
@@ -1532,7 +1587,9 @@ void AsyncSocket::handleErrMessages() noexcept {
     }
 
     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
-         cmsg != nullptr && cmsg->cmsg_len != 0;
+         cmsg != nullptr &&
+           cmsg->cmsg_len != 0 &&
+           errMessageCallback_ != nullptr;
          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
       errMessageCallback_->errMessage(*cmsg);
     }
@@ -1946,7 +2003,7 @@ void AsyncSocket::timeoutExpired() noexcept {
   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
   DestructorGuard dg(this);
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
 
   if (state_ == StateEnum::CONNECTING) {
     // connect() timed out
@@ -2003,7 +2060,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
         registerForConnectEvents();
       } catch (const AsyncSocketException& ex) {
         return WriteResult(
-            WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
+            WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
       }
       // Let's fake it that no bytes were written and return an errno.
       errno = EAGAIN;
@@ -2026,7 +2083,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
         totalWritten = -1;
       } catch (const AsyncSocketException& ex) {
         return WriteResult(
-            WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
+            WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
       }
     } else if (errno == EAGAIN) {
       // Normally sendmsg would indicate that the write would block.
@@ -2035,7 +2092,7 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
       // instead, and is an error condition indicating no fds available.
       return WriteResult(
           WRITE_ERROR,
-          folly::make_unique<AsyncSocketException>(
+          std::make_unique<AsyncSocketException>(
               AsyncSocketException::UNKNOWN, "No more free local ports"));
     }
   } else {
@@ -2124,13 +2181,13 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
  * and call all currently installed callbacks.  After an error, the
  * AsyncSocket is completely unregistered.
  *
- * @return Returns true on succcess, or false on error.
+ * @return Returns true on success, or false on error.
  */
 bool AsyncSocket::updateEventRegistration() {
   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
           << ", events=" << std::hex << eventFlags_;
-  assert(eventBase_->isInEventBaseThread());
+  eventBase_->dcheckIsInEventBaseThread();
   if (eventFlags_ == EventHandler::NONE) {
     ioHandler_.unregisterHandler();
     return true;