Fix TFO refused case
authorSubodh Iyengar <subodh@fb.com>
Fri, 25 Nov 2016 05:18:22 +0000 (21:18 -0800)
committerFacebook Github Bot <facebook-github-bot-bot@fb.com>
Fri, 25 Nov 2016 05:53:41 +0000 (21:53 -0800)
Summary:
When TFO falls back, it's possible that
the fallback can also error out.

We handle this correctly in AsyncSocket,
however because AsyncSSLSocket is so
inter-twined with AsyncSocket, we missed
the case of error as well.

This changes it so that a connect error on
fallback will cause a handshake error

Differential Revision: D4226477

fbshipit-source-id: c6e845e4a907bfef1e6ad1b4118db47184d047e0

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/test/AsyncSSLSocketTest.cpp
folly/io/async/test/AsyncSocketTest2.cpp
folly/io/async/test/SocketClient.cpp

index 8804b8d1ff64633303e23c2c8d99024618ff5498..f8e6b47104d245ed5e504ea03910efb57ebe60c4 100644 (file)
@@ -1127,11 +1127,21 @@ AsyncSSLSocket::handleConnect() noexcept {
 void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
   connectionTimeout_.cancelTimeout();
   AsyncSocket::invokeConnectErr(ex);
 void AsyncSSLSocket::invokeConnectErr(const AsyncSocketException& ex) {
   connectionTimeout_.cancelTimeout();
   AsyncSocket::invokeConnectErr(ex);
+  if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+    assert(tfoAttempted_);
+    if (handshakeTimeout_.isScheduled()) {
+      handshakeTimeout_.cancelTimeout();
+    }
+    // If we fell back to connecting state during TFO and the connection
+    // failed, it would be an SSL failure as well.
+    invokeHandshakeErr(ex);
+  }
 }
 
 void AsyncSSLSocket::invokeConnectSuccess() {
   connectionTimeout_.cancelTimeout();
   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
 }
 
 void AsyncSSLSocket::invokeConnectSuccess() {
   connectionTimeout_.cancelTimeout();
   if (sslState_ == SSLStateEnum::STATE_CONNECTING) {
+    assert(tfoAttempted_);
     // If we failed TFO, we'd fall back to trying to connect the socket,
     // to setup things like timeouts.
     startSSLConnect();
     // If we failed TFO, we'd fall back to trying to connect the socket,
     // to setup things like timeouts.
     startSSLConnect();
index 652f90f5d46cd2578199670dc71c4a5024980459..729185b89c4f35461fd5cf015a9483e0a645d849 100644 (file)
@@ -1798,8 +1798,8 @@ AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
       errno = EAGAIN;
       totalWritten = -1;
     } else if (errno == EOPNOTSUPP) {
       errno = EAGAIN;
       totalWritten = -1;
     } else if (errno == EOPNOTSUPP) {
-      VLOG(4) << "TFO not supported";
       // Try falling back to connecting.
       // Try falling back to connecting.
+      VLOG(4) << "TFO not supported";
       state_ = StateEnum::CONNECTING;
       try {
         int ret = socketConnect((const sockaddr*)&addr, len);
       state_ = StateEnum::CONNECTING;
       try {
         int ret = socketConnect((const sockaddr*)&addr, len);
@@ -1977,12 +1977,7 @@ void AsyncSocket::startFail() {
   }
 }
 
   }
 }
 
-void AsyncSocket::finishFail() {
-  assert(state_ == StateEnum::ERROR);
-  assert(getDestructorGuardCount() > 0);
-
-  AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
-                         withAddr("socket closing after error"));
+void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
   invokeConnectErr(ex);
   failAllWrites(ex);
 
   invokeConnectErr(ex);
   failAllWrites(ex);
 
@@ -1993,6 +1988,22 @@ void AsyncSocket::finishFail() {
   }
 }
 
   }
 }
 
