Adding OpenSSLPtrTypes.h.
[folly.git] / folly / io / async / AsyncSSLSocket.h
index 6fead8468d60fb1b97dfb2f88bc392f4a54bbde9..2836544f5ad4f357f2cd6df57f497914410cab52 100644 (file)
 #include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/SSLContext.h>
 #include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/async/OpenSSLPtrTypes.h>
 #include <folly/io/async/TimeoutManager.h>
 
 #include <folly/Bits.h>
 #include <folly/io/IOBuf.h>
 #include <folly/io/Cursor.h>
 
-using folly::io::Cursor;
-using std::unique_ptr;
-
 namespace folly {
 
 class SSLException: public folly::AsyncSocketException {
@@ -276,6 +274,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   virtual void shutdownWriteNow() override;
   virtual bool good() const override;
   virtual bool connecting() const override;
+  virtual std::string getApplicationProtocol() noexcept override;
+
+  virtual std::string getSecurityProtocol() const override { return "TLS"; }
 
   bool isEorTrackingEnabled() const override;
   virtual void setEorTracking(bool track) override;
@@ -323,10 +324,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   /**
    * Initiate an SSL connection on the socket
-   * THe callback will be invoked and uninstalled when an SSL connection
+   * The callback will be invoked and uninstalled when an SSL connection
    * has been establshed on the underlying socket.
-   * The verification option verifyPeer is applied if its passed explicitly.
-   * If its not, the options in SSLContext set on the underying SSLContext
+   * The verification option verifyPeer is applied if it's passed explicitly.
+   * If it's not, the options in SSLContext set on the underlying SSLContext
    * are applied.
    *
    * @param callback callback object to invoke on success/failure
@@ -378,7 +379,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   /**
    * Get the name of the protocol selected by the client during
-   * Next Protocol Negotiation (NPN)
+   * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN)
    *
    * Throw an exception if openssl does not support NPN
    *
@@ -388,13 +390,17 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                       Note: the AsyncSSLSocket retains ownership
    *                       of this string.
    * @param protoNameLen   Length of the name.
+   * @param protoType      Whether this was an NPN or ALPN negotiation
    */
-  virtual void getSelectedNextProtocol(const unsigned char** protoName,
-      unsigned* protoLen) const;
+  virtual void getSelectedNextProtocol(
+      const unsigned char** protoName,
+      unsigned* protoLen,
+      SSLContext::NextProtocolType* protoType = nullptr) const;
 
   /**
    * Get the name of the protocol selected by the client during
-   * Next Protocol Negotiation (NPN)
+   * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN)
    *
    * @param protoName      Name of the protocol (not guaranteed to be
    *                       null terminated); will be set to nullptr if
@@ -402,16 +408,19 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                       Note: the AsyncSSLSocket retains ownership
    *                       of this string.
    * @param protoNameLen   Length of the name.
+   * @param protoType      Whether this was an NPN or ALPN negotiation
    * @return false if openssl does not support NPN
    */
-  virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName,
-      unsigned* protoLen) const;
+  virtual bool getSelectedNextProtocolNoThrow(
+      const unsigned char** protoName,
+      unsigned* protoLen,
+      SSLContext::NextProtocolType* protoType = nullptr) const;
 
   /**
    * Determine if the session specified during setSSLSession was reused
    * or if the server rejected it and issued a new session.
    */
-  bool getSSLSessionReused() const;
+  virtual bool getSSLSessionReused() const;
 
   /**
    * true if the session was resumed using session ID
@@ -427,7 +436,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * Returns the cipher used or the constant value "NONE" when no SSL session
    * has been established.
    */
-  const char *getNegotiatedCipherName() const;
+  virtual const char* getNegotiatedCipherName() const;
 
   /**
    * Get the server name for this SSL connection.
@@ -454,29 +463,15 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   int getSSLVersion() const;
 
   /**
-   * Get the certificate size used for this SSL connection.
-   */
-  int getSSLCertSize() const;
-
-  /* Get the number of bytes read from the wire (including protocol
-   * overhead). Returns 0 once the connection has been closed.
+   * Get the signature algorithm used in the cert that is used for this
+   * connection.
    */
-  unsigned long getBytesRead() const {
-    if (ssl_ != nullptr) {
-      return BIO_number_read(SSL_get_rbio(ssl_));
-    }
-    return 0;
-  }
+  const char *getSSLCertSigAlgName() const;
 
-  /* Get the number of bytes written to the wire (including protocol
-   * overhead).  Returns 0 once the connection has been closed.
+  /**
+   * Get the certificate size used for this SSL connection.
    */
-  unsigned long getBytesWritten() const {
-    if (ssl_ != nullptr) {
-      return BIO_number_written(SSL_get_wbio(ssl_));
-    }
-    return 0;
-  }
+  int getSSLCertSize() const;
 
   virtual void attachEventBase(EventBase* eventBase) override {
     AsyncSocket::attachEventBase(eventBase);
@@ -488,6 +483,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     handshakeTimeout_.detachEventBase();
   }
 
+  virtual bool isDetachable() const override {
+    return AsyncSocket::isDetachable() && !handshakeTimeout_.isScheduled();
+  }
+
   virtual void attachTimeoutManager(TimeoutManager* manager) {
     handshakeTimeout_.attachTimeoutManager(manager);
   }
@@ -721,6 +720,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     return clientHelloInfo_.get();
   }
 
+  /**
+   * Returns the time taken to complete a handshake.
+   */
+  virtual std::chrono::nanoseconds getHandshakeTime() const {
+    return handshakeEndTime_ - handshakeStartTime_;
+  }
+
   void setMinWriteSize(size_t minWriteSize) {
     minWriteSize_ = minWriteSize;
   }
@@ -734,13 +740,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   /**
    * Returns the peer certificate, or nullptr if no peer certificate received.
    */
-  std::unique_ptr<X509, X509_deleter> getPeerCert() const {
+  virtual X509_UniquePtr getPeerCert() const {
     if (!ssl_) {
       return nullptr;
     }
 
     X509* cert = SSL_get_peer_certificate(ssl_);
-    return std::unique_ptr<X509, X509_deleter>(cert);
+    return X509_UniquePtr(cert);
   }
 
  private:
@@ -812,12 +818,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // Inherit error handling methods from AsyncSocket, plus the following.
   void failHandshake(const char* fn, const AsyncSocketException& ex);
 
+  void invokeHandshakeErr(const AsyncSocketException& ex);
   void invokeHandshakeCB();
 
   static void sslInfoCallback(const SSL *ssl, int type, int val);
 
-  // Whether we've applied the TCP_CORK option to the socket
-  bool corked_{false};
   // SSL related members.
   bool server_{false};
   // Used to prevent client-initiated renegotiation.  Note that AsyncSSLSocket
@@ -858,7 +863,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   static int sslVerifyCallback(int preverifyOk, X509_STORE_CTX* ctx);
 
   bool parseClientHello_{false};
-  unique_ptr<ClientHelloInfo> clientHelloInfo_;
+  std::unique_ptr<ClientHelloInfo> clientHelloInfo_;
+
+  // Time taken to complete the ssl handshake.
+  std::chrono::steady_clock::time_point handshakeStartTime_;
+  std::chrono::steady_clock::time_point handshakeEndTime_;
 };
 
 } // namespace