Adding 'compound assignment union operator' for folly::WriteFlags enum class
[folly.git] / folly / io / async / AsyncSocket.cpp
index 0fe8e88613c013b759771b476ed6b7db33344d3c..db6231291cbb55989836fb68c0d34691b09cfedf 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #include <folly/io/async/AsyncSocket.h>
 
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/EventHandler.h>
+#include <folly/ExceptionWrapper.h>
 #include <folly/SocketAddress.h>
 #include <folly/io/IOBuf.h>
+#include <folly/Portability.h>
 #include <folly/portability/Fcntl.h>
 #include <folly/portability/Sockets.h>
 #include <folly/portability/SysUio.h>
@@ -34,6 +34,8 @@
 using std::string;
 using std::unique_ptr;
 
+namespace fsp = folly::portability::sockets;
+
 namespace folly {
 
 // static members initializers
@@ -91,10 +93,12 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
   WriteResult performWrite() override {
     WriteFlags writeFlags = flags_;
     if (getNext() != nullptr) {
-      writeFlags = writeFlags | WriteFlags::CORK;
+      writeFlags |= WriteFlags::CORK;
     }
-    return socket_->performWrite(
+    auto writeResult = socket_->performWrite(
         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
+    bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
+    return writeResult;
   }
 
   bool isComplete() override {
@@ -122,7 +126,8 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     currentOp->iov_len -= partialBytes_;
 
     // Increment the totalBytesWritten_ count by bytesWritten_;
-    totalBytesWritten_ += bytesWritten_;
+    assert(bytesWritten_ >= 0);
+    totalBytesWritten_ += uint32_t(bytesWritten_);
   }
 
  private:
@@ -172,19 +177,19 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
 };
 
 AsyncSocket::AsyncSocket()
-  : eventBase_(nullptr)
-  , writeTimeout_(this, nullptr)
-  , ioHandler_(this, nullptr)
-  , immediateReadHandler_(this) {
+    : eventBase_(nullptr),
+      writeTimeout_(this, nullptr),
+      ioHandler_(this, nullptr),
+      immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket()";
   init();
 }
 
 AsyncSocket::AsyncSocket(EventBase* evb)
-  : eventBase_(evb)
-  , writeTimeout_(this, evb)
-  , ioHandler_(this, evb)
-  , immediateReadHandler_(this) {
+    : eventBase_(evb),
+      writeTimeout_(this, evb),
+      ioHandler_(this, evb),
+      immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
   init();
 }
@@ -205,10 +210,10 @@ AsyncSocket::AsyncSocket(EventBase* evb,
 }
 
 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
-  : eventBase_(evb)
-  , writeTimeout_(this, evb)
-  , ioHandler_(this, evb, fd)
-  , immediateReadHandler_(this) {
+    : eventBase_(evb),
+      writeTimeout_(this, evb),
+      ioHandler_(this, evb, fd),
+      immediateReadHandler_(this) {
   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
           << fd << ")";
   init();
@@ -335,7 +340,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
     // constant (PF_xxx) rather than an address family (AF_xxx), but the
     // distinction is mainly just historical.  In pretty much all
     // implementations the PF_foo and AF_foo constants are identical.
-    fd_ = socket(address.getFamily(), SOCK_STREAM, 0);
+    fd_ = fsp::socket(address.getFamily(), SOCK_STREAM, 0);
     if (fd_ < 0) {
       auto errnoCopy = errno;
       throw AsyncSocketException(
@@ -393,7 +398,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
     // bind the socket
     if (bindAddr != anyAddress()) {
       int one = 1;
-      if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
+      if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
         auto errnoCopy = errno;
         doClose();
         throw AsyncSocketException(
@@ -404,7 +409,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
 
       bindAddr.getAddress(&addrStorage);
 
-      if (::bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
+      if (bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
         auto errnoCopy = errno;
         doClose();
         throw AsyncSocketException(
@@ -416,7 +421,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
 
     // Apply the additional options if any.
     for (const auto& opt: options) {
-      int rv = opt.first.apply(fd_, opt.second);
+      rv = opt.first.apply(fd_, opt.second);
       if (rv != 0) {
         auto errnoCopy = errno;
         throw AsyncSocketException(
@@ -429,34 +434,12 @@ void AsyncSocket::connect(ConnectCallback* callback,
     // Perform the connect()
     address.getAddress(&addrStorage);
 
-    rv = ::connect(fd_, saddr, address.getActualSize());
-    if (rv < 0) {
-      auto errnoCopy = errno;
-      if (errnoCopy == EINPROGRESS) {
-        // Connection in progress.
-        if (timeout > 0) {
-          // Start a timer in case the connection takes too long.
-          if (!writeTimeout_.scheduleTimeout(timeout)) {
-            throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                withAddr("failed to schedule AsyncSocket connect timeout"));
-          }
-        }
-
-        // Register for write events, so we'll
-        // be notified when the connection finishes/fails.
-        // Note that we don't register for a persistent event here.
-        assert(eventFlags_ == EventHandler::NONE);
-        eventFlags_ = EventHandler::WRITE;
-        if (!ioHandler_.registerHandler(eventFlags_)) {
-          throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-              withAddr("failed to register AsyncSocket connect handler"));
-        }
+    if (tfoEnabled_) {
+      state_ = StateEnum::FAST_OPEN;
+      tfoAttempted_ = true;
+    } else {
+      if (socketConnect(saddr, addr_.getActualSize()) < 0) {
         return;
-      } else {
-        throw AsyncSocketException(
-            AsyncSocketException::NOT_OPEN,
-            "connect failed (immediately)",
-            errnoCopy);
       }
     }
 
@@ -481,10 +464,61 @@ void AsyncSocket::connect(ConnectCallback* callback,
   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
   assert(readCallback_ == nullptr);
   assert(writeReqHead_ == nullptr);
-  state_ = StateEnum::ESTABLISHED;
+  if (state_ != StateEnum::FAST_OPEN) {
+    state_ = StateEnum::ESTABLISHED;
+  }
   invokeConnectSuccess();
 }
 
+int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
+#if __linux__
+  if (noTransparentTls_) {
+    // Ignore return value, errors are ok
+    setsockopt(fd_, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
+  }
+#endif
+  int rv = fsp::connect(fd_, saddr, len);
+  if (rv < 0) {
+    auto errnoCopy = errno;
+    if (errnoCopy == EINPROGRESS) {
+      scheduleConnectTimeout();
+      registerForConnectEvents();
+    } else {
+      throw AsyncSocketException(
+          AsyncSocketException::NOT_OPEN,
+          "connect failed (immediately)",
+          errnoCopy);
+    }
+  }
+  return rv;
+}
+
+void AsyncSocket::scheduleConnectTimeout() {
+  // Connection in progress.
+  auto timeout = connectTimeout_.count();
+  if (timeout > 0) {
+    // Start a timer in case the connection takes too long.
+    if (!writeTimeout_.scheduleTimeout(uint32_t(timeout))) {
+      throw AsyncSocketException(
+          AsyncSocketException::INTERNAL_ERROR,
+          withAddr("failed to schedule AsyncSocket connect timeout"));
+    }
+  }
+}
+
+void AsyncSocket::registerForConnectEvents() {
+  // Register for write events, so we'll
+  // be notified when the connection finishes/fails.
+  // Note that we don't register for a persistent event here.
+  assert(eventFlags_ == EventHandler::NONE);
+  eventFlags_ = EventHandler::WRITE;
+  if (!ioHandler_.registerHandler(eventFlags_)) {
+    throw AsyncSocketException(
+        AsyncSocketException::INTERNAL_ERROR,
+        withAddr("failed to register AsyncSocket connect handler"));
+  }
+}
+
 void AsyncSocket::connect(ConnectCallback* callback,
                            const string& ip, uint16_t port,
                            int timeout,
@@ -502,7 +536,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
 
 void AsyncSocket::cancelConnect() {
   connectCallback_ = nullptr;
-  if (state_ == StateEnum::CONNECTING) {
+  if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
     closeNow();
   }
 }
@@ -514,7 +548,7 @@ void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
   // If we are currently pending on write requests, immediately update
   // writeTimeout_ with the new value.
   if ((eventFlags_ & EventHandler::WRITE) &&
-      (state_ != StateEnum::CONNECTING)) {
+      (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
     assert(state_ == StateEnum::ESTABLISHED);
     assert((shutdownFlags_ & SHUT_WRITE) == 0);
     if (sendTimeout_ > 0) {
@@ -573,6 +607,7 @@ void AsyncSocket::setReadCB(ReadCallback *callback) {
 
   switch ((StateEnum)state_) {
     case StateEnum::CONNECTING:
+    case StateEnum::FAST_OPEN:
       // For convenience, we allow the read callback to be set while we are
       // still connecting.  We just store the callback for now.  Once the
       // connection completes we'll register for read events.
@@ -642,7 +677,12 @@ void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
   constexpr size_t kSmallSizeMax = 64;
   size_t count = buf->countChainElements();
   if (count <= kSmallSizeMax) {
+    // suppress "warning: variable length array 'vec' is used [-Wvla]"
+    FOLLY_PUSH_WARNING;
+    FOLLY_GCC_DISABLE_WARNING(vla);
     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
+    FOLLY_POP_WARNING;
+
     writeChainImpl(callback, vec, count, std::move(buf), flags);
   } else {
     iovec* vec = new iovec[count];
@@ -681,17 +721,18 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
 
   uint32_t countWritten = 0;
   uint32_t partialWritten = 0;
-  int bytesWritten = 0;
+  ssize_t bytesWritten = 0;
   bool mustRegister = false;
-  if (state_ == StateEnum::ESTABLISHED && !connecting()) {
+  if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
+      !connecting()) {
     if (writeReqHead_ == nullptr) {
       // If we are established and there are no other writes pending,
       // we can attempt to perform the write immediately.
       assert(writeReqTail_ == nullptr);
       assert((eventFlags_ & EventHandler::WRITE) == 0);
 
-      auto writeResult =
-          performWrite(vec, count, flags, &countWritten, &partialWritten);
+      auto writeResult = performWrite(
+          vec, uint32_t(count), flags, &countWritten, &partialWritten);
       bytesWritten = writeResult.writeReturn;
       if (bytesWritten < 0) {
         auto errnoCopy = errno;
@@ -715,7 +756,13 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
           bufferCallback_->onEgressBuffered();
         }
       }
-      mustRegister = true;
+      if (!connecting()) {
+        // Writes might put the socket back into connecting state
+        // if TFO is enabled, and using TFO fails.
+        // This means that write timeouts would not be active, however
+        // connect timeouts would affect this stage.
+        mustRegister = true;
+      }
     }
   } else if (!connecting()) {
     // Invalid state for writing
@@ -725,14 +772,20 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
   // Create a new WriteRequest to add to the queue
   WriteRequest* req;
   try {
-    req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
-                                        count - countWritten, partialWritten,
-                                        bytesWritten, std::move(ioBuf), flags);
+    req = BytesWriteRequest::newRequest(
+        this,
+        callback,
+        vec + countWritten,
+        uint32_t(count - countWritten),
+        partialWritten,
+        uint32_t(bytesWritten),
+        std::move(ioBuf),
+        flags);
   } catch (const std::exception& ex) {
     // we mainly expect to catch std::bad_alloc here
     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
-    return failWrite(__func__, callback, bytesWritten, tex);
+    return failWrite(__func__, callback, size_t(bytesWritten), tex);
   }
   req->consume();
   if (writeReqTail_ == nullptr) {
@@ -839,7 +892,7 @@ void AsyncSocket::closeNow() {
   switch (state_) {
     case StateEnum::ESTABLISHED:
     case StateEnum::CONNECTING:
-    {
+    case StateEnum::FAST_OPEN: {
       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
       state_ = StateEnum::CLOSED;
 
@@ -971,7 +1024,7 @@ void AsyncSocket::shutdownWriteNow() {
       }
 
       // Shutdown writes on the file descriptor
-      ::shutdown(fd_, SHUT_WR);
+      shutdown(fd_, SHUT_WR);
 
       // Immediately fail all write requests
       failAllWrites(socketShutdownForWritesEx);
@@ -995,6 +1048,13 @@ void AsyncSocket::shutdownWriteNow() {
       // immediately shut down the write side of the socket.
       shutdownFlags_ |= SHUT_WRITE_PENDING;
       return;
+    case StateEnum::FAST_OPEN:
+      // In fast open state we haven't call connected yet, and if we shutdown
+      // the writes, we will never try to call connect, so shut everything down
+      shutdownFlags_ |= SHUT_WRITE;
+      // Immediately fail all write requests
+      failAllWrites(socketShutdownForWritesEx);
+      return;
     case StateEnum::CLOSED:
     case StateEnum::ERROR:
       // We should never get here.  SHUT_WRITE should always be set
@@ -1046,9 +1106,10 @@ bool AsyncSocket::hangup() const {
 }
 
 bool AsyncSocket::good() const {
-  return ((state_ == StateEnum::CONNECTING ||
-          state_ == StateEnum::ESTABLISHED) &&
-          (shutdownFlags_ == 0) && (eventBase_ != nullptr));
+  return (
+      (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
+       state_ == StateEnum::ESTABLISHED) &&
+      (shutdownFlags_ == 0) && (eventBase_ != nullptr));
 }
 
 bool AsyncSocket::error() const {
@@ -1066,6 +1127,9 @@ void AsyncSocket::attachEventBase(EventBase* eventBase) {
   eventBase_ = eventBase;
   ioHandler_.attachEventBase(eventBase);
   writeTimeout_.attachEventBase(eventBase);
+  if (evbChangeCb_) {
+    evbChangeCb_->evbAttached(this);
+  }
 }
 
 void AsyncSocket::detachEventBase() {
@@ -1078,6 +1142,9 @@ void AsyncSocket::detachEventBase() {
   eventBase_ = nullptr;
   ioHandler_.detachEventBase();
   writeTimeout_.detachEventBase();
+  if (evbChangeCb_) {
+    evbChangeCb_->evbDetached(this);
+  }
 }
 
 bool AsyncSocket::isDetachable() const {
@@ -1101,6 +1168,10 @@ void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
   *address = addr_;
 }
 
+bool AsyncSocket::getTFOSucceded() const {
+  return detail::tfo_succeeded(fd_);
+}
+
 int AsyncSocket::setNoDelay(bool noDelay) {
   if (fd_ < 0) {
     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
@@ -1134,8 +1205,12 @@ int AsyncSocket::setCongestionFlavor(const std::string &cname) {
 
   }
 
-  if (setsockopt(fd_, IPPROTO_TCP, TCP_CONGESTION, cname.c_str(),
-        cname.length() + 1) != 0) {
+  if (setsockopt(
+          fd_,
+          IPPROTO_TCP,
+          TCP_CONGESTION,
+          cname.c_str(),
+          socklen_t(cname.length() + 1)) != 0) {
     int errnoCopy = errno;
     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
@@ -1283,7 +1358,7 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
   }
 }
 
-void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
+void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
   // no matter what, buffer should be preapared for non-ssl socket
   CHECK(readCallback_);
   readCallback_->getReadBuffer(buf, buflen);
@@ -1506,7 +1581,7 @@ void AsyncSocket::handleWrite() noexcept {
             }
           } else {
             // Reads are still enabled, so we are only doing a half-shutdown
-            ::shutdown(fd_, SHUT_WR);
+            shutdown(fd_, SHUT_WR);
           }
         }
       }
@@ -1572,7 +1647,6 @@ void AsyncSocket::handleInitialReadWrite() noexcept {
   // one here just to make sure, in case one of our calling code paths ever
   // changes.
   DestructorGuard dg(this);
-
   // If we have a readCallback_, make sure we enable read events.  We
   // may already be registered for reads if connectSuccess() set
   // the read calback.
@@ -1657,7 +1731,7 @@ void AsyncSocket::handleConnect() noexcept {
     // are still connecting we just abort the connect rather than waiting for
     // it to complete.
     assert((shutdownFlags_ & SHUT_READ) == 0);
-    ::shutdown(fd_, SHUT_WR);
+    shutdown(fd_, SHUT_WR);
     shutdownFlags_ |= SHUT_WRITE;
   }
 
@@ -1695,17 +1769,95 @@ void AsyncSocket::timeoutExpired() noexcept {
   if (state_ == StateEnum::CONNECTING) {
     // connect() timed out
     // Unregister for I/O events.
-    AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
-                           "connect timed out");
-    failConnect(__func__, ex);
+    if (connectCallback_) {
+      AsyncSocketException ex(
+          AsyncSocketException::TIMED_OUT, "connect timed out");
+      failConnect(__func__, ex);
+    } else {
+      // we faced a connect error without a connect callback, which could
+      // happen due to TFO.
+      AsyncSocketException ex(
+          AsyncSocketException::TIMED_OUT, "write timed out during connection");
+      failWrite(__func__, ex);
+    }
   } else {
     // a normal write operation timed out
-    assert(state_ == StateEnum::ESTABLISHED);
     AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
     failWrite(__func__, ex);
   }
 }
 
+ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
+  return detail::tfo_sendmsg(fd, msg, msg_flags);
+}
+
+AsyncSocket::WriteResult
+AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
+  ssize_t totalWritten = 0;
+  if (state_ == StateEnum::FAST_OPEN) {
+    sockaddr_storage addr;
+    auto len = addr_.getAddress(&addr);
+    msg->msg_name = &addr;
+    msg->msg_namelen = len;
+    totalWritten = tfoSendMsg(fd_, msg, msg_flags);
+    if (totalWritten >= 0) {
+      tfoFinished_ = true;
+      state_ = StateEnum::ESTABLISHED;
+      // We schedule this asynchrously so that we don't end up
+      // invoking initial read or write while a write is in progress.
+      scheduleInitialReadWrite();
+    } else if (errno == EINPROGRESS) {
+      VLOG(4) << "TFO falling back to connecting";
+      // A normal sendmsg doesn't return EINPROGRESS, however
+      // TFO might fallback to connecting if there is no
+      // cookie.
+      state_ = StateEnum::CONNECTING;
+      try {
+        scheduleConnectTimeout();
+        registerForConnectEvents();
+      } catch (const AsyncSocketException& ex) {
+        return WriteResult(
+            WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
+      }
+      // Let's fake it that no bytes were written and return an errno.
+      errno = EAGAIN;
+      totalWritten = -1;
+    } else if (errno == EOPNOTSUPP) {
+      // Try falling back to connecting.
+      VLOG(4) << "TFO not supported";
+      state_ = StateEnum::CONNECTING;
+      try {
+        int ret = socketConnect((const sockaddr*)&addr, len);
+        if (ret == 0) {
+          // connect succeeded immediately
+          // Treat this like no data was written.
+          state_ = StateEnum::ESTABLISHED;
+          scheduleInitialReadWrite();
+        }
+        // If there was no exception during connections,
+        // we would return that no bytes were written.
+        errno = EAGAIN;
+        totalWritten = -1;
+      } catch (const AsyncSocketException& ex) {
+        return WriteResult(
+            WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
+      }
+    } else if (errno == EAGAIN) {
+      // Normally sendmsg would indicate that the write would block.
+      // However in the fast open case, it would indicate that sendmsg
+      // fell back to a connect. This is a return code from connect()
+      // instead, and is an error condition indicating no fds available.
+      return WriteResult(
+          WRITE_ERROR,
+          folly::make_unique<AsyncSocketException>(
+              AsyncSocketException::UNKNOWN, "No more free local ports"));
+    }
+  } else {
+    totalWritten = ::sendmsg(fd, msg, msg_flags);
+  }
+  return WriteResult(totalWritten);
+}
+
 AsyncSocket::WriteResult AsyncSocket::performWrite(
     const iovec* vec,
     uint32_t count,
@@ -1740,9 +1892,18 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
     // marks that this is the last byte of a record (response)
     msg_flags |= MSG_EOR;
   }
-  ssize_t totalWritten = ::sendmsg(fd_, &msg, msg_flags);
+  auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
+  auto totalWritten = writeResult.writeReturn;
   if (totalWritten < 0) {
-    if (errno == EAGAIN) {
+    bool tryAgain = (errno == EAGAIN);
+#ifdef __APPLE__
+    // Apple has a bug where doing a second write on a socket which we
+    // have opened with TFO causes an ENOTCONN to be thrown. However the
+    // socket is really connected, so treat ENOTCONN as a EAGAIN until
+    // this bug is fixed.
+    tryAgain |= (errno == ENOTCONN);
+#endif
+    if (!writeResult.exception && tryAgain) {
       // TCP buffer is full; we can't write any more data right now.
       *countWritten = 0;
       *partialWritten = 0;
@@ -1751,14 +1912,14 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
     // error
     *countWritten = 0;
     *partialWritten = 0;
-    return WriteResult(WRITE_ERROR);
+    return writeResult;
   }
 
   appBytesWritten_ += totalWritten;
 
   uint32_t bytesWritten;
   uint32_t n;
-  for (bytesWritten = totalWritten, n = 0; n < count; ++n) {
+  for (bytesWritten = uint32_t(totalWritten), n = 0; n < count; ++n) {
     const iovec* v = vec + n;
     if (v->iov_len > bytesWritten) {
       // Partial write finished in the middle of this iovec
@@ -1767,7 +1928,7 @@ AsyncSocket::WriteResult AsyncSocket::performWrite(
       return WriteResult(totalWritten);
     }
 
-    bytesWritten -= v->iov_len;
+    bytesWritten -= uint32_t(v->iov_len);
   }
 
   assert(bytesWritten == 0);
@@ -1841,12 +2002,7 @@ void AsyncSocket::startFail() {
   }
 }
 
-void AsyncSocket::finishFail() {
-  assert(state_ == StateEnum::ERROR);
-  assert(getDestructorGuardCount() > 0);
-
-  AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
-                         withAddr("socket closing after error"));
+void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
   invokeConnectErr(ex);
   failAllWrites(ex);
 
@@ -1857,6 +2013,22 @@ void AsyncSocket::finishFail() {
   }
 }
 
+void AsyncSocket::finishFail() {
+  assert(state_ == StateEnum::ERROR);
+  assert(getDestructorGuardCount() > 0);
+
+  AsyncSocketException ex(
+      AsyncSocketException::INTERNAL_ERROR,
+      withAddr("socket closing after error"));
+  invokeAllErrors(ex);
+}
+
+void AsyncSocket::finishFail(const AsyncSocketException& ex) {
+  assert(state_ == StateEnum::ERROR);
+  assert(getDestructorGuardCount() > 0);
+  invokeAllErrors(ex);
+}
+
 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
              << state_ << " host=" << addr_.describe()
@@ -1874,7 +2046,7 @@ void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
   startFail();
 
   invokeConnectErr(ex);
-  finishFail();
+  finishFail(ex);
 }
 
 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {