Implement LoopKeepAlive for EventBase
[folly.git] / folly / io / async / AsyncSSLSocket.h
index b203f13ad5ffcc42ebe179d6159cfb4850156102..af3fd06b8c1e6a7f9d224a534dd479320c2923eb 100644 (file)
 #include <folly/Optional.h>
 #include <folly/String.h>
 #include <folly/io/async/AsyncSocket.h>
-#include <folly/io/async/SSLContext.h>
 #include <folly/io/async/AsyncTimeout.h>
-#include <folly/io/async/OpenSSLPtrTypes.h>
+#include <folly/io/async/SSLContext.h>
 #include <folly/io/async/TimeoutManager.h>
+#include <folly/io/async/ssl/OpenSSLPtrTypes.h>
+#include <folly/io/async/ssl/OpenSSLUtils.h>
+#include <folly/io/async/ssl/SSLErrors.h>
+#include <folly/io/async/ssl/TLSDefinitions.h>
 
 #include <folly/Bits.h>
 #include <folly/io/IOBuf.h>
 
 namespace folly {
 
-class SSLException: public folly::AsyncSocketException {
- public:
-  SSLException(int sslError, int errno_copy);
-
-  int getSSLError() const { return error_; }
-
- protected:
-  int error_;
-  char msg_[256];
-};
-
 /**
  * A class for performing asynchronous I/O on an SSL connection.
  *
@@ -144,18 +136,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     AsyncSSLSocket* sslSocket_;
   };
 
-
-  /**
-   * These are passed to the application via errno, packed in an SSL err which
-   * are outside the valid errno range.  The values are chosen to be unique
-   * against values in ssl.h
-   */
-  enum SSLError {
-    SSL_CLIENT_RENEGOTIATION_ATTEMPT = 900,
-    SSL_INVALID_RENEGOTIATION = 901,
-    SSL_EARLY_WRITE = 902
-  };
-
   /**
    * Create a client AsyncSSLSocket
    */
@@ -366,6 +346,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   SSL_SESSION *getSSLSession();
 
+  /**
+   * Get a handle to the SSL struct.
+   */
+  const SSL* getSSL() const;
+
   /**
    * Set the SSL session to be used during sslConn.  AsyncSSLSocket will
    * hold a reference to the session until it is destroyed or released by the
@@ -539,7 +524,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * Get the list of supported ciphers sent by the client in the client's
    * preference order.
    */
-  void getSSLClientCiphers(std::string& clientCiphers) const {
+  void getSSLClientCiphers(
+      std::string& clientCiphers,
+      bool convertToString = true) const {
     std::stringstream ciphersStream;
     std::string cipherName;
 
@@ -551,22 +538,25 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
     for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_)
     {
-      // OpenSSL expects code as a big endian char array
-      auto cipherCode = htons(originalCipherCode);
+      const SSL_CIPHER* cipher = nullptr;
+      if (convertToString) {
+        // OpenSSL expects code as a big endian char array
+        auto cipherCode = htons(originalCipherCode);
 
 #if defined(SSL_OP_NO_TLSv1_2)
-      const SSL_CIPHER* cipher =
-          TLSv1_2_method()->get_cipher_by_char((unsigned char*)&cipherCode);
+        cipher =
+            TLSv1_2_method()->get_cipher_by_char((unsigned char*)&cipherCode);
 #elif defined(SSL_OP_NO_TLSv1_1)
-      const SSL_CIPHER* cipher =
-          TLSv1_1_method()->get_cipher_by_char((unsigned char*)&cipherCode);
+        cipher =
+            TLSv1_1_method()->get_cipher_by_char((unsigned char*)&cipherCode);
 #elif defined(SSL_OP_NO_TLSv1)
-      const SSL_CIPHER* cipher =
-          TLSv1_method()->get_cipher_by_char((unsigned char*)&cipherCode);
+        cipher =
+            TLSv1_method()->get_cipher_by_char((unsigned char*)&cipherCode);
 #else
-      const SSL_CIPHER* cipher =
-          SSLv3_method()->get_cipher_by_char((unsigned char*)&cipherCode);
+        cipher =
+            SSLv3_method()->get_cipher_by_char((unsigned char*)&cipherCode);
 #endif
+      }
 
       if (cipher == nullptr) {
         ciphersStream << std::setfill('0') << std::setw(4) << std::hex
@@ -621,6 +611,19 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     return sigAlgs;
   }
 
+  std::string getSSLAlertsReceived() const {
+    std::string ret;
+
+    for (const auto& alert : alertsReceived_) {
+      if (!ret.empty()) {
+        ret.append(",");
+      }
+      ret.append(folly::to<std::string>(alert.first, ": ", alert.second));
+    }
+
+    return ret;
+  }
+
   /**
    * Get the list of shared ciphers between the server and the client.
    * Works well for only SSLv2, not so good for SSLv3 or TLSv1.
@@ -655,68 +658,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
       int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
   static const char* getSSLServerNameFromSSL(SSL* ssl);
 
-  // http://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml
-  enum class TLSExtension: uint16_t {
-    SERVER_NAME = 0,
-    MAX_FRAGMENT_LENGTH = 1,
-    CLIENT_CERTIFICATE_URL = 2,
-    TRUSTED_CA_KEYS = 3,
-    TRUNCATED_HMAC = 4,
-    STATUS_REQUEST = 5,
-    USER_MAPPING = 6,
-    CLIENT_AUTHZ = 7,
-    SERVER_AUTHZ = 8,
-    CERT_TYPE = 9,
-    SUPPORTED_GROUPS = 10,
-    EC_POINT_FORMATS = 11,
-    SRP = 12,
-    SIGNATURE_ALGORITHMS = 13,
-    USE_SRTP = 14,
-    HEARTBEAT = 15,
-    APPLICATION_LAYER_PROTOCOL_NEGOTIATION = 16,
-    STATUS_REQUEST_V2 = 17,
-    SIGNED_CERTIFICATE_TIMESTAMP = 18,
-    CLIENT_CERTIFICATE_TYPE = 19,
-    SERVER_CERTIFICATE_TYPE = 20,
-    PADDING = 21,
-    ENCRYPT_THEN_MAC = 22,
-    EXTENDED_MASTER_SECRET = 23,
-    SESSION_TICKET = 35,
-    RENEGOTIATION_INFO = 65281
-  };
-
-  // http://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-18
-  enum class HashAlgorithm: uint8_t {
-    NONE = 0,
-    MD5 = 1,
-    SHA1 = 2,
-    SHA224 = 3,
-    SHA256 = 4,
-    SHA384 = 5,
-    SHA512 = 6
-  };
-
-  // http://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-16
-  enum class SignatureAlgorithm: uint8_t {
-    ANONYMOUS = 0,
-    RSA = 1,
-    DSA = 2,
-    ECDSA = 3
-  };
-
-  struct ClientHelloInfo {
-    folly::IOBufQueue clientHelloBuf_;
-    uint8_t clientHelloMajorVersion_;
-    uint8_t clientHelloMinorVersion_;
-    std::vector<uint16_t> clientHelloCipherSuites_;
-    std::vector<uint8_t> clientHelloCompressionMethods_;
-    std::vector<TLSExtension> clientHelloExtensions_;
-    std::vector<
-      std::pair<HashAlgorithm, SignatureAlgorithm>> clientHelloSigAlgs_;
-  };
-
   // For unit-tests
-  ClientHelloInfo* getClientHelloInfo() const {
+  ssl::ClientHelloInfo* getClientHelloInfo() const {
     return clientHelloInfo_.get();
   }
 
@@ -737,16 +680,23 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   void setReadCB(ReadCallback* callback) override;
 
+  /**
+   * Tries to enable the buffer movable experimental feature in openssl.
+   * This is not guaranteed to succeed in case openssl does not have
+   * the experimental feature built in.
+   */
+  void setBufferMovableEnabled(bool enabled);
+
   /**
    * Returns the peer certificate, or nullptr if no peer certificate received.
    */
-  virtual X509_UniquePtr getPeerCert() const override {
+  virtual ssl::X509UniquePtr getPeerCert() const override {
     if (!ssl_) {
       return nullptr;
     }
 
     X509* cert = SSL_get_peer_certificate(ssl_);
-    return X509_UniquePtr(cert);
+    return ssl::X509UniquePtr(cert);
   }
 
   /**
@@ -758,6 +708,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   void forceCacheAddrOnFailure(bool force) { cacheAddrOnFailure_ = force; }
 
+  const std::string& getServiceIdentity() const { return serviceIdentity_; }
+
+  void setServiceIdentity(std::string serviceIdentity) {
+    serviceIdentity_ = std::move(serviceIdentity);
+  }
+
  private:
 
   void init();
@@ -782,17 +738,22 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void handleConnect() noexcept override;
 
   void invalidState(HandshakeCB* callback);
-  bool willBlock(int ret, int *errorOut) noexcept;
+  bool willBlock(int ret,
+                 int* sslErrorOut,
+                 unsigned long* errErrorOut) noexcept;
 
   virtual void checkForImmediateRead() noexcept override;
   // AsyncSocket calls this at the wrong time for SSL
   void handleInitialReadWrite() noexcept override {}
 
-  int interpretSSLError(int rc, int error);
-  ssize_t performRead(void** buf, size_t* buflen, size_t* offset) override;
-  ssize_t performWrite(const iovec* vec, uint32_t count, WriteFlags flags,
-                       uint32_t* countWritten, uint32_t* partialWritten)
-    override;
+  WriteResult interpretSSLError(int rc, int error);
+  ReadResult performRead(void** buf, size_t* buflen, size_t* offset) override;
+  WriteResult performWrite(
+      const iovec* vec,
+      uint32_t count,
+      WriteFlags flags,
+      uint32_t* countWritten,
+      uint32_t* partialWritten) override;
 
   ssize_t performWriteIovec(const iovec* vec, uint32_t count,
                             WriteFlags flags, uint32_t* countWritten,
@@ -834,6 +795,8 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   static void sslInfoCallback(const SSL *ssl, int type, int val);
 
+  // Whether we've applied the TCP_CORK option to the socket
+  bool corked_{false};
   // SSL related members.
   bool server_{false};
   // Used to prevent client-initiated renegotiation.  Note that AsyncSSLSocket
@@ -867,6 +830,10 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   std::shared_ptr<folly::SSLContext> handshakeCtx_;
   std::string tlsextHostname_;
 #endif
+
+  // a service identity that this socket/connection is associated with
+  std::string serviceIdentity_;
+
   folly::SSLContext::SSLVerifyPeerEnum
     verifyPeer_{folly::SSLContext::SSLVerifyPeerEnum::USE_CTX};
 
@@ -875,7 +842,9 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   bool parseClientHello_{false};
   bool cacheAddrOnFailure_{false};
-  std::unique_ptr<ClientHelloInfo> clientHelloInfo_;
+  bool bufferMovableEnabled_{false};
+  std::unique_ptr<ssl::ClientHelloInfo> clientHelloInfo_;
+  std::vector<std::pair<char, StringPiece>> alertsReceived_;
 
   // Time taken to complete the ssl handshake.
   std::chrono::steady_clock::time_point handshakeStartTime_;