Adding OpenSSLPtrTypes.h.
[folly.git] / folly / io / async / AsyncSSLSocket.h
index 4b8e4773ec491fe39adf4f0f54b5db6bf91ca494..2836544f5ad4f357f2cd6df57f497914410cab52 100644 (file)
@@ -25,6 +25,7 @@
 #include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/SSLContext.h>
 #include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/async/OpenSSLPtrTypes.h>
 #include <folly/io/async/TimeoutManager.h>
 
 #include <folly/Bits.h>
@@ -273,6 +274,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   virtual void shutdownWriteNow() override;
   virtual bool good() const override;
   virtual bool connecting() const override;
+  virtual std::string getApplicationProtocol() noexcept override;
+
+  virtual std::string getSecurityProtocol() const override { return "TLS"; }
 
   bool isEorTrackingEnabled() const override;
   virtual void setEorTracking(bool track) override;
@@ -375,7 +379,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   /**
    * Get the name of the protocol selected by the client during
-   * Next Protocol Negotiation (NPN)
+   * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN)
    *
    * Throw an exception if openssl does not support NPN
    *
@@ -385,13 +390,17 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                       Note: the AsyncSSLSocket retains ownership
    *                       of this string.
    * @param protoNameLen   Length of the name.
+   * @param protoType      Whether this was an NPN or ALPN negotiation
    */
-  virtual void getSelectedNextProtocol(const unsigned char** protoName,
-      unsigned* protoLen) const;
+  virtual void getSelectedNextProtocol(
+      const unsigned char** protoName,
+      unsigned* protoLen,
+      SSLContext::NextProtocolType* protoType = nullptr) const;
 
   /**
    * Get the name of the protocol selected by the client during
-   * Next Protocol Negotiation (NPN)
+   * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN)
    *
    * @param protoName      Name of the protocol (not guaranteed to be
    *                       null terminated); will be set to nullptr if
@@ -399,16 +408,19 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                       Note: the AsyncSSLSocket retains ownership
    *                       of this string.
    * @param protoNameLen   Length of the name.
+   * @param protoType      Whether this was an NPN or ALPN negotiation
    * @return false if openssl does not support NPN
    */
-  virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName,
-      unsigned* protoLen) const;
+  virtual bool getSelectedNextProtocolNoThrow(
+      const unsigned char** protoName,
+      unsigned* protoLen,
+      SSLContext::NextProtocolType* protoType = nullptr) const;
 
   /**
    * Determine if the session specified during setSSLSession was reused
    * or if the server rejected it and issued a new session.
    */
-  bool getSSLSessionReused() const;
+  virtual bool getSSLSessionReused() const;
 
   /**
    * true if the session was resumed using session ID
@@ -424,7 +436,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * Returns the cipher used or the constant value "NONE" when no SSL session
    * has been established.
    */
-  const char *getNegotiatedCipherName() const;
+  virtual const char* getNegotiatedCipherName() const;
 
   /**
    * Get the server name for this SSL connection.
@@ -451,29 +463,15 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   int getSSLVersion() const;
 
   /**
-   * Get the certificate size used for this SSL connection.
+   * Get the signature algorithm used in the cert that is used for this
+   * connection.
    */
-  int getSSLCertSize() const;
+  const char *getSSLCertSigAlgName() const;
 
-  /* Get the number of bytes read from the wire (including protocol
-   * overhead). Returns 0 once the connection has been closed.
-   */
-  unsigned long getBytesRead() const {
-    if (ssl_ != nullptr) {
-      return BIO_number_read(SSL_get_rbio(ssl_));
-    }
-    return 0;
-  }
-
-  /* Get the number of bytes written to the wire (including protocol
-   * overhead).  Returns 0 once the connection has been closed.
+  /**
+   * Get the certificate size used for this SSL connection.
    */
-  unsigned long getBytesWritten() const {
-    if (ssl_ != nullptr) {
-      return BIO_number_written(SSL_get_wbio(ssl_));
-    }
-    return 0;
-  }
+  int getSSLCertSize() const;
 
   virtual void attachEventBase(EventBase* eventBase) override {
     AsyncSocket::attachEventBase(eventBase);
@@ -722,6 +720,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     return clientHelloInfo_.get();
   }
 
+  /**
+   * Returns the time taken to complete a handshake.
+   */
+  virtual std::chrono::nanoseconds getHandshakeTime() const {
+    return handshakeEndTime_ - handshakeStartTime_;
+  }
+
   void setMinWriteSize(size_t minWriteSize) {
     minWriteSize_ = minWriteSize;
   }
@@ -735,13 +740,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   /**
    * Returns the peer certificate, or nullptr if no peer certificate received.
    */
-  std::unique_ptr<X509, X509_deleter> getPeerCert() const {
+  virtual X509_UniquePtr getPeerCert() const {
     if (!ssl_) {
       return nullptr;
     }
 
     X509* cert = SSL_get_peer_certificate(ssl_);
-    return std::unique_ptr<X509, X509_deleter>(cert);
+    return X509_UniquePtr(cert);
   }
 
  private:
@@ -813,12 +818,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // Inherit error handling methods from AsyncSocket, plus the following.
   void failHandshake(const char* fn, const AsyncSocketException& ex);
 
+  void invokeHandshakeErr(const AsyncSocketException& ex);
   void invokeHandshakeCB();
 
   static void sslInfoCallback(const SSL *ssl, int type, int val);
 
-  // Whether we've applied the TCP_CORK option to the socket
-  bool corked_{false};
   // SSL related members.
   bool server_{false};
   // Used to prevent client-initiated renegotiation.  Note that AsyncSSLSocket
@@ -860,6 +864,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   bool parseClientHello_{false};
   std::unique_ptr<ClientHelloInfo> clientHelloInfo_;
+
+  // Time taken to complete the ssl handshake.
+  std::chrono::steady_clock::time_point handshakeStartTime_;
+  std::chrono::steady_clock::time_point handshakeEndTime_;
 };
 
 } // namespace