/*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
noexcept = 0;
};
- class HandshakeTimeout : public AsyncTimeout {
+ class Timeout : public AsyncTimeout {
public:
- HandshakeTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
- : AsyncTimeout(eventBase)
- , sslSocket_(sslSocket) {}
+ Timeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
+ : AsyncTimeout(eventBase), sslSocket_(sslSocket) {}
- virtual void timeoutExpired() noexcept {
- sslSocket_->timeoutExpired();
+ bool scheduleTimeout(TimeoutManager::timeout_type timeout) {
+ timeout_ = timeout;
+ return AsyncTimeout::scheduleTimeout(timeout);
}
- private:
- AsyncSSLSocket* sslSocket_;
- };
+ bool scheduleTimeout(uint32_t timeoutMs) {
+ return scheduleTimeout(std::chrono::milliseconds{timeoutMs});
+ }
- // Timer for if we fallback from SSL connects to TCP connects
- class ConnectionTimeout : public AsyncTimeout {
- public:
- ConnectionTimeout(AsyncSSLSocket* sslSocket, EventBase* eventBase)
- : AsyncTimeout(eventBase), sslSocket_(sslSocket) {}
+ TimeoutManager::timeout_type getTimeout() {
+ return timeout_;
+ }
virtual void timeoutExpired() noexcept override {
- sslSocket_->timeoutExpired();
+ sslSocket_->timeoutExpired(timeout_);
}
private:
AsyncSSLSocket* sslSocket_;
+ TimeoutManager::timeout_type timeout_;
};
/**
virtual std::string getSecurityProtocol() const override { return "TLS"; }
- bool isEorTrackingEnabled() const override;
virtual void setEorTracking(bool track) override;
virtual size_t getRawBytesWritten() const override;
virtual size_t getRawBytesReceived() const override;
void enableClientHelloParsing();
+ void setPreReceivedData(std::unique_ptr<IOBuf> data);
+
/**
* Accept an SSL connection on the socket.
*
const folly::SocketAddress& bindAddr = anyAddress())
noexcept override;
+ /**
+ * A variant of connect that allows the caller to specify
+ * the timeout for the regular connect and the ssl connect
+ * separately.
+ * connectTimeout is specified as the time to establish a TCP
+ * connection.
+ * totalConnectTimeout defines the
+ * time it takes from starting the TCP connection to the time
+ * the ssl connection is established. The reason the timeout is
+ * defined this way is because user's rarely need to specify the SSL
+ * timeout independently of the connect timeout. It allows us to
+ * bound the time for a connect and SSL connection in
+ * a finer grained manner than if timeout was just defined
+ * independently for SSL.
+ */
+ virtual void connect(
+ ConnectCallback* callback,
+ const folly::SocketAddress& address,
+ std::chrono::milliseconds connectTimeout,
+ std::chrono::milliseconds totalConnectTimeout,
+ const OptionMap& options = emptyOptionMap,
+ const folly::SocketAddress& bindAddr = anyAddress()) noexcept;
+
using AsyncSocket::connect;
/**
void setServerName(std::string serverName) noexcept;
#endif // FOLLY_OPENSSL_HAS_SNI
- void timeoutExpired() noexcept;
+ void timeoutExpired(std::chrono::milliseconds timeout) noexcept;
/**
* Get the list of supported ciphers sent by the client in the client's
std::string getSSLAlertsReceived() const;
+ /*
+ * Save an optional alert message generated during certificate verify
+ */
+ void setSSLCertVerificationAlert(std::string alert);
+
+ std::string getSSLCertVerificationAlert() const;
+
/**
* Get the list of shared ciphers between the server and the client.
* Works well for only SSLv2, not so good for SSLv3 or TLSv1.
*/
void getSSLServerCiphers(std::string& serverCiphers) const;
+ /**
+ * Method to check if peer verfication is set.
+ *
+ * @return true if peer verification is required.
+ */
+ bool needsPeerVerification() const;
+
static int getSSLExDataIndex();
static AsyncSSLSocket* getFromSSL(const SSL *ssl);
static int bioWrite(BIO* b, const char* in, int inl);
+ static int bioRead(BIO* b, char* out, int outl);
void resetClientHelloParsing(SSL *ssl);
static void clientHelloParsingCallback(int write_p, int version,
int content_type, const void *buf, size_t len, SSL *ssl, void *arg);
return sessionResumptionAttempted_;
}
+ /**
+ * If the SSL socket was used to connect as well
+ * as establish an SSL connection, this gives the total
+ * timeout for the connect + SSL connection that was
+ * set.
+ */
+ std::chrono::milliseconds getTotalConnectTimeout() const {
+ return totalConnectTimeout_;
+ }
+
private:
void init();
HandshakeCB* handshakeCallback_{nullptr};
SSL* ssl_{nullptr};
SSL_SESSION *sslSession_{nullptr};
- HandshakeTimeout handshakeTimeout_;
- ConnectionTimeout connectionTimeout_;
+ Timeout handshakeTimeout_;
+ Timeout connectionTimeout_;
// whether the SSL session was resumed using session ID or not
bool sessionIDResumed_{false};
- // Whether to track EOR or not.
- bool trackEor_{false};
-
// The app byte num that we are tracking for the MSG_EOR
// Only one app EOR byte can be tracked.
size_t appEorByteNo_{0};
std::chrono::steady_clock::time_point handshakeEndTime_;
std::chrono::milliseconds handshakeConnectTimeout_{0};
bool sessionResumptionAttempted_{false};
+ std::chrono::milliseconds totalConnectTimeout_{0};
+
+ std::unique_ptr<IOBuf> preReceivedData_;
+ std::string sslVerificationAlert_;
};
} // namespace