Enable EOR flag configuration for folly::AsyncSocket.
[folly.git] / folly / io / async / AsyncSocket.h
index 36949725c3558639de74e124f423430ee951e788..5aeb159c65fb635b8bc872688560ed6712348832 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -64,6 +64,10 @@ namespace folly {
  * responding and no further progress can be made sending the data.
  */
 
+#if defined __linux__ && !defined SO_NO_TRANSPARENT_TLS
+#define SO_NO_TRANSPARENT_TLS 200
+#endif
+
 #ifdef _MSC_VER
 // We do a dynamic_cast on this, in
 // AsyncTransportWrapper::getUnderlyingTransport so be safe and
@@ -94,6 +98,19 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
       noexcept = 0;
   };
 
+  class EvbChangeCallback {
+   public:
+    virtual ~EvbChangeCallback() = default;
+
+    // Called when the socket has been attached to a new EVB
+    // and is called from within that EVB thread
+    virtual void evbAttached(AsyncSocket* socket) = 0;
+
+    // Called when the socket is detached from an EVB and
+    // is called from the EVB thread being detached
+    virtual void evbDetached(AsyncSocket* socket) = 0;
+  };
+
   explicit AsyncSocket();
   /**
    * Create a new unconnected AsyncSocket.
@@ -375,9 +392,13 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   void getPeerAddress(
     folly::SocketAddress* address) const override;
 
-  bool isEorTrackingEnabled() const override { return false; }
+  bool isEorTrackingEnabled() const override {
+    return trackEor_;
+  }
 
-  void setEorTracking(bool /*track*/) override {}
+  void setEorTracking(bool track) override {
+    trackEor_ = track;
+  }
 
   bool connecting() const override {
     return (state_ == StateEnum::CONNECTING);
@@ -431,6 +452,20 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     return tfoFinished_;
   }
 
+  /**
+   * Returns whether or not TFO attempt succeded on this
+   * connection.
+   * For servers this is pretty straightforward API and can
+   * be invoked right after the connection is accepted. This API
+   * will perform one syscall.
+   * This API is a bit tricky to use for clients, since clients
+   * only know this for sure after the SYN-ACK is returned. So it's
+   * appropriate to call this only after the first application
+   * data is read from the socket when the caller knows that
+   * the SYN has been ACKed by the server.
+   */
+  bool getTFOSucceded() const;
+
   // Methods controlling socket options
 
   /**
@@ -535,6 +570,10 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 #endif
   }
 
+  void disableTransparentTls() {
+    noTransparentTls_ = true;
+  }
+
   enum class StateEnum : uint8_t {
     UNINIT,
     CONNECTING,
@@ -546,6 +585,12 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   void setBufferCallback(BufferCallback* cb);
 
+  // Callers should set this prior to connecting the socket for the safest
+  // behavior.
+  void setEvbChangedCallback(std::unique_ptr<EvbChangeCallback> cb) {
+    evbChangeCb_ = std::move(cb);
+  }
+
   /**
    * writeReturn is the total number of bytes written, or WRITE_ERROR on error.
    * If no data has been written, 0 is returned.
@@ -589,7 +634,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     WriteRequest(AsyncSocket* socket, WriteCallback* callback) :
       socket_(socket), callback_(callback) {}
 
-    virtual void start() {};
+    virtual void start() {}
 
     virtual void destroy() = 0;
 
@@ -621,7 +666,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     }
 
     void bytesWritten(size_t count) {
-      totalBytesWritten_ += count;
+      totalBytesWritten_ += uint32_t(count);
       socket_->appBytesWritten_ += count;
     }
 
@@ -735,11 +780,25 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
     }
   }
 
+  /**
+   * Schedule handleInitalReadWrite to run in the next iteration.
+   */
+  void scheduleInitialReadWrite() noexcept {
+    if (good()) {
+      DestructorGuard dg(this);
+      eventBase_->runInLoop([this, dg] {
+        if (good()) {
+          handleInitialReadWrite();
+        }
+      });
+    }
+  }
+
   // event notification methods
   void ioReady(uint16_t events) noexcept;
   virtual void checkForImmediateRead() noexcept;
   virtual void handleInitialReadWrite() noexcept;
-  virtual void prepareReadBuffer(void** buf, size_t* buflen) noexcept;
+  virtual void prepareReadBuffer(void** buf, size_t* buflen);
   virtual void handleRead() noexcept;
   virtual void handleWrite() noexcept;
   virtual void handleConnect() noexcept;
@@ -824,7 +883,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
 
   int socketConnect(const struct sockaddr* addr, socklen_t len);
 
-  void scheduleConnectTimeoutAndRegisterForEvents();
+  virtual void scheduleConnectTimeout();
+  void registerForConnectEvents();
 
   bool updateEventRegistration();
 
@@ -848,6 +908,8 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   // error handling methods
   void startFail();
   void finishFail();
+  void finishFail(const AsyncSocketException& ex);
+  void invokeAllErrors(const AsyncSocketException& ex);
   void fail(const char* fn, const AsyncSocketException& ex);
   void failConnect(const char* fn, const AsyncSocketException& ex);
   void failRead(const char* fn, const AsyncSocketException& ex);
@@ -855,7 +917,7 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
                  const AsyncSocketException& ex);
   void failWrite(const char* fn, const AsyncSocketException& ex);
   void failAllWrites(const AsyncSocketException& ex);
-  void invokeConnectErr(const AsyncSocketException& ex);
+  virtual void invokeConnectErr(const AsyncSocketException& ex);
   virtual void invokeConnectSuccess();
   void invalidState(ConnectCallback* callback);
   void invalidState(ReadCallback* callback);
@@ -899,6 +961,11 @@ class AsyncSocket : virtual public AsyncTransportWrapper {
   bool tfoEnabled_{false};
   bool tfoAttempted_{false};
   bool tfoFinished_{false};
+  bool noTransparentTls_{false};
+  // Whether to track EOR or not.
+  bool trackEor_{false};
+
+  std::unique_ptr<EvbChangeCallback> evbChangeCb_{nullptr};
 };
 #ifdef _MSC_VER
 #pragma vtordisp(pop)