Fix copyright lines
[folly.git] / folly / io / async / test / AsyncSocketTest.h
index 5d52ad204f68faa2859d91ec4c4c8bb42b58126d..fe69a4e4b3a35c57c9ab6a5f0303d39a32d55e59 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2015 Facebook, Inc.
+ * Copyright 2015-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/test/BlockingSocket.h>
+#include <folly/portability/Sockets.h>
 
 #include <boost/scoped_array.hpp>
-#include <poll.h>
-
-// This is a test-only header
-/* using override */
-using namespace folly;
+#include <memory>
 
 enum StateEnum {
   STATE_WAITING,
@@ -33,11 +30,11 @@ enum StateEnum {
 
 typedef std::function<void()> VoidCallback;
 
-class ConnCallback : public AsyncSocket::ConnectCallback {
+class ConnCallback : public folly::AsyncSocket::ConnectCallback {
  public:
   ConnCallback()
-    : state(STATE_WAITING)
-    , exception(AsyncSocketException::UNKNOWN, "none") {}
+      : state(STATE_WAITING),
+        exception(folly::AsyncSocketException::UNKNOWN, "none") {}
 
   void connectSuccess() noexcept override {
     state = STATE_SUCCEEDED;
@@ -46,7 +43,7 @@ class ConnCallback : public AsyncSocket::ConnectCallback {
     }
   }
 
-  void connectErr(const AsyncSocketException& ex) noexcept override {
+  void connectErr(const folly::AsyncSocketException& ex) noexcept override {
     state = STATE_FAILED;
     exception = ex;
     if (errorCallback) {
@@ -55,34 +52,17 @@ class ConnCallback : public AsyncSocket::ConnectCallback {
   }
 
   StateEnum state;
-  AsyncSocketException exception;
+  folly::AsyncSocketException exception;
   VoidCallback successCallback;
   VoidCallback errorCallback;
 };
 
-class BufferCallback : public AsyncTransportWrapper::BufferCallback {
- public:
-  BufferCallback()
-    : buffered_(false) {}
-
-  void onEgressBuffered() override {
-    buffered_ = true;
-  }
-
-  bool hasBuffered() const {
-    return buffered_;
-  }
-
- private:
-  bool buffered_{false};
-};
-
-class WriteCallback : public AsyncTransportWrapper::WriteCallback {
+class WriteCallback : public folly::AsyncTransportWrapper::WriteCallback {
  public:
   WriteCallback()
-    : state(STATE_WAITING)
-    , bytesWritten(0)
-    , exception(AsyncSocketException::UNKNOWN, "none") {}
+      : state(STATE_WAITING),
+        bytesWritten(0),
+        exception(folly::AsyncSocketException::UNKNOWN, "none") {}
 
   void writeSuccess() noexcept override {
     state = STATE_SUCCEEDED;
@@ -91,10 +71,11 @@ class WriteCallback : public AsyncTransportWrapper::WriteCallback {
     }
   }
 
-  void writeErr(size_t bytesWritten,
-                const AsyncSocketException& ex) noexcept override {
+  void writeErr(size_t nBytesWritten,
+                const folly::AsyncSocketException& ex) noexcept override {
+    LOG(ERROR) << ex.what();
     state = STATE_FAILED;
-    this->bytesWritten = bytesWritten;
+    this->bytesWritten = nBytesWritten;
     exception = ex;
     if (errorCallback) {
       errorCallback();
@@ -102,21 +83,21 @@ class WriteCallback : public AsyncTransportWrapper::WriteCallback {
   }
 
   StateEnum state;
-  size_t bytesWritten;
-  AsyncSocketException exception;
+  std::atomic<size_t> bytesWritten;
+  folly::AsyncSocketException exception;
   VoidCallback successCallback;
   VoidCallback errorCallback;
 };
 
-class ReadCallback : public AsyncTransportWrapper::ReadCallback {
+class ReadCallback : public folly::AsyncTransportWrapper::ReadCallback {
  public:
   explicit ReadCallback(size_t _maxBufferSz = 4096)
-    : state(STATE_WAITING)
-    , exception(AsyncSocketException::UNKNOWN, "none")
-    , buffers()
-    , maxBufferSz(_maxBufferSz) {}
+      : state(STATE_WAITING),
+        exception(folly::AsyncSocketException::UNKNOWN, "none"),
+        buffers(),
+        maxBufferSz(_maxBufferSz) {}
 
-  ~ReadCallback() {
+  ~ReadCallback() override {
     for (std::vector<Buffer>::iterator it = buffers.begin();
          it != buffers.end();
          ++it) {
@@ -146,7 +127,7 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
     state = STATE_SUCCEEDED;
   }
 
-  void readErr(const AsyncSocketException& ex) noexcept override {
+  void readErr(const folly::AsyncSocketException& ex) noexcept override {
     state = STATE_FAILED;
     exception = ex;
   }
@@ -180,10 +161,10 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
       buffer = nullptr;
       length = 0;
     }
-    void allocate(size_t length) {
+    void allocate(size_t len) {
       assert(buffer == nullptr);
-      this->buffer = static_cast<char*>(malloc(length));
-      this->length = length;
+      this->buffer = static_cast<char*>(malloc(len));
+      this->length = len;
     }
     void free() {
       ::free(buffer);
@@ -195,36 +176,152 @@ class ReadCallback : public AsyncTransportWrapper::ReadCallback {
   };
 
   StateEnum state;
-  AsyncSocketException exception;
+  folly::AsyncSocketException exception;
   std::vector<Buffer> buffers;
   Buffer currentBuffer;
   VoidCallback dataAvailableCallback;
   const size_t maxBufferSz;
 };
 
+class BufferCallback : public folly::AsyncTransport::BufferCallback {
+ public:
+  BufferCallback() : buffered_(false), bufferCleared_(false) {}
+
+  void onEgressBuffered() override { buffered_ = true; }
+
+  void onEgressBufferCleared() override { bufferCleared_ = true; }
+
+  bool hasBuffered() const { return buffered_; }
+
+  bool hasBufferCleared() const { return bufferCleared_; }
+
+ private:
+  bool buffered_{false};
+  bool bufferCleared_{false};
+};
+
 class ReadVerifier {
 };
 
+class TestSendMsgParamsCallback :
+    public folly::AsyncSocket::SendMsgParamsCallback {
+ public:
+  TestSendMsgParamsCallback(int flags, uint32_t dataSize, void* data)
+  : flags_(flags),
+    writeFlags_(folly::WriteFlags::NONE),
+    dataSize_(dataSize),
+    data_(data),
+    queriedFlags_(false),
+    queriedData_(false)
+  {}
+
+  void reset(int flags) {
+    flags_ = flags;
+    writeFlags_ = folly::WriteFlags::NONE;
+    queriedFlags_ = false;
+    queriedData_ = false;
+  }
+
+  int getFlagsImpl(folly::WriteFlags flags, int /*defaultFlags*/) noexcept
+                                                                  override {
+    queriedFlags_ = true;
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    return flags_;
+  }
+
+  void getAncillaryData(folly::WriteFlags flags, void* data) noexcept override {
+    queriedData_ = true;
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    assert(data != nullptr);
+    memcpy(data, data_, dataSize_);
+  }
+
+  uint32_t getAncillaryDataSize(folly::WriteFlags flags) noexcept override {
+    if (writeFlags_ == folly::WriteFlags::NONE) {
+      writeFlags_ = flags;
+    } else {
+      assert(flags == writeFlags_);
+    }
+    return dataSize_;
+  }
+
+  int flags_;
+  folly::WriteFlags writeFlags_;
+  uint32_t dataSize_;
+  void* data_;
+  bool queriedFlags_;
+  bool queriedData_;
+};
+
 class TestServer {
  public:
   // Create a TestServer.
   // This immediately starts listening on an ephemeral port.
-  TestServer()
-    : fd_(-1) {
-    fd_ = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
+  explicit TestServer(bool enableTFO = false, int bufSize = -1) : fd_(-1) {
+    namespace fsp = folly::portability::sockets;
+    fd_ = fsp::socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
     if (fd_ < 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "failed to create test server socket", errno);
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::INTERNAL_ERROR,
+          "failed to create test server socket",
+          errno);
     }
     if (fcntl(fd_, F_SETFL, O_NONBLOCK) != 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "failed to put test server socket in "
-                                "non-blocking mode", errno);
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::INTERNAL_ERROR,
+          "failed to put test server socket in "
+          "non-blocking mode",
+          errno);
+    }
+    if (enableTFO) {
+#if FOLLY_ALLOW_TFO
+      folly::detail::tfo_enable(fd_, 100);
+#endif
     }
+
+    struct addrinfo hints, *res;
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_family = AF_INET;
+    hints.ai_socktype = SOCK_STREAM;
+    hints.ai_flags = AI_PASSIVE;
+
+    if (getaddrinfo(nullptr, "0", &hints, &res)) {
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::INTERNAL_ERROR,
+          "Attempted to bind address to socket with "
+          "bad getaddrinfo",
+          errno);
+    }
+
+    SCOPE_EXIT {
+      freeaddrinfo(res);
+    };
+
+    if (bufSize > 0) {
+      setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize));
+      setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize));
+    }
+
+    if (bind(fd_, res->ai_addr, res->ai_addrlen)) {
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::INTERNAL_ERROR,
+          "failed to bind to async server socket for port 10",
+          errno);
+    }
+
     if (listen(fd_, 10) != 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "failed to listen on test server socket",
-                                errno);
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::INTERNAL_ERROR,
+          "failed to listen on test server socket",
+          errno);
     }
 
     address_.setFromLocalAddress(fd_);
@@ -233,28 +330,40 @@ class TestServer {
     address_.setFromIpPort("127.0.0.1", address_.getPort());
   }
 
+  ~TestServer() {
+    if (fd_ != -1) {
+      close(fd_);
+    }
+  }
+
   // Get the address for connecting to the server
   const folly::SocketAddress& getAddress() const {
     return address_;
   }
 
   int acceptFD(int timeout=50) {
+    namespace fsp = folly::portability::sockets;
     struct pollfd pfd;
     pfd.fd = fd_;
     pfd.events = POLLIN;
     int ret = poll(&pfd, 1, timeout);
     if (ret == 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "test server accept() timed out");
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::INTERNAL_ERROR,
+          "test server accept() timed out");
     } else if (ret < 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "test server accept() poll failed", errno);
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::INTERNAL_ERROR,
+          "test server accept() poll failed",
+          errno);
     }
 
-    int acceptedFd = ::accept(fd_, nullptr, nullptr);
+    int acceptedFd = fsp::accept(fd_, nullptr, nullptr);
     if (acceptedFd < 0) {
-      throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
-                                "test server accept() failed", errno);
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::INTERNAL_ERROR,
+          "test server accept() failed",
+          errno);
     }
 
     return acceptedFd;
@@ -262,12 +371,13 @@ class TestServer {
 
   std::shared_ptr<BlockingSocket> accept(int timeout=50) {
     int fd = acceptFD(timeout);
-    return std::shared_ptr<BlockingSocket>(new BlockingSocket(fd));
+    return std::make_shared<BlockingSocket>(fd);
   }
 
-  std::shared_ptr<AsyncSocket> acceptAsync(EventBase* evb, int timeout=50) {
+  std::shared_ptr<folly::AsyncSocket> acceptAsync(folly::EventBase* evb,
+                                                  int timeout = 50) {
     int fd = acceptFD(timeout);
-    return AsyncSocket::newSocket(evb, fd);
+    return folly::AsyncSocket::newSocket(evb, fd);
   }
 
   /**