Add getTotalConnectTimeout method
[folly.git] / folly / io / async / test / BlockingSocket.h
index 360cfcb18cfa342004ea7934a21a2cb610fde79f..0aeb4d0e49065816243e51c904d54a44eb3ff17b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 #pragma once
 
 #include <folly/Optional.h>
-#include <folly/io/async/SSLContext.h>
-#include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/AsyncSSLSocket.h>
+#include <folly/io/async/AsyncSocket.h>
+#include <folly/io/async/SSLContext.h>
 
 class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
                        public folly::AsyncTransportWrapper::ReadCallback,
-                       public folly::AsyncTransportWrapper::WriteCallback
-{
+                       public folly::AsyncTransportWrapper::WriteCallback {
  public:
   explicit BlockingSocket(int fd)
-    : sock_(new folly::AsyncSocket(&eventBase_, fd)) {
-  }
+      : sock_(new folly::AsyncSocket(&eventBase_, fd)) {}
 
-  BlockingSocket(folly::SocketAddress address,
-                 std::shared_ptr<folly::SSLContext> sslContext)
-    : sock_(sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_) :
-            new folly::AsyncSocket(&eventBase_)),
-    address_(address) {}
+  BlockingSocket(
+      folly::SocketAddress address,
+      std::shared_ptr<folly::SSLContext> sslContext)
+      : sock_(
+            sslContext ? new folly::AsyncSSLSocket(sslContext, &eventBase_)
+                       : new folly::AsyncSocket(&eventBase_)),
+        address_(address) {}
 
   explicit BlockingSocket(folly::AsyncSocket::UniquePtr socket)
       : sock_(std::move(socket)) {
     sock_->attachEventBase(&eventBase_);
   }
 
-  void open() {
-    sock_->connect(this, address_);
+  void enableTFO() {
+    sock_->enableTFO();
+  }
+
+  void setAddress(folly::SocketAddress address) {
+    address_ = address;
+  }
+
+  void open(
+      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) {
+    sock_->connect(this, address_, timeout.count());
     eventBase_.loop();
     if (err_.hasValue()) {
       throw err_.value();
     }
   }
+
   void close() {
     sock_->close();
   }
-  void closeWithReset() { sock_->closeWithReset(); }
+  void closeWithReset() {
+    sock_->closeWithReset();
+  }
 
   int32_t write(uint8_t const* buf, size_t len) {
     sock_->write(this, buf, len);
@@ -63,11 +75,11 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
 
   void flush() {}
 
-  int32_t readAll(uint8_t *buf, size_t len) {
+  int32_t readAll(uint8_tbuf, size_t len) {
     return readHelper(buf, len, true);
   }
 
-  int32_t read(uint8_t *buf, size_t len) {
+  int32_t read(uint8_tbuf, size_t len) {
     return readHelper(buf, len, false);
   }
 
@@ -75,11 +87,19 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
     return sock_->getFd();
   }
 
+  folly::AsyncSocket* getSocket() {
+    return sock_.get();
+  }
+
+  folly::AsyncSSLSocket* getSSLSocket() {
+    return dynamic_cast<folly::AsyncSSLSocket*>(sock_.get());
+  }
+
  private:
   folly::EventBase eventBase_;
   folly::AsyncSocket::UniquePtr sock_;
   folly::Optional<folly::AsyncSocketException> err_;
-  uint8_t *readBuf_{nullptr};
+  uint8_treadBuf_{nullptr};
   size_t readLen_{0};
   folly::SocketAddress address_;
 
@@ -98,23 +118,27 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
       sock_->setReadCB(nullptr);
     }
   }
-  void readEOF() noexcept override {
-  }
+  void readEOF() noexcept override {}
   void readErr(const folly::AsyncSocketException& ex) noexcept override {
     err_ = ex;
   }
   void writeSuccess() noexcept override {}
-  void writeErr(size_t /* bytesWritten */,
-                const folly::AsyncSocketException& ex) noexcept override {
+  void writeErr(
+      size_t /* bytesWritten */,
+      const folly::AsyncSocketException& ex) noexcept override {
     err_ = ex;
   }
 
-  int32_t readHelper(uint8_t *buf, size_t len, bool all) {
+  int32_t readHelper(uint8_t* buf, size_t len, bool all) {
+    if (!sock_->good()) {
+      return 0;
+    }
+
     readBuf_ = buf;
     readLen_ = len;
     sock_->setReadCB(this);
     while (!err_ && sock_->good() && readLen_ > 0) {
-      eventBase_.loop();
+      eventBase_.loopOnce();
       if (!all) {
         break;
       }
@@ -124,8 +148,8 @@ class BlockingSocket : public folly::AsyncSocket::ConnectCallback,
       throw err_.value();
     }
     if (all && readLen_ > 0) {
-      throw folly::AsyncSocketException(folly::AsyncSocketException::UNKNOWN,
-                                        "eof");
+      throw folly::AsyncSocketException(
+          folly::AsyncSocketException::UNKNOWN, "eof");
     }
     return len - readLen_;
   }