Adds writer test case for RCU
[folly.git] / folly / io / async / test / AsyncSSLSocketTest.h
index e4a73c19b0b491c30cbc7ed072c6cad1f2d83e59..7a05fbbc8ea5d995dd0751add4d059643fccdaa4 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 Facebook, Inc.
+ * Copyright 2012-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
 #pragma once
 
 #include <signal.h>
-#include <pthread.h>
 
 #include <folly/ExceptionWrapper.h>
 #include <folly/SocketAddress.h>
@@ -30,6 +29,7 @@
 #include <folly/io/async/ssl/SSLErrors.h>
 #include <folly/io/async/test/TestSSLServer.h>
 #include <folly/portability/GTest.h>
+#include <folly/portability/PThread.h>
 #include <folly/portability/Sockets.h>
 #include <folly/portability/Unistd.h>
 
@@ -38,6 +38,7 @@
 #include <condition_variable>
 #include <iostream>
 #include <list>
+#include <memory>
 
 namespace folly {
 
@@ -60,7 +61,7 @@ class SendMsgParamsCallbackBase :
 
   int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
                                                                      override {
-    return oldCallback_->getFlags(flags);
+    return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
   }
 
   void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
@@ -88,7 +89,7 @@ class SendMsgFlagsCallback : public SendMsgParamsCallbackBase {
     if (flags_) {
       return flags_;
     } else {
-      return oldCallback_->getFlags(flags);
+      return oldCallback_->getFlags(flags, false /*zeroCopyEnabled*/);
     }
   }
 
@@ -126,14 +127,14 @@ class SendMsgDataCallback : public SendMsgFlagsCallback {
 
 class WriteCallbackBase :
 public AsyncTransportWrapper::WriteCallback {
-public:
+ public:
   explicit WriteCallbackBase(SendMsgParamsCallbackBase* mcb = nullptr)
       : state(STATE_WAITING)
       , bytesWritten(0)
       , exception(AsyncSocketException::UNKNOWN, "none")
       , mcb_(mcb) {}
 
-  ~WriteCallbackBase() {
+  ~WriteCallbackBase() override {
     EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
@@ -145,7 +146,7 @@ public:
     }
   }
 
-  virtual void writeSuccess() noexcept override {
+  void writeSuccess() noexcept override {
     std::cerr << "writeSuccess" << std::endl;
     state = STATE_SUCCEEDED;
   }
@@ -171,11 +172,11 @@ public:
 
 class ExpectWriteErrorCallback :
 public WriteCallbackBase {
-public:
+ public:
   explicit ExpectWriteErrorCallback(SendMsgParamsCallbackBase* mcb = nullptr)
       : WriteCallbackBase(mcb) {}
 
-  ~ExpectWriteErrorCallback() {
+  ~ExpectWriteErrorCallback() override {
     EXPECT_EQ(STATE_FAILED, state);
     EXPECT_EQ(exception.type_,
              AsyncSocketException::AsyncSocketExceptionType::NETWORK_ERROR);
@@ -185,7 +186,7 @@ public:
   }
 };
 
-#ifdef MSG_ERRQUEUE
+#ifdef FOLLY_HAVE_MSG_ERRQUEUE
 /* copied from include/uapi/linux/net_tstamp.h */
 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
 enum SOF_TIMESTAMPING {
@@ -198,12 +199,12 @@ enum SOF_TIMESTAMPING {
 };
 
 class WriteCheckTimestampCallback :
- public WriteCallbackBase {
-public:
 public WriteCallbackBase {
+ public:
   explicit WriteCheckTimestampCallback(SendMsgParamsCallbackBase* mcb = nullptr)
     : WriteCallbackBase(mcb) {}
 
-  ~WriteCheckTimestampCallback() {
+  ~WriteCheckTimestampCallback() override {
     EXPECT_EQ(STATE_SUCCEEDED, state);
     EXPECT_TRUE(gotTimestamp_);
     EXPECT_TRUE(gotByteSeq_);
@@ -275,7 +276,7 @@ public:
   bool gotTimestamp_{false};
   bool gotByteSeq_{false};
 };
-#endif // MSG_ERRQUEUE
+#endif // FOLLY_HAVE_MSG_ERRQUEUE
 
 class ReadCallbackBase :
 public AsyncTransportWrapper::ReadCallback {
@@ -283,7 +284,7 @@ public AsyncTransportWrapper::ReadCallback {
   explicit ReadCallbackBase(WriteCallbackBase* wcb)
       : wcb_(wcb), state(STATE_WAITING) {}
 
-  ~ReadCallbackBase() {
+  ~ReadCallbackBase() override {
     EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
@@ -318,12 +319,12 @@ public AsyncTransportWrapper::ReadCallback {
 };
 
 class ReadCallback : public ReadCallbackBase {
-public:
+ public:
   explicit ReadCallback(WriteCallbackBase *wcb)
       : ReadCallbackBase(wcb)
       , buffers() {}
 
-  ~ReadCallback() {
+  ~ReadCallback() override {
     for (std::vector<Buffer>::iterator it = buffers.begin();
          it != buffers.end();
          ++it) {
@@ -356,7 +357,7 @@ public:
   }
 
   class Buffer {
-  public:
+   public:
     Buffer() : buffer(nullptr), length(0) {}
     Buffer(char* buf, size_t len) : buffer(buf), length(len) {}
 
@@ -383,7 +384,7 @@ public:
 };
 
 class ReadErrorCallback : public ReadCallbackBase {
-public:
+ public:
   explicit ReadErrorCallback(WriteCallbackBase *wcb)
       : ReadCallbackBase(wcb) {}
 
@@ -428,7 +429,7 @@ class ReadEOFCallback : public ReadCallbackBase {
 };
 
 class WriteErrorCallback : public ReadCallback {
-public:
+ public:
   explicit WriteErrorCallback(WriteCallbackBase *wcb)
       : ReadCallback(wcb) {}
 
@@ -464,7 +465,7 @@ public:
 };
 
 class EmptyReadCallback : public ReadCallback {
-public:
+ public:
   explicit EmptyReadCallback()
       : ReadCallback(nullptr) {}
 
@@ -489,7 +490,7 @@ public:
 
 class HandshakeCallback :
 public AsyncSSLSocket::HandshakeCB {
-public:
+ public:
   enum ExpectType {
     EXPECT_SUCCESS,
     EXPECT_ERROR
@@ -539,7 +540,7 @@ public:
     cv_.wait(lock, [this] { return state != STATE_WAITING; });
   }
 
-  ~HandshakeCallback() {
+  ~HandshakeCallback() override {
     EXPECT_EQ(STATE_SUCCEEDED, state);
   }
 
@@ -562,7 +563,7 @@ public:
 };
 
 class SSLServerAcceptCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
   uint32_t timeout_;
 
   explicit SSLServerAcceptCallback(HandshakeCallback *hcb,
@@ -570,7 +571,7 @@ public:
       SSLServerAcceptCallbackBase(hcb),
       timeout_(timeout) {}
 
-  virtual ~SSLServerAcceptCallback() {
+  ~SSLServerAcceptCallback() override {
     if (timeout_ > 0) {
       // if we set a timeout, we expect failure
       EXPECT_EQ(hcb_->state, STATE_FAILED);
@@ -578,7 +579,6 @@ public:
     }
   }
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -595,11 +595,10 @@ public:
 };
 
 class SSLServerAcceptCallbackDelay: public SSLServerAcceptCallback {
-public:
+ public:
   explicit SSLServerAcceptCallbackDelay(HandshakeCallback *hcb):
       SSLServerAcceptCallback(hcb) {}
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -636,12 +635,11 @@ public:
 };
 
 class SSLServerAsyncCacheAcceptCallback: public SSLServerAcceptCallback {
-public:
+ public:
   explicit SSLServerAsyncCacheAcceptCallback(HandshakeCallback *hcb,
                                              uint32_t timeout = 0):
     SSLServerAcceptCallback(hcb, timeout) {}
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -662,11 +660,10 @@ public:
 
 
 class HandshakeErrorCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
   explicit HandshakeErrorCallback(HandshakeCallback *hcb):
   SSLServerAcceptCallbackBase(hcb)  {}
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -698,11 +695,10 @@ public:
 };
 
 class HandshakeTimeoutCallback: public SSLServerAcceptCallbackBase {
-public:
+ public:
   explicit HandshakeTimeoutCallback(HandshakeCallback *hcb):
   SSLServerAcceptCallbackBase(hcb)  {}
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
     const std::shared_ptr<folly::AsyncSSLSocket> &s)
     noexcept override {
@@ -738,7 +734,6 @@ class ConnectTimeoutCallback : public SSLServerAcceptCallbackBase {
     state = STATE_SUCCEEDED;
   }
 
-  // Functions inherited from TAsyncSSLServerSocket::SSLAcceptCallback
   void connAccepted(
       const std::shared_ptr<folly::AsyncSSLSocket>& s) noexcept override {
     std::cerr << "ConnectTimeoutCallback::connAccepted" << std::endl;
@@ -830,13 +825,13 @@ class BlockingWriteClient :
       bufLen_(2500),
       iovCount_(2000) {
     // Fill buf_
-    buf_.reset(new uint8_t[bufLen_]);
+    buf_ = std::make_unique<uint8_t[]>(bufLen_);
     for (uint32_t n = 0; n < sizeof(buf_); ++n) {
       buf_[n] = n % 0xff;
     }
 
     // Initialize iov_
-    iov_.reset(new struct iovec[iovCount_]);
+    iov_ = std::make_unique<struct iovec[]>(iovCount_);
     for (uint32_t n = 0; n < iovCount_; ++n) {
       iov_[n].iov_base = buf_.get() + n;
       if (n & 0x1) {
@@ -891,7 +886,7 @@ class BlockingWriteServer :
     : socket_(std::move(socket)),
       bufSize_(2500 * 2000),
       bytesRead_(0) {
-    buf_.reset(new uint8_t[bufSize_]);
+    buf_ = std::make_unique<uint8_t[]>(bufSize_);
     socket_->sslAccept(this, std::chrono::milliseconds(100));
   }
 
@@ -968,6 +963,7 @@ class NpnClient :
   const unsigned char* nextProto;
   unsigned nextProtoLength;
   SSLContext::NextProtocolType protocolType;
+  folly::Optional<AsyncSocketException> except;
 
  private:
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
@@ -977,7 +973,7 @@ class NpnClient :
   void handshakeErr(
     AsyncSSLSocket*,
     const AsyncSocketException& ex) noexcept override {
-    ADD_FAILURE() << "client handshake error: " << ex.what();
+    except = ex;
   }
   void writeSuccess() noexcept override {
     socket_->close();
@@ -1004,6 +1000,7 @@ class NpnServer :
   const unsigned char* nextProto;
   unsigned nextProtoLength;
   SSLContext::NextProtocolType protocolType;
+  folly::Optional<AsyncSocketException> except;
 
  private:
   void handshakeSuc(AsyncSSLSocket*) noexcept override {
@@ -1013,7 +1010,7 @@ class NpnServer :
   void handshakeErr(
     AsyncSSLSocket*,
     const AsyncSocketException& ex) noexcept override {
-    ADD_FAILURE() << "server handshake error: " << ex.what();
+    except = ex;
   }
   void getReadBuffer(void** /* bufReturn */, size_t* lenReturn) override {
     *lenReturn = 0;
@@ -1038,7 +1035,7 @@ class RenegotiatingServer : public AsyncSSLSocket::HandshakeCB,
     socket_->sslAccept(this);
   }
 
-  ~RenegotiatingServer() {
+  ~RenegotiatingServer() override {
     socket_->setReadCB(nullptr);
   }
 
@@ -1216,7 +1213,7 @@ class SSLClient : public AsyncSocket::ConnectCallback,
     memset(buf_, 'a', sizeof(buf_));
   }
 
-  ~SSLClient() {
+  ~SSLClient() override {
     if (session_) {
       SSL_SESSION_free(session_);
     }
@@ -1505,4 +1502,4 @@ class EventBaseAborter : public AsyncTimeout {
   EventBase* eventBase_;
 };
 
-}
+} // namespace folly