+void AsyncSocket::finishFail() {
+  assert(state_ == StateEnum::ERROR);
+  assert(getDestructorGuardCount() > 0);
+
+  AsyncSocketException ex(
+      AsyncSocketException::INTERNAL_ERROR,
+      withAddr("socket closing after error"));
+  invokeAllErrors(ex);
+}
+
+void AsyncSocket::finishFail(const AsyncSocketException& ex) {
+  assert(state_ == StateEnum::ERROR);
+  assert(getDestructorGuardCount() > 0);
+  invokeAllErrors(ex);
+}
+
 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
              << state_ << " host=" << addr_.describe()
 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
              << state_ << " host=" << addr_.describe()
@@ -2010,7 +2021,7 @@ void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
   startFail();
 
   invokeConnectErr(ex);
   startFail();
 
   invokeConnectErr(ex);
-  finishFail();
+  finishFail(ex);
 }
 
 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
 }
 
 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
index 3f5d715d60952690b5b158fdf5d3f98a9099b4fc..2fed39a070a28b602328417585e86b26fd656a31 100644 (file)
@@ -877,6 +877,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   // error handling methods
   void startFail();
   void finishFail();
   // error handling methods
   void startFail();
   void finishFail();
+  void finishFail(const AsyncSocketException& ex);
+  void invokeAllErrors(const AsyncSocketException& ex);
   void fail(const char* fn, const AsyncSocketException& ex);
   void failConnect(const char* fn, const AsyncSocketException& ex);
   void failRead(const char* fn, const AsyncSocketException& ex);
   void fail(const char* fn, const AsyncSocketException& ex);
   void failConnect(const char* fn, const AsyncSocketException& ex);
   void failRead(const char* fn, const AsyncSocketException& ex);
index 65c056f6e8a6d5d7ab1baa2385ff22bfc0ba4d25..bc7c4f4896db1e174a5967854a49ae320def8036 100644 (file)
@@ -1918,6 +1918,21 @@ TEST(AsyncSSLSocketTest, HandshakeTFOFallbackTimeout) {
   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
 }
 
   EXPECT_THAT(ccb.error, testing::HasSubstr("SSL connect timed out"));
 }
 
+TEST(AsyncSSLSocketTest, HandshakeTFORefused) {
+  // Start listening on a local port
+  EventBase evb;
+
+  // Hopefully nothing is listening on this address
+  SocketAddress addr("127.0.0.1", 65535);
+  auto socket = setupSocketWithFallback(&evb, addr, AtMost(1));
+  ConnCallback ccb;
+  socket->connect(&ccb, addr, 100);
+
+  evb.loop();
+  EXPECT_EQ(ConnCallback::State::ERROR, ccb.state);
+  EXPECT_THAT(ccb.error, testing::HasSubstr("refused"));
+}
+
 #endif
 
 } // namespace
 #endif
 
 } // namespace
index 36ccc305928ba185ba86d108edb4ef0a06060c9f..afe23fa134eb37f37e037373df6f4e627f6fc1e9 100644 (file)
@@ -2524,7 +2524,7 @@ TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
 /**
  * Test connecting to a server that isn't listening
  */
 /**
  * Test connecting to a server that isn't listening
  */
