Fix AsyncSSLSocket handshake error reporting.
authorKyle Nekritz <knekritz@fb.com>
Thu, 3 Mar 2016 16:51:53 +0000 (08:51 -0800)
committerFacebook Github Bot 9 <facebook-github-bot-9-bot@fb.com>
Thu, 3 Mar 2016 17:05:29 +0000 (09:05 -0800)
Summary:https://www.openssl.org/docs/manmaster/ssl/SSL_get_error.html
OpenSSL errors are a pain to deal with and we were handling several cases
incorrectly, resulting in a ton of "DH lib" errors when none were likely
actually DH lib errors.

Reviewed By: siyengar

Differential Revision: D2999084

fb-gh-sync-id: b3182be2c199f79ed341af7dbf7524197a838584
shipit-source-id: b3182be2c199f79ed341af7dbf7524197a838584

folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h
folly/io/async/test/AsyncSSLSocketTest.cpp
folly/io/async/test/AsyncSSLSocketTest.h
folly/io/async/test/BlockingSocket.h

index 476b13b..0b71f84 100644 (file)
@@ -246,15 +246,38 @@ void* initEorBioMethod(void) {
   return nullptr;
 }
 
+std::string decodeOpenSSLError(int sslError,
+                               unsigned long errError,
+                               int sslOperationReturnValue) {
+  if (sslError == SSL_ERROR_SYSCALL && errError == 0) {
+    if (sslOperationReturnValue == 0) {
+      return "SSL_ERROR_SYSCALL: EOF";
+    } else {
+      // In this case errno is set, AsyncSocketException will add it.
+      return "SSL_ERROR_SYSCALL";
+    }
+  } else if (sslError == SSL_ERROR_ZERO_RETURN) {
+    // This signifies a TLS closure alert.
+    return "SSL_ERROR_ZERO_RETURN";
+  } else {
+    char buf[256];
+    std::string msg(ERR_error_string(errError, buf));
+    return msg;
+  }
+}
+
 } // anonymous namespace
 
 namespace folly {
 
-SSLException::SSLException(int sslError, int errno_copy):
-    AsyncSocketException(
-      AsyncSocketException::SSL_ERROR,
-      ERR_error_string(sslError, msg_),
-      sslError == SSL_ERROR_SYSCALL ? errno_copy : 0), error_(sslError) {}
+SSLException::SSLException(int sslError,
+                           unsigned long errError,
+                           int sslOperationReturnValue,
+                           int errno_copy)
+    : AsyncSocketException(
+          AsyncSocketException::SSL_ERROR,
+          decodeOpenSSLError(sslError, errError, sslOperationReturnValue),
+          sslError == SSL_ERROR_SYSCALL ? errno_copy : 0) {}
 
 /**
  * Create a client AsyncSSLSocket
@@ -889,8 +912,11 @@ int AsyncSSLSocket::getSSLCertSize() const {
   return certSize;
 }
 
-bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
-  int error = *errorOut = SSL_get_error(ssl_, ret);
+bool AsyncSSLSocket::willBlock(int ret,
+                               int* sslErrorOut,
+                               unsigned long* errErrorOut) noexcept {
+  *errErrorOut = 0;
+  int error = *sslErrorOut = SSL_get_error(ssl_, ret);
   if (error == SSL_ERROR_WANT_READ) {
     // Register for read event if not already.
     updateEventRegistration(EventHandler::READ, EventHandler::WRITE);
@@ -943,7 +969,7 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
   } else {
     // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
     // in the log
-    long lastError = ERR_get_error();
+    unsigned long lastError = *errErrorOut = ERR_get_error();
     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
             << "state=" << state_ << ", "
             << "sslState=" << sslState_ << ", "
@@ -955,16 +981,6 @@ bool AsyncSSLSocket::willBlock(int ret, int *errorOut) noexcept {
             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
             << "func: " << ERR_func_error_string(lastError) << ", "
             << "reason: " << ERR_reason_error_string(lastError);
-    if (error != SSL_ERROR_SYSCALL) {
-      if (error == SSL_ERROR_SSL) {
-        *errorOut = lastError;
-      }
-      if ((unsigned long)lastError < 0x8000) {
-        errno = ENOSYS;
-      } else {
-        errno = lastError;
-      }
-    }
     ERR_clear_error();
     return false;
   }
@@ -1042,12 +1058,14 @@ AsyncSSLSocket::handleAccept() noexcept {
   errno = 0;
   int ret = SSL_accept(ssl_);
   if (ret <= 0) {
-    int error;
-    if (willBlock(ret, &error)) {
+    int sslError;
+    unsigned long errError;
+    int errnoCopy = errno;
+    if (willBlock(ret, &sslError, &errError)) {
       return;
     } else {
       sslState_ = STATE_ERROR;
-      SSLException ex(error, errno);
+      SSLException ex(sslError, errError, ret, errnoCopy);
       return failHandshake(__func__, ex);
     }
   }
@@ -1104,12 +1122,14 @@ AsyncSSLSocket::handleConnect() noexcept {
   errno = 0;
   int ret = SSL_connect(ssl_);
   if (ret <= 0) {
-    int error;
-    if (willBlock(ret, &error)) {
+    int sslError;
+    unsigned long errError;
+    int errnoCopy = errno;
+    if (willBlock(ret, &sslError, &errError)) {
       return;
     } else {
       sslState_ = STATE_ERROR;
-      SSLException ex(error, errno);
+      SSLException ex(sslError, errError, ret, errnoCopy);
       return failHandshake(__func__, ex);
     }
   }
index b203f13..732d486 100644 (file)
@@ -35,13 +35,10 @@ namespace folly {
 
 class SSLException: public folly::AsyncSocketException {
  public:
-  SSLException(int sslError, int errno_copy);
-
-  int getSSLError() const { return error_; }
-
- protected:
-  int error_;
-  char msg_[256];
+  SSLException(int sslError,
+               unsigned long errError,
+               int sslOperationReturnValue,
+               int errno_copy);
 };
 
 /**
@@ -782,7 +779,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void handleConnect() noexcept override;
 
   void invalidState(HandshakeCB* callback);
-  bool willBlock(int ret, int *errorOut) noexcept;
+  bool willBlock(int ret,
+                 int* sslErrorOut,
+                 unsigned long* errErrorOut) noexcept;
 
   virtual void checkForImmediateRead() noexcept override;
   // AsyncSocket calls this at the wrong time for SSL
index 5bb0a82..87909f0 100644 (file)
@@ -879,7 +879,7 @@ TEST(AsyncSSLSocketTest, SSLServerCacheCloseTest) {
   server.getEventBase().runInEventBaseThread([&handshakeCallback]{
       handshakeCallback.closeSocket();});
   // give time for the cache lookup to come back and find it closed
-  usleep(500000);
+  handshakeCallback.waitForHandshake();
 
   EXPECT_EQ(server.getAsyncCallbacks(), 1);
   EXPECT_EQ(server.getAsyncLookups(), 1);
@@ -1520,6 +1520,71 @@ TEST(AsyncSSLSocketTest, UnencryptedTest) {
   EXPECT_EQ(AsyncSSLSocket::STATE_ESTABLISHED, client->getSSLState());
 }
 
+TEST(AsyncSSLSocketTest, ConnResetErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[3] = {0x16, 0x03, 0x01};
+  socket->write(buf, sizeof(buf));
+  socket->closeWithReset();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(handshakeCallback.errorString_.find("SSL_ERROR_SYSCALL"),
+            std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("104"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnEOFErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[3] = {0x16, 0x03, 0x01};
+  socket->write(buf, sizeof(buf));
+  socket->close();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(handshakeCallback.errorString_.find("SSL_ERROR_SYSCALL"),
+            std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("EOF"), std::string::npos);
+}
+
+TEST(AsyncSSLSocketTest, ConnOpenSSLErrorString) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  WriteErrorCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback,
+                                      HandshakeCallback::EXPECT_ERROR);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  TestSSLServer server(&acceptCallback);
+
+  auto socket = std::make_shared<BlockingSocket>(server.getAddress(), nullptr);
+  socket->open();
+  uint8_t buf[256] = {0x16, 0x03};
+  memset(buf + 2, 'a', sizeof(buf) - 2);
+  socket->write(buf, sizeof(buf));
+  socket->close();
+
+  handshakeCallback.waitForHandshake();
+  EXPECT_NE(handshakeCallback.errorString_.find("SSL routines"),
+            std::string::npos);
+  EXPECT_NE(handshakeCallback.errorString_.find("unknown protocol"),
+            std::string::npos);
+}
+
 } // namespace
 
 ///////////////////////////////////////////////////////////////////////////
index 612d0e1..a4e18aa 100644 (file)
@@ -305,6 +305,8 @@ public:
 
   // Functions inherited from AsyncSSLSocketHandshakeCallback
   void handshakeSuc(AsyncSSLSocket *sock) noexcept override {
+    std::lock_guard<std::mutex> g(mutex_);
+    cv_.notify_all();
     EXPECT_EQ(sock, socket_.get());
     std::cerr << "HandshakeCallback::connectionAccepted" << std::endl;
     rcb_->setSocket(socket_);
@@ -313,12 +315,20 @@ public:
   }
   void handshakeErr(AsyncSSLSocket* /* sock */,
                     const AsyncSocketException& ex) noexcept override {
+    std::lock_guard<std::mutex> g(mutex_);
+    cv_.notify_all();
     std::cerr << "HandshakeCallback::handshakeError " << ex.what() << std::endl;
     state = (expect_ == EXPECT_ERROR) ? STATE_SUCCEEDED : STATE_FAILED;
     if (expect_ == EXPECT_ERROR) {
       // rcb will never be invoked
       rcb_->setState(STATE_SUCCEEDED);
     }
+    errorString_ = ex.what();
+  }
+
+  void waitForHandshake() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] { return state != STATE_WAITING; });
   }
 
   ~HandshakeCallback() {
@@ -334,6 +344,9 @@ public:
   std::shared_ptr<AsyncSSLSocket> socket_;
   ReadCallbackBase *rcb_;
   ExpectType expect_;
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  std::string errorString_;
 };
 
 class SSLServerAcceptCallbackBase:
index 7858145..7cfb870 100644 (file)
@@ -45,6 +45,7 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
   void close() {
     sock_->close();
   }
+  void closeWithReset() { sock_->closeWithReset(); }
 
   int32_t write(uint8_t const* buf, size_t len) {
     sock_->write(this, buf, len);