Stop abusing errno
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
index a4e18aaa8447e271e72c2c25d930731f479c3099..69966d6743955e125237fc4106b997869241c1a8 100644 (file)
 #include <signal.h>
 #include <pthread.h>
 
-#include <folly/io/async/AsyncServerSocket.h>
+#include <folly/ExceptionWrapper.h>
+#include <folly/SocketAddress.h>
 #include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncServerSocket.h>
 #include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/AsyncTransport.h>
 #include <folly/io/async/EventBase.h>
-#include <folly/io/async/AsyncTimeout.h>
-#include <folly/SocketAddress.h>
+#include <folly/io/async/ssl/SSLErrors.h>
 
 #include <gtest/gtest.h>
 #include <iostream>
@@ -58,7 +60,7 @@ public:
       , exception(AsyncSocketException::UNKNOWN, "none") {}
 
   ~WriteCallbackBase() {
-    EXPECT_EQ(state, STATE_SUCCEEDED);
+    EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
   void setSocket(
@@ -92,10 +94,9 @@ public:
 
 class ReadCallbackBase :
 public AsyncTransportWrapper::ReadCallback {
-public:
-  explicit ReadCallbackBase(WriteCallbackBase *wcb)
-      : wcb_(wcb)
-      , state(STATE_WAITING) {}
+ public:
+  explicit ReadCallbackBase(WriteCallbackBase* wcb)
+      : wcb_(wcb), state(STATE_WAITING) {}
 
   ~ReadCallbackBase() {
     EXPECT_EQ(state, STATE_SUCCEEDED);
@@ -222,6 +223,27 @@ public:
   }
 };
 
+class ReadEOFCallback : public ReadCallbackBase {
+ public:
+  explicit ReadEOFCallback(WriteCallbackBase* wcb) : ReadCallbackBase(wcb) {}
+
+  // Return nullptr buffer to trigger readError()
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *bufReturn = nullptr;
+    *lenReturn = 0;
+  }
+
+  void readDataAvailable(size_t /* len */) noexcept override {
+    // This should never to called.
+    FAIL();
+  }
+
+  void readEOF() noexcept override {
+    ReadCallbackBase::readEOF();
+    setState(STATE_SUCCEEDED);
+  }
+};
+
 class WriteErrorCallback : public ReadCallback {
 public:
   explicit WriteErrorCallback(WriteCallbackBase *wcb)
@@ -340,6 +362,10 @@ public:
     state = STATE_SUCCEEDED;
   }
 
+  std::shared_ptr<AsyncSSLSocket> getSocket() {
+    return socket_;
+  }
+
   StateEnum state;
   std::shared_ptr<AsyncSSLSocket> socket_;
   ReadCallbackBase *rcb_;
@@ -879,6 +905,48 @@ class NpnServer :
   AsyncSSLSocket::UniquePtr socket_;
 };
 
+class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
+                            public AsyncTransportWrapper::ReadCallback {
+ public:
+  explicit RenegotiatingServer(AsyncSSLSocket::UniquePtr socket)
+      : socket_(std::move(socket)) {
+    socket_->sslAccept(this);
+  }
+
+  ~RenegotiatingServer() {
+    socket_->setReadCB(nullptr);
+  }
+
+  void handshakeSuc(AsyncSSLSocket* /* socket */) noexcept override {
+    LOG(INFO) << "Renegotiating server handshake success";
+    socket_->setReadCB(this);
+  }
+  void handshakeErr(
+      AsyncSSLSocket*,
+      const AsyncSocketException& ex) noexcept override {
+    ADD_FAILURE() << "Renegotiating server handshake error: " << ex.what();
+  }
+  void getReadBuffer(void** bufReturn, size_t* lenReturn) override {
+    *lenReturn = sizeof(buf);
+    *bufReturn = buf;
+  }
+  void readDataAvailable(size_t /* len */) noexcept override {}
+  void readEOF() noexcept override {}
+  void readErr(const AsyncSocketException& ex) noexcept override {
+    LOG(INFO) << "server got read error " << ex.what();
+    auto exPtr = dynamic_cast<const SSLException*>(&ex);
+    ASSERT_NE(nullptr, exPtr);
+    std::string exStr(ex.what());
+    SSLException sslEx(SSLError::CLIENT_RENEGOTIATION);
+    ASSERT_NE(std::string::npos, exStr.find(sslEx.what()));
+    renegotiationError_ = true;
+  }
+
+  AsyncSSLSocket::UniquePtr socket_;
+  unsigned char buf[128];
+  bool renegotiationError_{false};
+};
+
 #ifndef OPENSSL_NO_TLSEXT
 class SNIClient :
   private AsyncSSLSocket::HandshakeCB,
@@ -1139,6 +1207,10 @@ class SSLHandshakeBase :
     verifyResult_(verifyResult) {
   }
 
+  AsyncSSLSocket::UniquePtr moveSocket() && {
+    return std::move(socket_);
+  }
+
   bool handshakeVerify_;
   bool handshakeSuccess_;
   bool handshakeError_;
@@ -1160,12 +1232,15 @@ class SSLHandshakeBase :
   }
 
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
+    LOG(INFO) << "Handshake success";
     handshakeSuccess_ = true;
     handshakeTime = socket_->getHandshakeTime();
   }
 
-  void handshakeErr(AsyncSSLSocket*,
-                    const AsyncSocketException& /* ex */) noexcept override {
+  void handshakeErr(
+      AsyncSSLSocket*,
+      const AsyncSocketException& ex) noexcept override {
+    LOG(INFO) << "Handshake error " << ex.what();
     handshakeError_ = true;
     handshakeTime = socket_->getHandshakeTime();
   }