Enable EOR flag configuration for folly::AsyncSocket.
[folly.git] / folly / io / async / AsyncSSLSocket.h
index 6f8b14644bba8fa7875d021bfafa01a20917e436..e9aca826c379cb788578460c6179f709b827aa6f 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@
 #include <folly/Bits.h>
 #include <folly/io/IOBuf.h>
 #include <folly/io/Cursor.h>
+#include <folly/portability/OpenSSL.h>
 #include <folly/portability/Sockets.h>
 
 namespace folly {
@@ -202,7 +203,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   }
 
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   /**
    * Create a client AsyncSSLSocket with tlsext_servername in
    * the Client Hello message.
@@ -242,7 +243,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
       new AsyncSSLSocket(ctx, evb, serverName, deferSecurityNegotiation),
       Destructor());
   }
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
   /**
    * TODO: implement support for SSL renegotiation.
@@ -271,12 +272,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   virtual std::string getSecurityProtocol() const override { return "TLS"; }
 
-  bool isEorTrackingEnabled() const override;
   virtual void setEorTracking(bool track) override;
   virtual size_t getRawBytesWritten() const override;
   virtual size_t getRawBytesReceived() const override;
   void enableClientHelloParsing();
 
+  void setPreReceivedData(std::unique_ptr<IOBuf> data);
+
   /**
    * Accept an SSL connection on the socket.
    *
@@ -292,9 +294,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                context by default, can be set explcitly to override the
    *                method in the context
    */
-  virtual void sslAccept(HandshakeCB* callback, uint32_t timeout = 0,
+  virtual void sslAccept(
+      HandshakeCB* callback,
+      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(),
       const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
-            folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
+          folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
 
   /**
    * Invoke SSL accept following an asynchronous session cache lookup
@@ -332,9 +336,11 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    *                SSL_VERIFY_PEER and invokes
    *                HandshakeCB::handshakeVer().
    */
-  virtual void sslConn(HandshakeCB *callback, uint64_t timeout = 0,
-            const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
-                  folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
+  virtual void sslConn(
+      HandshakeCB* callback,
+      std::chrono::milliseconds timeout = std::chrono::milliseconds::zero(),
+      const folly::SSLContext::SSLVerifyPeerEnum& verifyPeer =
+          folly::SSLContext::SSLVerifyPeerEnum::USE_CTX);
 
   enum SSLStateEnum {
     STATE_UNINIT,
@@ -514,7 +520,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   void detachSSLContext();
 #endif
 
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   /**
    * Switch the SSLContext to continue the SSL handshake.
    * It can only be used in server mode.
@@ -537,7 +543,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
    * ClientHello message.
    */
   void setServerName(std::string serverName) noexcept;
-#endif
+#endif // FOLLY_OPENSSL_HAS_SNI
 
   void timeoutExpired() noexcept;
 
@@ -561,6 +567,12 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   std::string getSSLClientSigAlgs() const;
 
+  /**
+   * Get the list of versions in the supported versions extension (used to
+   * negotiate TLS 1.3).
+   */
+  std::string getSSLClientSupportedVersions() const;
+
   std::string getSSLAlertsReceived() const;
 
   /**
@@ -578,6 +590,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   static int getSSLExDataIndex();
   static AsyncSSLSocket* getFromSSL(const SSL *ssl);
   static int bioWrite(BIO* b, const char* in, int inl);
+  static int bioRead(BIO* b, char* out, int outl);
   void resetClientHelloParsing(SSL *ssl);
   static void clientHelloParsingCallback(int write_p, int version,
       int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
@@ -651,9 +664,20 @@ class AsyncSSLSocket : public virtual AsyncSocket {
     return sessionResumptionAttempted_;
   }
 
+  /**
+   * Clears the ERR stack before invoking SSL methods.
+   * This is useful if unrelated code that runs in the same thread
+   * does not properly handle SSL error conditions, in which case
+   * it could cause SSL_* methods to fail with incorrect error codes.
+   */
+  void setClearOpenSSLErrors(bool clearErr) {
+    clearOpenSSLErrors_ = clearErr;
+  }
+
  private:
 
   void init();
+  void clearOpenSSLErrors();
 
  protected:
 
@@ -698,6 +722,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
 
   // This virtual wrapper around SSL_write exists solely for testing/mockability
   virtual int sslWriteImpl(SSL *ssl, const void *buf, int n) {
+    clearOpenSSLErrors();
     return SSL_write(ssl, buf, n);
   }
 
@@ -766,9 +791,6 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // whether the SSL session was resumed using session ID or not
   bool sessionIDResumed_{false};
 
-  // Whether to track EOR or not.
-  bool trackEor_{false};
-
   // The app byte num that we are tracking for the MSG_EOR
   // Only one app EOR byte can be tracked.
   size_t appEorByteNo_{0};
@@ -780,7 +802,7 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // When openssl is about to sendmsg() across the minEorRawBytesNo_,
   // it will pass MSG_EOR to sendmsg().
   size_t minEorRawByteNo_{0};
-#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && !defined(OPENSSL_NO_TLSEXT)
+#if FOLLY_OPENSSL_HAS_SNI
   std::shared_ptr<folly::SSLContext> handshakeCtx_;
   std::string tlsextHostname_;
 #endif
@@ -804,8 +826,13 @@ class AsyncSSLSocket : public virtual AsyncSocket {
   // Time taken to complete the ssl handshake.
   std::chrono::steady_clock::time_point handshakeStartTime_;
   std::chrono::steady_clock::time_point handshakeEndTime_;
-  uint64_t handshakeConnectTimeout_{0};
+  std::chrono::milliseconds handshakeConnectTimeout_{0};
   bool sessionResumptionAttempted_{false};
+
+  std::unique_ptr<IOBuf> preReceivedData_;
+  // Whether or not to clear the err stack before invocation of another
+  // SSL method
+  bool clearOpenSSLErrors_{false};
 };
 
 } // namespace