Enable EOR flag configuration for folly::AsyncSocket.
[folly.git] / folly / io / async / AsyncSSLSocket.h
index b4e4ca47b134a4099ff54e11c42d1cbe476b459e..e9aca826c379cb788578460c6179f709b827aa6f 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@
 #include <folly/Bits.h>
 #include <folly/io/IOBuf.h>
 #include <folly/io/Cursor.h>
+#include <folly/portability/OpenSSL.h>
 #include <folly/portability/Sockets.h>
 
 namespace folly {
@@ -136,6 +137,20 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     AsyncSSLSocket* sslSocket_;
   };
 
+  // Timer for if we fallback from SSL connects to TCP connects
+  class ConnectionTimeout : public AsyncTimeout {
+   public:
+    ConnectionTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
+        : AsyncTimeout(eventBase), sslSocket_(sslSocket) {}
+
+    virtual void timeoutExpired() noexcept override {
+      sslSocket_->timeoutExpired();
+    }
+
+   private:
+    AsyncSSLSocket* sslSocket_;
+  };
+
   /**
    * Create a client AsyncSSLSocket
    */
@@ -188,7 +203,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   }
 
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   /**
    * Create a client AsyncSSLSocket with tlsext_servername in
    * the Client Hello message.
@@ -228,7 +243,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
       new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
       Destructor());
   }
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
   /**
    * TODO: implement support for SSL renegotiation.
@@ -257,12 +272,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   virtual std::string getSecurityProtocol() const override { return "TLS"; }
 
-  bool isEorTrackingEnabled() const override;
   virtual void setEorTracking(bool track) override;
   virtual size_t getRawBytesWritten() const override;
   virtual size_t getRawBytesReceived() const override;
   void enableClientHelloParsing();
 
+  void setPreReceivedData(std::unique_ptr<IOBuf> data);
+
   /**
    * Accept an SSL connection on the socket.
    *
@@ -278,9 +294,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                context by default, can be set explcitly to override the
    *                method in the context
    */
