2017
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
index d9618098f3c4c54d417360809908cad76634505f..1aba12ed0580c9589dc22a723ffe724975e794b8 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
 
 #include <folly/ExceptionWrapper.h>
 #include <folly/SocketAddress.h>
+#include <folly/experimental/TestUtil.h>
 #include <folly/io/async/AsyncSSLSocket.h>
 #include <folly/io/async/AsyncServerSocket.h>
 #include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/AsyncTransport.h>
 #include <folly/io/async/EventBase.h>
 #include <folly/io/async/ssl/SSLErrors.h>
+#include <folly/portability/GTest.h>
 #include <folly/portability/Sockets.h>
 #include <folly/portability/Unistd.h>
 
 #include <fcntl.h>
-#include <gtest/gtest.h>
 #include <sys/types.h>
 #include <condition_variable>
 #include <iostream>
@@ -73,16 +74,15 @@ public:
   }
 
   void writeErr(
-    size_t bytesWritten,
+    size_t nBytesWritten,
     const AsyncSocketException& ex) noexcept override {
-    std::cerr << "writeError: bytesWritten " << bytesWritten
+    std::cerr << "writeError: bytesWritten " << nBytesWritten
          << ", exception " << ex.what() << std::endl;
 
     state = STATE_FAILED;
-    this->bytesWritten = bytesWritten;
+    this->bytesWritten = nBytesWritten;
     exception = ex;
     socket_->close();
-    socket_->detachEventBase();
   }
 
   std::shared_ptr<AsyncSSLSocket> socket_;
@@ -118,14 +118,12 @@ public AsyncTransportWrapper::ReadCallback {
     std::cerr << "readError " << ex.what() << std::endl;
     state = STATE_FAILED;
     socket_->close();
-    socket_->detachEventBase();
   }
 
   void readEOF() noexcept override {
     std::cerr << "readEOF" << std::endl;
 
     socket_->close();
-    socket_->detachEventBase();
   }
 
   std::shared_ptr<AsyncSSLSocket> socket_;
@@ -180,10 +178,10 @@ public:
       buffer = nullptr;
       length = 0;
     }
