Start fixing implicit truncations
[folly.git] / folly / io / async / AsyncSocket.cpp
index bb41685dd3d3fb5a70befadf68cec62cc25d3597..987dfa5af0ee52fdfcb6f848b0b70494eca7a70c 100644 (file)
@@ -95,8 +95,10 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     if (getNext() != nullptr) {
       writeFlags = writeFlags | WriteFlags::CORK;
     }
-    return socket_->performWrite(
+    auto writeResult = socket_->performWrite(
         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
+    bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
+    return writeResult;
   }
 
   bool isComplete() override {
@@ -124,7 +126,8 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     currentOp->iov_len -= partialBytes_;
 
     // Increment the totalBytesWritten_ count by bytesWritten_;
-    totalBytesWritten_ += bytesWritten_;
+    assert(bytesWritten_ >= 0);
+    totalBytesWritten_ += uint32_t(bytesWritten_);
   }
 
  private:
@@ -418,7 +421,7 @@ void AsyncSocket::connect(ConnectCallback* callback,
 
     // Apply the additional options if any.
     for (const auto& opt: options) {
-      int rv = opt.first.apply(fd_, opt.second);
+      rv = opt.first.apply(fd_, opt.second);
       if (rv != 0) {
         auto errnoCopy = errno;
         throw AsyncSocketException(
@@ -1333,7 +1336,7 @@ AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
   }
 }
 
-void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
+void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
   // no matter what, buffer should be preapared for non-ssl socket
   CHECK(readCallback_);
   readCallback_->getReadBuffer(buf, buflen);
@@ -1798,8 +1801,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
       errno = EAGAIN;
       totalWritten = -1;
     } else if (errno == EOPNOTSUPP) {
-      VLOG(4) << "TFO not supported";
       // Try falling back to connecting.
+      VLOG(4) << "TFO not supported";
       state_ = StateEnum::CONNECTING;
       try {
         int ret = socketConnect((const sockaddr*)&addr, len);
@@ -1977,12 +1980,7 @@ void AsyncSocket::startFail() {
   }
 }
 
-void AsyncSocket::finishFail() {
-  assert(state_ == StateEnum::ERROR);
-  assert(getDestructorGuardCount() > 0);
-
-  AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
-                         withAddr("socket closing after error"));
+void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
   invokeConnectErr(ex);
   failAllWrites(ex);
 
@@ -1993,6 +1991,22 @@ void AsyncSocket::finishFail() {
   }
 }
 
+void AsyncSocket::finishFail() {
+  assert(state_ == StateEnum::ERROR);
+  assert(getDestructorGuardCount() > 0);
+
+  AsyncSocketException ex(
+      AsyncSocketException::INTERNAL_ERROR,
+      withAddr("socket closing after error"));
+  invokeAllErrors(ex);
+}
+
+void AsyncSocket::finishFail(const AsyncSocketException& ex) {
+  assert(state_ == StateEnum::ERROR);
+  assert(getDestructorGuardCount() > 0);
+  invokeAllErrors(ex);
+}
+
 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
              << state_ << " host=" << addr_.describe()
@@ -2010,7 +2024,7 @@ void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
   startFail();
 
   invokeConnectErr(ex);
-  finishFail();
+  finishFail(ex);
 }
 
 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {