Stop abusing errno
authorSubodh Iyengar <subodh@fb.com>
Thu, 28 Apr 2016 19:10:01 +0000 (12:10 -0700)
committerFacebook Github Bot 0 <facebook-github-bot-0-bot@fb.com>
Thu, 28 Apr 2016 19:20:29 +0000 (12:20 -0700)
Summary:
We abuse errno to propagate exceptions from AsyncSSLSocket.
Stop doing this and propagate exceptions correctly.

This also formats the exception messages better.

Reviewed By: anirudhvr

Differential Revision: D3226808

fb-gh-sync-id: 15a5e67b0332136857e5fb85b1765757e548e040
fbshipit-source-id: 15a5e67b0332136857e5fb85b1765757e548e040

folly/Makefile.am
folly/io/async/AsyncSSLSocket.cpp
folly/io/async/AsyncSSLSocket.h
folly/io/async/AsyncSocket.cpp
folly/io/async/AsyncSocket.h
folly/io/async/ssl/SSLErrors.cpp [new file with mode: 0644]
folly/io/async/ssl/SSLErrors.h [new file with mode: 0644]
folly/io/async/test/AsyncSSLSocketTest.cpp
folly/io/async/test/AsyncSSLSocketTest.h
folly/io/async/test/AsyncSSLSocketWriteTest.cpp
folly/io/async/test/BlockingSocket.h

index f82d4898586a0f8dc6d0f5c051c6bb0d31d80b96..173937948da87ad6a30badcee1940ea0acc2956d 100644 (file)
@@ -234,6 +234,7 @@ nobase_follyinclude_HEADERS = \
        io/async/HHWheelTimer.h \
        io/async/ssl/OpenSSLPtrTypes.h \
        io/async/ssl/OpenSSLUtils.h \
+       io/async/ssl/SSLErrors.h \
        io/async/ssl/TLSDefinitions.h \
        io/async/Request.h \
        io/async/SSLContext.h \
@@ -417,6 +418,7 @@ libfolly_la_SOURCES = \
        io/async/test/SocketPair.cpp \
        io/async/test/TimeUtil.cpp \
        io/async/ssl/OpenSSLUtils.cpp \
+       io/async/ssl/SSLErrors.cpp \
        json.cpp \
        detail/MemoryIdler.cpp \
        MacAddress.cpp \
index 0decbbc8072c7f75def9a3150da930d5c9118a5a..2146d6b9b6a19d8af3104c19350a8d434836cec4 100644 (file)
@@ -62,10 +62,6 @@ using folly::SSLContext;
 static SSLContext *dummyCtx = nullptr;
 static SpinLock dummyCtxLock;
 
-// Numbers chosen as to not collide with functions in ssl.h
-const uint8_t TASYNCSSLSOCKET_F_PERFORM_READ = 90;
-const uint8_t TASYNCSSLSOCKET_F_PERFORM_WRITE = 91;
-
 // If given min write size is less than this, buffer will be allocated on
 // stack, otherwise it is allocated on heap
 const size_t MAX_STACK_BUF_SIZE = 2048;
@@ -246,39 +242,10 @@ void* initEorBioMethod(void) {
   return nullptr;
 }
 