-  virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0,
+  virtual void sslAccept(
+      HandshakeCB* callback,
+      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(),
       const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
-            folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
+          folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
 
   /**
    * Invoke SSL accept following an asynchronous session cache lookup
@@ -318,9 +336,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                SSL_VERIFY_PEER and invokes
    *                HandshakeCB::handshakeVer().
    */
-  virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0,
-            const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
-                  folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
+  virtual void sslConn(
+      HandshakeCB* callback,
+      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(),
+      const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
+          folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
 
   enum SSLStateEnum {
     STATE_UNINIT,
@@ -457,14 +477,21 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   int getSSLCertSize() const;
 
+  /**
+   * Get the certificate used for this SSL connection. May be null
+   */
+  virtual const X509* getSelfCert() const override;
+
   virtual void attachEventBase(EventBase* eventBase) override {
     AsyncSocket::attachEventBase(eventBase);
     handshakeTimeout_.attachEventBase(eventBase);
+    connectionTimeout_.attachEventBase(eventBase);
   }
 
   virtual void detachEventBase() override {
     AsyncSocket::detachEventBase();
     handshakeTimeout_.detachEventBase();
+    connectionTimeout_.detachEventBase();
   }
 
   virtual bool isDetachable() const override {
@@ -493,7 +520,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void detachSSLContext();
 #endif
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   /**
    * Switch the SSLContext to continue the SSL handshake.
    * It can only be used in server mode.
@@ -516,7 +543,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * ClientHello message.
    */
   void setServerName(std::string serverName) noexcept;
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
   void timeoutExpired() noexcept;
 
@@ -526,133 +553,44 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   void getSSLClientCiphers(
       std::string& clientCiphers,
-      bool convertToString = true) const {
-    std::stringstream ciphersStream;
-    std::string cipherName;
-
-    if (parseClientHello_ == false
-        || clientHelloInfo_->clientHelloCipherSuites_.empty()) {
-      clientCiphers = "";
-      return;
-    }
-
-    for (auto originalCipherCode : clientHelloInfo_->clientHelloCipherSuites_)
-    {
-      const SSL_CIPHER* cipher = nullptr;
-      if (convertToString) {
-        // OpenSSL expects code as a big endian char array
-        auto cipherCode = htons(originalCipherCode);
-
-#if defined(SSL_OP_NO_TLSv1_2)
-        cipher =
-            TLSv1_2_method()->get_cipher_by_char((unsigned char*)&cipherCode);
-#elif defined(SSL_OP_NO_TLSv1_1)
-        cipher =
-            TLSv1_1_method()->get_cipher_by_char((unsigned char*)&cipherCode);
-#elif defined(SSL_OP_NO_TLSv1)
-        cipher =
-            TLSv1_method()->get_cipher_by_char((unsigned char*)&cipherCode);
-#else
-        cipher =
-            SSLv3_method()->get_cipher_by_char((unsigned char*)&cipherCode);
-#endif
-      }
-
-      if (cipher == nullptr) {
-        ciphersStream << std::setfill('0') << std::setw(4) << std::hex
-                      << originalCipherCode << ":";
-      } else {
-        ciphersStream << SSL_CIPHER_get_name(cipher) << ":";
-      }
-    }
-
-    clientCiphers = ciphersStream.str();
-    clientCiphers.erase(clientCiphers.end() - 1);
-  }
+      bool convertToString = true) const;
 
   /**
    * Get the list of compression methods sent by the client in TLS Hello.
    */
-  std::string getSSLClientComprMethods() const {
-    if (!parseClientHello_) {
-      return "";
-    }
-    return folly::join(":", clientHelloInfo_->clientHelloCompressionMethods_);
-  }
+  std::string getSSLClientComprMethods() const;
 
   /**
    * Get the list of TLS extensions sent by the client in the TLS Hello.
    */
-  std::string getSSLClientExts() const {
-    if (!parseClientHello_) {
-      return "";
-    }
-    return folly::join(":", clientHelloInfo_->clientHelloExtensions_);
-  }
-
-  std::string getSSLClientSigAlgs() const {
-    if (!parseClientHello_) {
-      return "";
-    }
+  std::string getSSLClientExts() const;
 
-    std::string sigAlgs;
-    sigAlgs.reserve(clientHelloInfo_->clientHelloSigAlgs_.size() * 4);
-    for (size_t i = 0; i < clientHelloInfo_->clientHelloSigAlgs_.size(); i++) {
-      if (i) {
-        sigAlgs.push_back(':');
-      }
-      sigAlgs.append(folly::to<std::string>(
-          clientHelloInfo_->clientHelloSigAlgs_[i].first));
-      sigAlgs.push_back(',');
-      sigAlgs.append(folly::to<std::string>(
-          clientHelloInfo_->clientHelloSigAlgs_[i].second));
-    }
-
-    return sigAlgs;
-  }
-
-  std::string getSSLAlertsReceived() const {
-    std::string ret;
+  std::string getSSLClientSigAlgs() const;
 
-    for (const auto& alert : alertsReceived_) {
-      if (!ret.empty()) {
-        ret.append(",");
-      }
-      ret.append(folly::to<std::string>(alert.first, ": ", alert.second));
-    }
+  /**
+   * Get the list of versions in the supported versions extension (used to
+   * negotiate TLS 1.3).
+   */
+  std::string getSSLClientSupportedVersions() const;
 
-    return ret;
-  }
+  std::string getSSLAlertsReceived() const;
 
   /**
    * Get the list of shared ciphers between the server and the client.
    * Works well for only SSLv2, not so good for SSLv3 or TLSv1.
    */
-  void getSSLSharedCiphers(std::string& sharedCiphers) const {
-    char ciphersBuffer[1024];
-    ciphersBuffer[0] = '\0';
-    SSL_get_shared_ciphers(ssl_, ciphersBuffer, sizeof(ciphersBuffer) - 1);
-    sharedCiphers = ciphersBuffer;
-  }
+  void getSSLSharedCiphers(std::string& sharedCiphers) const;
 
   /**
    * Get the list of ciphers supported by the server in the server's
    * preference order.
    */
-  void getSSLServerCiphers(std::string& serverCiphers) const {
-    serverCiphers = SSL_get_cipher_list(ssl_, 0);
-    int i = 1;
-    const char *cipher;
-    while ((cipher = SSL_get_cipher_list(ssl_, i)) != nullptr) {
-      serverCiphers.append(":");
-      serverCiphers.append(cipher);
-      i++;
-    }
-  }
+  void getSSLServerCiphers(std::string& serverCiphers) const;
 
   static int getSSLExDataIndex();
   static AsyncSSLSocket* getFromSSL(const SSL *ssl);
-  static int eorAwareBioWrite(BIO *b, const char *in, int inl);
+  static int bioWrite(BIO* b, const char* in, int inl);
+  static int bioRead(BIO* b, char* out, int outl);
   void resetClientHelloParsing(SSL *ssl);
   static void clientHelloParsingCallback(int write_p, int version,
       int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
@@ -714,9 +652,32 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     serviceIdentity_ = std::move(serviceIdentity);
   }
 
+  void setCertCacheHit(bool hit) {
+    certCacheHit_ = hit;
+  }
+
+  bool getCertCacheHit() const {
+    return certCacheHit_;
+  }
+
+  bool sessionResumptionAttempted() const {
+    return sessionResumptionAttempted_;
+  }
+
+  /**
+   * Clears the ERR stack before invoking SSL methods.
+   * This is useful if unrelated code that runs in the same thread
+   * does not properly handle SSL error conditions, in which case
+   * it could cause SSL_* methods to fail with incorrect error codes.
+   */
+  void setClearOpenSSLErrors(bool clearErr) {
+    clearOpenSSLErrors_ = clearErr;
+  }
+
  private:
 
   void init();
+  void clearOpenSSLErrors();
 
  protected:
 
@@ -731,7 +692,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   // Inherit event notification methods from AsyncSocket except
   // the following.
-  void prepareReadBuffer(void** buf, size_t* buflen) noexcept override;
+  void prepareReadBuffer(void** buf, size_t* buflen) override;
   void handleRead() noexcept override;
   void handleWrite() noexcept override;
   void handleAccept() noexcept;
@@ -761,6 +722,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   // This virtual wrapper around SSL_write exists solely for testing/mockability
   virtual int sslWriteImpl(SSL *ssl, const void *buf, int n) {
+    clearOpenSSLErrors();
     return SSL_write(ssl, buf, n);
   }
 
@@ -774,6 +736,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    */
   void applyVerificationOptions(SSL * ssl);
 
+  /**
+   * Sets up SSL with a custom write bio which intercepts all writes.
+   *
+   * @return true, if succeeds and false if there is an error creating the bio.
+   */
+  bool setupSSLBio();
+
   /**
    * A SSL_write wrapper that understand EOR
    *
@@ -791,12 +760,18 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void invokeHandshakeErr(const AsyncSocketException& ex);
   void invokeHandshakeCB();
 
+  void invokeConnectErr(const AsyncSocketException& ex) override;
+  void invokeConnectSuccess() override;
+  void scheduleConnectTimeout() override;
+
   void cacheLocalPeerAddr();
 
+  void startSSLConnect();
+
   static void sslInfoCallback(const SSL *ssl, int type, int val);
 
-  // Whether we've applied the TCP_CORK option to the socket
-  bool corked_{false};
+  // Whether the current write to the socket should use MSG_MORE.
+  bool corkCurrentWrite_{false};
   // SSL related members.
   bool server_{false};
   // Used to prevent client-initiated renegotiation.  Note that AsyncSSLSocket
@@ -812,6 +787,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   SSL* ssl_{nullptr};
   SSL_SESSION *sslSession_{nullptr};
   HandshakeTimeout handshakeTimeout_;
+  ConnectionTimeout connectionTimeout_;
   // whether the SSL session was resumed using session ID or not
   bool sessionIDResumed_{false};
 
@@ -826,7 +802,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // When openssl is about to sendmsg() across the minEorRawBytesNo_,
   // it will pass MSG_EOR to sendmsg().
   size_t minEorRawByteNo_{0};
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   std::shared_ptr<folly::SSLContext> handshakeCtx_;
   std::string tlsextHostname_;
 #endif
@@ -843,12 +819,20 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   bool parseClientHello_{false};
   bool cacheAddrOnFailure_{false};
   bool bufferMovableEnabled_{false};
+  bool certCacheHit_{false};
   std::unique_ptr<ssl::ClientHelloInfo> clientHelloInfo_;
   std::vector<std::pair<char, StringPiece>> alertsReceived_;
 
   // Time taken to complete the ssl handshake.
   std::chrono::steady_clock::time_point handshakeStartTime_;
   std::chrono::steady_clock::time_point handshakeEndTime_;
+  std::chrono::milliseconds handshakeConnectTimeout_{0};
+  bool sessionResumptionAttempted_{false};
+
+  std::unique_ptr<IOBuf> preReceivedData_;
+  // Whether or not to clear the err stack before invocation of another
+  // SSL method
+  bool clearOpenSSLErrors_{false};
 };
 
 } // namespace