-TEST(AsyncSocketTest, ConnectRefusedTFO) {
+TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
   EventBase evb;
 
   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
   EventBase evb;
 
   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
@@ -2541,7 +2541,6 @@ TEST(AsyncSocketTest, ConnectRefusedTFO) {
   WriteCallback write1;
   // Trigger the connect if TFO attempt is supported.
   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
   WriteCallback write1;
   // Trigger the connect if TFO attempt is supported.
   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
-  evb.loop();
   WriteCallback write2;
   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
   evb.loop();
   WriteCallback write2;
   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
   evb.loop();
@@ -2675,6 +2674,51 @@ TEST(AsyncSocketTest, TestTFOUnsupported) {
   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
 }
 
   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
 }
 
+TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
+  EventBase evb;
+
+  auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
+  socket->enableTFO();
+
+  // Hopefully this fails
+  folly::SocketAddress fakeAddr("127.0.0.1", 65535);
+  EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
+      .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
+        sockaddr_storage addr;
+        auto len = fakeAddr.getAddress(&addr);
+        int ret = connect(fd, (const struct sockaddr*)&addr, len);
+        LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
+                  << errno;
+        return ret;
+      }));
+
+  // Hopefully nothing is actually listening on this address
+  ConnCallback cb;
+  socket->connect(&cb, fakeAddr, 30);
+
+  WriteCallback write1;
+  // Trigger the connect if TFO attempt is supported.
+  socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
+
+  if (socket->getTFOFinished()) {
+    // This test is useless now.
+    return;
+  }
+  WriteCallback write2;
+  // Trigger the connect if TFO attempt is supported.
+  socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
+  evb.loop();
+
+  EXPECT_EQ(STATE_FAILED, write1.state);
+  EXPECT_EQ(STATE_FAILED, write2.state);
+  EXPECT_FALSE(socket->getTFOSucceded());
+
+  EXPECT_EQ(STATE_SUCCEEDED, cb.state);
+  EXPECT_LE(0, socket->getConnectTime().count());
+  EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
+  EXPECT_TRUE(socket->getTFOAttempted());
+}
+
 TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
   // Try connecting to server that won't respond.
   //
 TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
   // Try connecting to server that won't respond.
   //
index 1a93ab95bcddc2e7c844065a602d3355ff032185..2fbb546bd32d5778fc487a850738925a23d43fc6 100644 (file)
@@ -27,6 +27,9 @@ DEFINE_int32(port, 0, "port");
 DEFINE_bool(tfo, false, "enable tfo");
 DEFINE_string(msg, "", "Message to send");
 DEFINE_bool(ssl, false, "use ssl");
 DEFINE_bool(tfo, false, "enable tfo");
 DEFINE_string(msg, "", "Message to send");
 DEFINE_bool(ssl, false, "use ssl");
+DEFINE_int32(timeout_ms, 0, "timeout");
+DEFINE_int32(sendtimeout_ms, 0, "send timeout");
+DEFINE_int32(num_writes, 1, "number of writes");
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -53,6 +56,10 @@ int main(int argc, char** argv) {
 #endif
   }
 
 #endif
   }
 
+  if (FLAGS_sendtimeout_ms != 0) {
+    socket->setSendTimeout(FLAGS_sendtimeout_ms);
+  }
+
   // Keep this around
   auto sockAddr = socket.get();
 
   // Keep this around
   auto sockAddr = socket.get();
 
@@ -60,10 +67,13 @@ int main(int argc, char** argv) {
   SocketAddress addr;
   addr.setFromHostPort(FLAGS_host, FLAGS_port);
   sock.setAddress(addr);
   SocketAddress addr;
   addr.setFromHostPort(FLAGS_host, FLAGS_port);
   sock.setAddress(addr);
-  sock.open();
+  std::chrono::milliseconds timeout(FLAGS_timeout_ms);
+  sock.open(timeout);
   LOG(INFO) << "connected to " << addr.getAddressStr();
 
   LOG(INFO) << "connected to " << addr.getAddressStr();
 
-  sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size());
+  for (int32_t i = 0; i < FLAGS_num_writes; ++i) {
+    sock.write((const uint8_t*)FLAGS_msg.data(), FLAGS_msg.size());
+  }
 
   LOG(INFO) << "TFO attempted: " << sockAddr->getTFOAttempted();
   LOG(INFO) << "TFO finished: " << sockAddr->getTFOFinished();
 
   LOG(INFO) << "TFO attempted: " << sockAddr->getTFOAttempted();
   LOG(INFO) << "TFO finished: " << sockAddr->getTFOFinished();