-std::string decodeOpenSSLError(int sslError,
-                               unsigned long errError,
-                               int sslOperationReturnValue) {
-  if (sslError == SSL_ERROR_SYSCALL && errError == 0) {
-    if (sslOperationReturnValue == 0) {
-      return "SSL_ERROR_SYSCALL: EOF";
-    } else {
-      // In this case errno is set, AsyncSocketException will add it.
-      return "SSL_ERROR_SYSCALL";
-    }
-  } else if (sslError == SSL_ERROR_ZERO_RETURN) {
-    // This signifies a TLS closure alert.
-    return "SSL_ERROR_ZERO_RETURN";
-  } else {
-    char buf[256];
-    std::string msg(ERR_error_string(errError, buf));
-    return msg;
-  }
-}
-
 } // anonymous namespace
 
 namespace folly {
 
-SSLException::SSLException(int sslError,
-                           unsigned long errError,
-                           int sslOperationReturnValue,
-                           int errno_copy)
-    : AsyncSocketException(
-          AsyncSocketException::SSL_ERROR,
-          decodeOpenSSLError(sslError, errError, sslOperationReturnValue),
-          sslError == SSL_ERROR_SYSCALL ? errno_copy : 0) {}
-
 /**
  * Create a client AsyncSSLSocket
  */
@@ -807,6 +774,10 @@ SSL_SESSION *AsyncSSLSocket::getSSLSession() {
   return sslSession_;
 }
 
+const SSL* AsyncSSLSocket::getSSL() const {
+  return ssl_;
+}
+
 void AsyncSSLSocket::setSSLSession(SSL_SESSION *session, bool takeOwnership) {
   sslSession_ = session;
   if (!takeOwnership && session != nullptr) {
@@ -967,8 +938,6 @@ bool AsyncSSLSocket::willBlock(int ret,
     // The timeout (if set) keeps running here
     return true;
   } else {
-    // SSL_ERROR_ZERO_RETURN is processed here so we can get some detail
-    // in the log
     unsigned long lastError = *errErrorOut = ERR_get_error();
     VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
             << "state=" << state_ << ", "
@@ -981,7 +950,6 @@ bool AsyncSSLSocket::willBlock(int ret,
             << "written: " << BIO_number_written(SSL_get_wbio(ssl_)) << ", "
             << "func: " << ERR_func_error_string(lastError) << ", "
             << "reason: " << ERR_reason_error_string(lastError);
-    ERR_clear_error();
     return false;
   }
 }
@@ -1055,7 +1023,6 @@ AsyncSSLSocket::handleAccept() noexcept {
     SSL_set_msg_callback_arg(ssl_, this);
   }
 
-  errno = 0;
   int ret = SSL_accept(ssl_);
   if (ret <= 0) {
     int sslError;
@@ -1119,7 +1086,6 @@ AsyncSSLSocket::handleConnect() noexcept {
          sslState_ == STATE_CONNECTING);
   assert(ssl_);
 
-  errno = 0;
   int ret = SSL_connect(ssl_);
   if (ret <= 0) {
     int sslError;
@@ -1223,16 +1189,15 @@ AsyncSSLSocket::handleRead() noexcept {
   AsyncSocket::handleRead();
 }
 
-ssize_t
+AsyncSocket::ReadResult
 AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
-  VLOG(4) << "AsyncSSLSocket::performRead() this=" << this
-          << ", buf=" << *buf << ", buflen=" << *buflen;
+  VLOG(4) << "AsyncSSLSocket::performRead() this=" << this << ", buf=" << *buf
+          << ", buflen=" << *buflen;
 
   if (sslState_ == STATE_UNENCRYPTED) {
     return AsyncSocket::performRead(buf, buflen, offset);
   }
 
-  errno = 0;
   ssize_t bytes = 0;
   if (!isBufferMovable_) {
     bytes = SSL_read(ssl_, *buf, *buflen);
@@ -1247,20 +1212,18 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
                << ", sslstate=" << sslState_ << ", events=" << eventFlags_
                << "): client intitiated SSL renegotiation not permitted";
-    // We pack our own SSLerr here with a dummy function
-    errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
-                     SSL_CLIENT_RENEGOTIATION_ATTEMPT);
-    ERR_clear_error();
-    return READ_ERROR;
+    return ReadResult(
+        READ_ERROR,
+        folly::make_unique<SSLException>(SSLError::CLIENT_RENEGOTIATION));
   }
   if (bytes <= 0) {
     int error = SSL_get_error(ssl_, bytes);
     if (error == SSL_ERROR_WANT_READ) {
       // The caller will register for read event if not already.
       if (errno == EWOULDBLOCK || errno == EAGAIN) {
-        return READ_BLOCKING;
+        return ReadResult(READ_BLOCKING);
       } else {
-        return READ_ERROR;
+        return ReadResult(READ_ERROR);
       }
     } else if (error == SSL_ERROR_WANT_WRITE) {
       // TODO: Even though we are attempting to read data, SSL_read() may
@@ -1268,17 +1231,15 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
       // don't support this and just fail the read.
       LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
                  << ", sslState=" << sslState_ << ", events=" << eventFlags_
-                 << "): unsupported SSL renegotiation during read",
-      errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_READ,
-                       SSL_INVALID_RENEGOTIATION);
-      ERR_clear_error();
-      return READ_ERROR;
+                 << "): unsupported SSL renegotiation during read";
+      return ReadResult(
+          READ_ERROR,
+          folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
     } else {
-      // TODO: Fix this code so that it can return a proper error message
-      // to the callback, rather than relying on AsyncSocket code which
-      // can't handle SSL errors.
-      long lastError = ERR_get_error();
-
+      if (zero_return(error, bytes)) {
+        return ReadResult(bytes);
+      }
+      long errError = ERR_get_error();
       VLOG(6) << "AsyncSSLSocket(fd=" << fd_ << ", "
               << "state=" << state_ << ", "
               << "sslState=" << sslState_ << ", "
@@ -1286,24 +1247,15 @@ AsyncSSLSocket::performRead(void** buf, size_t* buflen, size_t* offset) {
               << "bytes: " << bytes << ", "
               << "error: " << error << ", "
               << "errno: " << errno << ", "
-              << "func: " << ERR_func_error_string(lastError) << ", "
-              << "reason: " << ERR_reason_error_string(lastError);
-      ERR_clear_error();
-      if (zero_return(error, bytes)) {
-        return bytes;
-      }
-      if (error != SSL_ERROR_SYSCALL) {
-        if ((unsigned long)lastError < 0x8000) {
-          errno = ENOSYS;
-        } else {
-          errno = lastError;
-        }
-      }
-      return READ_ERROR;
+              << "func: " << ERR_func_error_string(errError) << ", "
+              << "reason: " << ERR_reason_error_string(errError);
+      return ReadResult(
+          READ_ERROR,
+          folly::make_unique<SSLException>(error, errError, bytes, errno));
     }
   } else {
     appBytesReceived_ += bytes;
-    return bytes;
+    return ReadResult(bytes);
   }
 }
 
@@ -1331,49 +1283,40 @@ void AsyncSSLSocket::handleWrite() noexcept {
   AsyncSocket::handleWrite();
 }
 
-int AsyncSSLSocket::interpretSSLError(int rc, int error) {
+AsyncSocket::WriteResult AsyncSSLSocket::interpretSSLError(int rc, int error) {
   if (error == SSL_ERROR_WANT_READ) {
-    // TODO: Even though we are attempting to write data, SSL_write() may
+    // Even though we are attempting to write data, SSL_write() may
     // need to read data if renegotiation is being performed.  We currently
     // don't support this and just fail the write.
     LOG(ERROR) << "AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
                << ", sslState=" << sslState_ << ", events=" << eventFlags_
-               << "): " << "unsupported SSL renegotiation during write",
-      errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
-                       SSL_INVALID_RENEGOTIATION);
-    ERR_clear_error();
-    return -1;
+               << "): "
+               << "unsupported SSL renegotiation during write";
+    return WriteResult(
+        WRITE_ERROR,
+        folly::make_unique<SSLException>(SSLError::INVALID_RENEGOTIATION));
   } else {
-    // TODO: Fix this code so that it can return a proper error message
-    // to the callback, rather than relying on AsyncSocket code which
-    // can't handle SSL errors.
-    long lastError = ERR_get_error();
+    if (zero_return(error, rc)) {
+      return WriteResult(0);
+    }
+    auto errError = ERR_get_error();
     VLOG(3) << "ERROR: AsyncSSLSocket(fd=" << fd_ << ", state=" << int(state_)
             << ", sslState=" << sslState_ << ", events=" << eventFlags_ << "): "
             << "SSL error: " << error << ", errno: " << errno
-            << ", func: " << ERR_func_error_string(lastError)
-            << ", reason: " << ERR_reason_error_string(lastError);
-    if (error != SSL_ERROR_SYSCALL) {
-      if ((unsigned long)lastError < 0x8000) {
-        errno = ENOSYS;
-      } else {
-        errno = lastError;
-      }
-    }
-    ERR_clear_error();
-    if (!zero_return(error, rc)) {
-      return -1;
-    } else {
-      return 0;
-    }
+            << ", func: " << ERR_func_error_string(errError)
+            << ", reason: " << ERR_reason_error_string(errError);
+    return WriteResult(
+        WRITE_ERROR,
+        folly::make_unique<SSLException>(error, errError, rc, errno));
   }
 }
 
-ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
-                                      uint32_t count,
-                                      WriteFlags flags,
-                                      uint32_t* countWritten,
-                                      uint32_t* partialWritten) {
+AsyncSocket::WriteResult AsyncSSLSocket::performWrite(
+    const iovec* vec,
+    uint32_t count,
+    WriteFlags flags,
+    uint32_t* countWritten,
+    uint32_t* partialWritten) {
   if (sslState_ == STATE_UNENCRYPTED) {
     return AsyncSocket::performWrite(
       vec, count, flags, countWritten, partialWritten);
@@ -1384,9 +1327,8 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
                << ", events=" << eventFlags_ << "): "
                << "TODO: AsyncSSLSocket currently does not support calling "
                << "write() before the handshake has fully completed";
-      errno = ERR_PACK(ERR_LIB_USER, TASYNCSSLSOCKET_F_PERFORM_WRITE,
-                       SSL_EARLY_WRITE);
-      return -1;
+    return WriteResult(
+        WRITE_ERROR, folly::make_unique<SSLException>(SSLError::EARLY_WRITE));
   }
 
   bool cork = isSet(flags, WriteFlags::CORK);
@@ -1420,7 +1362,6 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
     buf = ((const char*)v->iov_base) + offset;
 
     ssize_t bytes;
-    errno = 0;
     uint32_t buffersStolen = 0;
     if ((len < minWriteSize_) && ((i + 1) < count)) {
       // Combine this buffer with part or all of the next buffers in
@@ -1474,11 +1415,11 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
       if (error == SSL_ERROR_WANT_WRITE) {
         // The caller will register for write event if not already.
         *partialWritten = offset;
-        return totalWritten;
+        return WriteResult(totalWritten);
       }
-      int rc = interpretSSLError(bytes, error);
-      if (rc < 0) {
-        return rc;
+      auto writeResult = interpretSSLError(bytes, error);
+      if (writeResult.writeReturn < 0) {
+        return writeResult;
       } // else fall through to below to correctly record totalWritten
     }
 
@@ -1500,11 +1441,11 @@ ssize_t AsyncSSLSocket::performWrite(const iovec* vec,
         v = &(vec[++i]);
       }
       *partialWritten = bytes;
-      return totalWritten;
+      return WriteResult(totalWritten);
     }
   }
 
-  return totalWritten;
+  return WriteResult(totalWritten);
 }
 
 int AsyncSSLSocket::eorAwareSSLWrite(SSL *ssl, const void *buf, int n,
@@ -1575,7 +1516,6 @@ int AsyncSSLSocket::eorAwareBioWrite(BIO *b, const char *in, int inl) {
     flags = MSG_EOR;
   }
 
-  errno = 0;
   ret = sendmsg(b->num, &msg, flags);
   BIO_clear_retry_flags(b);
   if (ret <= 0) {
index 1bb3fc170b79a8028839d5705dd2911018d1b8cd..af3fd06b8c1e6a7f9d224a534dd479320c2923eb 100644 (file)
@@ -27,6 +27,7 @@
 #include <folly/io/async/TimeoutManager.h>
 #include <folly/io/async/ssl/OpenSSLPtrTypes.h>
 #include <folly/io/async/ssl/OpenSSLUtils.h>
+#include <folly/io/async/ssl/SSLErrors.h>
 #include <folly/io/async/ssl/TLSDefinitions.h>
 
 #include <folly/Bits.h>
 
 namespace folly {
 
-class SSLException: public folly::AsyncSocketException {
- public:
-  SSLException(int sslError,
-               unsigned long errError,
-               int sslOperationReturnValue,
-               int errno_copy);
-};
-
 /**
  * A class for performing asynchronous I/O on an SSL connection.
  *
@@ -143,18 +136,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     AsyncSSLSocket* sslSocket_;
   };
 
-
-  /**
-   * These are passed to the application via errno, packed in an SSL err which
-   * are outside the valid errno range.  The values are chosen to be unique
-   * against values in ssl.h
-   */
-  enum SSLError {
-    SSL_CLIENT_RENEGOTIATION_ATTEMPT = 900,
-    SSL_INVALID_RENEGOTIATION = 901,
-    SSL_EARLY_WRITE = 902
-  };
-
   /**
    * Create a client AsyncSSLSocket
    */
@@ -365,6 +346,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   SSL_SESSION *getSSLSession();
 
+  /**
+   * Get a handle to the SSL struct.
+   */
+  const SSL* getSSL() const;
+
   /**
    * Set the SSL session to be used during sslConn.  AsyncSSLSocket will
    * hold a reference to the session until it is destroyed or released by the
@@ -760,11 +746,14 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // AsyncSocket calls this at the wrong time for SSL
   void handleInitialReadWrite() noexcept override {}
 
-  int interpretSSLError(int rc, int error);
-  ssize_t performRead(void** buf, size_t* buflen, size_t* offset) override;
-  ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags,
-                       uint32_t* countWritten, uint32_t* partialWritten)
-    override;
+  WriteResult interpretSSLError(int rc, int error);
+  ReadResult performRead(void** buf, size_t* buflen, size_t* offset) override;
+  WriteResult performWrite(
+      const iovec* vec,
+      uint32_t count,
+      WriteFlags flags,
+      uint32_t* countWritten,
+      uint32_t* partialWritten) override;
 
   ssize_t performWriteIovec(const iovec* vec, uint32_t count,
                             WriteFlags flags, uint32_t* countWritten,
index 6fd3c355918c326d586447ab571372d4e233ebae..848cf5c47a94f0567eb55754109790fde4b48120 100644 (file)
@@ -91,14 +91,13 @@ class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
     free(this);
   }
 
-  bool performWrite() override {
+  WriteResult performWrite() override {
     WriteFlags writeFlags = flags_;
     if (getNext() != nullptr) {
       writeFlags = writeFlags | WriteFlags::CORK;
     }
-    bytesWritten_ = socket_->performWrite(getOps(), getOpCount(), writeFlags,
-                                          &opsWritten_, &partialBytes_);
-    return bytesWritten_ >= 0;
+    return socket_->performWrite(
+        getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
   }
 
   bool isComplete() override {
@@ -694,10 +693,14 @@ void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
       assert(writeReqTail_ == nullptr);
       assert((eventFlags_ & EventHandler::WRITE) == 0);
 
-      bytesWritten = performWrite(vec, count, flags,
-                                  &countWritten, &partialWritten);
+      auto writeResult =
+          performWrite(vec, count, flags, &countWritten, &partialWritten);
+      bytesWritten = writeResult.writeReturn;
       if (bytesWritten < 0) {
         auto errnoCopy = errno;
+        if (writeResult.exception) {
+          return failWrite(__func__, callback, 0, *writeResult.exception);
+        }
         AsyncSocketException ex(
             AsyncSocketException::INTERNAL_ERROR,
             withAddr("writev failed"),
@@ -1259,11 +1262,10 @@ void AsyncSocket::ioReady(uint16_t events) noexcept {
   }
 }
 
-ssize_t AsyncSocket::performRead(void** buf,
-                                 size_t* buflen,
-                                 size_t* /* offset */) {
-  VLOG(5) << "AsyncSocket::performRead() this=" << this
-          << ", buf=" << *buf << ", buflen=" << *buflen;
+AsyncSocket::ReadResult
+AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
+  VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
+          << ", buflen=" << *buflen;
 
   int recvFlags = 0;
   if (peek_) {
@@ -1274,13 +1276,13 @@ ssize_t AsyncSocket::performRead(void** buf,
   if (bytes < 0) {
     if (errno == EAGAIN || errno == EWOULDBLOCK) {
       // No more data to read right now.
-      return READ_BLOCKING;
+      return ReadResult(READ_BLOCKING);
     } else {
-      return READ_ERROR;
+      return ReadResult(READ_ERROR);
     }
   } else {
     appBytesReceived_ += bytes;
-    return bytes;
+    return ReadResult(bytes);
   }
 }
 
@@ -1347,7 +1349,8 @@ void AsyncSocket::handleRead() noexcept {
     }
 
     // Perform the read
-    ssize_t bytesRead = performRead(&buf, &buflen, &offset);
+    auto readResult = performRead(&buf, &buflen, &offset);
+    auto bytesRead = readResult.readReturn;
     VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
             << bytesRead << " bytes";
     if (bytesRead > 0) {
@@ -1376,6 +1379,9 @@ void AsyncSocket::handleRead() noexcept {
         return;
     } else if (bytesRead == READ_ERROR) {
       readErr_ = READ_ERROR;
+      if (readResult.exception) {
+        return failRead(__func__, *readResult.exception);
+      }
       auto errnoCopy = errno;
       AsyncSocketException ex(
           AsyncSocketException::INTERNAL_ERROR,
@@ -1439,7 +1445,11 @@ void AsyncSocket::handleWrite() noexcept {
   // (See the comment in handleRead() explaining how this can happen.)
   EventBase* originalEventBase = eventBase_;
   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
-    if (!writeReqHead_->performWrite()) {
+    auto writeResult = writeReqHead_->performWrite();
+    if (writeResult.writeReturn < 0) {
+      if (writeResult.exception) {
+        return failWrite(__func__, *writeResult.exception);
+      }
       auto errnoCopy = errno;
       AsyncSocketException ex(
           AsyncSocketException::INTERNAL_ERROR,
@@ -1697,11 +1707,12 @@ void AsyncSocket::timeoutExpired() noexcept {
   }
 }
 
-ssize_t AsyncSocket::performWrite(const iovec* vec,
-                                   uint32_t count,
-                                   WriteFlags flags,
-                                   uint32_t* countWritten,
-                                   uint32_t* partialWritten) {
+AsyncSocket::WriteResult AsyncSocket::performWrite(
+    const iovec* vec,
+    uint32_t count,
+    WriteFlags flags,
+    uint32_t* countWritten,
+    uint32_t* partialWritten) {
   // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
   // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
   // (since it may terminate the program if the main program doesn't explicitly
@@ -1736,12 +1747,12 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
       // TCP buffer is full; we can't write any more data right now.
       *countWritten = 0;
       *partialWritten = 0;
-      return 0;
+      return WriteResult(0);
     }
     // error
     *countWritten = 0;
     *partialWritten = 0;
-    return -1;
+    return WriteResult(WRITE_ERROR);
   }
 
   appBytesWritten_ += totalWritten;
@@ -1754,7 +1765,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
       // Partial write finished in the middle of this iovec
       *countWritten = n;
       *partialWritten = bytesWritten;
-      return totalWritten;
+      return WriteResult(totalWritten);
     }
 
     bytesWritten -= v->iov_len;
@@ -1763,7 +1774,7 @@ ssize_t AsyncSocket::performWrite(const iovec* vec,
   assert(bytesWritten == 0);
   *countWritten = n;
   *partialWritten = 0;
-  return totalWritten;
+  return WriteResult(totalWritten);
 }
 
 /**
index ba706747c57a8a46e32950eebdb98aefbe7af20c..37fdf08e89318c8703e176f061e57326bde9c4b6 100644 (file)
 
 #pragma once
 
-#include <sys/types.h>
-#include <sys/socket.h>
+#include <folly/Optional.h>
 #include <folly/SocketAddress.h>
-#include <folly/io/ShutdownSocketSet.h>
 #include <folly/io/IOBuf.h>
-#include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/ShutdownSocketSet.h>
 #include <folly/io/async/AsyncSocketException.h>
+#include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/AsyncTransport.h>
-#include <folly/io/async/EventHandler.h>
 #include <folly/io/async/DelayedDestruction.h>
+#include <folly/io/async/EventHandler.h>
+#include <sys/socket.h>
+#include <sys/types.h>
 
 #include <chrono>
 #include <memory>
@@ -517,6 +518,41 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   void setBufferCallback(BufferCallback* cb);
 
+  /**
+   * writeReturn is the total number of bytes written, or WRITE_ERROR on error.
+   * If no data has been written, 0 is returned.
+   * exception is a more specific exception that cause a write error.
+   * Not all writes have exceptions associated with them thus writeReturn
+   * should be checked to determine whether the operation resulted in an error.
+   */
+  struct WriteResult {
+    explicit WriteResult(ssize_t ret) : writeReturn(ret) {}
+
+    WriteResult(ssize_t ret, std::unique_ptr<const AsyncSocketException> e)
+        : writeReturn(ret), exception(std::move(e)) {}
+
+    ssize_t writeReturn;
+    std::unique_ptr<const AsyncSocketException> exception;
+  };
+
+  /**
+   * readReturn is the number of bytes read, or READ_EOF on EOF, or
+   * READ_ERROR on error, or READ_BLOCKING if the operation will
+   * block.
+   * exception is a more specific exception that may have caused a read error.
+   * Not all read errors have exceptions associated with them thus readReturn
+   * should be checked to determine whether the operation resulted in an error.
+   */
+  struct ReadResult {
+    explicit ReadResult(ssize_t ret) : readReturn(ret) {}
+
+    ReadResult(ssize_t ret, std::unique_ptr<const AsyncSocketException> e)
+        : readReturn(ret), exception(std::move(e)) {}
+
+    ssize_t readReturn;
+    std::unique_ptr<const AsyncSocketException> exception;
+  };
+
   /**
    * A WriteRequest object tracks information about a pending write operation.
    */
@@ -529,7 +565,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
     virtual void destroy() = 0;
 
-    virtual bool performWrite() = 0;
+    virtual WriteResult performWrite() = 0;
 
     virtual void consume() = 0;
 
@@ -579,6 +615,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     READ_NO_ERROR = -3,
   };
 
+  enum WriteResultEnum {
+    WRITE_ERROR = -1,
+  };
+
   /**
    * Protected destructor.
    *
@@ -683,11 +723,9 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    * @param buf      The buffer to read data into.
    * @param buflen   The length of the buffer.
    *
-   * @return Returns the number of bytes read, or READ_EOF on EOF, or
-   * READ_ERROR on error, or READ_BLOCKING if the operation will
-   * block.
+   * @return Returns a read result. See read result for details.
    */
-  virtual ssize_t performRead(void** buf, size_t* buflen, size_t* offset);
+  virtual ReadResult performRead(void** buf, size_t* buflen, size_t* offset);
 
   /**
    * Populate an iovec array from an IOBuf and attempt to write it.
@@ -736,12 +774,14 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
    *                          will contain the number of bytes written in the
    *                          partially written iovec entry.
    *
-   * @return Returns the total number of bytes written, or -1 on error.  If no
-   *     data can be written immediately, 0 is returned.
+   * @return Returns a WriteResult. See WriteResult for more details.
    */
-  virtual ssize_t performWrite(const iovec* vec, uint32_t count,
-                               WriteFlags flags, uint32_t* countWritten,
-                               uint32_t* partialWritten);
+  virtual WriteResult performWrite(
+      const iovec* vec,
+      uint32_t count,
+      WriteFlags flags,
+      uint32_t* countWritten,
+      uint32_t* partialWritten);
 
   bool updateEventRegistration();
 
diff --git a/folly/io/async/ssl/SSLErrors.cpp b/folly/io/async/ssl/SSLErrors.cpp
new file mode 100644 (file)
index 0000000..94550f5
--- /dev/null
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2016 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <folly/io/async/ssl/SSLErrors.h>
+
+#include <folly/Range.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+using namespace folly;
+
+namespace {
+
+std::string decodeOpenSSLError(
+    int sslError,
+    unsigned long errError,
+    int sslOperationReturnValue) {
+  if (sslError == SSL_ERROR_SYSCALL && errError == 0) {
+    if (sslOperationReturnValue == 0) {
+      return "SSL_ERROR_SYSCALL: EOF";
+    } else {
+      // In this case errno is set, AsyncSocketException will add it.
+      return "SSL_ERROR_SYSCALL";
+    }
+  } else if (sslError == SSL_ERROR_ZERO_RETURN) {
+    // This signifies a TLS closure alert.
+    return "SSL_ERROR_ZERO_RETURN";
+  } else {
+    std::array<char, 256> buf;
+    std::string msg(ERR_error_string(errError, buf.data()));
+    return msg;
+  }
+}
+
+const StringPiece getSSLErrorString(SSLError error) {
+  StringPiece ret;
+  switch (error) {
+    case SSLError::CLIENT_RENEGOTIATION:
+      ret = "Client tried to renegotiate with server";
+      break;
+    case SSLError::INVALID_RENEGOTIATION:
+      ret = "Attempt to start renegotiation, but unsupported";
+      break;
+    case SSLError::EARLY_WRITE:
+      ret = "Attempt to write before SSL connection established";
+      break;
+    case SSLError::OPENSSL_ERR:
+      // decodeOpenSSLError should be used for this type.
+      ret = "OPENSSL error";
+      break;
+  }
+  return ret;
+}
+}
+
+namespace folly {
+
+SSLException::SSLException(
+    int sslError,
+    unsigned long errError,
+    int sslOperationReturnValue,
+    int errno_copy)
+    : AsyncSocketException(
+          AsyncSocketException::SSL_ERROR,
+          decodeOpenSSLError(sslError, errError, sslOperationReturnValue),
+          sslError == SSL_ERROR_SYSCALL ? errno_copy : 0),
+      sslError(SSLError::OPENSSL_ERR),
+      opensslSSLError(sslError),
+      opensslErr(errError) {}
+
+SSLException::SSLException(SSLError error)
+    : AsyncSocketException(
+          AsyncSocketException::SSL_ERROR,
+          getSSLErrorString(error).str(),
+          0),
+      sslError(error) {}
+}
diff --git a/folly/io/async/ssl/SSLErrors.h b/folly/io/async/ssl/SSLErrors.h
new file mode 100644 (file)
index 0000000..ad7a5de
--- /dev/null
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2016 Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <folly/Optional.h>
+#include <folly/io/async/AsyncSocketException.h>
+
+namespace folly {
+
+enum class SSLError {
+  CLIENT_RENEGOTIATION, // A client tried to renegotiate with this server
+  INVALID_RENEGOTIATION, // We attempted to start a renegotiation.
+  EARLY_WRITE, // Wrote before SSL connection established.
+  // An openssl error type. The openssl specific methods should be used
+  // to find the real error type.
+  // This exists for compatibility until all error types can be move to proper
+  // errors.
+  OPENSSL_ERR,
+};
+
+class SSLException : public folly::AsyncSocketException {
+ public:
+  SSLException(
+      int sslError,
+      unsigned long errError,
+      int sslOperationReturnValue,
+      int errno_copy);
+
+  explicit SSLException(SSLError error);
+
+  SSLError getType() const {
+    return sslError;
+  }
+
+  // These methods exist for compatibility until there are proper exceptions
+  // for all ssl error types.
+  int getOpensslSSLError() const {
+    return opensslSSLError;
+  }
+
+  unsigned long getOpensslErr() const {
+    return opensslErr;
+  }
+
+ private:
+  SSLError sslError;
+  int opensslSSLError;
+  unsigned long opensslErr;
+};
+}
index 038f4f312e8d43b360e26d07297bdd1e63ca1a7c..bd0242aa9c10ee6ffa61cd614730af07ce43927c 100644 (file)
@@ -201,13 +201,89 @@ TEST(AsyncSSLSocketTest, ConnectWriteReadClose) {
   cerr << "ConnectWriteReadClose test completed" << endl;
 }
 
+/**
+ * Test reading after server close.
+ */
+TEST(AsyncSSLSocketTest, ReadAfterClose) {
+  // Start listening on a local port
+  WriteCallbackBase writeCallback;
+  ReadEOFCallback readCallback(&writeCallback);
+  HandshakeCallback handshakeCallback(&readCallback);
+  SSLServerAcceptCallback acceptCallback(&handshakeCallback);
+  auto server = folly::make_unique<TestSSLServer>(&acceptCallback);
+
+  // Set up SSL context.
+  auto sslContext = std::make_shared<SSLContext>();
+  sslContext->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
+
+  auto socket =
+      std::make_shared<BlockingSocket>(server->getAddress(), sslContext);
+  socket->open();
+
+  // This should trigger an EOF on the client.
+  auto evb = handshakeCallback.getSocket()->getEventBase();
+  evb->runInEventBaseThreadAndWait([&]() { handshakeCallback.closeSocket(); });
+  std::array<uint8_t, 128> readbuf;
+  auto bytesRead = socket->read(readbuf.data(), readbuf.size());
+  EXPECT_EQ(0, bytesRead);
+}
+
+/**
+ * Test bad renegotiation
+ */
+TEST(AsyncSSLSocketTest, Renegotiate) {
+  EventBase eventBase;
+  auto clientCtx = std::make_shared<SSLContext>();
+  auto dfServerCtx = std::make_shared<SSLContext>();
+  std::array<int, 2> fds;
+  getfds(fds.data());
+  getctx(clientCtx, dfServerCtx);
+
+  AsyncSSLSocket::UniquePtr clientSock(
+      new AsyncSSLSocket(clientCtx, &eventBase, fds[0], false));
+  AsyncSSLSocket::UniquePtr serverSock(
+      new AsyncSSLSocket(dfServerCtx, &eventBase, fds[1], true));
+  SSLHandshakeClient client(std::move(clientSock), true, true);
+  RenegotiatingServer server(std::move(serverSock));
+
+  while (!client.handshakeSuccess_ && !client.handshakeError_) {
+    eventBase.loopOnce();
+  }
+
+  ASSERT_TRUE(client.handshakeSuccess_);
+
+  auto sslSock = std::move(client).moveSocket();
+  sslSock->detachEventBase();
+  // This is nasty, however we don't want to add support for
+  // renegotiation in AsyncSSLSocket.
+  SSL_renegotiate(const_cast<SSL*>(sslSock->getSSL()));
+
+  auto socket = std::make_shared<BlockingSocket>(std::move(sslSock));
+
+  std::thread t([&]() { eventBase.loopForever(); });
+
+  // Trigger the renegotiation.
+  std::array<uint8_t, 128> buf;
+  memset(buf.data(), 'a', buf.size());
+  try {
+    socket->write(buf.data(), buf.size());
+  } catch (AsyncSocketException& e) {
+    LOG(INFO) << "client got error " << e.what();
+  }
+  eventBase.terminateLoopSoon();
+  t.join();
+
+  eventBase.loop();
+  ASSERT_TRUE(server.renegotiationError_);
+}
+
 /**
  * Negative test for handshakeError().
  */
 TEST(AsyncSSLSocketTest, HandshakeError) {
   // Start listening on a local port
   WriteCallbackBase writeCallback;
-  ReadCallback readCallback(&writeCallback);
+  WriteErrorCallback readCallback(&writeCallback);
   HandshakeCallback handshakeCallback(&readCallback);
   HandshakeErrorCallback acceptCallback(&handshakeCallback);
   TestSSLServer server(&acceptCallback);
index a4e18aaa8447e271e72c2c25d930731f479c3099..69966d6743955e125237fc4106b997869241c1a8 100644 (file)
 #include <signal.h>
 #include <pthread.h>
 
-#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
 #include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/AsyncTransport.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/SocketAddress.h>
+#include <folly/io/async/ssl/SSLErrors.h>
 
 #include <gtest/gtest.h>
 #include <iostream>
@@ -58,7 +60,7 @@ public:
       , exception(AsyncSocketException::UNKNOWN, "none") {}
 
   ~WriteCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void setSocket(
@@ -92,10 +94,9 @@ public:
 
 class ReadCallbackBase :
 public AsyncTransportWrapper::ReadCallback {
-public:
-  explicit ReadCallbackBase(WriteCallbackBase *wcb)
-      : wcb_(wcb)
-      , state(STATE_WAITING) {}
+ public:
+  explicit ReadCallbackBase(WriteCallbackBase* wcb)
+      : wcb_(wcb), state(STATE_WAITING) {}
 
   ~ReadCallbackBase() {
     EXPECT_EQ(state, STATE_SUCCEEDED);
@@ -222,6 +223,27 @@ public:
   }
 };
 
+class ReadEOFCallback : public ReadCallbackBase {
+ public:
+  explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
+
+  // Return nullptr buffer to trigger readError()
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = nullptr;
+    *lenReturn = 0;
+  }
+
+  void readDataAvailable(size_t /* len */) noexcept override {
+    // This should never to called.
+    FAIL();
+  }
+
+  void readEOF() noexcept override {
+    ReadCallbackBase::readEOF();
+    setState(STATE_SUCCEEDED);
+  }
+};
+
 class WriteErrorCallback : public ReadCallback {
 public:
   explicit WriteErrorCallback(WriteCallbackBase *wcb)
@@ -340,6 +362,10 @@ public:
     state = STATE_SUCCEEDED;
   }
 
+  std::shared_ptr<AsyncSSLSocket> getSocket() {
+    return socket_;
+  }
+
   StateEnum state;
   std::shared_ptr<AsyncSSLSocket> socket_;
   ReadCallbackBase *rcb_;
@@ -879,6 +905,48 @@ class NpnServer :
   AsyncSSLSocket::UniquePtr socket_;
 };
 
+class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
+                            public AsyncTransportWrapper::ReadCallback {
+ public:
+  explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
+      : socket_(std::move(socket)) {
+    socket_->sslAccept(this);
+  }
+
+  ~RenegotiatingServer() {
+    socket_->setReadCB(nullptr);
+  }
+
+  void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
+    LOG(INFO) << "Renegotiating server handshake success";
+    socket_->setReadCB(this);
+  }
+  void handshakeErr(
+      AsyncSSLSocket*,
+      const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
+  }
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *lenReturn = sizeof(buf);
+    *bufReturn = buf;
+  }
+  void readDataAvailable(size_t /* len */) noexcept override {}
+  void readEOF() noexcept override {}
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    LOG(INFO) << "server got read error " << ex.what();
+    auto exPtr = dynamic_cast<const SSLException*>(&ex);
+    ASSERT_NE(nullptr, exPtr);
+    std::string exStr(ex.what());
+    SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
+    ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
+    renegotiationError_ = true;
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+  unsigned char buf[128];
+  bool renegotiationError_{false};
+};
+
 #ifndef OPENSSL_NO_TLSEXT
 class SNIClient :
   private AsyncSSLSocket::HandshakeCB,
@@ -1139,6 +1207,10 @@ class SSLHandshakeBase :
     verifyResult_(verifyResult) {
   }
 
+  AsyncSSLSocket::UniquePtr moveSocket() && {
+    return std::move(socket_);
+  }
+
   bool handshakeVerify_;
   bool handshakeSuccess_;
   bool handshakeError_;
@@ -1160,12 +1232,15 @@ class SSLHandshakeBase :
   }
 
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    LOG(INFO) << "Handshake success";
     handshakeSuccess_ = true;
     handshakeTime = socket_->getHandshakeTime();
   }
 
-  void handshakeErr(AsyncSSLSocket*,
-                    const AsyncSocketException& /* ex */) noexcept override {
+  void handshakeErr(
+      AsyncSSLSocket*,
+      const AsyncSocketException& ex) noexcept override {
+    LOG(INFO) << "Handshake error " << ex.what();
     handshakeError_ = true;
     handshakeTime = socket_->getHandshakeTime();
   }
index 4497e65ffa22f1f29db69ae0be4b7e686c59ce4e..4235517d38790401ab5b7ad5c3585dcac8a85f9c 100644 (file)
@@ -58,8 +58,12 @@ class MockAsyncSSLSocket : public AsyncSSLSocket{
   MOCK_CONST_METHOD0(getRawBytesWritten, size_t());
 
   // public wrapper for protected interface
-  ssize_t testPerformWrite(const iovec* vec, uint32_t count, WriteFlags flags,
-                           uint32_t* countWritten, uint32_t* partialWritten) {
+  WriteResult testPerformWrite(
+      const iovec* vec,
+      uint32_t count,
+      WriteFlags flags,
+      uint32_t* countWritten,
+      uint32_t* partialWritten) {
     return performWrite(vec, count, flags, countWritten, partialWritten);
   }
 
index 7cfb870c7b19b0112833119df2a41295ccb06e34..360cfcb18cfa342004ea7934a21a2cb610fde79f 100644 (file)
@@ -35,6 +35,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
             new folly::AsyncSocket(&eventBase_)),
     address_(address) {}
 
+  explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
+      : sock_(std::move(socket)) {
+    sock_->attachEventBase(&eventBase_);
+  }
+
   void open() {
     sock_->connect(this, address_);
     eventBase_.loop();