Adding OpenSSLPtrTypes.h.
[folly.git] / folly / io / async / AsyncSSLSocket.h
index c3d52150a3b21aebcbf4327092f964c3fc2d27fe..2836544f5ad4f357f2cd6df57f497914410cab52 100644 (file)
 #include <folly/io/async/AsyncSocket.h>
 #include <folly/io/async/SSLContext.h>
 #include <folly/io/async/AsyncTimeout.h>
+#include <folly/io/async/OpenSSLPtrTypes.h>
 #include <folly/io/async/TimeoutManager.h>
 
 #include <folly/Bits.h>
 #include <folly/io/IOBuf.h>
 #include <folly/io/Cursor.h>
 
-using folly::io::Cursor;
-using std::unique_ptr;
-
 namespace folly {
 
 class SSLException: public folly::AsyncSocketException {
@@ -80,10 +78,11 @@ class SSLException: public folly::AsyncSocketException {
 class AsyncSSLSocket : public virtual AsyncSocket {
  public:
   typedef std::unique_ptr<AsyncSSLSocket, Destructor> UniquePtr;
+  using X509_deleter = folly::static_function_deleter<X509, &X509_free>;
 
   class HandshakeCB {
    public:
-    virtual ~HandshakeCB() {}
+    virtual ~HandshakeCB() = default;
 
     /**
      * handshakeVer() is invoked during handshaking to give the
@@ -99,9 +98,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
      * See the passages on verify_callback in SSL_CTX_set_verify(3)
      * for more details.
      */
-    virtual bool handshakeVer(AsyncSSLSocket* sock,
+    virtual bool handshakeVer(AsyncSSLSocket* /*sock*/,
                                  bool preverifyOk,
-                                 X509_STORE_CTX* ctx) noexcept {
+                                 X509_STORE_CTX* /*ctx*/) noexcept {
       return preverifyOk;
     }
 
@@ -162,7 +161,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * Create a client AsyncSSLSocket
    */
   AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
-                  EventBase* evb);
+                 EventBase* evb, bool deferSecurityNegotiation = false);
 
   /**
    * Create a server/client AsyncSSLSocket from an already connected
@@ -178,9 +177,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * @param evb EventBase that will manage this socket.
    * @param fd  File descriptor to take over (should be a connected socket).
    * @param server Is socket in server mode?
+   * @param deferSecurityNegotiation
+   *          unencrypted data can be sent before sslConn/Accept
    */
   AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
-                  EventBase* evb, int fd, bool server = true);
+                 EventBase* evb, int fd,
+                 bool server = true, bool deferSecurityNegotiation = false);
 
 
   /**
@@ -188,9 +190,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   static std::shared_ptr<AsyncSSLSocket> newSocket(
     const std::shared_ptr<folly::SSLContext>& ctx,
-    EventBase* evb, int fd, bool server=true) {
+    EventBase* evb, int fd, bool server=true,
+    bool deferSecurityNegotiation = false) {
     return std::shared_ptr<AsyncSSLSocket>(
-      new AsyncSSLSocket(ctx, evb, fd, server),
+      new AsyncSSLSocket(ctx, evb, fd, server, deferSecurityNegotiation),
       Destructor());
   }
 
@@ -199,9 +202,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   static std::shared_ptr<AsyncSSLSocket> newSocket(
     const std::shared_ptr<folly::SSLContext>& ctx,
-    EventBase* evb) {
+    EventBase* evb, bool deferSecurityNegotiation = false) {
     return std::shared_ptr<AsyncSSLSocket>(
-      new AsyncSSLSocket(ctx, evb),
+      new AsyncSSLSocket(ctx, evb, deferSecurityNegotiation),
       Destructor());
   }
 
@@ -213,7 +216,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   AsyncSSLSocket(const std::shared_ptr<folly::SSLContext> &ctx,
                   EventBase* evb,
-                  const std::string& serverName);
+                 const std::string& serverName,
+                bool deferSecurityNegotiation = false);
 
   /**
    * Create a client AsyncSSLSocket from an already connected
@@ -233,14 +237,16 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   AsyncSSLSocket(const std::shared_ptr<folly::SSLContext>& ctx,
                   EventBase* evb,
                   int fd,
-                  const std::string& serverName);
+                 const std::string& serverName,
+                bool deferSecurityNegotiation = false);
 
   static std::shared_ptr<AsyncSSLSocket> newSocket(
     const std::shared_ptr<folly::SSLContext>& ctx,
     EventBase* evb,
-    const std::string& serverName) {
+    const std::string& serverName,
+    bool deferSecurityNegotiation = false) {
     return std::shared_ptr<AsyncSSLSocket>(
-      new AsyncSSLSocket(ctx, evb, serverName),
+      new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
       Destructor());
   }
 #endif
@@ -268,6 +274,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   virtual void shutdownWriteNow() override;
   virtual bool good() const override;
   virtual bool connecting() const override;
+  virtual std::string getApplicationProtocol() noexcept override;
+
+  virtual std::string getSecurityProtocol() const override { return "TLS"; }
 
   bool isEorTrackingEnabled() const override;
   virtual void setEorTracking(bool track) override;
@@ -315,10 +324,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   /**
    * Initiate an SSL connection on the socket
-   * THe callback will be invoked and uninstalled when an SSL connection
+   * The callback will be invoked and uninstalled when an SSL connection
    * has been establshed on the underlying socket.
-   * The verification option verifyPeer is applied if its passed explicitly.
-   * If its not, the options in SSLContext set on the underying SSLContext
+   * The verification option verifyPeer is applied if it's passed explicitly.
+   * If it's not, the options in SSLContext set on the underlying SSLContext
    * are applied.
    *
    * @param callback callback object to invoke on success/failure
@@ -336,6 +345,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   enum SSLStateEnum {
     STATE_UNINIT,
+    STATE_UNENCRYPTED,
     STATE_ACCEPTING,
     STATE_CACHE_LOOKUP,
     STATE_RSA_ASYNC_PENDING,
@@ -369,7 +379,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   /**
    * Get the name of the protocol selected by the client during
-   * Next Protocol Negotiation (NPN)
+   * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN)
    *
    * Throw an exception if openssl does not support NPN
    *
@@ -379,13 +390,17 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                       Note: the AsyncSSLSocket retains ownership
    *                       of this string.
    * @param protoNameLen   Length of the name.
+   * @param protoType      Whether this was an NPN or ALPN negotiation
    */
-  virtual void getSelectedNextProtocol(const unsigned char** protoName,
-      unsigned* protoLen) const;
+  virtual void getSelectedNextProtocol(
+      const unsigned char** protoName,
+      unsigned* protoLen,
+      SSLContext::NextProtocolType* protoType = nullptr) const;
 
   /**
    * Get the name of the protocol selected by the client during
-   * Next Protocol Negotiation (NPN)
+   * Next Protocol Negotiation (NPN) or Application Layer Protocol Negotiation
+   * (ALPN)
    *
    * @param protoName      Name of the protocol (not guaranteed to be
    *                       null terminated); will be set to nullptr if
@@ -393,16 +408,19 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                       Note: the AsyncSSLSocket retains ownership
    *                       of this string.
    * @param protoNameLen   Length of the name.
+   * @param protoType      Whether this was an NPN or ALPN negotiation
    * @return false if openssl does not support NPN
    */
-  virtual bool getSelectedNextProtocolNoThrow(const unsigned char** protoName,
-      unsigned* protoLen) const;
+  virtual bool getSelectedNextProtocolNoThrow(
+      const unsigned char** protoName,
+      unsigned* protoLen,
+      SSLContext::NextProtocolType* protoType = nullptr) const;
 
   /**
    * Determine if the session specified during setSSLSession was reused
    * or if the server rejected it and issued a new session.
    */
-  bool getSSLSessionReused() const;
+  virtual bool getSSLSessionReused() const;
 
   /**
    * true if the session was resumed using session ID
@@ -418,7 +436,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * Returns the cipher used or the constant value "NONE" when no SSL session
    * has been established.
    */
-  const char *getNegotiatedCipherName() const;
+  virtual const char* getNegotiatedCipherName() const;
 
   /**
    * Get the server name for this SSL connection.
@@ -445,29 +463,15 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   int getSSLVersion() const;
 
   /**
-   * Get the certificate size used for this SSL connection.
+   * Get the signature algorithm used in the cert that is used for this
+   * connection.
    */
-  int getSSLCertSize() const;
+  const char *getSSLCertSigAlgName() const;
 
-  /* Get the number of bytes read from the wire (including protocol
-   * overhead). Returns 0 once the connection has been closed.
-   */
-  unsigned long getBytesRead() const {
-    if (ssl_ != nullptr) {
-      return BIO_number_read(SSL_get_rbio(ssl_));
-    }
-    return 0;
-  }
-
-  /* Get the number of bytes written to the wire (including protocol
-   * overhead).  Returns 0 once the connection has been closed.
+  /**
+   * Get the certificate size used for this SSL connection.
    */
-  unsigned long getBytesWritten() const {
-    if (ssl_ != nullptr) {
-      return BIO_number_written(SSL_get_wbio(ssl_));
-    }
-    return 0;
-  }
+  int getSSLCertSize() const;
 
   virtual void attachEventBase(EventBase* eventBase) override {
     AsyncSocket::attachEventBase(eventBase);
@@ -479,6 +483,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     handshakeTimeout_.detachEventBase();
   }
 
+  virtual bool isDetachable() const override {
+    return AsyncSocket::isDetachable() && !handshakeTimeout_.isScheduled();
+  }
+
   virtual void attachTimeoutManager(TimeoutManager* manager) {
     handshakeTimeout_.attachTimeoutManager(manager);
   }
@@ -532,7 +540,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * Get the list of supported ciphers sent by the client in the client's
    * preference order.
    */
-  void getSSLClientCiphers(std::string& clientCiphers) {
+  void getSSLClientCiphers(std::string& clientCiphers) const {
     std::stringstream ciphersStream;
     std::string cipherName;
 
@@ -576,7 +584,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   /**
    * Get the list of compression methods sent by the client in TLS Hello.
    */
-  std::string getSSLClientComprMethods() {
+  std::string getSSLClientComprMethods() const {
     if (!parseClientHello_) {
       return "";
     }
@@ -586,18 +594,39 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   /**
    * Get the list of TLS extensions sent by the client in the TLS Hello.
    */
-  std::string getSSLClientExts() {
+  std::string getSSLClientExts() const {
     if (!parseClientHello_) {
       return "";
     }
     return folly::join(":", clientHelloInfo_->clientHelloExtensions_);
   }
 
+  std::string getSSLClientSigAlgs() const {
+    if (!parseClientHello_) {
+      return "";
+    }
+
+    std::string sigAlgs;
+    sigAlgs.reserve(clientHelloInfo_->clientHelloSigAlgs_.size() * 4);
+    for (size_t i = 0; i < clientHelloInfo_->clientHelloSigAlgs_.size(); i++) {
+      if (i) {
+        sigAlgs.push_back(':');
+      }
+      sigAlgs.append(folly::to<std::string>(
+          clientHelloInfo_->clientHelloSigAlgs_[i].first));
+      sigAlgs.push_back(',');
+      sigAlgs.append(folly::to<std::string>(
+          clientHelloInfo_->clientHelloSigAlgs_[i].second));
+    }
+
+    return sigAlgs;
+  }
+
   /**
    * Get the list of shared ciphers between the server and the client.
    * Works well for only SSLv2, not so good for SSLv3 or TLSv1.
    */
-  void getSSLSharedCiphers(std::string& sharedCiphers) {
+  void getSSLSharedCiphers(std::string& sharedCiphers) const {
     char ciphersBuffer[1024];
     ciphersBuffer[0] = '\0';
     SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1);
@@ -608,7 +637,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * Get the list of ciphers supported by the server in the server's
    * preference order.
    */
-  void getSSLServerCiphers(std::string& serverCiphers) {
+  void getSSLServerCiphers(std::string& serverCiphers) const {
     serverCiphers = SSL_get_cipher_list(ssl_, 0);
     int i = 1;
     const char *cipher;
@@ -626,34 +655,118 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   static void clientHelloParsingCallback(int write_p, int version,
       int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
 
+  // http://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml
+  enum class TLSExtension: uint16_t {
+    SERVER_NAME = 0,
+    MAX_FRAGMENT_LENGTH = 1,
+    CLIENT_CERTIFICATE_URL = 2,
+    TRUSTED_CA_KEYS = 3,
+    TRUNCATED_HMAC = 4,
+    STATUS_REQUEST = 5,
+    USER_MAPPING = 6,
+    CLIENT_AUTHZ = 7,
+    SERVER_AUTHZ = 8,
+    CERT_TYPE = 9,
+    SUPPORTED_GROUPS = 10,
+    EC_POINT_FORMATS = 11,
+    SRP = 12,
+    SIGNATURE_ALGORITHMS = 13,
+    USE_SRTP = 14,
+    HEARTBEAT = 15,
+    APPLICATION_LAYER_PROTOCOL_NEGOTIATION = 16,
+    STATUS_REQUEST_V2 = 17,
+    SIGNED_CERTIFICATE_TIMESTAMP = 18,
+    CLIENT_CERTIFICATE_TYPE = 19,
+    SERVER_CERTIFICATE_TYPE = 20,
+    PADDING = 21,
+    ENCRYPT_THEN_MAC = 22,
+    EXTENDED_MASTER_SECRET = 23,
+    SESSION_TICKET = 35,
+    RENEGOTIATION_INFO = 65281
+  };
+
+  // http://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18
+  enum class HashAlgorithm: uint8_t {
+    NONE = 0,
+    MD5 = 1,
+    SHA1 = 2,
+    SHA224 = 3,
+    SHA256 = 4,
+    SHA384 = 5,
+    SHA512 = 6
+  };
+
+  // http://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-16
+  enum class SignatureAlgorithm: uint8_t {
+    ANONYMOUS = 0,
+    RSA = 1,
+    DSA = 2,
+    ECDSA = 3
+  };
+
   struct ClientHelloInfo {
     folly::IOBufQueue clientHelloBuf_;
     uint8_t clientHelloMajorVersion_;
     uint8_t clientHelloMinorVersion_;
     std::vector<uint16_t> clientHelloCipherSuites_;
     std::vector<uint8_t> clientHelloCompressionMethods_;
-    std::vector<uint16_t> clientHelloExtensions_;
+    std::vector<TLSExtension> clientHelloExtensions_;
+    std::vector<
+      std::pair<HashAlgorithm, SignatureAlgorithm>> clientHelloSigAlgs_;
   };
 
   // For unit-tests
-  ClientHelloInfo* getClientHelloInfo() {
+  ClientHelloInfo* getClientHelloInfo() const {
     return clientHelloInfo_.get();
   }
 
+  /**
+   * Returns the time taken to complete a handshake.
+   */
+  virtual std::chrono::nanoseconds getHandshakeTime() const {
+    return handshakeEndTime_ - handshakeStartTime_;
+  }
+
+  void setMinWriteSize(size_t minWriteSize) {
+    minWriteSize_ = minWriteSize;
+  }
+
+  size_t getMinWriteSize() const {
+    return minWriteSize_;
+  }
+
+  void setReadCB(ReadCallback* callback) override;
+
+  /**
+   * Returns the peer certificate, or nullptr if no peer certificate received.
+   */
+  virtual X509_UniquePtr getPeerCert() const {
+    if (!ssl_) {
+      return nullptr;
+    }
+
+    X509* cert = SSL_get_peer_certificate(ssl_);
+    return X509_UniquePtr(cert);
+  }
+
+ private:
+
+  void init();
+
  protected:
 
   /**
    * Protected destructor.
    *
    * Users of AsyncSSLSocket must never delete it directly.  Instead, invoke
-   * destroy() instead.  (See the documentation in TDelayedDestruction.h for
+   * destroy() instead.  (See the documentation in DelayedDestruction.h for
    * more details.)
    */
   ~AsyncSSLSocket();
 
   // Inherit event notification methods from AsyncSocket except
   // the following.
-
+  void prepareReadBuffer(void** buf, size_t* buflen) noexcept override;
   void handleRead() noexcept override;
   void handleWrite() noexcept override;
   void handleAccept() noexcept;
@@ -666,11 +779,16 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // AsyncSocket calls this at the wrong time for SSL
   void handleInitialReadWrite() noexcept override {}
 
-  ssize_t performRead(void* buf, size_t buflen) override;
+  int interpretSSLError(int rc, int error);
+  ssize_t performRead(void** buf, size_t* buflen, size_t* offset) override;
   ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags,
                        uint32_t* countWritten, uint32_t* partialWritten)
     override;
 
+  ssize_t performWriteIovec(const iovec* vec, uint32_t count,
+                            WriteFlags flags, uint32_t* countWritten,
+                            uint32_t* partialWritten);
+
   // This virtual wrapper around SSL_write exists solely for testing/mockability
   virtual int sslWriteImpl(SSL *ssl, const void *buf, int n) {
     return SSL_write(ssl, buf, n);
@@ -700,14 +818,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // Inherit error handling methods from AsyncSocket, plus the following.
   void failHandshake(const char* fn, const AsyncSocketException& ex);
 
+  void invokeHandshakeErr(const AsyncSocketException& ex);
   void invokeHandshakeCB();
 
   static void sslInfoCallback(const SSL *ssl, int type, int val);
 
-  static std::mutex mutex_;
-  static int sslExDataIndex_;
-  // Whether we've applied the TCP_CORK option to the socket
-  bool corked_{false};
   // SSL related members.
   bool server_{false};
   // Used to prevent client-initiated renegotiation.  Note that AsyncSSLSocket
@@ -730,6 +845,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // Only one app EOR byte can be tracked.
   size_t appEorByteNo_{0};
 
+  // Try to avoid calling SSL_write() for buffers smaller than this.
+  // It doesn't take effect when it is 0.
+  size_t minWriteSize_{1500};
+
   // When openssl is about to sendmsg() across the minEorRawBytesNo_,
   // it will pass MSG_EOR to sendmsg().
   size_t minEorRawByteNo_{0};
@@ -744,7 +863,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   static int sslVerifyCallback(int preverifyOk, X509_STORE_CTX* ctx);
 
   bool parseClientHello_{false};
-  unique_ptr<ClientHelloInfo> clientHelloInfo_;
+  std::unique_ptr<ClientHelloInfo> clientHelloInfo_;
+
+  // Time taken to complete the ssl handshake.
+  std::chrono::steady_clock::time_point handshakeStartTime_;
+  std::chrono::steady_clock::time_point handshakeEndTime_;
 };
 
 } // namespace