Move address caching logic from AsyncSSLSocket to AsyncSocket.
[folly.git] / folly / io / async / AsyncSSLSocket.h
index 5140a82430a0aafa07abb2769ee771074b341bb7..474225c060bfd66ab0bfd86e925c26037ee01ef0 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@
 #include <folly/io/async/AsyncTimeout.h>
 #include <folly/io/async/SSLContext.h>
 #include <folly/io/async/TimeoutManager.h>
-#include <folly/io/async/ssl/OpenSSLPtrTypes.h>
+#include <folly/ssl/OpenSSLPtrTypes.h>
 #include <folly/io/async/ssl/OpenSSLUtils.h>
 #include <folly/io/async/ssl/SSLErrors.h>
 #include <folly/io/async/ssl/TLSDefinitions.h>
@@ -32,6 +32,7 @@
 #include <folly/Bits.h>
 #include <folly/io/IOBuf.h>
 #include <folly/io/Cursor.h>
+#include <folly/portability/OpenSSL.h>
 #include <folly/portability/Sockets.h>
 
 namespace folly {
@@ -122,32 +123,31 @@ class AsyncSSLSocket : public virtual AsyncSocket {
       noexcept = 0;
   };
 
-  class HandshakeTimeout : public AsyncTimeout {
+  class Timeout : public AsyncTimeout {
    public:
-    HandshakeTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
-      : AsyncTimeout(eventBase)
-      , sslSocket_(sslSocket) {}
+    Timeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
+        : AsyncTimeout(eventBase), sslSocket_(sslSocket) {}
 
-    virtual void timeoutExpired() noexcept {
-      sslSocket_->timeoutExpired();
+    bool scheduleTimeout(TimeoutManager::timeout_type timeout) {
+      timeout_ = timeout;
+      return AsyncTimeout::scheduleTimeout(timeout);
     }
 
-   private:
-    AsyncSSLSocket* sslSocket_;
-  };
+    bool scheduleTimeout(uint32_t timeoutMs) {
+      return scheduleTimeout(std::chrono::milliseconds{timeoutMs});
+    }
 
-  // Timer for if we fallback from SSL connects to TCP connects
-  class ConnectionTimeout : public AsyncTimeout {
-   public:
-    ConnectionTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
-        : AsyncTimeout(eventBase), sslSocket_(sslSocket) {}
+    TimeoutManager::timeout_type getTimeout() {
+      return timeout_;
+    }
 
-    virtual void timeoutExpired() noexcept override {
-      sslSocket_->timeoutExpired();
+    void timeoutExpired() noexcept override {
+      sslSocket_->timeoutExpired(timeout_);
     }
 
    private:
     AsyncSSLSocket* sslSocket_;
+    TimeoutManager::timeout_type timeout_;
   };
 
   /**
@@ -173,10 +173,22 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * @param deferSecurityNegotiation
    *          unencrypted data can be sent before sslConn/Accept
    */
-  AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
-                 EventBase* evb, int fd,
-                 bool server = true, bool deferSecurityNegotiation = false);
+  AsyncSSLSocket(
+      const std::shared_ptr<folly::SSLContext>& ctx,
+      EventBase* evb,
+      int fd,
+      bool server = true,
+      bool deferSecurityNegotiation = false);
 
+  /**
+   * Create a server/client AsyncSSLSocket from an already connected
+   * AsyncSocket.
+   */
+  AsyncSSLSocket(
+      const std::shared_ptr<folly::SSLContext>& ctx,
+      AsyncSocket::UniquePtr oldAsyncSocket,
+      bool server = true,
+      bool deferSecurityNegotiation = false);
 
   /**
    * Helper function to create a server/client shared_ptr<AsyncSSLSocket>.
@@ -202,7 +214,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   }
 
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   /**
    * Create a client AsyncSSLSocket with tlsext_servername in
    * the Client Hello message.
@@ -227,11 +239,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * @param fd   File descriptor to take over (should be a connected socket).
    * @param serverName tlsext_hostname that will be sent in ClientHello.
    */
-  AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
-                  EventBase* evb,
-                  int fd,
-                 const std::string& serverName,
-                bool deferSecurityNegotiation = false);
+  AsyncSSLSocket(
+      const std::shared_ptr<folly::SSLContext>& ctx,
+      EventBase* evb,
+      int fd,
+      const std::string& serverName,
+      bool deferSecurityNegotiation = false);
 
   static std::shared_ptr<AsyncSSLSocket> newSocket(
     const std::shared_ptr<folly::SSLContext>& ctx,
@@ -242,7 +255,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
       new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
       Destructor());
   }
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
   /**
    * TODO: implement support for SSL renegotiation.
@@ -262,19 +275,20 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // See the documentation in TAsyncTransport.h
   // TODO: implement graceful shutdown in close()
   // TODO: implement detachSSL() that returns the SSL connection
-  virtual void closeNow() override;
-  virtual void shutdownWrite() override;
-  virtual void shutdownWriteNow() override;
-  virtual bool good() const override;
-  virtual bool connecting() const override;
-  virtual std::string getApplicationProtocol() noexcept override;
-
-  virtual std::string getSecurityProtocol() const override { return "TLS"; }
-
-  bool isEorTrackingEnabled() const override;
-  virtual void setEorTracking(bool track) override;
-  virtual size_t getRawBytesWritten() const override;
-  virtual size_t getRawBytesReceived() const override;
+  void closeNow() override;
+  void shutdownWrite() override;
+  void shutdownWriteNow() override;
+  bool good() const override;
+  bool connecting() const override;
+  std::string getApplicationProtocol() noexcept override;
+
+  std::string getSecurityProtocol() const override {
+    return "TLS";
+  }
+
+  void setEorTracking(bool track) override;
+  size_t getRawBytesWritten() const override;
+  size_t getRawBytesReceived() const override;
   void enableClientHelloParsing();
 
   /**
@@ -292,9 +306,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                context by default, can be set explcitly to override the
    *                method in the context
    */
-  virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0,
+  virtual void sslAccept(
+      HandshakeCB* callback,
+      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(),
       const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
-            folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
+          folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
 
   /**
    * Invoke SSL accept following an asynchronous session cache lookup
@@ -313,6 +329,29 @@ class AsyncSSLSocket : public virtual AsyncSocket {
                const folly::SocketAddress& bindAddr = anyAddress())
                noexcept override;
 
+  /**
+   * A variant of connect that allows the caller to specify
+   * the timeout for the regular connect and the ssl connect
+   * separately.
+   * connectTimeout is specified as the time to establish a TCP
+   * connection.
+   * totalConnectTimeout defines the
+   * time it takes from starting the TCP connection to the time
+   * the ssl connection is established. The reason the timeout is
+   * defined this way is because user's rarely need to specify the SSL
+   * timeout independently of the connect timeout. It allows us to
+   * bound the time for a connect and SSL connection in
+   * a finer grained manner than if timeout was just defined
+   * independently for SSL.
+   */
+  virtual void connect(
+      ConnectCallback* callback,
+      const folly::SocketAddress& address,
+      std::chrono::milliseconds connectTimeout,
+      std::chrono::milliseconds totalConnectTimeout,
+      const OptionMap& options = emptyOptionMap,
+      const folly::SocketAddress& bindAddr = anyAddress()) noexcept;
+
   using AsyncSocket::connect;
 
   /**
@@ -332,9 +371,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                SSL_VERIFY_PEER and invokes
    *                HandshakeCB::handshakeVer().
    */
-  virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0,
-            const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
-                  folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
+  virtual void sslConn(
+      HandshakeCB* callback,
+      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(),
+      const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
+          folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
 
   enum SSLStateEnum {
     STATE_UNINIT,
@@ -474,21 +515,21 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   /**
    * Get the certificate used for this SSL connection. May be null
    */
-  virtual const X509* getSelfCert() const override;
+  const X509* getSelfCert() const override;
 
-  virtual void attachEventBase(EventBase* eventBase) override {
+  void attachEventBase(EventBase* eventBase) override {
     AsyncSocket::attachEventBase(eventBase);
     handshakeTimeout_.attachEventBase(eventBase);
     connectionTimeout_.attachEventBase(eventBase);
   }
 
-  virtual void detachEventBase() override {
+  void detachEventBase() override {
     AsyncSocket::detachEventBase();
     handshakeTimeout_.detachEventBase();
     connectionTimeout_.detachEventBase();
   }
 
-  virtual bool isDetachable() const override {
+  bool isDetachable() const override {
     return AsyncSocket::isDetachable() && !handshakeTimeout_.isScheduled();
   }
 
@@ -514,7 +555,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void detachSSLContext();
 #endif
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   /**
    * Switch the SSLContext to continue the SSL handshake.
    * It can only be used in server mode.
@@ -537,9 +578,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * ClientHello message.
    */
   void setServerName(std::string serverName) noexcept;
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
-  void timeoutExpired() noexcept;
+  void timeoutExpired(std::chrono::milliseconds timeout) noexcept;
 
   /**
    * Get the list of supported ciphers sent by the client in the client's
@@ -569,6 +610,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   std::string getSSLAlertsReceived() const;
 
+  /*
+   * Save an optional alert message generated during certificate verify
+   */
+  void setSSLCertVerificationAlert(std::string alert);
+
+  std::string getSSLCertVerificationAlert() const;
+
   /**
    * Get the list of shared ciphers between the server and the client.
    * Works well for only SSLv2, not so good for SSLv3 or TLSv1.
@@ -581,9 +629,17 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   void getSSLServerCiphers(std::string& serverCiphers) const;
 
+  /**
+   * Method to check if peer verfication is set.
+   *
+   * @return true if peer verification is required.
+   */
+  bool needsPeerVerification() const;
+
   static int getSSLExDataIndex();
   static AsyncSSLSocket* getFromSSL(const SSL *ssl);
   static int bioWrite(BIO* b, const char* in, int inl);
+  static int bioRead(BIO* b, char* out, int outl);
   void resetClientHelloParsing(SSL *ssl);
   static void clientHelloParsingCallback(int write_p, int version,
       int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
@@ -621,7 +677,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   /**
    * Returns the peer certificate, or nullptr if no peer certificate received.
    */
-  virtual ssl::X509UniquePtr getPeerCert() const override {
+  ssl::X509UniquePtr getPeerCert() const override {
     if (!ssl_) {
       return nullptr;
     }
@@ -657,6 +713,16 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     return sessionResumptionAttempted_;
   }
 
+  /**
+   * If the SSL socket was used to connect as well
+   * as establish an SSL connection, this gives the total
+   * timeout for the connect + SSL connection that was
+   * set.
+   */
+  std::chrono::milliseconds getTotalConnectTimeout() const {
+    return totalConnectTimeout_;
+  }
+
  private:
 
   void init();
@@ -670,7 +736,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * destroy() instead.  (See the documentation in DelayedDestruction.h for
    * more details.)
    */
-  ~AsyncSSLSocket();
+  ~AsyncSSLSocket() override;
 
   // Inherit event notification methods from AsyncSocket except
   // the following.
@@ -685,7 +751,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
                  int* sslErrorOut,
                  unsigned long* errErrorOut) noexcept;
 
-  virtual void checkForImmediateRead() noexcept override;
+  void checkForImmediateRead() noexcept override;
   // AsyncSocket calls this at the wrong time for SSL
   void handleInitialReadWrite() noexcept override {}
 
@@ -745,8 +811,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void invokeConnectSuccess() override;
   void scheduleConnectTimeout() override;
 
-  void cacheLocalPeerAddr();
-
   void startSSLConnect();
 
   static void sslInfoCallback(const SSL *ssl, int type, int val);
@@ -767,14 +831,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   HandshakeCB* handshakeCallback_{nullptr};
   SSL* ssl_{nullptr};
   SSL_SESSION *sslSession_{nullptr};
-  HandshakeTimeout handshakeTimeout_;
-  ConnectionTimeout connectionTimeout_;
+  Timeout handshakeTimeout_;
+  Timeout connectionTimeout_;
   // whether the SSL session was resumed using session ID or not
   bool sessionIDResumed_{false};
 
-  // Whether to track EOR or not.
-  bool trackEor_{false};
-
   // The app byte num that we are tracking for the MSG_EOR
   // Only one app EOR byte can be tracked.
   size_t appEorByteNo_{0};
@@ -786,7 +847,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // When openssl is about to sendmsg() across the minEorRawBytesNo_,
   // it will pass MSG_EOR to sendmsg().
   size_t minEorRawByteNo_{0};
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   std::shared_ptr<folly::SSLContext> handshakeCtx_;
   std::string tlsextHostname_;
 #endif
@@ -810,8 +871,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // Time taken to complete the ssl handshake.
   std::chrono::steady_clock::time_point handshakeStartTime_;
   std::chrono::steady_clock::time_point handshakeEndTime_;
-  uint64_t handshakeConnectTimeout_{0};
+  std::chrono::milliseconds handshakeConnectTimeout_{0};
   bool sessionResumptionAttempted_{false};
+  std::chrono::milliseconds totalConnectTimeout_{0};
+
+  std::string sslVerificationAlert_;
 };
 
 } // namespace