-    void allocate(size_t length) {
+    void allocate(size_t len) {
       assert(buffer == nullptr);
-      this->buffer = static_cast<char*>(malloc(length));
-      this->length = length;
+      this->buffer = static_cast<char*>(malloc(len));
+      this->length = len;
     }
     void free() {
       ::free(buffer);
@@ -259,7 +257,9 @@ public:
     wcb_->setSocket(socket_);
 
     // Write back the same data.
-    socket_->write(wcb_, currentBuffer.buffer, len);
+    folly::test::msvcSuppressAbortOnInvalidParams([&] {
+      socket_->write(wcb_, currentBuffer.buffer, len);
+    });
 
     if (wcb_->state == STATE_FAILED) {
       setState(STATE_SUCCEEDED);
@@ -285,15 +285,16 @@ public:
   void readErr(const AsyncSocketException& ex) noexcept override {
     std::cerr << "readError " << ex.what() << std::endl;
     state = STATE_FAILED;
-    tcpSocket_->close();
-    tcpSocket_->detachEventBase();
+    if (tcpSocket_) {
+      tcpSocket_->close();
+    }
   }
 
   void readEOF() noexcept override {
     std::cerr << "readEOF" << std::endl;
-
-    tcpSocket_->close();
-    tcpSocket_->detachEventBase();
+    if (tcpSocket_) {
+      tcpSocket_->close();
+    }
     state = STATE_SUCCEEDED;
   }
 
@@ -392,12 +393,14 @@ public:
 
   void connectionAccepted(
       int fd, const folly::SocketAddress& /* clientAddr */) noexcept override {
+    if (socket_) {
+      socket_->detachEventBase();
+    }
     printf("Connection accepted\n");
-    std::shared_ptr<AsyncSSLSocket> sslSock;
     try {
       // Create a AsyncSSLSocket object with the fd. The socket should be
       // added to the event base and in the state of accepting SSL connection.
-      sslSock = AsyncSSLSocket::newSocket(ctx_, base_, fd);
+      socket_ = AsyncSSLSocket::newSocket(ctx_, base_, fd);
     } catch (const std::exception &e) {
       LOG(ERROR) << "Exception %s caught while creating a AsyncSSLSocket "
         "object with socket " << e.what() << fd;
@@ -406,15 +409,20 @@ public:
       return;
     }
 
-    connAccepted(sslSock);
+    connAccepted(socket_);
   }
 
   virtual void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s) = 0;
 
+  void detach() {
+    socket_->detachEventBase();
+  }
+
   StateEnum state;
   HandshakeCallback *hcb_;
   std::shared_ptr<folly::SSLContext> ctx_;
+  std::shared_ptr<AsyncSSLSocket> socket_;
   folly::EventBase* base_;
 };
 
@@ -443,7 +451,7 @@ public:
     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
 
     hcb_->setSocket(sock);
-    sock->sslAccept(hcb_, timeout_);
+    sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
     EXPECT_EQ(sock->getSSLState(),
                       AsyncSSLSocket::STATE_ACCEPTING);
 
@@ -507,7 +515,7 @@ public:
     std::cerr << "SSLServerAcceptCallback::connAccepted" << std::endl;
 
     hcb_->setSocket(sock);
-    sock->sslAccept(hcb_, timeout_);
+    sock->sslAccept(hcb_, std::chrono::milliseconds(timeout_));
     ASSERT_TRUE((sock->getSSLState() ==
                  AsyncSSLSocket::STATE_ACCEPTING) ||
                 (sock->getSSLState() ==
@@ -548,8 +556,6 @@ public:
     EXPECT_EQ(hcb_->state, STATE_FAILED);
     EXPECT_EQ(callback2.state, STATE_FAILED);
 
-    sock->detachEventBase();
-
     state = STATE_SUCCEEDED;
     hcb_->setState(STATE_SUCCEEDED);
     callback2.setState(STATE_SUCCEEDED);
@@ -620,6 +626,7 @@ class TestSSLServer {
   static void *Main(void *ctx) {
     TestSSLServer *self = static_cast<TestSSLServer*>(ctx);
     self->evb_.loop();
+    self->acb_->detach();
     std::cerr << "Server thread exited event loop" << std::endl;
     return nullptr;
   }
@@ -676,6 +683,7 @@ class TestSSLAsyncCacheServer : public TestSSLServer {
                                          int* copyflag) {
     *copyflag = 0;
     asyncCallbacks_++;
+    (void)ssl;
 #ifdef SSL_ERROR_WANT_SESS_CACHE_LOOKUP
     if (!SSL_want_sess_cache_lookup(ssl)) {
       // libssl.so mismatch
@@ -741,7 +749,7 @@ class BlockingWriteClient :
       }
     }
 
-    socket_->sslConn(this, 100);
+    socket_->sslConn(this, std::chrono::milliseconds(100));
   }
 
   struct iovec* getIovec() const {
@@ -787,7 +795,7 @@ class BlockingWriteServer :
       bufSize_(2500 * 2000),
       bytesRead_(0) {
     buf_.reset(new uint8_t[bufSize_]);
-    socket_->sslAccept(this, 100);
+    socket_->sslAccept(this, std::chrono::milliseconds(100));
   }
 
   void checkBuffer(struct iovec* iov, uint32_t count) const {
@@ -1286,7 +1294,7 @@ class SSLHandshakeClient : public SSLHandshakeBase {
    bool preverifyResult,
    bool verifyResult) :
     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslConn(this, 0);
+    socket_->sslConn(this, std::chrono::milliseconds::zero());
   }
 };
 
@@ -1297,8 +1305,10 @@ class SSLHandshakeClientNoVerify : public SSLHandshakeBase {
    bool preverifyResult,
    bool verifyResult) :
     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslConn(this, 0,
-      folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+    socket_->sslConn(
+        this,
+        std::chrono::milliseconds::zero(),
+        folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
   }
 };
 
@@ -1309,8 +1319,10 @@ class SSLHandshakeClientDoVerify : public SSLHandshakeBase {
    bool preverifyResult,
    bool verifyResult) :
     SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslConn(this, 0,
-      folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
+    socket_->sslConn(
+        this,
+        std::chrono::milliseconds::zero(),
+        folly::SSLContext::SSLVerifyPeerEnum::VERIFY);
   }
 };
 
@@ -1321,7 +1333,7 @@ class SSLHandshakeServer : public SSLHandshakeBase {
       bool preverifyResult,
       bool verifyResult)
     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslAccept(this, 0);
+    socket_->sslAccept(this, std::chrono::milliseconds::zero());
   }
 };
 
@@ -1333,7 +1345,7 @@ class SSLHandshakeServerParseClientHello : public SSLHandshakeBase {
       bool verifyResult)
       : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
     socket_->enableClientHelloParsing();
-    socket_->sslAccept(this, 0);
+    socket_->sslAccept(this, std::chrono::milliseconds::zero());
   }
 
   std::string clientCiphers_, sharedCiphers_, serverCiphers_, chosenCipher_;
@@ -1356,8 +1368,10 @@ class SSLHandshakeServerNoVerify : public SSLHandshakeBase {
       bool preverifyResult,
       bool verifyResult)
     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslAccept(this, 0,
-      folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
+    socket_->sslAccept(
+        this,
+        std::chrono::milliseconds::zero(),
+        folly::SSLContext::SSLVerifyPeerEnum::NO_VERIFY);
   }
 };
 
@@ -1368,8 +1382,10 @@ class SSLHandshakeServerDoVerify : public SSLHandshakeBase {
       bool preverifyResult,
       bool verifyResult)
     : SSLHandshakeBase(std::move(socket), preverifyResult, verifyResult) {
-    socket_->sslAccept(this, 0,
-      folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
+    socket_->sslAccept(
+        this,
+        std::chrono::milliseconds::zero(),
+        folly::SSLContext::SSLVerifyPeerEnum::VERIFY_REQ_CLIENT_CERT);
   }
 };