From 5575da1ffdf860588f8a1da9557f1fbc6d86b649 Mon Sep 17 00:00:00 2001 From: James Sedgwick Date: Tue, 26 May 2015 15:14:41 -0700 Subject: [PATCH] copy wangle back into folly Summary: copy everything but example/ Test Plan: fbconfig -r folly/wangle && fbmake runtests Reviewed By: hans@fb.com Subscribers: fugalh, ps, bmatheny, folly-diffs@, jsedgwick, yfeldblum, markdrayton, chalfant FB internal diff: D2100811 Tasks: 5802833 Signature: t1:2100811:1432678173:6c336fe53aa223993f6f82de4ac91b3c19beacf1 --- folly/wangle/acceptor/Acceptor.cpp | 456 ++++++++++++ folly/wangle/acceptor/Acceptor.h | 356 ++++++++++ folly/wangle/acceptor/ConnectionCounter.h | 54 ++ folly/wangle/acceptor/ConnectionManager.cpp | 233 +++++++ folly/wangle/acceptor/ConnectionManager.h | 245 +++++++ folly/wangle/acceptor/DomainNameMisc.h | 68 ++ .../wangle/acceptor/LoadShedConfiguration.cpp | 43 ++ folly/wangle/acceptor/LoadShedConfiguration.h | 118 ++++ folly/wangle/acceptor/ManagedConnection.cpp | 64 ++ folly/wangle/acceptor/ManagedConnection.h | 140 ++++ folly/wangle/acceptor/NetworkAddress.h | 60 ++ folly/wangle/acceptor/ServerSocketConfig.h | 128 ++++ folly/wangle/acceptor/SocketOptions.cpp | 38 + folly/wangle/acceptor/SocketOptions.h | 24 + folly/wangle/acceptor/TransportInfo.cpp | 65 ++ folly/wangle/acceptor/TransportInfo.h | 298 ++++++++ folly/wangle/bootstrap/BootstrapTest.cpp | 365 ++++++++++ folly/wangle/bootstrap/ClientBootstrap.h | 109 +++ folly/wangle/bootstrap/ServerBootstrap-inl.h | 198 ++++++ folly/wangle/bootstrap/ServerBootstrap.cpp | 62 ++ folly/wangle/bootstrap/ServerBootstrap.h | 351 ++++++++++ folly/wangle/bootstrap/ServerSocketFactory.h | 122 ++++ folly/wangle/channel/AsyncSocketHandler.h | 164 +++++ folly/wangle/channel/EventBaseHandler.h | 45 ++ folly/wangle/channel/Handler.h | 173 +++++ folly/wangle/channel/HandlerContext-inl.h | 447 ++++++++++++ folly/wangle/channel/HandlerContext.h | 108 +++ folly/wangle/channel/OutputBufferingHandler.h | 84 +++ folly/wangle/channel/Pipeline-inl.h | 267 +++++++ folly/wangle/channel/Pipeline.h | 182 +++++ folly/wangle/channel/StaticPipeline.h | 137 ++++ folly/wangle/channel/test/MockHandler.h | 75 ++ .../test/OutputBufferingHandlerTest.cpp | 65 ++ folly/wangle/channel/test/PipelineTest.cpp | 306 ++++++++ folly/wangle/codec/ByteToMessageCodec.cpp | 33 + folly/wangle/codec/ByteToMessageCodec.h | 52 ++ folly/wangle/codec/CodecTest.cpp | 637 +++++++++++++++++ folly/wangle/codec/FixedLengthFrameDecoder.h | 59 ++ .../codec/LengthFieldBasedFrameDecoder.cpp | 127 ++++ .../codec/LengthFieldBasedFrameDecoder.h | 209 ++++++ folly/wangle/codec/LengthFieldPrepender.cpp | 99 +++ folly/wangle/codec/LengthFieldPrepender.h | 67 ++ folly/wangle/codec/LineBasedFrameDecoder.cpp | 103 +++ folly/wangle/codec/LineBasedFrameDecoder.h | 59 ++ folly/wangle/codec/README.md | 5 + folly/wangle/codec/StringCodec.h | 46 ++ folly/wangle/concurrent/BlockingQueue.h | 38 + .../concurrent/CPUThreadPoolExecutor.cpp | 152 ++++ .../wangle/concurrent/CPUThreadPoolExecutor.h | 99 +++ folly/wangle/concurrent/Codel.cpp | 91 +++ folly/wangle/concurrent/Codel.h | 66 ++ folly/wangle/concurrent/FiberIOExecutor.h | 49 ++ folly/wangle/concurrent/FutureExecutor.h | 79 +++ folly/wangle/concurrent/GlobalExecutor.cpp | 120 ++++ folly/wangle/concurrent/GlobalExecutor.h | 46 ++ folly/wangle/concurrent/IOExecutor.h | 47 ++ .../concurrent/IOThreadPoolExecutor.cpp | 188 +++++ .../wangle/concurrent/IOThreadPoolExecutor.h | 71 ++ folly/wangle/concurrent/LifoSemMPMCQueue.h | 57 ++ folly/wangle/concurrent/NamedThreadFactory.h | 56 ++ .../concurrent/PriorityLifoSemMPMCQueue.h | 80 +++ folly/wangle/concurrent/ThreadFactory.h | 30 + .../wangle/concurrent/ThreadPoolExecutor.cpp | 202 ++++++ folly/wangle/concurrent/ThreadPoolExecutor.h | 234 +++++++ folly/wangle/concurrent/test/CodelTest.cpp | 38 + .../concurrent/test/GlobalExecutorTest.cpp | 85 +++ .../test/ThreadPoolExecutorTest.cpp | 395 +++++++++++ folly/wangle/rx/Dummy.cpp | 19 + folly/wangle/rx/Observable.h | 285 ++++++++ folly/wangle/rx/Observer.h | 113 +++ folly/wangle/rx/README.md | 36 + folly/wangle/rx/Subject.h | 47 ++ folly/wangle/rx/Subscription.h | 70 ++ folly/wangle/rx/test/RxBenchmark.cpp | 155 +++++ folly/wangle/rx/test/RxTest.cpp | 195 ++++++ folly/wangle/rx/types.h | 35 + folly/wangle/service/ClientDispatcher.h | 69 ++ folly/wangle/service/ServerDispatcher.h | 46 ++ folly/wangle/service/Service.h | 154 +++++ folly/wangle/service/ServiceTest.cpp | 258 +++++++ folly/wangle/ssl/ClientHelloExtStats.h | 24 + folly/wangle/ssl/DHParam.h | 53 ++ folly/wangle/ssl/PasswordInFile.cpp | 31 + folly/wangle/ssl/PasswordInFile.h | 38 + folly/wangle/ssl/SSLCacheOptions.h | 23 + folly/wangle/ssl/SSLCacheProvider.h | 69 ++ folly/wangle/ssl/SSLContextConfig.h | 95 +++ folly/wangle/ssl/SSLContextManager.cpp | 651 ++++++++++++++++++ folly/wangle/ssl/SSLContextManager.h | 182 +++++ folly/wangle/ssl/SSLSessionCacheManager.cpp | 354 ++++++++++ folly/wangle/ssl/SSLSessionCacheManager.h | 292 ++++++++ folly/wangle/ssl/SSLStats.h | 42 ++ folly/wangle/ssl/SSLUtil.cpp | 76 ++ folly/wangle/ssl/SSLUtil.h | 102 +++ folly/wangle/ssl/TLSTicketKeyManager.cpp | 305 ++++++++ folly/wangle/ssl/TLSTicketKeyManager.h | 198 ++++++ folly/wangle/ssl/TLSTicketKeySeeds.h | 20 + folly/wangle/ssl/test/SSLCacheTest.cpp | 272 ++++++++ .../wangle/ssl/test/SSLContextManagerTest.cpp | 87 +++ 99 files changed, 13765 insertions(+) create mode 100644 folly/wangle/acceptor/Acceptor.cpp create mode 100644 folly/wangle/acceptor/Acceptor.h create mode 100644 folly/wangle/acceptor/ConnectionCounter.h create mode 100644 folly/wangle/acceptor/ConnectionManager.cpp create mode 100644 folly/wangle/acceptor/ConnectionManager.h create mode 100644 folly/wangle/acceptor/DomainNameMisc.h create mode 100644 folly/wangle/acceptor/LoadShedConfiguration.cpp create mode 100644 folly/wangle/acceptor/LoadShedConfiguration.h create mode 100644 folly/wangle/acceptor/ManagedConnection.cpp create mode 100644 folly/wangle/acceptor/ManagedConnection.h create mode 100644 folly/wangle/acceptor/NetworkAddress.h create mode 100644 folly/wangle/acceptor/ServerSocketConfig.h create mode 100644 folly/wangle/acceptor/SocketOptions.cpp create mode 100644 folly/wangle/acceptor/SocketOptions.h create mode 100644 folly/wangle/acceptor/TransportInfo.cpp create mode 100644 folly/wangle/acceptor/TransportInfo.h create mode 100644 folly/wangle/bootstrap/BootstrapTest.cpp create mode 100644 folly/wangle/bootstrap/ClientBootstrap.h create mode 100644 folly/wangle/bootstrap/ServerBootstrap-inl.h create mode 100644 folly/wangle/bootstrap/ServerBootstrap.cpp create mode 100644 folly/wangle/bootstrap/ServerBootstrap.h create mode 100644 folly/wangle/bootstrap/ServerSocketFactory.h create mode 100644 folly/wangle/channel/AsyncSocketHandler.h create mode 100644 folly/wangle/channel/EventBaseHandler.h create mode 100644 folly/wangle/channel/Handler.h create mode 100644 folly/wangle/channel/HandlerContext-inl.h create mode 100644 folly/wangle/channel/HandlerContext.h create mode 100644 folly/wangle/channel/OutputBufferingHandler.h create mode 100644 folly/wangle/channel/Pipeline-inl.h create mode 100644 folly/wangle/channel/Pipeline.h create mode 100644 folly/wangle/channel/StaticPipeline.h create mode 100644 folly/wangle/channel/test/MockHandler.h create mode 100644 folly/wangle/channel/test/OutputBufferingHandlerTest.cpp create mode 100644 folly/wangle/channel/test/PipelineTest.cpp create mode 100644 folly/wangle/codec/ByteToMessageCodec.cpp create mode 100644 folly/wangle/codec/ByteToMessageCodec.h create mode 100644 folly/wangle/codec/CodecTest.cpp create mode 100644 folly/wangle/codec/FixedLengthFrameDecoder.h create mode 100644 folly/wangle/codec/LengthFieldBasedFrameDecoder.cpp create mode 100644 folly/wangle/codec/LengthFieldBasedFrameDecoder.h create mode 100644 folly/wangle/codec/LengthFieldPrepender.cpp create mode 100644 folly/wangle/codec/LengthFieldPrepender.h create mode 100644 folly/wangle/codec/LineBasedFrameDecoder.cpp create mode 100644 folly/wangle/codec/LineBasedFrameDecoder.h create mode 100644 folly/wangle/codec/README.md create mode 100644 folly/wangle/codec/StringCodec.h create mode 100644 folly/wangle/concurrent/BlockingQueue.h create mode 100644 folly/wangle/concurrent/CPUThreadPoolExecutor.cpp create mode 100644 folly/wangle/concurrent/CPUThreadPoolExecutor.h create mode 100644 folly/wangle/concurrent/Codel.cpp create mode 100644 folly/wangle/concurrent/Codel.h create mode 100644 folly/wangle/concurrent/FiberIOExecutor.h create mode 100644 folly/wangle/concurrent/FutureExecutor.h create mode 100644 folly/wangle/concurrent/GlobalExecutor.cpp create mode 100644 folly/wangle/concurrent/GlobalExecutor.h create mode 100644 folly/wangle/concurrent/IOExecutor.h create mode 100644 folly/wangle/concurrent/IOThreadPoolExecutor.cpp create mode 100644 folly/wangle/concurrent/IOThreadPoolExecutor.h create mode 100644 folly/wangle/concurrent/LifoSemMPMCQueue.h create mode 100644 folly/wangle/concurrent/NamedThreadFactory.h create mode 100644 folly/wangle/concurrent/PriorityLifoSemMPMCQueue.h create mode 100644 folly/wangle/concurrent/ThreadFactory.h create mode 100644 folly/wangle/concurrent/ThreadPoolExecutor.cpp create mode 100644 folly/wangle/concurrent/ThreadPoolExecutor.h create mode 100644 folly/wangle/concurrent/test/CodelTest.cpp create mode 100644 folly/wangle/concurrent/test/GlobalExecutorTest.cpp create mode 100644 folly/wangle/concurrent/test/ThreadPoolExecutorTest.cpp create mode 100644 folly/wangle/rx/Dummy.cpp create mode 100644 folly/wangle/rx/Observable.h create mode 100644 folly/wangle/rx/Observer.h create mode 100644 folly/wangle/rx/README.md create mode 100644 folly/wangle/rx/Subject.h create mode 100644 folly/wangle/rx/Subscription.h create mode 100644 folly/wangle/rx/test/RxBenchmark.cpp create mode 100644 folly/wangle/rx/test/RxTest.cpp create mode 100644 folly/wangle/rx/types.h create mode 100644 folly/wangle/service/ClientDispatcher.h create mode 100644 folly/wangle/service/ServerDispatcher.h create mode 100644 folly/wangle/service/Service.h create mode 100644 folly/wangle/service/ServiceTest.cpp create mode 100644 folly/wangle/ssl/ClientHelloExtStats.h create mode 100644 folly/wangle/ssl/DHParam.h create mode 100644 folly/wangle/ssl/PasswordInFile.cpp create mode 100644 folly/wangle/ssl/PasswordInFile.h create mode 100644 folly/wangle/ssl/SSLCacheOptions.h create mode 100644 folly/wangle/ssl/SSLCacheProvider.h create mode 100644 folly/wangle/ssl/SSLContextConfig.h create mode 100644 folly/wangle/ssl/SSLContextManager.cpp create mode 100644 folly/wangle/ssl/SSLContextManager.h create mode 100644 folly/wangle/ssl/SSLSessionCacheManager.cpp create mode 100644 folly/wangle/ssl/SSLSessionCacheManager.h create mode 100644 folly/wangle/ssl/SSLStats.h create mode 100644 folly/wangle/ssl/SSLUtil.cpp create mode 100644 folly/wangle/ssl/SSLUtil.h create mode 100644 folly/wangle/ssl/TLSTicketKeyManager.cpp create mode 100644 folly/wangle/ssl/TLSTicketKeyManager.h create mode 100644 folly/wangle/ssl/TLSTicketKeySeeds.h create mode 100644 folly/wangle/ssl/test/SSLCacheTest.cpp create mode 100644 folly/wangle/ssl/test/SSLContextManagerTest.cpp diff --git a/folly/wangle/acceptor/Acceptor.cpp b/folly/wangle/acceptor/Acceptor.cpp new file mode 100644 index 00000000..8ef0d18b --- /dev/null +++ b/folly/wangle/acceptor/Acceptor.cpp @@ -0,0 +1,456 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using folly::wangle::ConnectionManager; +using folly::wangle::ManagedConnection; +using std::chrono::microseconds; +using std::chrono::milliseconds; +using std::filebuf; +using std::ifstream; +using std::ios; +using std::shared_ptr; +using std::string; + +namespace folly { + +#ifndef NO_LIB_GFLAGS +DEFINE_int32(shutdown_idle_grace_ms, 5000, "milliseconds to wait before " + "closing idle conns"); +#else +const int32_t FLAGS_shutdown_idle_grace_ms = 5000; +#endif + +static const std::string empty_string; +std::atomic Acceptor::totalNumPendingSSLConns_{0}; + +/** + * Lightweight wrapper class to keep track of a newly + * accepted connection during SSL handshaking. + */ +class AcceptorHandshakeHelper : + public AsyncSSLSocket::HandshakeCB, + public ManagedConnection { + public: + AcceptorHandshakeHelper(AsyncSSLSocket::UniquePtr socket, + Acceptor* acceptor, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime, + TransportInfo& tinfo) + : socket_(std::move(socket)), acceptor_(acceptor), + acceptTime_(acceptTime), clientAddr_(clientAddr), + tinfo_(tinfo) { + acceptor_->downstreamConnectionManager_->addConnection(this, true); + if(acceptor_->parseClientHello_) { + socket_->enableClientHelloParsing(); + } + socket_->sslAccept(this); + } + + virtual void timeoutExpired() noexcept override { + VLOG(4) << "SSL handshake timeout expired"; + sslError_ = SSLErrorEnum::TIMEOUT; + dropConnection(); + } + virtual void describe(std::ostream& os) const override { + os << "pending handshake on " << clientAddr_; + } + virtual bool isBusy() const override { + return true; + } + virtual void notifyPendingShutdown() override {} + virtual void closeWhenIdle() override {} + + virtual void dropConnection() override { + VLOG(10) << "Dropping in progress handshake for " << clientAddr_; + socket_->closeNow(); + } + virtual void dumpConnectionState(uint8_t loglevel) override { + } + + private: + // AsyncSSLSocket::HandshakeCallback API + virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept override { + + const unsigned char* nextProto = nullptr; + unsigned nextProtoLength = 0; + sock->getSelectedNextProtocol(&nextProto, &nextProtoLength); + if (VLOG_IS_ON(3)) { + if (nextProto) { + VLOG(3) << "Client selected next protocol " << + string((const char*)nextProto, nextProtoLength); + } else { + VLOG(3) << "Client did not select a next protocol"; + } + } + + // fill in SSL-related fields from TransportInfo + // the other fields like RTT are filled in the Acceptor + tinfo_.ssl = true; + tinfo_.acceptTime = acceptTime_; + tinfo_.sslSetupTime = std::chrono::duration_cast( + std::chrono::steady_clock::now() - acceptTime_ + ); + tinfo_.sslSetupBytesRead = sock->getRawBytesReceived(); + tinfo_.sslSetupBytesWritten = sock->getRawBytesWritten(); + tinfo_.sslServerName = sock->getSSLServerName() ? + std::make_shared(sock->getSSLServerName()) : nullptr; + tinfo_.sslCipher = sock->getNegotiatedCipherName() ? + std::make_shared(sock->getNegotiatedCipherName()) : nullptr; + tinfo_.sslVersion = sock->getSSLVersion(); + tinfo_.sslCertSize = sock->getSSLCertSize(); + tinfo_.sslResume = SSLUtil::getResumeState(sock); + tinfo_.sslClientCiphers = std::make_shared(); + sock->getSSLClientCiphers(*tinfo_.sslClientCiphers); + tinfo_.sslServerCiphers = std::make_shared(); + sock->getSSLServerCiphers(*tinfo_.sslServerCiphers); + tinfo_.sslClientComprMethods = + std::make_shared(sock->getSSLClientComprMethods()); + tinfo_.sslClientExts = + std::make_shared(sock->getSSLClientExts()); + tinfo_.sslNextProtocol = std::make_shared(); + tinfo_.sslNextProtocol->assign(reinterpret_cast(nextProto), + nextProtoLength); + + acceptor_->updateSSLStats( + sock, + tinfo_.sslSetupTime, + SSLErrorEnum::NO_ERROR + ); + acceptor_->downstreamConnectionManager_->removeConnection(this); + acceptor_->sslConnectionReady(std::move(socket_), clientAddr_, + nextProto ? string((const char*)nextProto, nextProtoLength) : + empty_string, tinfo_); + delete this; + } + + virtual void handshakeErr(AsyncSSLSocket* sock, + const AsyncSocketException& ex) noexcept override { + auto elapsedTime = std::chrono::duration_cast(std::chrono::steady_clock::now() - acceptTime_); + VLOG(3) << "SSL handshake error after " << elapsedTime.count() << + " ms; " << sock->getRawBytesReceived() << " bytes received & " << + sock->getRawBytesWritten() << " bytes sent: " << + ex.what(); + acceptor_->updateSSLStats(sock, elapsedTime, sslError_); + acceptor_->sslConnectionError(); + delete this; + } + + AsyncSSLSocket::UniquePtr socket_; + Acceptor* acceptor_; + std::chrono::steady_clock::time_point acceptTime_; + SocketAddress clientAddr_; + TransportInfo tinfo_; + SSLErrorEnum sslError_{SSLErrorEnum::NO_ERROR}; +}; + +Acceptor::Acceptor(const ServerSocketConfig& accConfig) : + accConfig_(accConfig), + socketOptions_(accConfig.getSocketOptions()) { +} + +void +Acceptor::init(AsyncServerSocket* serverSocket, + EventBase* eventBase) { + CHECK(nullptr == this->base_); + + if (accConfig_.isSSL()) { + if (!sslCtxManager_) { + sslCtxManager_ = folly::make_unique( + eventBase, + "vip_" + getName(), + accConfig_.strictSSL, nullptr); + } + for (const auto& sslCtxConfig : accConfig_.sslContextConfigs) { + sslCtxManager_->addSSLContextConfig( + sslCtxConfig, + accConfig_.sslCacheOptions, + &accConfig_.initialTicketSeeds, + accConfig_.bindAddress, + cacheProvider_); + parseClientHello_ |= sslCtxConfig.clientHelloParsingEnabled; + } + + CHECK(sslCtxManager_->getDefaultSSLCtx()); + } + + base_ = eventBase; + state_ = State::kRunning; + downstreamConnectionManager_ = ConnectionManager::makeUnique( + eventBase, accConfig_.connectionIdleTimeout, this); + + if (serverSocket) { + serverSocket->addAcceptCallback(this, eventBase); + + for (auto& fd : serverSocket->getSockets()) { + if (fd < 0) { + continue; + } + for (const auto& opt: socketOptions_) { + opt.first.apply(fd, opt.second); + } + } + } +} + +Acceptor::~Acceptor(void) { +} + +void Acceptor::addSSLContextConfig(const SSLContextConfig& sslCtxConfig) { + sslCtxManager_->addSSLContextConfig(sslCtxConfig, + accConfig_.sslCacheOptions, + &accConfig_.initialTicketSeeds, + accConfig_.bindAddress, + cacheProvider_); +} + +void +Acceptor::drainAllConnections() { + if (downstreamConnectionManager_) { + downstreamConnectionManager_->initiateGracefulShutdown( + std::chrono::milliseconds(FLAGS_shutdown_idle_grace_ms)); + } +} + +void Acceptor::setLoadShedConfig(const LoadShedConfiguration& from, + IConnectionCounter* counter) { + loadShedConfig_ = from; + connectionCounter_ = counter; +} + +bool Acceptor::canAccept(const SocketAddress& address) { + if (!connectionCounter_) { + return true; + } + + uint64_t maxConnections = connectionCounter_->getMaxConnections(); + if (maxConnections == 0) { + return true; + } + + uint64_t currentConnections = connectionCounter_->getNumConnections(); + if (currentConnections < maxConnections) { + return true; + } + + if (loadShedConfig_.isWhitelisted(address)) { + return true; + } + + // Take care of comparing connection count against max connections across + // all acceptors. Expensive since a lock must be taken to get the counter. + auto connectionCountForLoadShedding = getConnectionCountForLoadShedding(); + if (connectionCountForLoadShedding < loadShedConfig_.getMaxConnections()) { + return true; + } + + VLOG(4) << address.describe() << " not whitelisted"; + return false; +} + +void +Acceptor::connectionAccepted( + int fd, const SocketAddress& clientAddr) noexcept { + if (!canAccept(clientAddr)) { + close(fd); + return; + } + auto acceptTime = std::chrono::steady_clock::now(); + for (const auto& opt: socketOptions_) { + opt.first.apply(fd, opt.second); + } + + onDoneAcceptingConnection(fd, clientAddr, acceptTime); +} + +void Acceptor::onDoneAcceptingConnection( + int fd, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime) noexcept { + TransportInfo tinfo; + processEstablishedConnection(fd, clientAddr, acceptTime, tinfo); +} + +void +Acceptor::processEstablishedConnection( + int fd, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime, + TransportInfo& tinfo) noexcept { + if (accConfig_.isSSL()) { + CHECK(sslCtxManager_); + AsyncSSLSocket::UniquePtr sslSock( + makeNewAsyncSSLSocket( + sslCtxManager_->getDefaultSSLCtx(), base_, fd)); + ++numPendingSSLConns_; + ++totalNumPendingSSLConns_; + if (totalNumPendingSSLConns_ > accConfig_.maxConcurrentSSLHandshakes) { + VLOG(2) << "dropped SSL handshake on " << accConfig_.name << + " too many handshakes in progress"; + updateSSLStats(sslSock.get(), std::chrono::milliseconds(0), + SSLErrorEnum::DROPPED); + sslConnectionError(); + return; + } + new AcceptorHandshakeHelper( + std::move(sslSock), + this, + clientAddr, + acceptTime, + tinfo + ); + } else { + tinfo.ssl = false; + tinfo.acceptTime = acceptTime; + AsyncSocket::UniquePtr sock(makeNewAsyncSocket(base_, fd)); + connectionReady(std::move(sock), clientAddr, empty_string, tinfo); + } +} + +void +Acceptor::connectionReady( + AsyncSocket::UniquePtr sock, + const SocketAddress& clientAddr, + const string& nextProtocolName, + TransportInfo& tinfo) { + // Limit the number of reads from the socket per poll loop iteration, + // both to keep memory usage under control and to prevent one fast- + // writing client from starving other connections. + sock->setMaxReadsPerEvent(16); + tinfo.initWithSocket(sock.get()); + onNewConnection(std::move(sock), &clientAddr, nextProtocolName, tinfo); +} + +void +Acceptor::sslConnectionReady(AsyncSocket::UniquePtr sock, + const SocketAddress& clientAddr, + const string& nextProtocol, + TransportInfo& tinfo) { + CHECK(numPendingSSLConns_ > 0); + connectionReady(std::move(sock), clientAddr, nextProtocol, tinfo); + --numPendingSSLConns_; + --totalNumPendingSSLConns_; + if (state_ == State::kDraining) { + checkDrained(); + } +} + +void +Acceptor::sslConnectionError() { + CHECK(numPendingSSLConns_ > 0); + --numPendingSSLConns_; + --totalNumPendingSSLConns_; + if (state_ == State::kDraining) { + checkDrained(); + } +} + +void +Acceptor::acceptError(const std::exception& ex) noexcept { + // An error occurred. + // The most likely error is out of FDs. AsyncServerSocket will back off + // briefly if we are out of FDs, then continue accepting later. + // Just log a message here. + LOG(ERROR) << "error accepting on acceptor socket: " << ex.what(); +} + +void +Acceptor::acceptStopped() noexcept { + VLOG(3) << "Acceptor " << this << " acceptStopped()"; + // Drain the open client connections + drainAllConnections(); + + // If we haven't yet finished draining, begin doing so by marking ourselves + // as in the draining state. We must be sure to hit checkDrained() here, as + // if we're completely idle, we can should consider ourself drained + // immediately (as there is no outstanding work to complete to cause us to + // re-evaluate this). + if (state_ != State::kDone) { + state_ = State::kDraining; + checkDrained(); + } +} + +void +Acceptor::onEmpty(const ConnectionManager& cm) { + VLOG(3) << "Acceptor=" << this << " onEmpty()"; + if (state_ == State::kDraining) { + checkDrained(); + } +} + +void +Acceptor::checkDrained() { + CHECK(state_ == State::kDraining); + if (forceShutdownInProgress_ || + (downstreamConnectionManager_->getNumConnections() != 0) || + (numPendingSSLConns_ != 0)) { + return; + } + + VLOG(2) << "All connections drained from Acceptor=" << this << " in thread " + << base_; + + downstreamConnectionManager_.reset(); + + state_ = State::kDone; + + onConnectionsDrained(); +} + +milliseconds +Acceptor::getConnTimeout() const { + return accConfig_.connectionIdleTimeout; +} + +void Acceptor::addConnection(ManagedConnection* conn) { + // Add the socket to the timeout manager so that it can be cleaned + // up after being left idle for a long time. + downstreamConnectionManager_->addConnection(conn, true); +} + +void +Acceptor::forceStop() { + base_->runInEventBaseThread([&] { dropAllConnections(); }); +} + +void +Acceptor::dropAllConnections() { + if (downstreamConnectionManager_) { + VLOG(3) << "Dropping all connections from Acceptor=" << this << + " in thread " << base_; + assert(base_->isInEventBaseThread()); + forceShutdownInProgress_ = true; + downstreamConnectionManager_->dropAllConnections(); + CHECK(downstreamConnectionManager_->getNumConnections() == 0); + downstreamConnectionManager_.reset(); + } + CHECK(numPendingSSLConns_ == 0); + + state_ = State::kDone; + onConnectionsDrained(); +} + +} // namespace diff --git a/folly/wangle/acceptor/Acceptor.h b/folly/wangle/acceptor/Acceptor.h new file mode 100644 index 00000000..bf2f567a --- /dev/null +++ b/folly/wangle/acceptor/Acceptor.h @@ -0,0 +1,356 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { +class ManagedConnection; +}} + +namespace folly { + +class SocketAddress; +class SSLContext; +class AsyncTransport; +class SSLContextManager; + +/** + * An abstract acceptor for TCP-based network services. + * + * There is one acceptor object per thread for each listening socket. When a + * new connection arrives on the listening socket, it is accepted by one of the + * acceptor objects. From that point on the connection will be processed by + * that acceptor's thread. + * + * The acceptor will call the abstract onNewConnection() method to create + * a new ManagedConnection object for each accepted socket. The acceptor + * also tracks all outstanding connections that it has accepted. + */ +class Acceptor : + public folly::AsyncServerSocket::AcceptCallback, + public folly::wangle::ConnectionManager::Callback, + public AsyncUDPServerSocket::Callback { + public: + + enum class State : uint32_t { + kInit, // not yet started + kRunning, // processing requests normally + kDraining, // processing outstanding conns, but not accepting new ones + kDone, // no longer accepting, and all connections finished + }; + + explicit Acceptor(const ServerSocketConfig& accConfig); + virtual ~Acceptor(); + + /** + * Supply an SSL cache provider + * @note Call this before init() + */ + virtual void setSSLCacheProvider( + const std::shared_ptr& cacheProvider) { + cacheProvider_ = cacheProvider; + } + + /** + * Initialize the Acceptor to run in the specified EventBase + * thread, receiving connections from the specified AsyncServerSocket. + * + * This method will be called from the AsyncServerSocket's primary thread, + * not the specified EventBase thread. + */ + virtual void init(AsyncServerSocket* serverSocket, + EventBase* eventBase); + + /** + * Dynamically add a new SSLContextConfig + */ + void addSSLContextConfig(const SSLContextConfig& sslCtxConfig); + + SSLContextManager* getSSLContextManager() const { + return sslCtxManager_.get(); + } + + /** + * Return the number of outstanding connections in this service instance. + */ + uint32_t getNumConnections() const { + return downstreamConnectionManager_ ? + (uint32_t)downstreamConnectionManager_->getNumConnections() : 0; + } + + /** + * Access the Acceptor's event base. + */ + virtual EventBase* getEventBase() const { return base_; } + + /** + * Access the Acceptor's downstream (client-side) ConnectionManager + */ + virtual folly::wangle::ConnectionManager* getConnectionManager() { + return downstreamConnectionManager_.get(); + } + + /** + * Invoked when a new ManagedConnection is created. + * + * This allows the Acceptor to track the outstanding connections, + * for tracking timeouts and for ensuring that all connections have been + * drained on shutdown. + */ + void addConnection(folly::wangle::ManagedConnection* connection); + + /** + * Get this acceptor's current state. + */ + State getState() const { + return state_; + } + + /** + * Get the current connection timeout. + */ + std::chrono::milliseconds getConnTimeout() const; + + /** + * Returns the name of this VIP. + * + * Will return an empty string if no name has been configured. + */ + const std::string& getName() const { + return accConfig_.name; + } + + /** + * Force the acceptor to drop all connections and stop processing. + * + * This function may be called from any thread. The acceptor will not + * necessarily stop before this function returns: the stop will be scheduled + * to run in the acceptor's thread. + */ + virtual void forceStop(); + + bool isSSL() const { return accConfig_.isSSL(); } + + const ServerSocketConfig& getConfig() const { return accConfig_; } + + static uint64_t getTotalNumPendingSSLConns() { + return totalNumPendingSSLConns_.load(); + } + + /** + * Called right when the TCP connection has been accepted, before processing + * the first HTTP bytes (HTTP) or the SSL handshake (HTTPS) + */ + virtual void onDoneAcceptingConnection( + int fd, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime + ) noexcept; + + /** + * Begins either processing HTTP bytes (HTTP) or the SSL handshake (HTTPS) + */ + void processEstablishedConnection( + int fd, + const SocketAddress& clientAddr, + std::chrono::steady_clock::time_point acceptTime, + TransportInfo& tinfo + ) noexcept; + + /** + * Drains all open connections of their outstanding transactions. When + * a connection's transaction count reaches zero, the connection closes. + */ + void drainAllConnections(); + + /** + * Drop all connections. + * + * forceStop() schedules dropAllConnections() to be called in the acceptor's + * thread. + */ + void dropAllConnections(); + + protected: + friend class AcceptorHandshakeHelper; + + /** + * Our event loop. + * + * Probably needs to be used to pass to a ManagedConnection + * implementation. Also visible in case a subclass wishes to do additional + * things w/ the event loop (e.g. in attach()). + */ + EventBase* base_{nullptr}; + + virtual uint64_t getConnectionCountForLoadShedding(void) const { return 0; } + + /** + * Hook for subclasses to drop newly accepted connections prior + * to handshaking. + */ + virtual bool canAccept(const folly::SocketAddress&); + + /** + * Invoked when a new connection is created. This is where application starts + * processing a new downstream connection. + * + * NOTE: Application should add the new connection to + * downstreamConnectionManager so that it can be garbage collected after + * certain period of idleness. + * + * @param sock the socket connected to the client + * @param address the address of the client + * @param nextProtocolName the name of the L6 or L7 protocol to be + * spoken on the connection, if known (e.g., + * from TLS NPN during secure connection setup), + * or an empty string if unknown + */ + virtual void onNewConnection( + AsyncSocket::UniquePtr sock, + const folly::SocketAddress* address, + const std::string& nextProtocolName, + const TransportInfo& tinfo) {} + + void onListenStarted() noexcept {} + void onListenStopped() noexcept {} + void onDataAvailable( + std::shared_ptr socket, + const SocketAddress&, + std::unique_ptr, bool) noexcept {} + + virtual AsyncSocket::UniquePtr makeNewAsyncSocket(EventBase* base, int fd) { + return AsyncSocket::UniquePtr(new AsyncSocket(base, fd)); + } + + virtual AsyncSSLSocket::UniquePtr makeNewAsyncSSLSocket( + const std::shared_ptr& ctx, EventBase* base, int fd) { + return AsyncSSLSocket::UniquePtr(new AsyncSSLSocket(ctx, base, fd)); + } + + /** + * Hook for subclasses to record stats about SSL connection establishment. + */ + virtual void updateSSLStats( + const AsyncSSLSocket* sock, + std::chrono::milliseconds acceptLatency, + SSLErrorEnum error) noexcept {} + + protected: + + /** + * onConnectionsDrained() will be called once all connections have been + * drained while the acceptor is stopping. + * + * Subclasses can override this method to perform any subclass-specific + * cleanup. + */ + virtual void onConnectionsDrained() {} + + // AsyncServerSocket::AcceptCallback methods + void connectionAccepted(int fd, + const folly::SocketAddress& clientAddr) + noexcept; + void acceptError(const std::exception& ex) noexcept; + void acceptStopped() noexcept; + + // ConnectionManager::Callback methods + void onEmpty(const folly::wangle::ConnectionManager& cm); + void onConnectionAdded(const folly::wangle::ConnectionManager& cm) {} + void onConnectionRemoved(const folly::wangle::ConnectionManager& cm) {} + + /** + * Process a connection that is to ready to receive L7 traffic. + * This method is called immediately upon accept for plaintext + * connections and upon completion of SSL handshaking or resumption + * for SSL connections. + */ + void connectionReady( + AsyncSocket::UniquePtr sock, + const folly::SocketAddress& clientAddr, + const std::string& nextProtocolName, + TransportInfo& tinfo); + + const LoadShedConfiguration& getLoadShedConfiguration() const { + return loadShedConfig_; + } + + protected: + const ServerSocketConfig accConfig_; + void setLoadShedConfig(const LoadShedConfiguration& from, + IConnectionCounter* counter); + + /** + * Socket options to apply to the client socket + */ + AsyncSocket::OptionMap socketOptions_; + + std::unique_ptr sslCtxManager_; + + /** + * Whether we want to enable client hello parsing in the handshake helper + * to get list of supported client ciphers. + */ + bool parseClientHello_{false}; + + folly::wangle::ConnectionManager::UniquePtr downstreamConnectionManager_; + + private: + + // Forbidden copy constructor and assignment opererator + Acceptor(Acceptor const &) = delete; + Acceptor& operator=(Acceptor const &) = delete; + + /** + * Wrapper for connectionReady() that decrements the count of + * pending SSL connections. + */ + void sslConnectionReady(AsyncSocket::UniquePtr sock, + const folly::SocketAddress& clientAddr, + const std::string& nextProtocol, + TransportInfo& tinfo); + + /** + * Notification callback for SSL handshake failures. + */ + void sslConnectionError(); + + void checkDrained(); + + State state_{State::kInit}; + uint64_t numPendingSSLConns_{0}; + + static std::atomic totalNumPendingSSLConns_; + + bool forceShutdownInProgress_{false}; + LoadShedConfiguration loadShedConfig_; + IConnectionCounter* connectionCounter_{nullptr}; + std::shared_ptr cacheProvider_; +}; + +class AcceptorFactory { + public: + virtual std::shared_ptr newAcceptor(folly::EventBase*) = 0; + virtual ~AcceptorFactory() = default; +}; + +} // namespace diff --git a/folly/wangle/acceptor/ConnectionCounter.h b/folly/wangle/acceptor/ConnectionCounter.h new file mode 100644 index 00000000..3aaba647 --- /dev/null +++ b/folly/wangle/acceptor/ConnectionCounter.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +namespace folly { + +class IConnectionCounter { + public: + virtual uint64_t getNumConnections() const = 0; + + /** + * Get the maximum number of non-whitelisted client-side connections + * across all Acceptors managed by this. A value + * of zero means "unlimited." + */ + virtual uint64_t getMaxConnections() const = 0; + + /** + * Increment the count of client-side connections. + */ + virtual void onConnectionAdded() = 0; + + /** + * Decrement the count of client-side connections. + */ + virtual void onConnectionRemoved() = 0; + virtual ~IConnectionCounter() {} +}; + +class SimpleConnectionCounter: public IConnectionCounter { + public: + uint64_t getNumConnections() const override { return numConnections_; } + uint64_t getMaxConnections() const override { return maxConnections_; } + void setMaxConnections(uint64_t maxConnections) { + maxConnections_ = maxConnections; + } + + void onConnectionAdded() override { numConnections_++; } + void onConnectionRemoved() override { numConnections_--; } + virtual ~SimpleConnectionCounter() {} + + protected: + uint64_t maxConnections_{0}; + uint64_t numConnections_{0}; +}; + +} diff --git a/folly/wangle/acceptor/ConnectionManager.cpp b/folly/wangle/acceptor/ConnectionManager.cpp new file mode 100644 index 00000000..bb75c74a --- /dev/null +++ b/folly/wangle/acceptor/ConnectionManager.cpp @@ -0,0 +1,233 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +using folly::HHWheelTimer; +using std::chrono::milliseconds; + +namespace folly { namespace wangle { + +ConnectionManager::ConnectionManager(EventBase* eventBase, + milliseconds timeout, Callback* callback) + : connTimeouts_(new HHWheelTimer(eventBase)), + callback_(callback), + eventBase_(eventBase), + idleIterator_(conns_.end()), + idleLoopCallback_(this), + timeout_(timeout), + idleConnEarlyDropThreshold_(timeout_ / 2) { + +} + +void +ConnectionManager::addConnection(ManagedConnection* connection, + bool timeout) { + CHECK_NOTNULL(connection); + ConnectionManager* oldMgr = connection->getConnectionManager(); + if (oldMgr != this) { + if (oldMgr) { + // 'connection' was being previously managed in a different thread. + // We must remove it from that manager before adding it to this one. + oldMgr->removeConnection(connection); + } + + // put the connection into busy part first. This should not matter at all + // because the last callback for an idle connection must be onDeactivated(), + // so the connection must be moved to idle part then. + conns_.push_front(*connection); + + connection->setConnectionManager(this); + if (callback_) { + callback_->onConnectionAdded(*this); + } + } + if (timeout) { + scheduleTimeout(connection, timeout_); + } +} + +void +ConnectionManager::scheduleTimeout(ManagedConnection* const connection, + std::chrono::milliseconds timeout) { + if (timeout > std::chrono::milliseconds(0)) { + connTimeouts_->scheduleTimeout(connection, timeout); + } +} + +void ConnectionManager::scheduleTimeout( + folly::HHWheelTimer::Callback* callback, + std::chrono::milliseconds timeout) { + connTimeouts_->scheduleTimeout(callback, timeout); +} + +void +ConnectionManager::removeConnection(ManagedConnection* connection) { + if (connection->getConnectionManager() == this) { + connection->cancelTimeout(); + connection->setConnectionManager(nullptr); + + // Un-link the connection from our list, being careful to keep the iterator + // that we're using for idle shedding valid + auto it = conns_.iterator_to(*connection); + if (it == idleIterator_) { + ++idleIterator_; + } + conns_.erase(it); + + if (callback_) { + callback_->onConnectionRemoved(*this); + if (getNumConnections() == 0) { + callback_->onEmpty(*this); + } + } + } +} + +void +ConnectionManager::initiateGracefulShutdown( + std::chrono::milliseconds idleGrace) { + if (idleGrace.count() > 0) { + idleLoopCallback_.scheduleTimeout(idleGrace); + VLOG(3) << "Scheduling idle grace period of " << idleGrace.count() << "ms"; + } else { + action_ = ShutdownAction::DRAIN2; + VLOG(3) << "proceeding directly to closing idle connections"; + } + drainAllConnections(); +} + +void +ConnectionManager::drainAllConnections() { + DestructorGuard g(this); + size_t numCleared = 0; + size_t numKept = 0; + + auto it = idleIterator_ == conns_.end() ? + conns_.begin() : idleIterator_; + + while (it != conns_.end() && (numKept + numCleared) < 64) { + ManagedConnection& conn = *it++; + if (action_ == ShutdownAction::DRAIN1) { + conn.notifyPendingShutdown(); + } else { + // Second time around: close idle sessions. If they aren't idle yet, + // have them close when they are idle + if (conn.isBusy()) { + numKept++; + } else { + numCleared++; + } + conn.closeWhenIdle(); + } + } + + if (action_ == ShutdownAction::DRAIN2) { + VLOG(2) << "Idle connections cleared: " << numCleared << + ", busy conns kept: " << numKept; + } + if (it != conns_.end()) { + idleIterator_ = it; + eventBase_->runInLoop(&idleLoopCallback_); + } else { + action_ = ShutdownAction::DRAIN2; + } +} + +void +ConnectionManager::dropAllConnections() { + DestructorGuard g(this); + + // Iterate through our connection list, and drop each connection. + VLOG(3) << "connections to drop: " << conns_.size(); + idleLoopCallback_.cancelTimeout(); + unsigned i = 0; + while (!conns_.empty()) { + ManagedConnection& conn = conns_.front(); + conns_.pop_front(); + conn.cancelTimeout(); + conn.setConnectionManager(nullptr); + // For debugging purposes, dump information about the first few + // connections. + static const unsigned MAX_CONNS_TO_DUMP = 2; + if (++i <= MAX_CONNS_TO_DUMP) { + conn.dumpConnectionState(3); + } + conn.dropConnection(); + } + idleIterator_ = conns_.end(); + idleLoopCallback_.cancelLoopCallback(); + + if (callback_) { + callback_->onEmpty(*this); + } +} + +void +ConnectionManager::onActivated(ManagedConnection& conn) { + auto it = conns_.iterator_to(conn); + if (it == idleIterator_) { + idleIterator_++; + } + conns_.erase(it); + conns_.push_front(conn); +} + +void +ConnectionManager::onDeactivated(ManagedConnection& conn) { + auto it = conns_.iterator_to(conn); + conns_.erase(it); + conns_.push_back(conn); + if (idleIterator_ == conns_.end()) { + idleIterator_--; + } +} + +size_t +ConnectionManager::dropIdleConnections(size_t num) { + VLOG(4) << "attempt to drop " << num << " idle connections"; + if (idleConnEarlyDropThreshold_ >= timeout_) { + return 0; + } + + size_t count = 0; + while(count < num) { + auto it = idleIterator_; + if (it == conns_.end()) { + return count; // no more idle session + } + auto idleTime = it->getIdleTime(); + if (idleTime == std::chrono::milliseconds(0) || + idleTime <= idleConnEarlyDropThreshold_) { + VLOG(4) << "conn's idletime: " << idleTime.count() + << ", earlyDropThreshold: " << idleConnEarlyDropThreshold_.count() + << ", attempt to drop " << count << "/" << num; + return count; // idleTime cannot be further reduced + } + ManagedConnection& conn = *it; + idleIterator_++; + conn.timeoutExpired(); + count++; + } + + return count; +} + + +}} // folly::wangle diff --git a/folly/wangle/acceptor/ConnectionManager.h b/folly/wangle/acceptor/ConnectionManager.h new file mode 100644 index 00000000..45400a6a --- /dev/null +++ b/folly/wangle/acceptor/ConnectionManager.h @@ -0,0 +1,245 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +/** + * A ConnectionManager keeps track of ManagedConnections. + */ +class ConnectionManager: public folly::DelayedDestruction, + private ManagedConnection::Callback { + public: + + /** + * Interface for an optional observer that's notified about + * various events in a ConnectionManager + */ + class Callback { + public: + virtual ~Callback() {} + + /** + * Invoked when the number of connections managed by the + * ConnectionManager changes from nonzero to zero. + */ + virtual void onEmpty(const ConnectionManager& cm) = 0; + + /** + * Invoked when a connection is added to the ConnectionManager. + */ + virtual void onConnectionAdded(const ConnectionManager& cm) = 0; + + /** + * Invoked when a connection is removed from the ConnectionManager. + */ + virtual void onConnectionRemoved(const ConnectionManager& cm) = 0; + }; + + typedef std::unique_ptr UniquePtr; + + /** + * Returns a new instance of ConnectionManager wrapped in a unique_ptr + */ + template + static UniquePtr makeUnique(Args&&... args) { + return folly::make_unique( + std::forward(args)...); + } + + /** + * Constructor not to be used by itself. + */ + ConnectionManager(folly::EventBase* eventBase, + std::chrono::milliseconds timeout, + Callback* callback = nullptr); + + /** + * Add a connection to the set of connections managed by this + * ConnectionManager. + * + * @param connection The connection to add. + * @param timeout Whether to immediately register this connection + * for an idle timeout callback. + */ + void addConnection(ManagedConnection* connection, + bool timeout = false); + + /** + * Schedule a timeout callback for a connection. + */ + void scheduleTimeout(ManagedConnection* const connection, + std::chrono::milliseconds timeout); + + /* + * Schedule a callback on the wheel timer + */ + void scheduleTimeout(folly::HHWheelTimer::Callback* callback, + std::chrono::milliseconds timeout); + + /** + * Remove a connection from this ConnectionManager and, if + * applicable, cancel the pending timeout callback that the + * ConnectionManager has scheduled for the connection. + * + * @note This method does NOT destroy the connection. + */ + void removeConnection(ManagedConnection* connection); + + /* Begin gracefully shutting down connections in this ConnectionManager. + * Notify all connections of pending shutdown, and after idleGrace, + * begin closing idle connections. + */ + void initiateGracefulShutdown(std::chrono::milliseconds idleGrace); + + /** + * Destroy all connections Managed by this ConnectionManager, even + * the ones that are busy. + */ + void dropAllConnections(); + + size_t getNumConnections() const { return conns_.size(); } + + template + void iterateConns(F func) { + auto it = conns_.begin(); + while ( it != conns_.end()) { + func(&(*it)); + it++; + } + } + + std::chrono::milliseconds getDefaultTimeout() const { + return timeout_; + } + + void setLoweredIdleTimeout(std::chrono::milliseconds timeout) { + CHECK(timeout >= std::chrono::milliseconds(0)); + CHECK(timeout <= timeout_); + idleConnEarlyDropThreshold_ = timeout; + } + + /** + * try to drop num idle connections to release system resources. Return the + * actual number of dropped idle connections + */ + size_t dropIdleConnections(size_t num); + + /** + * ManagedConnection::Callbacks + */ + void onActivated(ManagedConnection& conn); + + void onDeactivated(ManagedConnection& conn); + + private: + class CloseIdleConnsCallback : + public folly::EventBase::LoopCallback, + public folly::AsyncTimeout { + public: + explicit CloseIdleConnsCallback(ConnectionManager* manager) + : folly::AsyncTimeout(manager->eventBase_), + manager_(manager) {} + + void runLoopCallback() noexcept override { + VLOG(3) << "Draining more conns from loop callback"; + manager_->drainAllConnections(); + } + + void timeoutExpired() noexcept override { + VLOG(3) << "Idle grace expired"; + manager_->drainAllConnections(); + } + + private: + ConnectionManager* manager_; + }; + + enum class ShutdownAction : uint8_t { + /** + * Drain part 1: inform remote that you will soon reject new requests. + */ + DRAIN1 = 0, + /** + * Drain part 2: start rejecting new requests. + */ + DRAIN2 = 1, + }; + + ~ConnectionManager() {} + + ConnectionManager(const ConnectionManager&) = delete; + ConnectionManager& operator=(ConnectionManager&) = delete; + + /** + * Destroy all connections managed by this ConnectionManager that + * are currently idle, as determined by a call to each ManagedConnection's + * isBusy() method. + */ + void drainAllConnections(); + + /** + * All the managed connections. idleIterator_ seperates them into two parts: + * idle and busy ones. [conns_.begin(), idleIterator_) are the busy ones, + * while [idleIterator_, conns_.end()) are the idle one. Moreover, the idle + * ones are organized in the decreasing idle time order. */ + folly::CountedIntrusiveList< + ManagedConnection,&ManagedConnection::listHook_> conns_; + + /** Connections that currently are registered for timeouts */ + folly::HHWheelTimer::UniquePtr connTimeouts_; + + /** Optional callback to notify of state changes */ + Callback* callback_; + + /** Event base in which we run */ + folly::EventBase* eventBase_; + + /** Iterator to the next connection to shed; used by drainAllConnections() */ + folly::CountedIntrusiveList< + ManagedConnection,&ManagedConnection::listHook_>::iterator idleIterator_; + CloseIdleConnsCallback idleLoopCallback_; + ShutdownAction action_{ShutdownAction::DRAIN1}; + + /** + * the default idle timeout for downstream sessions when no system resource + * limit is reached + */ + std::chrono::milliseconds timeout_; + + /** + * The idle connections can be closed earlier that their idle timeout when any + * system resource limit is reached. This feature can be considerred as a pre + * load shedding stage for the system, and can be easily disabled by setting + * idleConnEarlyDropThreshold_ to defaultIdleTimeout_. Also, + * idleConnEarlyDropThreshold_ can be used to bottom the idle timeout. That + * is, connection manager will not early drop the idle connections whose idle + * time is less than idleConnEarlyDropThreshold_. + */ + std::chrono::milliseconds idleConnEarlyDropThreshold_; +}; + +}} // folly::wangle diff --git a/folly/wangle/acceptor/DomainNameMisc.h b/folly/wangle/acceptor/DomainNameMisc.h new file mode 100644 index 00000000..bba27a32 --- /dev/null +++ b/folly/wangle/acceptor/DomainNameMisc.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +namespace folly { + +struct dn_char_traits : public std::char_traits { + static bool eq(char c1, char c2) { + return ::tolower(c1) == ::tolower(c2); + } + + static bool ne(char c1, char c2) { + return ::tolower(c1) != ::tolower(c2); + } + + static bool lt(char c1, char c2) { + return ::tolower(c1) < ::tolower(c2); + } + + static int compare(const char* s1, const char* s2, size_t n) { + while (n--) { + if(::tolower(*s1) < ::tolower(*s2) ) { + return -1; + } + if(::tolower(*s1) > ::tolower(*s2) ) { + return 1; + } + ++s1; + ++s2; + } + return 0; + } + + static const char* find(const char* s, size_t n, char a) { + char la = ::tolower(a); + while (n--) { + if(::tolower(*s) == la) { + return s; + } else { + ++s; + } + } + return nullptr; + } +}; + +// Case insensitive string +typedef std::basic_string DNString; + +struct DNStringHash : public std::hash { + size_t operator()(const DNString& s1) const noexcept { + std::string s2(s1.data(), s1.size()); + for (char& c : s2) + c = ::tolower(c); + return std::hash()(s2); + } +}; + +} // namespace diff --git a/folly/wangle/acceptor/LoadShedConfiguration.cpp b/folly/wangle/acceptor/LoadShedConfiguration.cpp new file mode 100644 index 00000000..44381e41 --- /dev/null +++ b/folly/wangle/acceptor/LoadShedConfiguration.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include + +using std::string; + +namespace folly { + +void LoadShedConfiguration::addWhitelistAddr(folly::StringPiece input) { + auto addr = input.str(); + size_t separator = addr.find_first_of('/'); + if (separator == string::npos) { + whitelistAddrs_.insert(SocketAddress(addr, 0)); + } else { + unsigned prefixLen = folly::to(addr.substr(separator + 1)); + addr.erase(separator); + whitelistNetworks_.insert(NetworkAddress(SocketAddress(addr, 0), prefixLen)); + } +} + +bool LoadShedConfiguration::isWhitelisted(const SocketAddress& address) const { + if (whitelistAddrs_.find(address) != whitelistAddrs_.end()) { + return true; + } + for (auto& network : whitelistNetworks_) { + if (network.contains(address)) { + return true; + } + } + return false; +} + +} diff --git a/folly/wangle/acceptor/LoadShedConfiguration.h b/folly/wangle/acceptor/LoadShedConfiguration.h new file mode 100644 index 00000000..57c51b97 --- /dev/null +++ b/folly/wangle/acceptor/LoadShedConfiguration.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace folly { + +/** + * Class that holds an LoadShed configuration for a service + */ +class LoadShedConfiguration { + public: + + // Comparison function for SocketAddress that disregards the port + struct AddressOnlyCompare { + bool operator()( + const SocketAddress& addr1, + const SocketAddress& addr2) const { + return addr1.getIPAddress() < addr2.getIPAddress(); + } + }; + + typedef std::set AddressSet; + typedef std::set NetworkSet; + + LoadShedConfiguration() {} + + ~LoadShedConfiguration() {} + + void addWhitelistAddr(folly::StringPiece); + + /** + * Set/get the set of IPs that should be whitelisted through even when we're + * trying to shed load. + */ + void setWhitelistAddrs(const AddressSet& addrs) { whitelistAddrs_ = addrs; } + const AddressSet& getWhitelistAddrs() const { return whitelistAddrs_; } + + /** + * Set/get the set of networks that should be whitelisted through even + * when we're trying to shed load. + */ + void setWhitelistNetworks(const NetworkSet& networks) { + whitelistNetworks_ = networks; + } + const NetworkSet& getWhitelistNetworks() const { return whitelistNetworks_; } + + /** + * Set/get the maximum number of downstream connections across all VIPs. + */ + void setMaxConnections(uint64_t maxConns) { maxConnections_ = maxConns; } + uint64_t getMaxConnections() const { return maxConnections_; } + + /** + * Set/get the maximum cpu usage. + */ + void setMaxMemUsage(double max) { + CHECK(max >= 0); + CHECK(max <= 1); + maxMemUsage_ = max; + } + double getMaxMemUsage() const { return maxMemUsage_; } + + /** + * Set/get the maximum memory usage. + */ + void setMaxCpuUsage(double max) { + CHECK(max >= 0); + CHECK(max <= 1); + maxCpuUsage_ = max; + } + double getMaxCpuUsage() const { return maxCpuUsage_; } + + /** + * Set/get the minium actual free memory on the system. + */ + void setMinFreeMem(uint64_t min) { + minFreeMem_ = min; + } + uint64_t getMinFreeMem() const { + return minFreeMem_; + } + + void setLoadUpdatePeriod(std::chrono::milliseconds period) { + period_ = period; + } + std::chrono::milliseconds getLoadUpdatePeriod() const { return period_; } + + bool isWhitelisted(const SocketAddress& addr) const; + + private: + + AddressSet whitelistAddrs_; + NetworkSet whitelistNetworks_; + uint64_t maxConnections_{0}; + uint64_t minFreeMem_{0}; + double maxMemUsage_; + double maxCpuUsage_; + std::chrono::milliseconds period_; +}; + +} diff --git a/folly/wangle/acceptor/ManagedConnection.cpp b/folly/wangle/acceptor/ManagedConnection.cpp new file mode 100644 index 00000000..3ddd0d3c --- /dev/null +++ b/folly/wangle/acceptor/ManagedConnection.cpp @@ -0,0 +1,64 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include + +namespace folly { namespace wangle { + +ManagedConnection::ManagedConnection() + : connectionManager_(nullptr) { +} + +ManagedConnection::~ManagedConnection() { + if (connectionManager_) { + connectionManager_->removeConnection(this); + } +} + +void +ManagedConnection::resetTimeout() { + if (connectionManager_) { + resetTimeoutTo(connectionManager_->getDefaultTimeout()); + } +} + +void +ManagedConnection::resetTimeoutTo(std::chrono::milliseconds timeout) { + if (connectionManager_) { + connectionManager_->scheduleTimeout(this, timeout); + } +} + +void +ManagedConnection::scheduleTimeout( + folly::HHWheelTimer::Callback* callback, + std::chrono::milliseconds timeout) { + if (connectionManager_) { + connectionManager_->scheduleTimeout(callback, timeout); + } +} + +////////////////////// Globals ///////////////////// + +std::ostream& +operator<<(std::ostream& os, const ManagedConnection& conn) { + conn.describe(os); + return os; +} + +}} // folly::wangle diff --git a/folly/wangle/acceptor/ManagedConnection.h b/folly/wangle/acceptor/ManagedConnection.h new file mode 100644 index 00000000..abee324e --- /dev/null +++ b/folly/wangle/acceptor/ManagedConnection.h @@ -0,0 +1,140 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace folly { namespace wangle { + +class ConnectionManager; + +/** + * Interface describing a connection that can be managed by a + * container such as an Acceptor. + */ +class ManagedConnection: + public folly::HHWheelTimer::Callback, + public folly::DelayedDestruction { + public: + + ManagedConnection(); + + class Callback { + public: + virtual ~Callback() {} + + /* Invoked when this connection becomes busy */ + virtual void onActivated(ManagedConnection& conn) = 0; + + /* Invoked when a connection becomes idle */ + virtual void onDeactivated(ManagedConnection& conn) = 0; + }; + + // HHWheelTimer::Callback API (left for subclasses to implement). + virtual void timeoutExpired() noexcept = 0; + + /** + * Print a human-readable description of the connection. + * @param os Destination stream. + */ + virtual void describe(std::ostream& os) const = 0; + + /** + * Check whether the connection has any requests outstanding. + */ + virtual bool isBusy() const = 0; + + /** + * Get the idle time of the connection. If it returning 0, that means the idle + * connections will never be dropped during pre load shedding stage. + */ + virtual std::chrono::milliseconds getIdleTime() const { + return std::chrono::milliseconds(0); + } + + /** + * Notify the connection that a shutdown is pending. This method will be + * called at the beginning of graceful shutdown. + */ + virtual void notifyPendingShutdown() = 0; + + /** + * Instruct the connection that it should shutdown as soon as it is + * safe. This is called after notifyPendingShutdown(). + */ + virtual void closeWhenIdle() = 0; + + /** + * Forcibly drop a connection. + * + * If a request is in progress, this should cause the connection to be + * closed with a reset. + */ + virtual void dropConnection() = 0; + + /** + * Dump the state of the connection to the log + */ + virtual void dumpConnectionState(uint8_t loglevel) = 0; + + /** + * If the connection has a connection manager, reset the timeout countdown to + * connection manager's default timeout. + * @note If the connection manager doesn't have the connection scheduled + * for a timeout already, this method will schedule one. If the + * connection manager does have the connection connection scheduled + * for a timeout, this method will push back the timeout to N msec + * from now, where N is the connection manager's timer interval. + */ + virtual void resetTimeout(); + + /** + * If the connection has a connection manager, reset the timeout countdown to + * user specified timeout. + */ + void resetTimeoutTo(std::chrono::milliseconds); + + // Schedule an arbitrary timeout on the HHWheelTimer + virtual void scheduleTimeout( + folly::HHWheelTimer::Callback* callback, + std::chrono::milliseconds timeout); + + ConnectionManager* getConnectionManager() { + return connectionManager_; + } + + protected: + virtual ~ManagedConnection(); + + private: + friend class ConnectionManager; + + void setConnectionManager(ConnectionManager* mgr) { + connectionManager_ = mgr; + } + + ConnectionManager* connectionManager_; + + folly::SafeIntrusiveListHook listHook_; +}; + +std::ostream& operator<<(std::ostream& os, const ManagedConnection& conn); + +}} // folly::wangle diff --git a/folly/wangle/acceptor/NetworkAddress.h b/folly/wangle/acceptor/NetworkAddress.h new file mode 100644 index 00000000..4b444adb --- /dev/null +++ b/folly/wangle/acceptor/NetworkAddress.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +namespace folly { + +/** + * A simple wrapper around SocketAddress that represents + * a network in CIDR notation + */ +class NetworkAddress { +public: + /** + * Create a NetworkAddress for an addr/prefixLen + * @param addr IPv4 or IPv6 address of the network + * @param prefixLen Prefix length, in bits + */ + NetworkAddress(const folly::SocketAddress& addr, + unsigned prefixLen): + addr_(addr), prefixLen_(prefixLen) {} + + /** Get the network address */ + const folly::SocketAddress& getAddress() const { + return addr_; + } + + /** Get the prefix length in bits */ + unsigned getPrefixLength() const { return prefixLen_; } + + /** Check whether a given address lies within the network */ + bool contains(const folly::SocketAddress& addr) const { + return addr_.prefixMatch(addr, prefixLen_); + } + + /** Comparison operator to enable use in ordered collections */ + bool operator<(const NetworkAddress& other) const { + if (addr_ < other.addr_) { + return true; + } else if (other.addr_ < addr_) { + return false; + } else { + return (prefixLen_ < other.prefixLen_); + } + } + +private: + folly::SocketAddress addr_; + unsigned prefixLen_; +}; + +} // namespace diff --git a/folly/wangle/acceptor/ServerSocketConfig.h b/folly/wangle/acceptor/ServerSocketConfig.h new file mode 100644 index 00000000..1e776ad1 --- /dev/null +++ b/folly/wangle/acceptor/ServerSocketConfig.h @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace folly { + +/** + * Configuration for a single Acceptor. + * + * This configures not only accept behavior, but also some types of SSL + * behavior that may make sense to configure on a per-VIP basis (e.g. which + * cert(s) we use, etc). + */ +struct ServerSocketConfig { + ServerSocketConfig() { + // generate a single random current seed + uint8_t seed[32]; + folly::Random::secureRandom(seed, sizeof(seed)); + initialTicketSeeds.currentSeeds.push_back( + SSLUtil::hexlify(std::string((char *)seed, sizeof(seed)))); + } + + bool isSSL() const { return !(sslContextConfigs.empty()); } + + /** + * Set/get the socket options to apply on all downstream connections. + */ + void setSocketOptions( + const AsyncSocket::OptionMap& opts) { + socketOptions_ = filterIPSocketOptions(opts, bindAddress.getFamily()); + } + AsyncSocket::OptionMap& + getSocketOptions() { + return socketOptions_; + } + const AsyncSocket::OptionMap& + getSocketOptions() const { + return socketOptions_; + } + + bool hasExternalPrivateKey() const { + for (const auto& cfg : sslContextConfigs) { + if (!cfg.isLocalPrivateKey) { + return true; + } + } + return false; + } + + /** + * The name of this acceptor; used for stats/reporting purposes. + */ + std::string name; + + /** + * The depth of the accept queue backlog. + */ + uint32_t acceptBacklog{1024}; + + /** + * The number of milliseconds a connection can be idle before we close it. + */ + std::chrono::milliseconds connectionIdleTimeout{600000}; + + /** + * The address to bind to. + */ + SocketAddress bindAddress; + + /** + * Options for controlling the SSL cache. + */ + SSLCacheOptions sslCacheOptions{std::chrono::seconds(600), 20480, 200}; + + /** + * The initial TLS ticket seeds. + */ + TLSTicketKeySeeds initialTicketSeeds; + + /** + * The configs for all the SSL_CTX for use by this Acceptor. + */ + std::vector sslContextConfigs; + + /** + * Determines if the Acceptor does strict checking when loading the SSL + * contexts. + */ + bool strictSSL{true}; + + /** + * Maximum number of concurrent pending SSL handshakes + */ + uint32_t maxConcurrentSSLHandshakes{30720}; + + private: + AsyncSocket::OptionMap socketOptions_; +}; + +} // folly diff --git a/folly/wangle/acceptor/SocketOptions.cpp b/folly/wangle/acceptor/SocketOptions.cpp new file mode 100644 index 00000000..b33e103a --- /dev/null +++ b/folly/wangle/acceptor/SocketOptions.cpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include + +namespace folly { + +AsyncSocket::OptionMap filterIPSocketOptions( + const AsyncSocket::OptionMap& allOptions, + const int addrFamily) { + AsyncSocket::OptionMap opts; + int exclude; + if (addrFamily == AF_INET) { + exclude = IPPROTO_IPV6; + } else if (addrFamily == AF_INET6) { + exclude = IPPROTO_IP; + } else { + LOG(FATAL) << "Address family " << addrFamily << " was not IPv4 or IPv6"; + return opts; + } + for (const auto& opt: allOptions) { + if (opt.first.level != exclude) { + opts[opt.first] = opt.second; + } + } + return opts; +} + +} diff --git a/folly/wangle/acceptor/SocketOptions.h b/folly/wangle/acceptor/SocketOptions.h new file mode 100644 index 00000000..dadd22b1 --- /dev/null +++ b/folly/wangle/acceptor/SocketOptions.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +namespace folly { + +/** + * Returns a copy of the socket options excluding options with the given + * level. + */ +AsyncSocket::OptionMap filterIPSocketOptions( + const AsyncSocket::OptionMap& allOptions, + const int addrFamily); + +} diff --git a/folly/wangle/acceptor/TransportInfo.cpp b/folly/wangle/acceptor/TransportInfo.cpp new file mode 100644 index 00000000..4f735b4f --- /dev/null +++ b/folly/wangle/acceptor/TransportInfo.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include +#include + +using std::chrono::microseconds; +using std::map; +using std::string; + +namespace folly { + +bool TransportInfo::initWithSocket(const AsyncSocket* sock) { +#if defined(__linux__) || defined(__FreeBSD__) + if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) { + tcpinfoErrno = errno; + return false; + } + rtt = microseconds(tcpinfo.tcpi_rtt); + validTcpinfo = true; +#else + tcpinfoErrno = EINVAL; + rtt = microseconds(-1); +#endif + return true; +} + +int64_t TransportInfo::readRTT(const AsyncSocket* sock) { +#if defined(__linux__) || defined(__FreeBSD__) + struct tcp_info tcpinfo; + if (!TransportInfo::readTcpInfo(&tcpinfo, sock)) { + return -1; + } + return tcpinfo.tcpi_rtt; +#else + return -1; +#endif +} + +#if defined(__linux__) || defined(__FreeBSD__) +bool TransportInfo::readTcpInfo(struct tcp_info* tcpinfo, + const AsyncSocket* sock) { + socklen_t len = sizeof(struct tcp_info); + if (!sock) { + return false; + } + if (getsockopt(sock->getFd(), IPPROTO_TCP, + TCP_INFO, (void*) tcpinfo, &len) < 0) { + VLOG(4) << "Error calling getsockopt(): " << strerror(errno); + return false; + } + return true; +} +#endif + +} // folly diff --git a/folly/wangle/acceptor/TransportInfo.h b/folly/wangle/acceptor/TransportInfo.h new file mode 100644 index 00000000..67203c7e --- /dev/null +++ b/folly/wangle/acceptor/TransportInfo.h @@ -0,0 +1,298 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +#include +#include +#include + +namespace folly { +class AsyncSocket; + +/** + * A structure that encapsulates byte counters related to the HTTP headers. + */ +struct HTTPHeaderSize { + /** + * The number of bytes used to represent the header after compression or + * before decompression. If header compression is not supported, the value + * is set to 0. + */ + size_t compressed{0}; + + /** + * The number of bytes used to represent the serialized header before + * compression or after decompression, in plain-text format. + */ + size_t uncompressed{0}; +}; + +struct TransportInfo { + /* + * timestamp of when the connection handshake was completed + */ + std::chrono::steady_clock::time_point acceptTime{}; + + /* + * connection RTT (Round-Trip Time) + */ + std::chrono::microseconds rtt{0}; + +#if defined(__linux__) || defined(__FreeBSD__) + /* + * TCP information as fetched from getsockopt(2) + */ + tcp_info tcpinfo { +#if __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17 + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 32 +#else + 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 // 29 +#endif // __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 17 + }; +#endif // defined(__linux__) || defined(__FreeBSD__) + + /* + * time for setting the connection, from the moment in was accepted until it + * is established. + */ + std::chrono::milliseconds setupTime{0}; + + /* + * time for setting up the SSL connection or SSL handshake + */ + std::chrono::milliseconds sslSetupTime{0}; + + /* + * The name of the SSL ciphersuite used by the transaction's + * transport. Returns null if the transport is not SSL. + */ + std::shared_ptr sslCipher{nullptr}; + + /* + * The SSL server name used by the transaction's + * transport. Returns null if the transport is not SSL. + */ + std::shared_ptr sslServerName{nullptr}; + + /* + * list of ciphers sent by the client + */ + std::shared_ptr sslClientCiphers{nullptr}; + + /* + * list of compression methods sent by the client + */ + std::shared_ptr sslClientComprMethods{nullptr}; + + /* + * list of TLS extensions sent by the client + */ + std::shared_ptr sslClientExts{nullptr}; + + /* + * hash of all the SSL parameters sent by the client + */ + std::shared_ptr sslSignature{nullptr}; + + /* + * list of ciphers supported by the server + */ + std::shared_ptr sslServerCiphers{nullptr}; + + /* + * guessed "(os) (browser)" based on SSL Signature + */ + std::shared_ptr guessedUserAgent{nullptr}; + + /** + * The result of SSL NPN negotiation. + */ + std::shared_ptr sslNextProtocol{nullptr}; + + /* + * total number of bytes sent over the connection + */ + int64_t totalBytes{0}; + + /** + * If the client passed through one of our L4 proxies (using PROXY Protocol), + * then this will contain the IP address of the proxy host. + */ + std::shared_ptr clientAddrOriginal; + + /** + * header bytes read + */ + HTTPHeaderSize ingressHeader; + + /* + * header bytes written + */ + HTTPHeaderSize egressHeader; + + /* + * Here is how the timeToXXXByte variables are planned out: + * 1. All timeToXXXByte variables are measuring the ByteEvent from reqStart_ + * 2. You can get the timing between two ByteEvents by calculating their + * differences. For example: + * timeToLastBodyByteAck - timeToFirstByte + * => Total time to deliver the body + * 3. The calculation in point (2) is typically done outside acceptor + * + * Future plan: + * We should log the timestamps (TimePoints) and allow + * the consumer to calculate the latency whatever it + * wants instead of calculating them in wangle, for the sake of flexibility. + * For example: + * 1. TimePoint reqStartTimestamp; + * 2. TimePoint firstHeaderByteSentTimestamp; + * 3. TimePoint firstBodyByteTimestamp; + * 3. TimePoint lastBodyByteTimestamp; + * 4. TimePoint lastBodyByteAckTimestamp; + */ + + /* + * time to first header byte written to the kernel send buffer + * NOTE: It is not 100% accurate since TAsyncSocket does not do + * do callback on partial write. + */ + int32_t timeToFirstHeaderByte{-1}; + + /* + * time to first body byte written to the kernel send buffer + */ + int32_t timeToFirstByte{-1}; + + /* + * time to last body byte written to the kernel send buffer + */ + int32_t timeToLastByte{-1}; + + /* + * time to TCP Ack received for the last written body byte + */ + int32_t timeToLastBodyByteAck{-1}; + + /* + * time it took the client to ACK the last byte, from the moment when the + * kernel sent the last byte to the client and until it received the ACK + * for that byte + */ + int32_t lastByteAckLatency{-1}; + + /* + * time spent inside wangle + */ + int32_t proxyLatency{-1}; + + /* + * time between connection accepted and client message headers completed + */ + int32_t clientLatency{-1}; + + /* + * latency for communication with the server + */ + int32_t serverLatency{-1}; + + /* + * time used to get a usable connection. + */ + int32_t connectLatency{-1}; + + /* + * body bytes written + */ + uint32_t egressBodySize{0}; + + /* + * value of errno in case of getsockopt() error + */ + int tcpinfoErrno{0}; + + /* + * bytes read & written during SSL Setup + */ + uint32_t sslSetupBytesWritten{0}; + uint32_t sslSetupBytesRead{0}; + + /** + * SSL error detail + */ + uint32_t sslError{0}; + + /** + * body bytes read + */ + uint32_t ingressBodySize{0}; + + /* + * The SSL version used by the transaction's transport, in + * OpenSSL's format: 4 bits for the major version, followed by 4 bits + * for the minor version. Returns zero for non-SSL. + */ + uint16_t sslVersion{0}; + + /* + * The SSL certificate size. + */ + uint16_t sslCertSize{0}; + + /** + * response status code + */ + uint16_t statusCode{0}; + + /* + * The SSL mode for the transaction's transport: new session, + * resumed session, or neither (non-SSL). + */ + SSLResumeEnum sslResume{SSLResumeEnum::NA}; + + /* + * true if the tcpinfo was successfully read from the kernel + */ + bool validTcpinfo{false}; + + /* + * true if the connection is SSL, false otherwise + */ + bool ssl{false}; + + /* + * get the RTT value in milliseconds + */ + std::chrono::milliseconds getRttMs() const { + return std::chrono::duration_cast(rtt); + } + + /* + * initialize the fields related with tcp_info + */ + bool initWithSocket(const AsyncSocket* sock); + + /* + * Get the kernel's estimate of round-trip time (RTT) to the transport's peer + * in microseconds. Returns -1 on error. + */ + static int64_t readRTT(const AsyncSocket* sock); + +#if defined(__linux__) || defined(__FreeBSD__) + /* + * perform the getsockopt(2) syscall to fetch TCP info for a given socket + */ + static bool readTcpInfo(struct tcp_info* tcpinfo, + const AsyncSocket* sock); +#endif +}; + +} // folly diff --git a/folly/wangle/bootstrap/BootstrapTest.cpp b/folly/wangle/bootstrap/BootstrapTest.cpp new file mode 100644 index 00000000..724afec9 --- /dev/null +++ b/folly/wangle/bootstrap/BootstrapTest.cpp @@ -0,0 +1,365 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "folly/wangle/bootstrap/ServerBootstrap.h" +#include "folly/wangle/bootstrap/ClientBootstrap.h" +#include "folly/wangle/channel/Handler.h" + +#include +#include +#include + +using namespace folly::wangle; +using namespace folly; + +typedef Pipeline> BytesPipeline; + +typedef ServerBootstrap TestServer; +typedef ClientBootstrap TestClient; + +class TestClientPipelineFactory : public PipelineFactory { + public: + std::unique_ptr + newPipeline(std::shared_ptr sock) { + // We probably aren't connected immedately, check after a small delay + EventBaseManager::get()->getEventBase()->tryRunAfterDelay([sock](){ + CHECK(sock->good()); + CHECK(sock->readable()); + }, 100); + return nullptr; + } +}; + +class TestPipelineFactory : public PipelineFactory { + public: + std::unique_ptr newPipeline( + std::shared_ptr sock) { + + pipelines++; + return std::unique_ptr( + new BytesPipeline()); + } + std::atomic pipelines{0}; +}; + +class TestAcceptor : public Acceptor { +EventBase base_; + public: + TestAcceptor() : Acceptor(ServerSocketConfig()) { + Acceptor::init(nullptr, &base_); + } + void onNewConnection( + AsyncSocket::UniquePtr sock, + const folly::SocketAddress* address, + const std::string& nextProtocolName, + const TransportInfo& tinfo) { + } +}; + +class TestAcceptorFactory : public AcceptorFactory { + public: + std::shared_ptr newAcceptor(EventBase* base) { + return std::make_shared(); + } +}; + +TEST(Bootstrap, Basic) { + TestServer server; + TestClient client; +} + +TEST(Bootstrap, ServerWithPipeline) { + TestServer server; + server.childPipeline(std::make_shared()); + server.bind(0); + server.stop(); +} + +TEST(Bootstrap, ServerWithChildHandler) { + TestServer server; + server.childHandler(std::make_shared()); + server.bind(0); + server.stop(); +} + +TEST(Bootstrap, ClientServerTest) { + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.bind(0); + auto base = EventBaseManager::get()->getEventBase(); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + TestClient client; + client.pipelineFactory(std::make_shared()); + client.connect(address); + base->loop(); + server.stop(); + + CHECK(factory->pipelines == 1); +} + +TEST(Bootstrap, ClientConnectionManagerTest) { + // Create a single IO thread, and verify that + // client connections are pooled properly + + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.group(std::make_shared(1)); + server.bind(0); + auto base = EventBaseManager::get()->getEventBase(); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + TestClient client; + client.pipelineFactory(std::make_shared()); + + client.connect(address); + + TestClient client2; + client2.pipelineFactory(std::make_shared()); + client2.connect(address); + + base->loop(); + server.stop(); + + CHECK(factory->pipelines == 2); +} + +TEST(Bootstrap, ServerAcceptGroupTest) { + // Verify that server is using the accept IO group + + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.group(std::make_shared(1), nullptr); + server.bind(0); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + boost::barrier barrier(2); + auto thread = std::thread([&](){ + TestClient client; + client.pipelineFactory(std::make_shared()); + client.connect(address); + EventBaseManager::get()->getEventBase()->loop(); + barrier.wait(); + }); + barrier.wait(); + server.stop(); + thread.join(); + + CHECK(factory->pipelines == 1); +} + +TEST(Bootstrap, ServerAcceptGroup2Test) { + // Verify that server is using the accept IO group + + // Check if reuse port is supported, if not, don't run this test + try { + EventBase base; + auto serverSocket = AsyncServerSocket::newSocket(&base); + serverSocket->bind(0); + serverSocket->listen(0); + serverSocket->startAccepting(); + serverSocket->setReusePortEnabled(true); + serverSocket->stopAccepting(); + } catch(...) { + LOG(INFO) << "Reuse port probably not supported"; + return; + } + + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.group(std::make_shared(4), nullptr); + server.bind(0); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + TestClient client; + client.pipelineFactory(std::make_shared()); + + client.connect(address); + EventBaseManager::get()->getEventBase()->loop(); + + server.stop(); + + CHECK(factory->pipelines == 1); +} + +TEST(Bootstrap, SharedThreadPool) { + // Check if reuse port is supported, if not, don't run this test + try { + EventBase base; + auto serverSocket = AsyncServerSocket::newSocket(&base); + serverSocket->bind(0); + serverSocket->listen(0); + serverSocket->startAccepting(); + serverSocket->setReusePortEnabled(true); + serverSocket->stopAccepting(); + } catch(...) { + LOG(INFO) << "Reuse port probably not supported"; + return; + } + + auto pool = std::make_shared(2); + + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + server.group(pool, pool); + + server.bind(0); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + TestClient client; + client.pipelineFactory(std::make_shared()); + client.connect(address); + + TestClient client2; + client2.pipelineFactory(std::make_shared()); + client2.connect(address); + + TestClient client3; + client3.pipelineFactory(std::make_shared()); + client3.connect(address); + + TestClient client4; + client4.pipelineFactory(std::make_shared()); + client4.connect(address); + + TestClient client5; + client5.pipelineFactory(std::make_shared()); + client5.connect(address); + + EventBaseManager::get()->getEventBase()->loop(); + + server.stop(); + CHECK(factory->pipelines == 5); +} + +TEST(Bootstrap, ExistingSocket) { + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + folly::AsyncServerSocket::UniquePtr socket(new AsyncServerSocket); + server.bind(std::move(socket)); +} + +std::atomic connections{0}; + +class TestHandlerPipeline : public InboundHandler { + public: + void read(Context* ctx, void* conn) { + connections++; + return ctx->fireRead(conn); + } +}; + +template +class TestHandlerPipelineFactory + : public PipelineFactory::AcceptPipeline> { + public: + std::unique_ptr::AcceptPipeline, + folly::DelayedDestruction::Destructor> + newPipeline(std::shared_ptr) { + + std::unique_ptr::AcceptPipeline, + folly::DelayedDestruction::Destructor> pipeline( + new ServerBootstrap::AcceptPipeline); + pipeline->addBack(HandlerPipeline()); + return pipeline; + } +}; + +TEST(Bootstrap, LoadBalanceHandler) { + TestServer server; + auto factory = std::make_shared(); + server.childPipeline(factory); + + auto pipelinefactory = + std::make_shared>(); + server.pipeline(pipelinefactory); + server.bind(0); + auto base = EventBaseManager::get()->getEventBase(); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + TestClient client; + client.pipelineFactory(std::make_shared()); + client.connect(address); + base->loop(); + server.stop(); + + CHECK(factory->pipelines == 1); + CHECK(connections == 1); +} + +class TestUDPPipeline : public InboundHandler { + public: + void read(Context* ctx, void* conn) { + connections++; + } +}; + +TEST(Bootstrap, UDP) { + TestServer server; + auto factory = std::make_shared(); + auto pipelinefactory = + std::make_shared>(); + server.pipeline(pipelinefactory); + server.channelFactory(std::make_shared()); + server.bind(0); +} + +TEST(Bootstrap, UDPClientServerTest) { + connections = 0; + + TestServer server; + auto factory = std::make_shared(); + auto pipelinefactory = + std::make_shared>(); + server.pipeline(pipelinefactory); + server.channelFactory(std::make_shared()); + server.bind(0); + + auto base = EventBaseManager::get()->getEventBase(); + + SocketAddress address; + server.getSockets()[0]->getAddress(&address); + + SocketAddress localhost("::1", 0); + AsyncUDPSocket client(base); + client.bind(localhost); + auto data = IOBuf::create(1); + data->append(1); + *(data->writableData()) = 'a'; + client.write(address, std::move(data)); + base->loop(); + server.stop(); + + CHECK(connections == 1); +} diff --git a/folly/wangle/bootstrap/ClientBootstrap.h b/folly/wangle/bootstrap/ClientBootstrap.h new file mode 100644 index 00000000..ecd0c3b1 --- /dev/null +++ b/folly/wangle/bootstrap/ClientBootstrap.h @@ -0,0 +1,109 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace folly { + +/* + * A thin wrapper around Pipeline and AsyncSocket to match + * ServerBootstrap. On connect() a new pipeline is created. + */ +template +class ClientBootstrap { + + class ConnectCallback : public AsyncSocket::ConnectCallback { + public: + ConnectCallback(Promise promise, ClientBootstrap* bootstrap) + : promise_(std::move(promise)) + , bootstrap_(bootstrap) {} + + void connectSuccess() noexcept override { + if (bootstrap_->getPipeline()) { + bootstrap_->getPipeline()->transportActive(); + } + promise_.setValue(bootstrap_->getPipeline()); + delete this; + } + + void connectErr(const AsyncSocketException& ex) noexcept override { + promise_.setException( + folly::make_exception_wrapper(ex)); + delete this; + } + private: + Promise promise_; + ClientBootstrap* bootstrap_; + }; + + public: + ClientBootstrap() { + } + + ClientBootstrap* group( + std::shared_ptr group) { + group_ = group; + return this; + } + ClientBootstrap* bind(int port) { + port_ = port; + return this; + } + Future connect(SocketAddress address) { + DCHECK(pipelineFactory_); + auto base = EventBaseManager::get()->getEventBase(); + if (group_) { + base = group_->getEventBase(); + } + Future retval((Pipeline*)nullptr); + base->runImmediatelyOrRunInEventBaseThreadAndWait([&](){ + auto socket = AsyncSocket::newSocket(base); + Promise promise; + retval = promise.getFuture(); + socket->connect( + new ConnectCallback(std::move(promise), this), address); + pipeline_ = pipelineFactory_->newPipeline(socket); + }); + return retval; + } + + ClientBootstrap* pipelineFactory( + std::shared_ptr> factory) { + pipelineFactory_ = factory; + return this; + } + + Pipeline* getPipeline() { + return pipeline_.get(); + } + + virtual ~ClientBootstrap() {} + + protected: + std::unique_ptr pipeline_; + + int port_; + + std::shared_ptr> pipelineFactory_; + std::shared_ptr group_; +}; + +} // namespace diff --git a/folly/wangle/bootstrap/ServerBootstrap-inl.h b/folly/wangle/bootstrap/ServerBootstrap-inl.h new file mode 100644 index 00000000..4b56de86 --- /dev/null +++ b/folly/wangle/bootstrap/ServerBootstrap-inl.h @@ -0,0 +1,198 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace folly { + +template +class ServerAcceptor + : public Acceptor + , public folly::wangle::InboundHandler { + typedef std::unique_ptr PipelinePtr; + + class ServerConnection : public wangle::ManagedConnection, + public wangle::PipelineManager { + public: + explicit ServerConnection(PipelinePtr pipeline) + : pipeline_(std::move(pipeline)) { + pipeline_->setPipelineManager(this); + } + + ~ServerConnection() {} + + void timeoutExpired() noexcept override { + } + + void describe(std::ostream& os) const override {} + bool isBusy() const override { + return false; + } + void notifyPendingShutdown() override {} + void closeWhenIdle() override {} + void dropConnection() override { + delete this; + } + void dumpConnectionState(uint8_t loglevel) override {} + + void deletePipeline(wangle::PipelineBase* p) override { + CHECK(p == pipeline_.get()); + delete this; + } + + private: + PipelinePtr pipeline_; + }; + + public: + explicit ServerAcceptor( + std::shared_ptr> pipelineFactory, + std::shared_ptr> acceptorPipeline, + EventBase* base) + : Acceptor(ServerSocketConfig()) + , base_(base) + , childPipelineFactory_(pipelineFactory) + , acceptorPipeline_(acceptorPipeline) { + Acceptor::init(nullptr, base_); + CHECK(acceptorPipeline_); + + acceptorPipeline_->addBack(this); + acceptorPipeline_->finalize(); + } + + void read(Context* ctx, void* conn) { + AsyncSocket::UniquePtr transport((AsyncSocket*)conn); + std::unique_ptr + pipeline(childPipelineFactory_->newPipeline( + std::shared_ptr( + transport.release(), + folly::DelayedDestruction::Destructor()))); + pipeline->transportActive(); + auto connection = new ServerConnection(std::move(pipeline)); + Acceptor::addConnection(connection); + } + + /* See Acceptor::onNewConnection for details */ + void onNewConnection( + AsyncSocket::UniquePtr transport, const SocketAddress* address, + const std::string& nextProtocolName, const TransportInfo& tinfo) { + acceptorPipeline_->read(transport.release()); + } + + // UDP thunk + void onDataAvailable(std::shared_ptr socket, + const folly::SocketAddress& addr, + std::unique_ptr buf, + bool truncated) noexcept { + acceptorPipeline_->read(buf.release()); + } + + private: + EventBase* base_; + + std::shared_ptr> childPipelineFactory_; + std::shared_ptr> acceptorPipeline_; +}; + +template +class ServerAcceptorFactory : public AcceptorFactory { + public: + explicit ServerAcceptorFactory( + std::shared_ptr> factory, + std::shared_ptr>> pipeline) + : factory_(factory) + , pipeline_(pipeline) {} + + std::shared_ptr newAcceptor(EventBase* base) { + std::shared_ptr> pipeline( + pipeline_->newPipeline(nullptr)); + return std::make_shared>(factory_, pipeline, base); + } + private: + std::shared_ptr> factory_; + std::shared_ptr>> pipeline_; +}; + +class ServerWorkerPool : public folly::wangle::ThreadPoolExecutor::Observer { + public: + explicit ServerWorkerPool( + std::shared_ptr acceptorFactory, + folly::wangle::IOThreadPoolExecutor* exec, + std::shared_ptr>> sockets, + std::shared_ptr socketFactory) + : acceptorFactory_(acceptorFactory) + , exec_(exec) + , sockets_(sockets) + , socketFactory_(socketFactory) { + CHECK(exec); + } + + template + void forEachWorker(F&& f) const; + + void threadStarted( + folly::wangle::ThreadPoolExecutor::ThreadHandle*); + void threadStopped( + folly::wangle::ThreadPoolExecutor::ThreadHandle*); + void threadPreviouslyStarted( + folly::wangle::ThreadPoolExecutor::ThreadHandle* thread) { + threadStarted(thread); + } + void threadNotYetStopped( + folly::wangle::ThreadPoolExecutor::ThreadHandle* thread) { + threadStopped(thread); + } + + private: + std::map> workers_; + std::shared_ptr acceptorFactory_; + folly::wangle::IOThreadPoolExecutor* exec_{nullptr}; + std::shared_ptr>> sockets_; + std::shared_ptr socketFactory_; +}; + +template +void ServerWorkerPool::forEachWorker(F&& f) const { + for (const auto& kv : workers_) { + f(kv.second.get()); + } +} + +class DefaultAcceptPipelineFactory + : public PipelineFactory> { + typedef wangle::Pipeline AcceptPipeline; + + public: + std::unique_ptr + newPipeline(std::shared_ptr) { + + return std::unique_ptr + (new AcceptPipeline); + } +}; + +} // namespace diff --git a/folly/wangle/bootstrap/ServerBootstrap.cpp b/folly/wangle/bootstrap/ServerBootstrap.cpp new file mode 100644 index 00000000..6b7a4101 --- /dev/null +++ b/folly/wangle/bootstrap/ServerBootstrap.cpp @@ -0,0 +1,62 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + +namespace folly { + +void ServerWorkerPool::threadStarted( + folly::wangle::ThreadPoolExecutor::ThreadHandle* h) { + auto worker = acceptorFactory_->newAcceptor(exec_->getEventBase(h)); + workers_.insert({h, worker}); + + for(auto socket : *sockets_) { + socket->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait( + [this, worker, socket](){ + socketFactory_->addAcceptCB( + socket, worker.get(), worker->getEventBase()); + }); + } +} + +void ServerWorkerPool::threadStopped( + folly::wangle::ThreadPoolExecutor::ThreadHandle* h) { + auto worker = workers_.find(h); + CHECK(worker != workers_.end()); + + for (auto socket : *sockets_) { + socket->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait( + [&]() { + socketFactory_->removeAcceptCB( + socket, worker->second.get(), nullptr); + }); + } + + if (!worker->second->getEventBase()->isInEventBaseThread()) { + worker->second->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait( + [=]() { + worker->second->dropAllConnections(); + }); + } else { + worker->second->dropAllConnections(); + } + + workers_.erase(worker); +} + +} // namespace diff --git a/folly/wangle/bootstrap/ServerBootstrap.h b/folly/wangle/bootstrap/ServerBootstrap.h new file mode 100644 index 00000000..4940e67b --- /dev/null +++ b/folly/wangle/bootstrap/ServerBootstrap.h @@ -0,0 +1,351 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace folly { + +typedef folly::wangle::Pipeline< + folly::IOBufQueue&, std::unique_ptr> DefaultPipeline; + +/* + * ServerBootstrap is a parent class intended to set up a + * high-performance TCP accepting server. It will manage a pool of + * accepting threads, any number of accepting sockets, a pool of + * IO-worker threads, and connection pool for each IO thread for you. + * + * The output is given as a Pipeline template: given a + * PipelineFactory, it will create a new pipeline for each connection, + * and your server can handle the incoming bytes. + * + * BACKWARDS COMPATIBLITY: for servers already taking a pool of + * Acceptor objects, an AcceptorFactory can be given directly instead + * of a pipeline factory. + */ +template +class ServerBootstrap { + public: + + ServerBootstrap(const ServerBootstrap& that) = delete; + ServerBootstrap(ServerBootstrap&& that) = default; + + ServerBootstrap() {} + + ~ServerBootstrap() { + stop(); + join(); + } + + typedef wangle::Pipeline AcceptPipeline; + /* + * Pipeline used to add connections to event bases. + * This is used for UDP or for load balancing + * TCP connections to IO threads explicitly + */ + ServerBootstrap* pipeline( + std::shared_ptr> factory) { + pipeline_ = factory; + return this; + } + + ServerBootstrap* channelFactory( + std::shared_ptr factory) { + socketFactory_ = factory; + return this; + } + + /* + * BACKWARDS COMPATIBILITY - an acceptor factory can be set. Your + * Acceptor is responsible for managing the connection pool. + * + * @param childHandler - acceptor factory to call for each IO thread + */ + ServerBootstrap* childHandler(std::shared_ptr h) { + acceptorFactory_ = h; + return this; + } + + /* + * Set a pipeline factory that will be called for each new connection + * + * @param factory pipeline factory to use for each new connection + */ + ServerBootstrap* childPipeline( + std::shared_ptr> factory) { + childPipelineFactory_ = factory; + return this; + } + + /* + * Set the IO executor. If not set, a default one will be created + * with one thread per core. + * + * @param io_group - io executor to use for IO threads. + */ + ServerBootstrap* group( + std::shared_ptr io_group) { + return group(nullptr, io_group); + } + + /* + * Set the acceptor executor, and IO executor. + * + * If no acceptor executor is set, a single thread will be created for accepts + * If no IO executor is set, a default of one thread per core will be created + * + * @param group - acceptor executor to use for acceptor threads. + * @param io_group - io executor to use for IO threads. + */ + ServerBootstrap* group( + std::shared_ptr accept_group, + std::shared_ptr io_group) { + if (!accept_group) { + accept_group = std::make_shared( + 1, std::make_shared("Acceptor Thread")); + } + if (!io_group) { + io_group = std::make_shared( + 32, std::make_shared("IO Thread")); + } + + // TODO better config checking + // CHECK(acceptorFactory_ || childPipelineFactory_); + CHECK(!(acceptorFactory_ && childPipelineFactory_)); + + if (acceptorFactory_) { + workerFactory_ = std::make_shared( + acceptorFactory_, io_group.get(), sockets_, socketFactory_); + } else { + workerFactory_ = std::make_shared( + std::make_shared>( + childPipelineFactory_, + pipeline_), + io_group.get(), sockets_, socketFactory_); + } + + io_group->addObserver(workerFactory_); + + acceptor_group_ = accept_group; + io_group_ = io_group; + + return this; + } + + /* + * Bind to an existing socket + * + * @param sock Existing socket to use for accepting + */ + void bind(folly::AsyncServerSocket::UniquePtr s) { + if (!workerFactory_) { + group(nullptr); + } + + // Since only a single socket is given, + // we can only accept on a single thread + CHECK(acceptor_group_->numThreads() == 1); + + std::shared_ptr socket( + s.release(), DelayedDestruction::Destructor()); + + folly::Baton<> barrier; + acceptor_group_->add([&](){ + socket->attachEventBase(EventBaseManager::get()->getEventBase()); + socket->listen(socketConfig.acceptBacklog); + socket->startAccepting(); + barrier.post(); + }); + barrier.wait(); + + // Startup all the threads + workerFactory_->forEachWorker([this, socket](Acceptor* worker){ + socket->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait( + [this, worker, socket](){ + socketFactory_->addAcceptCB(socket, worker, worker->getEventBase()); + }); + }); + + sockets_->push_back(socket); + } + + void bind(folly::SocketAddress& address) { + bindImpl(-1, address); + } + + /* + * Bind to a port and start listening. + * One of childPipeline or childHandler must be called before bind + * + * @param port Port to listen on + */ + void bind(int port) { + CHECK(port >= 0); + folly::SocketAddress address; + bindImpl(port, address); + } + + void bindImpl(int port, folly::SocketAddress& address) { + if (!workerFactory_) { + group(nullptr); + } + + bool reusePort = false; + if (acceptor_group_->numThreads() > 1) { + reusePort = true; + } + + std::mutex sock_lock; + std::vector> new_sockets; + + + std::exception_ptr exn; + + auto startupFunc = [&](std::shared_ptr> barrier){ + + try { + auto socket = socketFactory_->newSocket( + port, address, socketConfig.acceptBacklog, reusePort, socketConfig); + + sock_lock.lock(); + new_sockets.push_back(socket); + sock_lock.unlock(); + + if (port <= 0) { + socket->getAddress(&address); + port = address.getPort(); + } + + barrier->post(); + } catch (...) { + exn = std::current_exception(); + barrier->post(); + + return; + } + + + + }; + + auto wait0 = std::make_shared>(); + acceptor_group_->add(std::bind(startupFunc, wait0)); + wait0->wait(); + + for (size_t i = 1; i < acceptor_group_->numThreads(); i++) { + auto barrier = std::make_shared>(); + acceptor_group_->add(std::bind(startupFunc, barrier)); + barrier->wait(); + } + + if (exn) { + std::rethrow_exception(exn); + } + + for (auto& socket : new_sockets) { + // Startup all the threads + workerFactory_->forEachWorker([this, socket](Acceptor* worker){ + socket->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait( + [this, worker, socket](){ + socketFactory_->addAcceptCB(socket, worker, worker->getEventBase()); + }); + }); + + sockets_->push_back(socket); + } + } + + /* + * Stop listening on all sockets. + */ + void stop() { + // sockets_ may be null if ServerBootstrap has been std::move'd + if (sockets_) { + for (auto socket : *sockets_) { + socket->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait( + [&]() mutable { + socketFactory_->stopSocket(socket); + }); + } + sockets_->clear(); + } + if (!stopped_) { + stopped_ = true; + // stopBaton_ may be null if ServerBootstrap has been std::move'd + if (stopBaton_) { + stopBaton_->post(); + } + } + } + + void join() { + if (acceptor_group_) { + acceptor_group_->join(); + } + if (io_group_) { + io_group_->join(); + } + } + + void waitForStop() { + if (!stopped_) { + CHECK(stopBaton_); + stopBaton_->wait(); + } + } + + /* + * Get the list of listening sockets + */ + const std::vector>& + getSockets() const { + return *sockets_; + } + + std::shared_ptr getIOGroup() const { + return io_group_; + } + + template + void forEachWorker(F&& f) const { + workerFactory_->forEachWorker(f); + } + + ServerSocketConfig socketConfig; + + private: + std::shared_ptr acceptor_group_; + std::shared_ptr io_group_; + + std::shared_ptr workerFactory_; + std::shared_ptr>> sockets_{ + std::make_shared>>()}; + + std::shared_ptr acceptorFactory_; + std::shared_ptr> childPipelineFactory_; + std::shared_ptr> pipeline_{ + std::make_shared()}; + std::shared_ptr socketFactory_{ + std::make_shared()}; + + std::unique_ptr> stopBaton_{ + folly::make_unique>()}; + bool stopped_{false}; +}; + +} // namespace diff --git a/folly/wangle/bootstrap/ServerSocketFactory.h b/folly/wangle/bootstrap/ServerSocketFactory.h new file mode 100644 index 00000000..fad602db --- /dev/null +++ b/folly/wangle/bootstrap/ServerSocketFactory.h @@ -0,0 +1,122 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace folly { + +class ServerSocketFactory { + public: + virtual std::shared_ptr newSocket( + int port, SocketAddress address, int backlog, + bool reuse, ServerSocketConfig& config) = 0; + + virtual void stopSocket( + std::shared_ptr& socket) = 0; + + virtual void removeAcceptCB(std::shared_ptr sock, Acceptor *callback, EventBase* base) = 0; + virtual void addAcceptCB(std::shared_ptr sock, Acceptor* callback, EventBase* base) = 0 ; + virtual ~ServerSocketFactory() = default; +}; + +class AsyncServerSocketFactory : public ServerSocketFactory { + public: + std::shared_ptr newSocket( + int port, SocketAddress address, int backlog, bool reuse, + ServerSocketConfig& config) { + + auto socket = folly::AsyncServerSocket::newSocket(); + socket->setReusePortEnabled(reuse); + socket->attachEventBase(EventBaseManager::get()->getEventBase()); + if (port >= 0) { + socket->bind(port); + } else { + socket->bind(address); + } + + socket->listen(config.acceptBacklog); + socket->startAccepting(); + + return socket; + } + + virtual void stopSocket( + std::shared_ptr& s) { + auto socket = std::dynamic_pointer_cast(s); + DCHECK(socket); + socket->stopAccepting(); + socket->detachEventBase(); + } + + virtual void removeAcceptCB(std::shared_ptr s, + Acceptor *callback, EventBase* base) { + auto socket = std::dynamic_pointer_cast(s); + CHECK(socket); + socket->removeAcceptCallback(callback, base); + } + + virtual void addAcceptCB(std::shared_ptr s, + Acceptor* callback, EventBase* base) { + auto socket = std::dynamic_pointer_cast(s); + CHECK(socket); + socket->addAcceptCallback(callback, base); + } +}; + +class AsyncUDPServerSocketFactory : public ServerSocketFactory { + public: + std::shared_ptr newSocket( + int port, SocketAddress address, int backlog, bool reuse, + ServerSocketConfig& config) { + + auto socket = std::make_shared( + EventBaseManager::get()->getEventBase()); + socket->setReusePort(reuse); + if (port >= 0) { + SocketAddress addressr("::1", port); + socket->bind(addressr); + } else { + socket->bind(address); + } + socket->listen(); + + return socket; + } + + virtual void stopSocket( + std::shared_ptr& s) { + auto socket = std::dynamic_pointer_cast(s); + DCHECK(socket); + socket->close(); + } + + virtual void removeAcceptCB(std::shared_ptr s, + Acceptor *callback, EventBase* base) { + } + + virtual void addAcceptCB(std::shared_ptr s, + Acceptor* callback, EventBase* base) { + auto socket = std::dynamic_pointer_cast(s); + DCHECK(socket); + socket->addListener(base, callback); + } +}; + +} // namespace diff --git a/folly/wangle/channel/AsyncSocketHandler.h b/folly/wangle/channel/AsyncSocketHandler.h new file mode 100644 index 00000000..26728494 --- /dev/null +++ b/folly/wangle/channel/AsyncSocketHandler.h @@ -0,0 +1,164 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +// This handler may only be used in a single Pipeline +class AsyncSocketHandler + : public folly::wangle::BytesToBytesHandler, + public AsyncSocket::ReadCallback { + public: + explicit AsyncSocketHandler( + std::shared_ptr socket) + : socket_(std::move(socket)) {} + + AsyncSocketHandler(AsyncSocketHandler&&) = default; + + ~AsyncSocketHandler() { + detachReadCallback(); + } + + void attachReadCallback() { + socket_->setReadCB(socket_->good() ? this : nullptr); + } + + void detachReadCallback() { + if (socket_ && socket_->getReadCallback() == this) { + socket_->setReadCB(nullptr); + } + auto ctx = getContext(); + if (ctx && !firedInactive_) { + firedInactive_ = true; + ctx->fireTransportInactive(); + } + } + + void attachEventBase(folly::EventBase* eventBase) { + if (eventBase && !socket_->getEventBase()) { + socket_->attachEventBase(eventBase); + } + } + + void detachEventBase() { + detachReadCallback(); + if (socket_->getEventBase()) { + socket_->detachEventBase(); + } + } + + void transportActive(Context* ctx) override { + ctx->getPipeline()->setTransport(socket_); + attachReadCallback(); + ctx->fireTransportActive(); + } + + void detachPipeline(Context* ctx) override { + detachReadCallback(); + } + + folly::Future write( + Context* ctx, + std::unique_ptr buf) override { + if (UNLIKELY(!buf)) { + return folly::makeFuture(); + } + + if (!socket_->good()) { + VLOG(5) << "socket is closed in write()"; + return folly::makeFuture(AsyncSocketException( + AsyncSocketException::AsyncSocketExceptionType::NOT_OPEN, + "socket is closed in write()")); + } + + auto cb = new WriteCallback(); + auto future = cb->promise_.getFuture(); + socket_->writeChain(cb, std::move(buf), ctx->getWriteFlags()); + return future; + }; + + folly::Future close(Context* ctx) override { + if (socket_) { + detachReadCallback(); + socket_->closeNow(); + } + ctx->getPipeline()->deletePipeline(); + return folly::makeFuture(); + } + + // Must override to avoid warnings about hidden overloaded virtual due to + // AsyncSocket::ReadCallback::readEOF() + void readEOF(Context* ctx) override { + ctx->fireReadEOF(); + } + + void getReadBuffer(void** bufReturn, size_t* lenReturn) override { + const auto readBufferSettings = getContext()->getReadBufferSettings(); + const auto ret = bufQueue_.preallocate( + readBufferSettings.first, + readBufferSettings.second); + *bufReturn = ret.first; + *lenReturn = ret.second; + } + + void readDataAvailable(size_t len) noexcept override { + bufQueue_.postallocate(len); + getContext()->fireRead(bufQueue_); + } + + void readEOF() noexcept override { + getContext()->fireReadEOF(); + } + + void readErr(const AsyncSocketException& ex) + noexcept override { + getContext()->fireReadException( + make_exception_wrapper(ex)); + } + + private: + class WriteCallback : private AsyncSocket::WriteCallback { + void writeSuccess() noexcept override { + promise_.setValue(); + delete this; + } + + void writeErr(size_t bytesWritten, + const AsyncSocketException& ex) + noexcept override { + promise_.setException(ex); + delete this; + } + + private: + friend class AsyncSocketHandler; + folly::Promise promise_; + }; + + folly::IOBufQueue bufQueue_{folly::IOBufQueue::cacheChainLength()}; + std::shared_ptr socket_{nullptr}; + bool firedInactive_{false}; +}; + +}} diff --git a/folly/wangle/channel/EventBaseHandler.h b/folly/wangle/channel/EventBaseHandler.h new file mode 100644 index 00000000..55290ded --- /dev/null +++ b/folly/wangle/channel/EventBaseHandler.h @@ -0,0 +1,45 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +namespace folly { namespace wangle { + +class EventBaseHandler : public OutboundBytesToBytesHandler { + public: + folly::Future write( + Context* ctx, + std::unique_ptr buf) override { + folly::Future retval; + DCHECK(ctx->getTransport()); + DCHECK(ctx->getTransport()->getEventBase()); + ctx->getTransport()->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait([&](){ + retval = ctx->fireWrite(std::move(buf)); + }); + return retval; + } + + Future close(Context* ctx) override { + DCHECK(ctx->getTransport()); + DCHECK(ctx->getTransport()->getEventBase()); + Future retval; + ctx->getTransport()->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait([&](){ + retval = ctx->fireClose(); + }); + return retval; + } +}; + +}} // namespace diff --git a/folly/wangle/channel/Handler.h b/folly/wangle/channel/Handler.h new file mode 100644 index 00000000..2080c2c2 --- /dev/null +++ b/folly/wangle/channel/Handler.h @@ -0,0 +1,173 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace folly { namespace wangle { + +template +class HandlerBase { + public: + virtual ~HandlerBase() {} + + virtual void attachPipeline(Context* ctx) {} + virtual void detachPipeline(Context* ctx) {} + + Context* getContext() { + if (attachCount_ != 1) { + return nullptr; + } + CHECK(ctx_); + return ctx_; + } + + private: + friend PipelineContext; + uint64_t attachCount_{0}; + Context* ctx_{nullptr}; +}; + +template +class Handler : public HandlerBase> { + public: + static const HandlerDir dir = HandlerDir::BOTH; + + typedef Rin rin; + typedef Rout rout; + typedef Win win; + typedef Wout wout; + typedef HandlerContext Context; + virtual ~Handler() {} + + virtual void read(Context* ctx, Rin msg) = 0; + virtual void readEOF(Context* ctx) { + ctx->fireReadEOF(); + } + virtual void readException(Context* ctx, exception_wrapper e) { + ctx->fireReadException(std::move(e)); + } + virtual void transportActive(Context* ctx) { + ctx->fireTransportActive(); + } + virtual void transportInactive(Context* ctx) { + ctx->fireTransportInactive(); + } + + virtual Future write(Context* ctx, Win msg) = 0; + virtual Future close(Context* ctx) { + return ctx->fireClose(); + } + + /* + // Other sorts of things we might want, all shamelessly stolen from Netty + // inbound + virtual void exceptionCaught( + HandlerContext* ctx, + exception_wrapper e) {} + virtual void channelRegistered(HandlerContext* ctx) {} + virtual void channelUnregistered(HandlerContext* ctx) {} + virtual void channelReadComplete(HandlerContext* ctx) {} + virtual void userEventTriggered(HandlerContext* ctx, void* evt) {} + virtual void channelWritabilityChanged(HandlerContext* ctx) {} + + // outbound + virtual Future bind( + HandlerContext* ctx, + SocketAddress localAddress) {} + virtual Future connect( + HandlerContext* ctx, + SocketAddress remoteAddress, SocketAddress localAddress) {} + virtual Future disconnect(HandlerContext* ctx) {} + virtual Future deregister(HandlerContext* ctx) {} + virtual Future read(HandlerContext* ctx) {} + virtual void flush(HandlerContext* ctx) {} + */ +}; + +template +class InboundHandler : public HandlerBase> { + public: + static const HandlerDir dir = HandlerDir::IN; + + typedef Rin rin; + typedef Rout rout; + typedef Nothing win; + typedef Nothing wout; + typedef InboundHandlerContext Context; + virtual ~InboundHandler() {} + + virtual void read(Context* ctx, Rin msg) = 0; + virtual void readEOF(Context* ctx) { + ctx->fireReadEOF(); + } + virtual void readException(Context* ctx, exception_wrapper e) { + ctx->fireReadException(std::move(e)); + } + virtual void transportActive(Context* ctx) { + ctx->fireTransportActive(); + } + virtual void transportInactive(Context* ctx) { + ctx->fireTransportInactive(); + } +}; + +template +class OutboundHandler : public HandlerBase> { + public: + static const HandlerDir dir = HandlerDir::OUT; + + typedef Nothing rin; + typedef Nothing rout; + typedef Win win; + typedef Wout wout; + typedef OutboundHandlerContext Context; + virtual ~OutboundHandler() {} + + virtual Future write(Context* ctx, Win msg) = 0; + virtual Future close(Context* ctx) { + return ctx->fireClose(); + } +}; + +template +class HandlerAdapter : public Handler { + public: + typedef typename Handler::Context Context; + + void read(Context* ctx, R msg) override { + ctx->fireRead(std::forward(msg)); + } + + Future write(Context* ctx, W msg) override { + return ctx->fireWrite(std::forward(msg)); + } +}; + +typedef HandlerAdapter> +BytesToBytesHandler; + +typedef InboundHandler> +InboundBytesToBytesHandler; + +typedef OutboundHandler> +OutboundBytesToBytesHandler; + +}} diff --git a/folly/wangle/channel/HandlerContext-inl.h b/folly/wangle/channel/HandlerContext-inl.h new file mode 100644 index 00000000..9a220bcd --- /dev/null +++ b/folly/wangle/channel/HandlerContext-inl.h @@ -0,0 +1,447 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace folly { namespace wangle { + +class PipelineContext { + public: + virtual ~PipelineContext() {} + + virtual void attachPipeline() = 0; + virtual void detachPipeline() = 0; + + template + void attachContext(H* handler, HandlerContext* ctx) { + if (++handler->attachCount_ == 1) { + handler->ctx_ = ctx; + } else { + handler->ctx_ = nullptr; + } + } + + virtual void setNextIn(PipelineContext* ctx) = 0; + virtual void setNextOut(PipelineContext* ctx) = 0; +}; + +template +class InboundLink { + public: + virtual ~InboundLink() {} + virtual void read(In msg) = 0; + virtual void readEOF() = 0; + virtual void readException(exception_wrapper e) = 0; + virtual void transportActive() = 0; + virtual void transportInactive() = 0; +}; + +template +class OutboundLink { + public: + virtual ~OutboundLink() {} + virtual Future write(Out msg) = 0; + virtual Future close() = 0; +}; + +template +class ContextImplBase : public PipelineContext { + public: + ~ContextImplBase() {} + + H* getHandler() { + return handler_.get(); + } + + void initialize(P* pipeline, std::shared_ptr handler) { + pipeline_ = pipeline; + handler_ = std::move(handler); + } + + // PipelineContext overrides + void attachPipeline() override { + if (!attached_) { + this->attachContext(handler_.get(), impl_); + handler_->attachPipeline(impl_); + attached_ = true; + } + } + + void detachPipeline() override { + handler_->detachPipeline(impl_); + attached_ = false; + } + + void setNextIn(PipelineContext* ctx) override { + auto nextIn = dynamic_cast*>(ctx); + if (nextIn) { + nextIn_ = nextIn; + } else { + throw std::invalid_argument("inbound type mismatch"); + } + } + + void setNextOut(PipelineContext* ctx) override { + auto nextOut = dynamic_cast*>(ctx); + if (nextOut) { + nextOut_ = nextOut; + } else { + throw std::invalid_argument("outbound type mismatch"); + } + } + + protected: + Context* impl_; + P* pipeline_; + std::shared_ptr handler_; + InboundLink* nextIn_{nullptr}; + OutboundLink* nextOut_{nullptr}; + + private: + bool attached_{false}; + using DestructorGuard = typename P::DestructorGuard; +}; + +template +class ContextImpl + : public HandlerContext, + public InboundLink, + public OutboundLink, + public ContextImplBase> { + public: + typedef typename H::rin Rin; + typedef typename H::rout Rout; + typedef typename H::win Win; + typedef typename H::wout Wout; + static const HandlerDir dir = HandlerDir::BOTH; + + explicit ContextImpl(P* pipeline, std::shared_ptr handler) { + this->impl_ = this; + this->initialize(pipeline, std::move(handler)); + } + + // For StaticPipeline + ContextImpl() { + this->impl_ = this; + } + + ~ContextImpl() {} + + // HandlerContext overrides + void fireRead(Rout msg) override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->read(std::forward(msg)); + } else { + LOG(WARNING) << "read reached end of pipeline"; + } + } + + void fireReadEOF() override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->readEOF(); + } else { + LOG(WARNING) << "readEOF reached end of pipeline"; + } + } + + void fireReadException(exception_wrapper e) override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->readException(std::move(e)); + } else { + LOG(WARNING) << "readException reached end of pipeline"; + } + } + + void fireTransportActive() override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->transportActive(); + } + } + + void fireTransportInactive() override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->transportInactive(); + } + } + + Future fireWrite(Wout msg) override { + DestructorGuard dg(this->pipeline_); + if (this->nextOut_) { + return this->nextOut_->write(std::forward(msg)); + } else { + LOG(WARNING) << "write reached end of pipeline"; + return makeFuture(); + } + } + + Future fireClose() override { + DestructorGuard dg(this->pipeline_); + if (this->nextOut_) { + return this->nextOut_->close(); + } else { + LOG(WARNING) << "close reached end of pipeline"; + return makeFuture(); + } + } + + PipelineBase* getPipeline() override { + return this->pipeline_; + } + + void setWriteFlags(WriteFlags flags) override { + this->pipeline_->setWriteFlags(flags); + } + + WriteFlags getWriteFlags() override { + return this->pipeline_->getWriteFlags(); + } + + void setReadBufferSettings( + uint64_t minAvailable, + uint64_t allocationSize) override { + this->pipeline_->setReadBufferSettings(minAvailable, allocationSize); + } + + std::pair getReadBufferSettings() override { + return this->pipeline_->getReadBufferSettings(); + } + + // InboundLink overrides + void read(Rin msg) override { + DestructorGuard dg(this->pipeline_); + this->handler_->read(this, std::forward(msg)); + } + + void readEOF() override { + DestructorGuard dg(this->pipeline_); + this->handler_->readEOF(this); + } + + void readException(exception_wrapper e) override { + DestructorGuard dg(this->pipeline_); + this->handler_->readException(this, std::move(e)); + } + + void transportActive() override { + DestructorGuard dg(this->pipeline_); + this->handler_->transportActive(this); + } + + void transportInactive() override { + DestructorGuard dg(this->pipeline_); + this->handler_->transportInactive(this); + } + + // OutboundLink overrides + Future write(Win msg) override { + DestructorGuard dg(this->pipeline_); + return this->handler_->write(this, std::forward(msg)); + } + + Future close() override { + DestructorGuard dg(this->pipeline_); + return this->handler_->close(this); + } + + private: + using DestructorGuard = typename P::DestructorGuard; +}; + +template +class InboundContextImpl + : public InboundHandlerContext, + public InboundLink, + public ContextImplBase> { + public: + typedef typename H::rin Rin; + typedef typename H::rout Rout; + typedef typename H::win Win; + typedef typename H::wout Wout; + static const HandlerDir dir = HandlerDir::IN; + + explicit InboundContextImpl(P* pipeline, std::shared_ptr handler) { + this->impl_ = this; + this->initialize(pipeline, std::move(handler)); + } + + // For StaticPipeline + InboundContextImpl() { + this->impl_ = this; + } + + ~InboundContextImpl() {} + + // InboundHandlerContext overrides + void fireRead(Rout msg) override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->read(std::forward(msg)); + } else { + LOG(WARNING) << "read reached end of pipeline"; + } + } + + void fireReadEOF() override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->readEOF(); + } else { + LOG(WARNING) << "readEOF reached end of pipeline"; + } + } + + void fireReadException(exception_wrapper e) override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->readException(std::move(e)); + } else { + LOG(WARNING) << "readException reached end of pipeline"; + } + } + + void fireTransportActive() override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->transportActive(); + } + } + + void fireTransportInactive() override { + DestructorGuard dg(this->pipeline_); + if (this->nextIn_) { + this->nextIn_->transportInactive(); + } + } + + PipelineBase* getPipeline() override { + return this->pipeline_; + } + + // InboundLink overrides + void read(Rin msg) override { + DestructorGuard dg(this->pipeline_); + this->handler_->read(this, std::forward(msg)); + } + + void readEOF() override { + DestructorGuard dg(this->pipeline_); + this->handler_->readEOF(this); + } + + void readException(exception_wrapper e) override { + DestructorGuard dg(this->pipeline_); + this->handler_->readException(this, std::move(e)); + } + + void transportActive() override { + DestructorGuard dg(this->pipeline_); + this->handler_->transportActive(this); + } + + void transportInactive() override { + DestructorGuard dg(this->pipeline_); + this->handler_->transportInactive(this); + } + + private: + using DestructorGuard = typename P::DestructorGuard; +}; + +template +class OutboundContextImpl + : public OutboundHandlerContext, + public OutboundLink, + public ContextImplBase> { + public: + typedef typename H::rin Rin; + typedef typename H::rout Rout; + typedef typename H::win Win; + typedef typename H::wout Wout; + static const HandlerDir dir = HandlerDir::OUT; + + explicit OutboundContextImpl(P* pipeline, std::shared_ptr handler) { + this->impl_ = this; + this->initialize(pipeline, std::move(handler)); + } + + // For StaticPipeline + OutboundContextImpl() { + this->impl_ = this; + } + + ~OutboundContextImpl() {} + + // OutboundHandlerContext overrides + Future fireWrite(Wout msg) override { + DestructorGuard dg(this->pipeline_); + if (this->nextOut_) { + return this->nextOut_->write(std::forward(msg)); + } else { + LOG(WARNING) << "write reached end of pipeline"; + return makeFuture(); + } + } + + Future fireClose() override { + DestructorGuard dg(this->pipeline_); + if (this->nextOut_) { + return this->nextOut_->close(); + } else { + LOG(WARNING) << "close reached end of pipeline"; + return makeFuture(); + } + } + + PipelineBase* getPipeline() override { + return this->pipeline_; + } + + // OutboundLink overrides + Future write(Win msg) override { + DestructorGuard dg(this->pipeline_); + return this->handler_->write(this, std::forward(msg)); + } + + Future close() override { + DestructorGuard dg(this->pipeline_); + return this->handler_->close(this); + } + + private: + using DestructorGuard = typename P::DestructorGuard; +}; + +template +struct ContextType { + typedef typename std::conditional< + Handler::dir == HandlerDir::BOTH, + ContextImpl, + typename std::conditional< + Handler::dir == HandlerDir::IN, + InboundContextImpl, + OutboundContextImpl + >::type>::type + type; +}; + +}} // folly::wangle diff --git a/folly/wangle/channel/HandlerContext.h b/folly/wangle/channel/HandlerContext.h new file mode 100644 index 00000000..ddd9a576 --- /dev/null +++ b/folly/wangle/channel/HandlerContext.h @@ -0,0 +1,108 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace folly { namespace wangle { + +class PipelineBase; + +template +class HandlerContext { + public: + virtual ~HandlerContext() {} + + virtual void fireRead(In msg) = 0; + virtual void fireReadEOF() = 0; + virtual void fireReadException(exception_wrapper e) = 0; + virtual void fireTransportActive() = 0; + virtual void fireTransportInactive() = 0; + + virtual Future fireWrite(Out msg) = 0; + virtual Future fireClose() = 0; + + virtual PipelineBase* getPipeline() = 0; + std::shared_ptr getTransport() { + return getPipeline()->getTransport(); + } + + virtual void setWriteFlags(WriteFlags flags) = 0; + virtual WriteFlags getWriteFlags() = 0; + + virtual void setReadBufferSettings( + uint64_t minAvailable, + uint64_t allocationSize) = 0; + virtual std::pair getReadBufferSettings() = 0; + + /* TODO + template + virtual void addHandlerBefore(H&&) {} + template + virtual void addHandlerAfter(H&&) {} + template + virtual void replaceHandler(H&&) {} + virtual void removeHandler() {} + */ +}; + +template +class InboundHandlerContext { + public: + virtual ~InboundHandlerContext() {} + + virtual void fireRead(In msg) = 0; + virtual void fireReadEOF() = 0; + virtual void fireReadException(exception_wrapper e) = 0; + virtual void fireTransportActive() = 0; + virtual void fireTransportInactive() = 0; + + virtual PipelineBase* getPipeline() = 0; + std::shared_ptr getTransport() { + return getPipeline()->getTransport(); + } + + // TODO Need get/set writeFlags, readBufferSettings? Probably not. + // Do we even really need them stored in the pipeline at all? + // Could just always delegate to the socket impl +}; + +template +class OutboundHandlerContext { + public: + virtual ~OutboundHandlerContext() {} + + virtual Future fireWrite(Out msg) = 0; + virtual Future fireClose() = 0; + + virtual PipelineBase* getPipeline() = 0; + std::shared_ptr getTransport() { + return getPipeline()->getTransport(); + } +}; + +enum class HandlerDir { + IN, + OUT, + BOTH +}; + +}} // folly::wangle + +#include diff --git a/folly/wangle/channel/OutputBufferingHandler.h b/folly/wangle/channel/OutputBufferingHandler.h new file mode 100644 index 00000000..d712b8a0 --- /dev/null +++ b/folly/wangle/channel/OutputBufferingHandler.h @@ -0,0 +1,84 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +/* + * OutputBufferingHandler buffers writes in order to minimize syscalls. The + * transport will be written to once per event loop instead of on every write. + * + * This handler may only be used in a single Pipeline. + */ +class OutputBufferingHandler : public OutboundBytesToBytesHandler, + protected EventBase::LoopCallback { + public: + Future write(Context* ctx, std::unique_ptr buf) override { + CHECK(buf); + if (!queueSends_) { + return ctx->fireWrite(std::move(buf)); + } else { + // Delay sends to optimize for fewer syscalls + if (!sends_) { + DCHECK(!isLoopCallbackScheduled()); + // Buffer all the sends, and call writev once per event loop. + sends_ = std::move(buf); + ctx->getTransport()->getEventBase()->runInLoop(this); + } else { + DCHECK(isLoopCallbackScheduled()); + sends_->prependChain(std::move(buf)); + } + return sharedPromise_.getFuture(); + } + } + + void runLoopCallback() noexcept override { + MoveWrapper> sharedPromise; + std::swap(*sharedPromise, sharedPromise_); + getContext()->fireWrite(std::move(sends_)) + .then([sharedPromise](Try t) mutable { + sharedPromise->setTry(std::move(t)); + }); + } + + Future close(Context* ctx) override { + if (isLoopCallbackScheduled()) { + cancelLoopCallback(); + } + + // If there are sends queued, cancel them + sharedPromise_.setException( + folly::make_exception_wrapper( + "close() called while sends still pending")); + sends_.reset(); + sharedPromise_ = SharedPromise(); + return ctx->fireClose(); + } + + SharedPromise sharedPromise_; + std::unique_ptr sends_{nullptr}; + bool queueSends_{true}; +}; + +}} diff --git a/folly/wangle/channel/Pipeline-inl.h b/folly/wangle/channel/Pipeline-inl.h new file mode 100644 index 00000000..7c1d46bc --- /dev/null +++ b/folly/wangle/channel/Pipeline-inl.h @@ -0,0 +1,267 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace folly { namespace wangle { + +template +Pipeline::Pipeline() : isStatic_(false) {} + +template +Pipeline::Pipeline(bool isStatic) : isStatic_(isStatic) { + CHECK(isStatic_); +} + +template +Pipeline::~Pipeline() { + if (!isStatic_) { + detachHandlers(); + } +} + +template +void Pipeline::setWriteFlags(WriteFlags flags) { + writeFlags_ = flags; +} + +template +WriteFlags Pipeline::getWriteFlags() { + return writeFlags_; +} + +template +void Pipeline::setReadBufferSettings( + uint64_t minAvailable, + uint64_t allocationSize) { + readBufferSettings_ = std::make_pair(minAvailable, allocationSize); +} + +template +std::pair Pipeline::getReadBufferSettings() { + return readBufferSettings_; +} + +template +template +typename std::enable_if::value>::type +Pipeline::read(R msg) { + if (!front_) { + throw std::invalid_argument("read(): no inbound handler in Pipeline"); + } + front_->read(std::forward(msg)); +} + +template +template +typename std::enable_if::value>::type +Pipeline::readEOF() { + if (!front_) { + throw std::invalid_argument("readEOF(): no inbound handler in Pipeline"); + } + front_->readEOF(); +} + +template +template +typename std::enable_if::value>::type +Pipeline::transportActive() { + if (front_) { + front_->transportActive(); + } +} + +template +template +typename std::enable_if::value>::type +Pipeline::transportInactive() { + if (front_) { + front_->transportInactive(); + } +} + +template +template +typename std::enable_if::value>::type +Pipeline::readException(exception_wrapper e) { + if (!front_) { + throw std::invalid_argument( + "readException(): no inbound handler in Pipeline"); + } + front_->readException(std::move(e)); +} + +template +template +typename std::enable_if::value, Future>::type +Pipeline::write(W msg) { + if (!back_) { + throw std::invalid_argument("write(): no outbound handler in Pipeline"); + } + return back_->write(std::forward(msg)); +} + +template +template +typename std::enable_if::value, Future>::type +Pipeline::close() { + if (!back_) { + throw std::invalid_argument("close(): no outbound handler in Pipeline"); + } + return back_->close(); +} + +template +template +Pipeline& Pipeline::addBack(std::shared_ptr handler) { + typedef typename ContextType>::type Context; + return addHelper(std::make_shared(this, std::move(handler)), false); +} + +template +template +Pipeline& Pipeline::addBack(H&& handler) { + return addBack(std::make_shared(std::forward(handler))); +} + +template +template +Pipeline& Pipeline::addBack(H* handler) { + return addBack(std::shared_ptr(handler, [](H*){})); +} + +template +template +Pipeline& Pipeline::addFront(std::shared_ptr handler) { + typedef typename ContextType>::type Context; + return addHelper(std::make_shared(this, std::move(handler)), true); +} + +template +template +Pipeline& Pipeline::addFront(H&& handler) { + return addFront(std::make_shared(std::forward(handler))); +} + +template +template +Pipeline& Pipeline::addFront(H* handler) { + return addFront(std::shared_ptr(handler, [](H*){})); +} + +template +template +H* Pipeline::getHandler(int i) { + typedef typename ContextType>::type Context; + auto ctx = dynamic_cast(ctxs_[i].get()); + CHECK(ctx); + return ctx->getHandler(); +} + +namespace detail { + +template +inline void logWarningIfNotNothing(const std::string& warning) { + LOG(WARNING) << warning; +} + +template <> +inline void logWarningIfNotNothing(const std::string& warning) { + // do nothing +} + +} // detail + +// TODO Have read/write/etc check that pipeline has been finalized +template +void Pipeline::finalize() { + if (!inCtxs_.empty()) { + front_ = dynamic_cast*>(inCtxs_.front()); + for (size_t i = 0; i < inCtxs_.size() - 1; i++) { + inCtxs_[i]->setNextIn(inCtxs_[i+1]); + } + } + + if (!outCtxs_.empty()) { + back_ = dynamic_cast*>(outCtxs_.back()); + for (size_t i = outCtxs_.size() - 1; i > 0; i--) { + outCtxs_[i]->setNextOut(outCtxs_[i-1]); + } + } + + if (!front_) { + detail::logWarningIfNotNothing( + "No inbound handler in Pipeline, inbound operations will throw " + "std::invalid_argument"); + } + if (!back_) { + detail::logWarningIfNotNothing( + "No outbound handler in Pipeline, outbound operations will throw " + "std::invalid_argument"); + } + + for (auto it = ctxs_.rbegin(); it != ctxs_.rend(); it++) { + (*it)->attachPipeline(); + } +} + +template +template +bool Pipeline::setOwner(H* handler) { + typedef typename ContextType>::type Context; + for (auto& ctx : ctxs_) { + auto ctxImpl = dynamic_cast(ctx.get()); + if (ctxImpl && ctxImpl->getHandler() == handler) { + owner_ = ctx; + return true; + } + } + return false; +} + +template +template +void Pipeline::addContextFront(Context* ctx) { + addHelper(std::shared_ptr(ctx, [](Context*){}), true); +} + +template +void Pipeline::detachHandlers() { + for (auto& ctx : ctxs_) { + if (ctx != owner_) { + ctx->detachPipeline(); + } + } +} + +template +template +Pipeline& Pipeline::addHelper( + std::shared_ptr&& ctx, + bool front) { + ctxs_.insert(front ? ctxs_.begin() : ctxs_.end(), ctx); + if (Context::dir == HandlerDir::BOTH || Context::dir == HandlerDir::IN) { + inCtxs_.insert(front ? inCtxs_.begin() : inCtxs_.end(), ctx.get()); + } + if (Context::dir == HandlerDir::BOTH || Context::dir == HandlerDir::OUT) { + outCtxs_.insert(front ? outCtxs_.begin() : outCtxs_.end(), ctx.get()); + } + return *this; +} + +}} // folly::wangle diff --git a/folly/wangle/channel/Pipeline.h b/folly/wangle/channel/Pipeline.h new file mode 100644 index 00000000..81b68023 --- /dev/null +++ b/folly/wangle/channel/Pipeline.h @@ -0,0 +1,182 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +class PipelineManager { + public: + virtual ~PipelineManager() {} + virtual void deletePipeline(PipelineBase* pipeline) = 0; +}; + +class PipelineBase { + public: + virtual ~PipelineBase() {} + + void setPipelineManager(PipelineManager* manager) { + manager_ = manager; + } + + void deletePipeline() { + if (manager_) { + manager_->deletePipeline(this); + } + } + + void setTransport(std::shared_ptr transport) { + transport_ = transport; + } + + std::shared_ptr getTransport() { + return transport_; + } + + private: + PipelineManager* manager_{nullptr}; + std::shared_ptr transport_; +}; + +struct Nothing{}; + +/* + * R is the inbound type, i.e. inbound calls start with pipeline.read(R) + * W is the outbound type, i.e. outbound calls start with pipeline.write(W) + * + * Use Nothing for one of the types if your pipeline is unidirectional. + * If R is Nothing, read(), readEOF(), and readException() will be disabled. + * If W is Nothing, write() and close() will be disabled. + */ +template +class Pipeline : public PipelineBase, public DelayedDestruction { + public: + Pipeline(); + ~Pipeline(); + + void setWriteFlags(WriteFlags flags); + WriteFlags getWriteFlags(); + + void setReadBufferSettings(uint64_t minAvailable, uint64_t allocationSize); + std::pair getReadBufferSettings(); + + template + typename std::enable_if::value>::type + read(R msg); + + template + typename std::enable_if::value>::type + readEOF(); + + template + typename std::enable_if::value>::type + readException(exception_wrapper e); + + template + typename std::enable_if::value>::type + transportActive(); + + template + typename std::enable_if::value>::type + transportInactive(); + + template + typename std::enable_if::value, Future>::type + write(W msg); + + template + typename std::enable_if::value, Future>::type + close(); + + template + Pipeline& addBack(std::shared_ptr handler); + + template + Pipeline& addBack(H&& handler); + + template + Pipeline& addBack(H* handler); + + template + Pipeline& addFront(std::shared_ptr handler); + + template + Pipeline& addFront(H&& handler); + + template + Pipeline& addFront(H* handler); + + template + H* getHandler(int i); + + void finalize(); + + // If one of the handlers owns the pipeline itself, use setOwner to ensure + // that the pipeline doesn't try to detach the handler during destruction, + // lest destruction ordering issues occur. + // See thrift/lib/cpp2/async/Cpp2Channel.cpp for an example + template + bool setOwner(H* handler); + + protected: + explicit Pipeline(bool isStatic); + + template + void addContextFront(Context* ctx); + + void detachHandlers(); + + private: + template + Pipeline& addHelper(std::shared_ptr&& ctx, bool front); + + WriteFlags writeFlags_{WriteFlags::NONE}; + std::pair readBufferSettings_{2048, 2048}; + + bool isStatic_{false}; + std::shared_ptr owner_; + std::vector> ctxs_; + std::vector inCtxs_; + std::vector outCtxs_; + InboundLink* front_{nullptr}; + OutboundLink* back_{nullptr}; +}; + +}} + +namespace folly { + +class AsyncSocket; + +template +class PipelineFactory { + public: + virtual std::unique_ptr + newPipeline(std::shared_ptr) = 0; + + virtual ~PipelineFactory() {} +}; + +} + +#include diff --git a/folly/wangle/channel/StaticPipeline.h b/folly/wangle/channel/StaticPipeline.h new file mode 100644 index 00000000..a5d2e893 --- /dev/null +++ b/folly/wangle/channel/StaticPipeline.h @@ -0,0 +1,137 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +namespace folly { namespace wangle { + +/* + * StaticPipeline allows you to create a Pipeline with minimal allocations. + * Specify your handlers after the input/output types of your Pipeline in order + * from front to back, and construct with either H&&, H*, or std::shared_ptr + * for each handler. The pipeline will be finalized for you at the end of + * construction. For example: + * + * StringToStringHandler stringHandler1; + * auto stringHandler2 = std::make_shared(); + * + * StaticPipeline( + * IntToStringHandler(), // H&& + * &stringHandler1, // H* + * stringHandler2) // std::shared_ptr + * pipeline; + * + * You can then use pipeline just like any Pipeline. See Pipeline.h. + */ +template +class StaticPipeline; + +template +class StaticPipeline : public Pipeline { + protected: + explicit StaticPipeline(bool) : Pipeline(true) {} +}; + +template +class BaseWithOptional { + protected: + folly::Optional handler_; +}; + +template +class BaseWithoutOptional { +}; + +template +class StaticPipeline + : public StaticPipeline + , public std::conditional::value, + BaseWithoutOptional, + BaseWithOptional>::type { + public: + template + explicit StaticPipeline(HandlerArgs&&... handlers) + : StaticPipeline(true, std::forward(handlers)...) { + isFirst_ = true; + } + + ~StaticPipeline() { + if (isFirst_) { + Pipeline::detachHandlers(); + } + } + + protected: + template + StaticPipeline( + bool isFirst, + HandlerArg&& handler, + HandlerArgs&&... handlers) + : StaticPipeline( + false, + std::forward(handlers)...) { + isFirst_ = isFirst; + setHandler(std::forward(handler)); + CHECK(handlerPtr_); + ctx_.initialize(this, handlerPtr_); + Pipeline::addContextFront(&ctx_); + if (isFirst_) { + Pipeline::finalize(); + } + } + + private: + template + typename std::enable_if::type, + Handler + >::value>::type + setHandler(HandlerArg&& arg) { + BaseWithOptional::handler_.emplace(std::forward(arg)); + handlerPtr_ = std::shared_ptr(&(*BaseWithOptional::handler_), [](Handler*){}); + } + + template + typename std::enable_if::type, + std::shared_ptr + >::value>::type + setHandler(HandlerArg&& arg) { + handlerPtr_ = std::forward(arg); + } + + template + typename std::enable_if::type, + Handler* + >::value>::type + setHandler(HandlerArg&& arg) { + handlerPtr_ = std::shared_ptr(arg, [](Handler*){}); + } + + bool isFirst_; + std::shared_ptr handlerPtr_; + typename ContextType>::type ctx_; +}; + +}} // folly::wangle diff --git a/folly/wangle/channel/test/MockHandler.h b/folly/wangle/channel/test/MockHandler.h new file mode 100644 index 00000000..5a476646 --- /dev/null +++ b/folly/wangle/channel/test/MockHandler.h @@ -0,0 +1,75 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +template +class MockHandler : public Handler { + public: + typedef typename Handler::Context Context; + + MockHandler() = default; + MockHandler(MockHandler&&) = default; + +#ifdef __clang__ +# pragma clang diagnostic push +# if __clang_major__ > 3 || __clang_minor__ >= 6 +# pragma clang diagnostic ignored "-Winconsistent-missing-override" +# endif +#endif + + MOCK_METHOD2_T(read_, void(Context*, Rin&)); + MOCK_METHOD1_T(readEOF, void(Context*)); + MOCK_METHOD2_T(readException, void(Context*, exception_wrapper)); + + MOCK_METHOD2_T(write_, void(Context*, Win&)); + MOCK_METHOD1_T(close_, void(Context*)); + + MOCK_METHOD1_T(attachPipeline, void(Context*)); + MOCK_METHOD1_T(attachTransport, void(Context*)); + MOCK_METHOD1_T(detachPipeline, void(Context*)); + MOCK_METHOD1_T(detachTransport, void(Context*)); + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + + void read(Context* ctx, Rin msg) override { + read_(ctx, msg); + } + + Future write(Context* ctx, Win msg) override { + return makeFutureWith([&](){ + write_(ctx, msg); + }); + } + + Future close(Context* ctx) override { + return makeFutureWith([&](){ + close_(ctx); + }); + } +}; + +template +using MockHandlerAdapter = MockHandler; + +}} diff --git a/folly/wangle/channel/test/OutputBufferingHandlerTest.cpp b/folly/wangle/channel/test/OutputBufferingHandlerTest.cpp new file mode 100644 index 00000000..0fce7911 --- /dev/null +++ b/folly/wangle/channel/test/OutputBufferingHandlerTest.cpp @@ -0,0 +1,65 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; +using namespace testing; + +typedef StrictMock>> +MockBytesHandler; + +MATCHER_P(IOBufContains, str, "") { return arg->moveToFbString() == str; } + +TEST(OutputBufferingHandlerTest, Basic) { + MockBytesHandler mockHandler; + EXPECT_CALL(mockHandler, attachPipeline(_)); + StaticPipeline, + MockBytesHandler, + OutputBufferingHandler> + pipeline(&mockHandler, OutputBufferingHandler{}); + + EventBase eb; + auto socket = AsyncSocket::newSocket(&eb); + pipeline.setTransport(socket); + + // Buffering should prevent writes until the EB loops, and the writes should + // be batched into one write call. + auto f1 = pipeline.write(IOBuf::copyBuffer("hello")); + auto f2 = pipeline.write(IOBuf::copyBuffer("world")); + EXPECT_FALSE(f1.isReady()); + EXPECT_FALSE(f2.isReady()); + EXPECT_CALL(mockHandler, write_(_, IOBufContains("helloworld"))); + eb.loopOnce(); + EXPECT_TRUE(f1.isReady()); + EXPECT_TRUE(f2.isReady()); + EXPECT_CALL(mockHandler, detachPipeline(_)); + + // Make sure the SharedPromise resets correctly + auto f = pipeline.write(IOBuf::copyBuffer("foo")); + EXPECT_FALSE(f.isReady()); + EXPECT_CALL(mockHandler, write_(_, IOBufContains("foo"))); + eb.loopOnce(); + EXPECT_TRUE(f.isReady()); +} diff --git a/folly/wangle/channel/test/PipelineTest.cpp b/folly/wangle/channel/test/PipelineTest.cpp new file mode 100644 index 00000000..cdc4e980 --- /dev/null +++ b/folly/wangle/channel/test/PipelineTest.cpp @@ -0,0 +1,306 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; +using namespace testing; + +typedef StrictMock> IntHandler; +class IntHandler2 : public StrictMock> {}; + +ACTION(FireRead) { + arg0->fireRead(arg1); +} + +ACTION(FireReadEOF) { + arg0->fireReadEOF(); +} + +ACTION(FireReadException) { + arg0->fireReadException(arg1); +} + +ACTION(FireWrite) { + arg0->fireWrite(arg1); +} + +ACTION(FireClose) { + arg0->fireClose(); +} + +// Test move only types, among other things +TEST(PipelineTest, RealHandlersCompile) { + EventBase eb; + auto socket = AsyncSocket::newSocket(&eb); + // static + { + StaticPipeline, + AsyncSocketHandler, + OutputBufferingHandler> + pipeline{AsyncSocketHandler(socket), OutputBufferingHandler()}; + EXPECT_TRUE(pipeline.getHandler(0)); + EXPECT_TRUE(pipeline.getHandler(1)); + } + // dynamic + { + Pipeline> pipeline; + pipeline + .addBack(AsyncSocketHandler(socket)) + .addBack(OutputBufferingHandler()) + .finalize(); + EXPECT_TRUE(pipeline.getHandler(0)); + EXPECT_TRUE(pipeline.getHandler(1)); + } +} + +// Test that handlers correctly fire the next handler when directed +TEST(PipelineTest, FireActions) { + IntHandler handler1; + IntHandler2 handler2; + + { + InSequence sequence; + EXPECT_CALL(handler2, attachPipeline(_)); + EXPECT_CALL(handler1, attachPipeline(_)); + } + + StaticPipeline + pipeline(&handler1, &handler2); + + EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); + EXPECT_CALL(handler2, read_(_, _)).Times(1); + pipeline.read(1); + + EXPECT_CALL(handler1, readEOF(_)).WillOnce(FireReadEOF()); + EXPECT_CALL(handler2, readEOF(_)).Times(1); + pipeline.readEOF(); + + EXPECT_CALL(handler1, readException(_, _)).WillOnce(FireReadException()); + EXPECT_CALL(handler2, readException(_, _)).Times(1); + pipeline.readException(make_exception_wrapper("blah")); + + EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite()); + EXPECT_CALL(handler1, write_(_, _)).Times(1); + EXPECT_NO_THROW(pipeline.write(1).value()); + + EXPECT_CALL(handler2, close_(_)).WillOnce(FireClose()); + EXPECT_CALL(handler1, close_(_)).Times(1); + EXPECT_NO_THROW(pipeline.close().value()); + + { + InSequence sequence; + EXPECT_CALL(handler1, detachPipeline(_)); + EXPECT_CALL(handler2, detachPipeline(_)); + } +} + +// Test that nothing bad happens when actions reach the end of the pipeline +// (a warning will be logged, however) +TEST(PipelineTest, ReachEndOfPipeline) { + IntHandler handler; + EXPECT_CALL(handler, attachPipeline(_)); + StaticPipeline + pipeline(&handler); + + EXPECT_CALL(handler, read_(_, _)).WillOnce(FireRead()); + pipeline.read(1); + + EXPECT_CALL(handler, readEOF(_)).WillOnce(FireReadEOF()); + pipeline.readEOF(); + + EXPECT_CALL(handler, readException(_, _)).WillOnce(FireReadException()); + pipeline.readException(make_exception_wrapper("blah")); + + EXPECT_CALL(handler, write_(_, _)).WillOnce(FireWrite()); + EXPECT_NO_THROW(pipeline.write(1).value()); + + EXPECT_CALL(handler, close_(_)).WillOnce(FireClose()); + EXPECT_NO_THROW(pipeline.close().value()); + + EXPECT_CALL(handler, detachPipeline(_)); +} + +// Test having the last read handler turn around and write +TEST(PipelineTest, TurnAround) { + IntHandler handler1; + IntHandler2 handler2; + + { + InSequence sequence; + EXPECT_CALL(handler2, attachPipeline(_)); + EXPECT_CALL(handler1, attachPipeline(_)); + } + + StaticPipeline + pipeline(&handler1, &handler2); + + EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); + EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireWrite()); + EXPECT_CALL(handler1, write_(_, _)).Times(1); + pipeline.read(1); + + { + InSequence sequence; + EXPECT_CALL(handler1, detachPipeline(_)); + EXPECT_CALL(handler2, detachPipeline(_)); + } +} + +TEST(PipelineTest, DynamicFireActions) { + IntHandler handler1, handler2, handler3; + EXPECT_CALL(handler2, attachPipeline(_)); + StaticPipeline + pipeline(&handler2); + + { + InSequence sequence; + EXPECT_CALL(handler3, attachPipeline(_)); + EXPECT_CALL(handler1, attachPipeline(_)); + } + + pipeline + .addFront(&handler1) + .addBack(&handler3) + .finalize(); + + EXPECT_TRUE(pipeline.getHandler(0)); + EXPECT_TRUE(pipeline.getHandler(1)); + EXPECT_TRUE(pipeline.getHandler(2)); + + EXPECT_CALL(handler1, read_(_, _)).WillOnce(FireRead()); + EXPECT_CALL(handler2, read_(_, _)).WillOnce(FireRead()); + EXPECT_CALL(handler3, read_(_, _)).Times(1); + pipeline.read(1); + + EXPECT_CALL(handler3, write_(_, _)).WillOnce(FireWrite()); + EXPECT_CALL(handler2, write_(_, _)).WillOnce(FireWrite()); + EXPECT_CALL(handler1, write_(_, _)).Times(1); + EXPECT_NO_THROW(pipeline.write(1).value()); + + { + InSequence sequence; + EXPECT_CALL(handler1, detachPipeline(_)); + EXPECT_CALL(handler2, detachPipeline(_)); + EXPECT_CALL(handler3, detachPipeline(_)); + } +} + +TEST(PipelineTest, DynamicAttachDetachOrder) { + IntHandler handler1, handler2; + Pipeline pipeline; + { + InSequence sequence; + EXPECT_CALL(handler2, attachPipeline(_)); + EXPECT_CALL(handler1, attachPipeline(_)); + } + pipeline + .addBack(&handler1) + .addBack(&handler2) + .finalize(); + { + InSequence sequence; + EXPECT_CALL(handler1, detachPipeline(_)); + EXPECT_CALL(handler2, detachPipeline(_)); + } +} + +TEST(PipelineTest, GetContext) { + IntHandler handler; + EXPECT_CALL(handler, attachPipeline(_)); + StaticPipeline pipeline(&handler); + EXPECT_TRUE(handler.getContext()); + EXPECT_CALL(handler, detachPipeline(_)); +} + +TEST(PipelineTest, HandlerInMultiplePipelines) { + IntHandler handler; + EXPECT_CALL(handler, attachPipeline(_)).Times(2); + StaticPipeline pipeline1(&handler); + StaticPipeline pipeline2(&handler); + EXPECT_FALSE(handler.getContext()); + EXPECT_CALL(handler, detachPipeline(_)).Times(2); +} + +TEST(PipelineTest, HandlerInPipelineTwice) { + auto handler = std::make_shared(); + EXPECT_CALL(*handler, attachPipeline(_)).Times(2); + Pipeline pipeline; + pipeline.addBack(handler); + pipeline.addBack(handler); + pipeline.finalize(); + EXPECT_FALSE(handler->getContext()); + EXPECT_CALL(*handler, detachPipeline(_)).Times(2); +} + +TEST(PipelineTest, NoDetachOnOwner) { + IntHandler handler; + EXPECT_CALL(handler, attachPipeline(_)); + StaticPipeline pipeline(&handler); + pipeline.setOwner(&handler); +} + +template +class ConcreteHandler : public Handler { + typedef typename Handler::Context Context; + public: + void read(Context* ctx, Rin msg) {} + Future write(Context* ctx, Win msg) { return makeFuture(); } +}; + +typedef HandlerAdapter StringHandler; +typedef ConcreteHandler IntToStringHandler; +typedef ConcreteHandler StringToIntHandler; + +TEST(Pipeline, MissingInboundOrOutbound) { + Pipeline pipeline; + pipeline + .addBack(HandlerAdapter{}) + .finalize(); + EXPECT_THROW(pipeline.read(0), std::invalid_argument); + EXPECT_THROW(pipeline.readEOF(), std::invalid_argument); + EXPECT_THROW( + pipeline.readException(exception_wrapper(std::runtime_error("blah"))), + std::invalid_argument); + EXPECT_THROW(pipeline.write(0), std::invalid_argument); + EXPECT_THROW(pipeline.close(), std::invalid_argument); +} + +TEST(Pipeline, DynamicConstruction) { + { + Pipeline pipeline; + pipeline.addBack(StringHandler()); + pipeline.addBack(StringHandler()); + + // Exercise both addFront and addBack. Final pipeline is + // StI <-> ItS <-> StS <-> StS <-> StI <-> ItS + EXPECT_NO_THROW( + pipeline + .addFront(IntToStringHandler{}) + .addFront(StringToIntHandler{}) + .addBack(StringToIntHandler{}) + .addBack(IntToStringHandler{}) + .finalize()); + } +} diff --git a/folly/wangle/codec/ByteToMessageCodec.cpp b/folly/wangle/codec/ByteToMessageCodec.cpp new file mode 100644 index 00000000..e16183bb --- /dev/null +++ b/folly/wangle/codec/ByteToMessageCodec.cpp @@ -0,0 +1,33 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +namespace folly { namespace wangle { + +void ByteToMessageCodec::read(Context* ctx, IOBufQueue& q) { + size_t needed = 0; + std::unique_ptr result; + while (true) { + result = decode(ctx, q, needed); + if (result) { + ctx->fireRead(std::move(result)); + } else { + break; + } + } +} + +}} // namespace diff --git a/folly/wangle/codec/ByteToMessageCodec.h b/folly/wangle/codec/ByteToMessageCodec.h new file mode 100644 index 00000000..600bb028 --- /dev/null +++ b/folly/wangle/codec/ByteToMessageCodec.h @@ -0,0 +1,52 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace folly { namespace wangle { + +/** + * A Handler which decodes bytes in a stream-like fashion from + * IOBufQueue to a Message type. + * + * Frame detection + * + * Generally frame detection should be handled earlier in the pipeline + * by adding a DelimiterBasedFrameDecoder, FixedLengthFrameDecoder, + * LengthFieldBasedFrameDecoder, LineBasedFrameDecoder. + * + * If a custom frame decoder is required, then one needs to be careful + * when implementing one with {@link ByteToMessageDecoder}. Ensure + * there are enough bytes in the buffer for a complete frame by + * checking {@link ByteBuf#readableBytes()}. If there are not enough + * bytes for a complete frame, return without modify the reader index + * to allow more bytes to arrive. + * + * To check for complete frames without modify the reader index, use + * IOBufQueue.front(), without split() or pop_front(). + */ +class ByteToMessageCodec + : public InboundBytesToBytesHandler { + public: + + virtual std::unique_ptr decode( + Context* ctx, IOBufQueue& buf, size_t&) = 0; + + void read(Context* ctx, IOBufQueue& q); +}; + +}} diff --git a/folly/wangle/codec/CodecTest.cpp b/folly/wangle/codec/CodecTest.cpp new file mode 100644 index 00000000..ecca824f --- /dev/null +++ b/folly/wangle/codec/CodecTest.cpp @@ -0,0 +1,637 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; +using namespace folly::io; + +class FrameTester + : public InboundHandler> { + public: + explicit FrameTester(std::function)> test) + : test_(test) {} + + void read(Context* ctx, std::unique_ptr buf) { + test_(std::move(buf)); + } + + void readException(Context* ctx, exception_wrapper w) { + test_(nullptr); + } + private: + std::function)> test_; +}; + +class BytesReflector + : public BytesToBytesHandler { + public: + + Future write(Context* ctx, std::unique_ptr buf) { + IOBufQueue q_(IOBufQueue::cacheChainLength()); + q_.append(std::move(buf)); + ctx->fireRead(q_); + + return makeFuture(); + } +}; + +TEST(FixedLengthFrameDecoder, FailWhenLengthFieldEndOffset) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(FixedLengthFrameDecoder(10)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 10); + })) + .finalize(); + + auto buf3 = IOBuf::create(3); + buf3->append(3); + auto buf11 = IOBuf::create(11); + buf11->append(11); + auto buf16 = IOBuf::create(16); + buf16->append(16); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(buf3)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + q.append(std::move(buf11)); + pipeline.read(q); + EXPECT_EQ(called, 1); + + q.append(std::move(buf16)); + pipeline.read(q); + EXPECT_EQ(called, 3); +} + +TEST(LengthFieldFramePipeline, SimpleTest) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(BytesReflector()) + .addBack(LengthFieldPrepender()) + .addBack(LengthFieldBasedFrameDecoder()) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 2); + })) + .finalize(); + + auto buf = IOBuf::create(2); + buf->append(2); + pipeline.write(std::move(buf)); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFramePipeline, LittleEndian) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(BytesReflector()) + .addBack(LengthFieldBasedFrameDecoder(4, 100, 0, 0, 4, false)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 1); + })) + .addBack(LengthFieldPrepender(4, 0, false, false)) + .finalize(); + + auto buf = IOBuf::create(1); + buf->append(1); + pipeline.write(std::move(buf)); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, Simple) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder()) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 1); + })) + .finalize(); + + auto bufFrame = IOBuf::create(4); + bufFrame->append(4); + RWPrivateCursor c(bufFrame.get()); + c.writeBE((uint32_t)1); + auto bufData = IOBuf::create(1); + bufData->append(1); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + q.append(std::move(bufData)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, NoStrip) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(2, 10, 0, 0, 0)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 3); + })) + .finalize(); + + auto bufFrame = IOBuf::create(2); + bufFrame->append(2); + RWPrivateCursor c(bufFrame.get()); + c.writeBE((uint16_t)1); + auto bufData = IOBuf::create(1); + bufData->append(1); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + q.append(std::move(bufData)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, Adjustment) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(2, 10, 0, -2, 0)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 3); + })) + .finalize(); + + auto bufFrame = IOBuf::create(2); + bufFrame->append(2); + RWPrivateCursor c(bufFrame.get()); + c.writeBE((uint16_t)3); // includes frame size + auto bufData = IOBuf::create(1); + bufData->append(1); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + q.append(std::move(bufData)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, PreHeader) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(2, 10, 2, 0, 0)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 5); + })) + .finalize(); + + auto bufFrame = IOBuf::create(4); + bufFrame->append(4); + RWPrivateCursor c(bufFrame.get()); + c.write((uint16_t)100); // header + c.writeBE((uint16_t)1); // frame size + auto bufData = IOBuf::create(1); + bufData->append(1); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + q.append(std::move(bufData)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, PostHeader) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(2, 10, 0, 2, 0)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 5); + })) + .finalize(); + + auto bufFrame = IOBuf::create(4); + bufFrame->append(4); + RWPrivateCursor c(bufFrame.get()); + c.writeBE((uint16_t)1); // frame size + c.write((uint16_t)100); // header + auto bufData = IOBuf::create(1); + bufData->append(1); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + q.append(std::move(bufData)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoderStrip, PrePostHeader) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(2, 10, 2, 2, 4)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 3); + })) + .finalize(); + + auto bufFrame = IOBuf::create(6); + bufFrame->append(6); + RWPrivateCursor c(bufFrame.get()); + c.write((uint16_t)100); // pre header + c.writeBE((uint16_t)1); // frame size + c.write((uint16_t)100); // post header + auto bufData = IOBuf::create(1); + bufData->append(1); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + q.append(std::move(bufData)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, StripPrePostHeaderFrameInclHeader) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(2, 10, 2, -2, 4)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 3); + })) + .finalize(); + + auto bufFrame = IOBuf::create(6); + bufFrame->append(6); + RWPrivateCursor c(bufFrame.get()); + c.write((uint16_t)100); // pre header + c.writeBE((uint16_t)5); // frame size + c.write((uint16_t)100); // post header + auto bufData = IOBuf::create(1); + bufData->append(1); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + q.append(std::move(bufData)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, FailTestLengthFieldEndOffset) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(4, 10, 4, -2, 4)) + .addBack(FrameTester([&](std::unique_ptr buf) { + ASSERT_EQ(nullptr, buf); + called++; + })) + .finalize(); + + auto bufFrame = IOBuf::create(8); + bufFrame->append(8); + RWPrivateCursor c(bufFrame.get()); + c.writeBE((uint32_t)0); // frame size + c.write((uint32_t)0); // crap + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, FailTestLengthFieldFrameSize) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(4, 10, 0, 0, 4)) + .addBack(FrameTester([&](std::unique_ptr buf) { + ASSERT_EQ(nullptr, buf); + called++; + })) + .finalize(); + + auto bufFrame = IOBuf::create(16); + bufFrame->append(16); + RWPrivateCursor c(bufFrame.get()); + c.writeBE((uint32_t)12); // frame size + c.write((uint32_t)0); // nothing + c.write((uint32_t)0); // nothing + c.write((uint32_t)0); // nothing + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LengthFieldFrameDecoder, FailTestLengthFieldInitialBytes) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LengthFieldBasedFrameDecoder(4, 10, 0, 0, 10)) + .addBack(FrameTester([&](std::unique_ptr buf) { + ASSERT_EQ(nullptr, buf); + called++; + })) + .finalize(); + + auto bufFrame = IOBuf::create(16); + bufFrame->append(16); + RWPrivateCursor c(bufFrame.get()); + c.writeBE((uint32_t)4); // frame size + c.write((uint32_t)0); // nothing + c.write((uint32_t)0); // nothing + c.write((uint32_t)0); // nothing + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(bufFrame)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LineBasedFrameDecoder, Simple) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LineBasedFrameDecoder(10)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 3); + })) + .finalize(); + + auto buf = IOBuf::create(3); + buf->append(3); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + buf = IOBuf::create(1); + buf->append(1); + RWPrivateCursor c(buf.get()); + c.write('\n'); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); + + buf = IOBuf::create(4); + buf->append(4); + RWPrivateCursor c1(buf.get()); + c1.write(' '); + c1.write(' '); + c1.write(' '); + + c1.write('\r'); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); + + buf = IOBuf::create(1); + buf->append(1); + RWPrivateCursor c2(buf.get()); + c2.write('\n'); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 2); +} + +TEST(LineBasedFrameDecoder, SaveDelimiter) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LineBasedFrameDecoder(10, false)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 4); + })) + .finalize(); + + auto buf = IOBuf::create(3); + buf->append(3); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 0); + + buf = IOBuf::create(1); + buf->append(1); + RWPrivateCursor c(buf.get()); + c.write('\n'); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); + + buf = IOBuf::create(3); + buf->append(3); + RWPrivateCursor c1(buf.get()); + c1.write(' '); + c1.write(' '); + c1.write('\r'); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); + + buf = IOBuf::create(1); + buf->append(1); + RWPrivateCursor c2(buf.get()); + c2.write('\n'); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 2); +} + +TEST(LineBasedFrameDecoder, Fail) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LineBasedFrameDecoder(10)) + .addBack(FrameTester([&](std::unique_ptr buf) { + ASSERT_EQ(nullptr, buf); + called++; + })) + .finalize(); + + auto buf = IOBuf::create(11); + buf->append(11); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); + + buf = IOBuf::create(1); + buf->append(1); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); + + buf = IOBuf::create(2); + buf->append(2); + RWPrivateCursor c(buf.get()); + c.write(' '); + c.write('\n'); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); + + buf = IOBuf::create(12); + buf->append(12); + RWPrivateCursor c2(buf.get()); + for (int i = 0; i < 11; i++) { + c2.write(' '); + } + c2.write('\n'); + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 2); +} + +TEST(LineBasedFrameDecoder, NewLineOnly) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LineBasedFrameDecoder( + 10, true, LineBasedFrameDecoder::TerminatorType::NEWLINE)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 1); + })) + .finalize(); + + auto buf = IOBuf::create(2); + buf->append(2); + RWPrivateCursor c(buf.get()); + c.write('\r'); + c.write('\n'); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} + +TEST(LineBasedFrameDecoder, CarriageNewLineOnly) { + Pipeline> pipeline; + int called = 0; + + pipeline + .addBack(LineBasedFrameDecoder( + 10, true, LineBasedFrameDecoder::TerminatorType::CARRIAGENEWLINE)) + .addBack(FrameTester([&](std::unique_ptr buf) { + auto sz = buf->computeChainDataLength(); + called++; + EXPECT_EQ(sz, 1); + })) + .finalize(); + + auto buf = IOBuf::create(3); + buf->append(3); + RWPrivateCursor c(buf.get()); + c.write('\n'); + c.write('\r'); + c.write('\n'); + + IOBufQueue q(IOBufQueue::cacheChainLength()); + + q.append(std::move(buf)); + pipeline.read(q); + EXPECT_EQ(called, 1); +} diff --git a/folly/wangle/codec/FixedLengthFrameDecoder.h b/folly/wangle/codec/FixedLengthFrameDecoder.h new file mode 100644 index 00000000..5b6d1893 --- /dev/null +++ b/folly/wangle/codec/FixedLengthFrameDecoder.h @@ -0,0 +1,59 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace folly {namespace wangle { + +/** + * A decoder that splits the received IOBufs by the fixed number + * of bytes. For example, if you received the following four + * fragmented packets: + * + * +---+----+------+----+ + * | A | BC | DEFG | HI | + * +---+----+------+----+ + * + * A FixedLengthFrameDecoder will decode them into the following three + * packets with the fixed length: + * + * +-----+-----+-----+ + * | ABC | DEF | GHI | + * +-----+-----+-----+ + * + */ +class FixedLengthFrameDecoder + : public ByteToMessageCodec { + public: + + FixedLengthFrameDecoder(size_t length) + : length_(length) {} + + std::unique_ptr decode(Context* ctx, IOBufQueue& q, size_t& needed) { + if (q.chainLength() < length_) { + needed = length_ - q.chainLength(); + return nullptr; + } + + return q.split(length_); + } + + private: + size_t length_; +}; + +}} // Namespace diff --git a/folly/wangle/codec/LengthFieldBasedFrameDecoder.cpp b/folly/wangle/codec/LengthFieldBasedFrameDecoder.cpp new file mode 100644 index 00000000..5fff70d5 --- /dev/null +++ b/folly/wangle/codec/LengthFieldBasedFrameDecoder.cpp @@ -0,0 +1,127 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +namespace folly { namespace wangle { + +LengthFieldBasedFrameDecoder::LengthFieldBasedFrameDecoder( + uint32_t lengthFieldLength, + uint32_t maxFrameLength, + uint32_t lengthFieldOffset, + uint32_t lengthAdjustment, + uint32_t initialBytesToStrip, + bool networkByteOrder) + : lengthFieldLength_(lengthFieldLength) + , maxFrameLength_(maxFrameLength) + , lengthFieldOffset_(lengthFieldOffset) + , lengthAdjustment_(lengthAdjustment) + , initialBytesToStrip_(initialBytesToStrip) + , networkByteOrder_(networkByteOrder) + , lengthFieldEndOffset_(lengthFieldOffset + lengthFieldLength) { + CHECK(maxFrameLength > 0); + CHECK(lengthFieldOffset <= maxFrameLength - lengthFieldLength); +} + +std::unique_ptr LengthFieldBasedFrameDecoder::decode( + Context* ctx, IOBufQueue& buf, size_t&) { + // discarding too long frame + if (buf.chainLength() < lengthFieldEndOffset_) { + return nullptr; + } + + uint64_t frameLength = getUnadjustedFrameLength( + buf, lengthFieldOffset_, lengthFieldLength_, networkByteOrder_); + + frameLength += lengthAdjustment_ + lengthFieldEndOffset_; + + if (frameLength < lengthFieldEndOffset_) { + buf.trimStart(lengthFieldEndOffset_); + ctx->fireReadException(folly::make_exception_wrapper( + "Frame too small")); + return nullptr; + } + + if (frameLength > maxFrameLength_) { + buf.trimStart(frameLength); + ctx->fireReadException(folly::make_exception_wrapper( + "Frame larger than " + + folly::to(maxFrameLength_))); + return nullptr; + } + + if (buf.chainLength() < frameLength) { + return nullptr; + } + + if (initialBytesToStrip_ > frameLength) { + buf.trimStart(frameLength); + ctx->fireReadException(folly::make_exception_wrapper( + "InitialBytesToSkip larger than frame")); + return nullptr; + } + + buf.trimStart(initialBytesToStrip_); + int actualFrameLength = frameLength - initialBytesToStrip_; + return buf.split(actualFrameLength); +} + +uint64_t LengthFieldBasedFrameDecoder::getUnadjustedFrameLength( + IOBufQueue& buf, int offset, int length, bool networkByteOrder) { + folly::io::Cursor c(buf.front()); + uint64_t frameLength; + + c.skip(offset); + + switch(length) { + case 1:{ + if (networkByteOrder) { + frameLength = c.readBE(); + } else { + frameLength = c.readLE(); + } + break; + } + case 2:{ + if (networkByteOrder) { + frameLength = c.readBE(); + } else { + frameLength = c.readLE(); + } + break; + } + case 4:{ + if (networkByteOrder) { + frameLength = c.readBE(); + } else { + frameLength = c.readLE(); + } + break; + } + case 8:{ + if (networkByteOrder) { + frameLength = c.readBE(); + } else { + frameLength = c.readLE(); + } + break; + } + } + + return frameLength; +} + + +}} // namespace diff --git a/folly/wangle/codec/LengthFieldBasedFrameDecoder.h b/folly/wangle/codec/LengthFieldBasedFrameDecoder.h new file mode 100644 index 00000000..1c44de55 --- /dev/null +++ b/folly/wangle/codec/LengthFieldBasedFrameDecoder.h @@ -0,0 +1,209 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +/** + * A decoder that splits the received IOBufs dynamically by the + * value of the length field in the message. It is particularly useful when you + * decode a binary message which has an integer header field that represents the + * length of the message body or the whole message. + * + * LengthFieldBasedFrameDecoder has many configuration parameters so + * that it can decode any message with a length field, which is often seen in + * proprietary client-server protocols. Here are some example that will give + * you the basic idea on which option does what. + * + * 2 bytes length field at offset 0, do not strip header + * + * The value of the length field in this example is 12 (0x0C) which + * represents the length of "HELLO, WORLD". By default, the decoder assumes + * that the length field represents the number of the bytes that follows the + * length field. Therefore, it can be decoded with the simplistic parameter + * combination. + * + * lengthFieldOffset = 0 + * lengthFieldLength = 2 + * lengthAdjustment = 0 + * initialBytesToStrip = 0 (= do not strip header) + * + * BEFORE DECODE (14 bytes) AFTER DECODE (14 bytes) + * +--------+----------------+ +--------+----------------+ + * | Length | Actual Content |----->| Length | Actual Content | + * | 0x000C | "HELLO, WORLD" | | 0x000C | "HELLO, WORLD" | + * +--------+----------------+ +--------+----------------+ + * + * + * 2 bytes length field at offset 0, strip header + * + * Because we can get the length of the content by calling + * ioBuf->computeChainDataLength(), you might want to strip the length + * field by specifying initialBytesToStrip. In this example, we + * specified 2, that is same with the length of the length field, to + * strip the first two bytes. + * + * lengthFieldOffset = 0 + * lengthFieldLength = 2 + * lengthAdjustment = 0 + * initialBytesToStrip = 2 (= the length of the Length field) + * + * BEFORE DECODE (14 bytes) AFTER DECODE (12 bytes) + * +--------+----------------+ +----------------+ + * | Length | Actual Content |----->| Actual Content | + * | 0x000C | "HELLO, WORLD" | | "HELLO, WORLD" | + * +--------+----------------+ +----------------+ + * + * + * 2 bytes length field at offset 0, do not strip header, the length field + * represents the length of the whole message + * + * In most cases, the length field represents the length of the message body + * only, as shown in the previous examples. However, in some protocols, the + * length field represents the length of the whole message, including the + * message header. In such a case, we specify a non-zero + * lengthAdjustment. Because the length value in this example message + * is always greater than the body length by 2, we specify -2 + * as lengthAdjustment for compensation. + * + * lengthFieldOffset = 0 + * lengthFieldLength = 2 + * lengthAdjustment = -2 (= the length of the Length field) + * initialBytesToStrip = 0 + * + * BEFORE DECODE (14 bytes) AFTER DECODE (14 bytes) + * +--------+----------------+ +--------+----------------+ + * | Length | Actual Content |----->| Length | Actual Content | + * | 0x000E | "HELLO, WORLD" | | 0x000E | "HELLO, WORLD" | + * +--------+----------------+ +--------+----------------+ + * + * + * 3 bytes length field at the end of 5 bytes header, do not strip header + * + * The following message is a simple variation of the first example. An extra + * header value is prepended to the message. lengthAdjustment is zero + * again because the decoder always takes the length of the prepended data into + * account during frame length calculation. + * + * lengthFieldOffset = 2 (= the length of Header 1) + * lengthFieldLength = 3 + * lengthAdjustment = 0 + * initialBytesToStrip = 0 + * + * BEFORE DECODE (17 bytes) AFTER DECODE (17 bytes) + * +----------+----------+----------------+ +----------+----------+----------------+ + * | Header 1 | Length | Actual Content |----->| Header 1 | Length | Actual Content | + * | 0xCAFE | 0x00000C | "HELLO, WORLD" | | 0xCAFE | 0x00000C | "HELLO, WORLD" | + * +----------+----------+----------------+ +----------+----------+----------------+ + * + * + * 3 bytes length field at the beginning of 5 bytes header, do not strip header + * + * This is an advanced example that shows the case where there is an extra + * header between the length field and the message body. You have to specify a + * positive lengthAdjustment so that the decoder counts the extra + * header into the frame length calculation. + * + * lengthFieldOffset = 0 + * lengthFieldLength = 3 + * lengthAdjustment = 2 (= the length of Header 1) + * initialBytesToStrip = 0 + * + * BEFORE DECODE (17 bytes) AFTER DECODE (17 bytes) + * +----------+----------+----------------+ +----------+----------+----------------+ + * | Length | Header 1 | Actual Content |----->| Length | Header 1 | Actual Content | + * | 0x00000C | 0xCAFE | "HELLO, WORLD" | | 0x00000C | 0xCAFE | "HELLO, WORLD" | + * +----------+----------+----------------+ +----------+----------+----------------+ + * + * + * 2 bytes length field at offset 1 in the middle of 4 bytes header, + * strip the first header field and the length field + * + * This is a combination of all the examples above. There are the prepended + * header before the length field and the extra header after the length field. + * The prepended header affects the lengthFieldOffset and the extra + * header affects the lengthAdjustment. We also specified a non-zero + * initialBytesToStrip to strip the length field and the prepended + * header from the frame. If you don't want to strip the prepended header, you + * could specify 0 for initialBytesToSkip. + * + * lengthFieldOffset = 1 (= the length of HDR1) + * lengthFieldLength = 2 + * lengthAdjustment = 1 (= the length of HDR2) + * initialBytesToStrip = 3 (= the length of HDR1 + LEN) + * + * BEFORE DECODE (16 bytes) AFTER DECODE (13 bytes) + * +------+--------+------+----------------+ +------+----------------+ + * | HDR1 | Length | HDR2 | Actual Content |----->| HDR2 | Actual Content | + * | 0xCA | 0x000C | 0xFE | "HELLO, WORLD" | | 0xFE | "HELLO, WORLD" | + * +------+--------+------+----------------+ +------+----------------+ + * + * + * 2 bytes length field at offset 1 in the middle of 4 bytes header, + * strip the first header field and the length field, the length field + * represents the length of the whole message + * + * Let's give another twist to the previous example. The only difference from + * the previous example is that the length field represents the length of the + * whole message instead of the message body, just like the third example. + * We have to count the length of HDR1 and Length into lengthAdjustment. + * Please note that we don't need to take the length of HDR2 into account + * because the length field already includes the whole header length. + * + * lengthFieldOffset = 1 + * lengthFieldLength = 2 + * lengthAdjustment = -3 (= the length of HDR1 + LEN, negative) + * initialBytesToStrip = 3 + * + * BEFORE DECODE (16 bytes) AFTER DECODE (13 bytes) + * +------+--------+------+----------------+ +------+----------------+ + * | HDR1 | Length | HDR2 | Actual Content |----->| HDR2 | Actual Content | + * | 0xCA | 0x0010 | 0xFE | "HELLO, WORLD" | | 0xFE | "HELLO, WORLD" | + * +------+--------+------+----------------+ +------+----------------+ + * + * @see LengthFieldPrepender + */ +class LengthFieldBasedFrameDecoder : public ByteToMessageCodec { + public: + LengthFieldBasedFrameDecoder( + uint32_t lengthFieldLength = 4, + uint32_t maxFrameLength = UINT_MAX, + uint32_t lengthFieldOffset = 0, + uint32_t lengthAdjustment = 0, + uint32_t initialBytesToStrip = 4, + bool networkByteOrder = true); + + std::unique_ptr decode(Context* ctx, IOBufQueue& buf, size_t&); + + private: + + uint64_t getUnadjustedFrameLength( + IOBufQueue& buf, int offset, int length, bool networkByteOrder); + + uint32_t lengthFieldLength_; + uint32_t maxFrameLength_; + uint32_t lengthFieldOffset_; + uint32_t lengthAdjustment_; + uint32_t initialBytesToStrip_; + bool networkByteOrder_; + + uint32_t lengthFieldEndOffset_; +}; + +}} // namespace diff --git a/folly/wangle/codec/LengthFieldPrepender.cpp b/folly/wangle/codec/LengthFieldPrepender.cpp new file mode 100644 index 00000000..88238e08 --- /dev/null +++ b/folly/wangle/codec/LengthFieldPrepender.cpp @@ -0,0 +1,99 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +namespace folly { namespace wangle { + +LengthFieldPrepender::LengthFieldPrepender( + int lengthFieldLength, + int lengthAdjustment, + bool lengthIncludesLengthField, + bool networkByteOrder) + : lengthFieldLength_(lengthFieldLength) + , lengthAdjustment_(lengthAdjustment) + , lengthIncludesLengthField_(lengthIncludesLengthField) + , networkByteOrder_(networkByteOrder) { + CHECK(lengthFieldLength == 1 || + lengthFieldLength == 2 || + lengthFieldLength == 4 || + lengthFieldLength == 8 ); + } + +Future LengthFieldPrepender::write( + Context* ctx, std::unique_ptr buf) { + int length = lengthAdjustment_ + buf->computeChainDataLength(); + if (lengthIncludesLengthField_) { + length += lengthFieldLength_; + } + + if (length < 0) { + throw std::runtime_error("Length field < 0"); + } + + auto len = IOBuf::create(lengthFieldLength_); + len->append(lengthFieldLength_); + folly::io::RWPrivateCursor c(len.get()); + + switch (lengthFieldLength_) { + case 1: { + if (length >= 256) { + throw std::runtime_error("length does not fit byte"); + } + if (networkByteOrder_) { + c.writeBE((uint8_t)length); + } else { + c.writeLE((uint8_t)length); + } + break; + } + case 2: { + if (length >= 65536) { + throw std::runtime_error("length does not fit byte"); + } + if (networkByteOrder_) { + c.writeBE((uint16_t)length); + } else { + c.writeLE((uint16_t)length); + } + break; + } + case 4: { + if (networkByteOrder_) { + c.writeBE((uint32_t)length); + } else { + c.writeLE((uint32_t)length); + } + break; + } + case 8: { + if (networkByteOrder_) { + c.writeBE((uint64_t)length); + } else { + c.writeLE((uint64_t)length); + } + break; + } + default: { + throw std::runtime_error("Invalid lengthFieldLength"); + } + } + + len->prependChain(std::move(buf)); + return ctx->fireWrite(std::move(len)); +} + + +}} // Namespace diff --git a/folly/wangle/codec/LengthFieldPrepender.h b/folly/wangle/codec/LengthFieldPrepender.h new file mode 100644 index 00000000..d2e1d37b --- /dev/null +++ b/folly/wangle/codec/LengthFieldPrepender.h @@ -0,0 +1,67 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +/** + * An encoder that prepends the length of the message. The length value is + * prepended as a binary form. + * + * For example, LengthFieldPrepender(2)will encode the + * following 12-bytes string: + * + * +----------------+ + * | "HELLO, WORLD" | + * +----------------+ + * + * into the following: + * + * +--------+----------------+ + * + 0x000C | "HELLO, WORLD" | + * +--------+----------------+ + * + * If you turned on the lengthIncludesLengthFieldLength flag in the + * constructor, the encoded data would look like the following + * (12 (original data) + 2 (prepended data) = 14 (0xE)): + * + * +--------+----------------+ + * + 0x000E | "HELLO, WORLD" | + * +--------+----------------+ + * + */ +class LengthFieldPrepender +: public OutboundBytesToBytesHandler { + public: + LengthFieldPrepender( + int lengthFieldLength = 4, + int lengthAdjustment = 0, + bool lengthIncludesLengthField = false, + bool networkByteOrder = true); + + Future write(Context* ctx, std::unique_ptr buf); + + private: + int lengthFieldLength_; + int lengthAdjustment_; + bool lengthIncludesLengthField_; + bool networkByteOrder_; +}; + +}} // namespace diff --git a/folly/wangle/codec/LineBasedFrameDecoder.cpp b/folly/wangle/codec/LineBasedFrameDecoder.cpp new file mode 100644 index 00000000..ab0bb074 --- /dev/null +++ b/folly/wangle/codec/LineBasedFrameDecoder.cpp @@ -0,0 +1,103 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +namespace folly { namespace wangle { + +using folly::io::Cursor; + +LineBasedFrameDecoder::LineBasedFrameDecoder(uint32_t maxLength, + bool stripDelimiter, + TerminatorType terminatorType) + : maxLength_(maxLength) + , stripDelimiter_(stripDelimiter) + , terminatorType_(terminatorType) {} + +std::unique_ptr LineBasedFrameDecoder::decode( + Context* ctx, IOBufQueue& buf, size_t&) { + int64_t eol = findEndOfLine(buf); + + if (!discarding_) { + if (eol >= 0) { + Cursor c(buf.front()); + c += eol; + auto delimLength = c.read() == '\r' ? 2 : 1; + if (eol > maxLength_) { + buf.split(eol + delimLength); + fail(ctx, folly::to(eol)); + return nullptr; + } + + std::unique_ptr frame; + + if (stripDelimiter_) { + frame = buf.split(eol); + buf.trimStart(delimLength); + } else { + frame = buf.split(eol + delimLength); + } + + return std::move(frame); + } else { + auto len = buf.chainLength(); + if (len > maxLength_) { + discardedBytes_ = len; + buf.trimStart(len); + discarding_ = true; + fail(ctx, "over " + folly::to(len)); + } + return nullptr; + } + } else { + if (eol >= 0) { + Cursor c(buf.front()); + c += eol; + auto delimLength = c.read() == '\r' ? 2 : 1; + buf.trimStart(eol + delimLength); + discardedBytes_ = 0; + discarding_ = false; + } else { + discardedBytes_ = buf.chainLength(); + buf.move(); + } + + return nullptr; + } +} + +void LineBasedFrameDecoder::fail(Context* ctx, std::string len) { + ctx->fireReadException( + folly::make_exception_wrapper( + "frame length" + len + + " exeeds max " + folly::to(maxLength_))); +} + +int64_t LineBasedFrameDecoder::findEndOfLine(IOBufQueue& buf) { + Cursor c(buf.front()); + for (uint32_t i = 0; i < maxLength_ && i < buf.chainLength(); i++) { + auto b = c.read(); + if (b == '\n' && terminatorType_ != TerminatorType::CARRIAGENEWLINE) { + return i; + } else if (terminatorType_ != TerminatorType::NEWLINE && + b == '\r' && !c.isAtEnd() && c.read() == '\n') { + return i; + } + } + + return -1; +} + +}} // namespace diff --git a/folly/wangle/codec/LineBasedFrameDecoder.h b/folly/wangle/codec/LineBasedFrameDecoder.h new file mode 100644 index 00000000..5ae9433f --- /dev/null +++ b/folly/wangle/codec/LineBasedFrameDecoder.h @@ -0,0 +1,59 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +/** + * A decoder that splits the received IOBufQueue on line endings. + * + * Both "\n" and "\r\n" are handled, or optionally reqire only + * one or the other. + */ +class LineBasedFrameDecoder : public ByteToMessageCodec { + public: + enum class TerminatorType { + BOTH, + NEWLINE, + CARRIAGENEWLINE + }; + + LineBasedFrameDecoder(uint32_t maxLength = UINT_MAX, + bool stripDelimiter = true, + TerminatorType terminatorType = + TerminatorType::BOTH); + + std::unique_ptr decode(Context* ctx, IOBufQueue& buf, size_t&); + + private: + + int64_t findEndOfLine(IOBufQueue& buf); + + void fail(Context* ctx, std::string len); + + uint32_t maxLength_; + bool stripDelimiter_; + + bool discarding_{false}; + uint32_t discardedBytes_{0}; + + TerminatorType terminatorType_; +}; + +}} // namespace diff --git a/folly/wangle/codec/README.md b/folly/wangle/codec/README.md new file mode 100644 index 00000000..46ddf660 --- /dev/null +++ b/folly/wangle/codec/README.md @@ -0,0 +1,5 @@ +Codecs are modeled after netty's codecs: + +https://github.com/netty/netty/tree/master/codec/src/main/java/io/netty/handler/codec + +Most of the changes are due to differing memory allocation strategies. \ No newline at end of file diff --git a/folly/wangle/codec/StringCodec.h b/folly/wangle/codec/StringCodec.h new file mode 100644 index 00000000..cbe0e843 --- /dev/null +++ b/folly/wangle/codec/StringCodec.h @@ -0,0 +1,46 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace folly { namespace wangle { + +/* + * StringCodec converts a pipeline from IOBufs to std::strings. + */ +class StringCodec : public Handler, std::string, + std::string, std::unique_ptr> { + public: + typedef typename Handler< + std::unique_ptr, std::string, + std::string, std::unique_ptr>::Context Context; + + void read(Context* ctx, std::unique_ptr buf) override { + buf->coalesce(); + std::string data((const char*)buf->data(), buf->length()); + + ctx->fireRead(data); + } + + Future write(Context* ctx, std::string msg) override { + auto buf = IOBuf::copyBuffer(msg.data(), msg.length()); + return ctx->fireWrite(std::move(buf)); + } +}; + +}} // namespace diff --git a/folly/wangle/concurrent/BlockingQueue.h b/folly/wangle/concurrent/BlockingQueue.h new file mode 100644 index 00000000..72c7aa35 --- /dev/null +++ b/folly/wangle/concurrent/BlockingQueue.h @@ -0,0 +1,38 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace folly { namespace wangle { + +template +class BlockingQueue { + public: + virtual ~BlockingQueue() {} + virtual void add(T item) = 0; + virtual void addWithPriority(T item, int8_t priority) { + add(std::move(item)); + } + virtual uint8_t getNumPriorities() { + return 1; + } + virtual T take() = 0; + virtual size_t size() = 0; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/CPUThreadPoolExecutor.cpp b/folly/wangle/concurrent/CPUThreadPoolExecutor.cpp new file mode 100644 index 00000000..864bd3a1 --- /dev/null +++ b/folly/wangle/concurrent/CPUThreadPoolExecutor.cpp @@ -0,0 +1,152 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace folly { namespace wangle { + +const size_t CPUThreadPoolExecutor::kDefaultMaxQueueSize = 1 << 14; + +CPUThreadPoolExecutor::CPUThreadPoolExecutor( + size_t numThreads, + std::unique_ptr> taskQueue, + std::shared_ptr threadFactory) + : ThreadPoolExecutor(numThreads, std::move(threadFactory)), + taskQueue_(std::move(taskQueue)) { + addThreads(numThreads); + CHECK(threadList_.get().size() == numThreads); +} + +CPUThreadPoolExecutor::CPUThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory) + : CPUThreadPoolExecutor( + numThreads, + folly::make_unique>( + CPUThreadPoolExecutor::kDefaultMaxQueueSize), + std::move(threadFactory)) {} + +CPUThreadPoolExecutor::CPUThreadPoolExecutor(size_t numThreads) + : CPUThreadPoolExecutor( + numThreads, + std::make_shared("CPUThreadPool")) {} + +CPUThreadPoolExecutor::CPUThreadPoolExecutor( + size_t numThreads, + int8_t numPriorities, + std::shared_ptr threadFactory) + : CPUThreadPoolExecutor( + numThreads, + folly::make_unique>( + numPriorities, + CPUThreadPoolExecutor::kDefaultMaxQueueSize), + std::move(threadFactory)) {} + +CPUThreadPoolExecutor::CPUThreadPoolExecutor( + size_t numThreads, + int8_t numPriorities, + size_t maxQueueSize, + std::shared_ptr threadFactory) + : CPUThreadPoolExecutor( + numThreads, + folly::make_unique>( + numPriorities, + maxQueueSize), + std::move(threadFactory)) {} + +CPUThreadPoolExecutor::~CPUThreadPoolExecutor() { + stop(); + CHECK(threadsToStop_ == 0); +} + +void CPUThreadPoolExecutor::add(Func func) { + add(std::move(func), std::chrono::milliseconds(0)); +} + +void CPUThreadPoolExecutor::add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback) { + // TODO handle enqueue failure, here and in other add() callsites + taskQueue_->add( + CPUTask(std::move(func), expiration, std::move(expireCallback))); +} + +void CPUThreadPoolExecutor::addWithPriority(Func func, int8_t priority) { + add(std::move(func), priority, std::chrono::milliseconds(0)); +} + +void CPUThreadPoolExecutor::add( + Func func, + int8_t priority, + std::chrono::milliseconds expiration, + Func expireCallback) { + CHECK(getNumPriorities() > 0); + taskQueue_->addWithPriority( + CPUTask(std::move(func), expiration, std::move(expireCallback)), + priority); +} + +uint8_t CPUThreadPoolExecutor::getNumPriorities() const { + return taskQueue_->getNumPriorities(); +} + +BlockingQueue* +CPUThreadPoolExecutor::getTaskQueue() { + return taskQueue_.get(); +} + +void CPUThreadPoolExecutor::threadRun(std::shared_ptr thread) { + thread->startupBaton.post(); + while (1) { + auto task = taskQueue_->take(); + if (UNLIKELY(task.poison)) { + CHECK(threadsToStop_-- > 0); + for (auto& o : observers_) { + o->threadStopped(thread.get()); + } + + stoppedThreads_.add(thread); + return; + } else { + runTask(thread, std::move(task)); + } + + if (UNLIKELY(threadsToStop_ > 0 && !isJoin_)) { + if (--threadsToStop_ >= 0) { + stoppedThreads_.add(thread); + return; + } else { + threadsToStop_++; + } + } + } +} + +void CPUThreadPoolExecutor::stopThreads(size_t n) { + CHECK(stoppedThreads_.size() == 0); + threadsToStop_ = n; + for (size_t i = 0; i < n; i++) { + taskQueue_->addWithPriority(CPUTask(), Executor::LO_PRI); + } +} + +uint64_t CPUThreadPoolExecutor::getPendingTaskCount() { + return taskQueue_->size(); +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/CPUThreadPoolExecutor.h b/folly/wangle/concurrent/CPUThreadPoolExecutor.h new file mode 100644 index 00000000..7b85ae1f --- /dev/null +++ b/folly/wangle/concurrent/CPUThreadPoolExecutor.h @@ -0,0 +1,99 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace folly { namespace wangle { + +class CPUThreadPoolExecutor : public ThreadPoolExecutor { + public: + struct CPUTask; + + CPUThreadPoolExecutor( + size_t numThreads, + std::unique_ptr> taskQueue, + std::shared_ptr threadFactory = + std::make_shared("CPUThreadPool")); + + explicit CPUThreadPoolExecutor(size_t numThreads); +CPUThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory); + + CPUThreadPoolExecutor( + size_t numThreads, + int8_t numPriorities, + std::shared_ptr threadFactory = + std::make_shared("CPUThreadPool")); + + CPUThreadPoolExecutor( + size_t numThreads, + int8_t numPriorities, + size_t maxQueueSize, + std::shared_ptr threadFactory = + std::make_shared("CPUThreadPool")); + + ~CPUThreadPoolExecutor(); + + void add(Func func) override; + void add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback = nullptr) override; + + void addWithPriority(Func func, int8_t priority) override; + void add( + Func func, + int8_t priority, + std::chrono::milliseconds expiration, + Func expireCallback = nullptr); + + uint8_t getNumPriorities() const override; + + struct CPUTask : public ThreadPoolExecutor::Task { + // Must be noexcept move constructible so it can be used in MPMCQueue + explicit CPUTask( + Func&& f, + std::chrono::milliseconds expiration, + Func&& expireCallback) + : Task(std::move(f), expiration, std::move(expireCallback)), + poison(false) {} + CPUTask() + : Task(nullptr, std::chrono::milliseconds(0), nullptr), + poison(true) {} + CPUTask(CPUTask&& o) noexcept : Task(std::move(o)), poison(o.poison) {} + CPUTask(const CPUTask&) = default; + CPUTask& operator=(const CPUTask&) = default; + bool poison; + }; + + static const size_t kDefaultMaxQueueSize; + + protected: + BlockingQueue* getTaskQueue(); + + private: + void threadRun(ThreadPtr thread) override; + void stopThreads(size_t n) override; + uint64_t getPendingTaskCount() override; + + std::unique_ptr> taskQueue_; + std::atomic threadsToStop_{0}; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/Codel.cpp b/folly/wangle/concurrent/Codel.cpp new file mode 100644 index 00000000..74a832b7 --- /dev/null +++ b/folly/wangle/concurrent/Codel.cpp @@ -0,0 +1,91 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#ifndef NO_LIB_GFLAGS + #include + DEFINE_int32(codel_interval, 100, + "Codel default interval time in ms"); + DEFINE_int32(codel_target_delay, 5, + "Target codel queueing delay in ms"); +#endif + +namespace folly { namespace wangle { + +#ifdef NO_LIB_GFLAGS + int32_t FLAGS_codel_interval = 100; + int32_t FLAGS_codel_target_delay = 5; +#endif + +Codel::Codel() + : codelMinDelay_(0), + codelIntervalTime_(std::chrono::steady_clock::now()), + codelResetDelay_(true), + overloaded_(false) {} + +bool Codel::overloaded(std::chrono::microseconds delay) { + bool ret = false; + auto now = std::chrono::steady_clock::now(); + + // Avoid another thread updating the value at the same time we are using it + // to calculate the overloaded state + auto minDelay = codelMinDelay_; + + if (now > codelIntervalTime_ && + (!codelResetDelay_.load(std::memory_order_acquire) + && !codelResetDelay_.exchange(true))) { + codelIntervalTime_ = now + std::chrono::milliseconds(FLAGS_codel_interval); + + if (minDelay > std::chrono::milliseconds(FLAGS_codel_target_delay)) { + overloaded_ = true; + } else { + overloaded_ = false; + } + } + // Care must be taken that only a single thread resets codelMinDelay_, + // and that it happens after the interval reset above + if (codelResetDelay_.load(std::memory_order_acquire) && + codelResetDelay_.exchange(false)) { + codelMinDelay_ = delay; + // More than one request must come in during an interval before codel + // starts dropping requests + return false; + } else if(delay < codelMinDelay_) { + codelMinDelay_ = delay; + } + + if (overloaded_ && + delay > std::chrono::milliseconds(FLAGS_codel_target_delay * 2)) { + ret = true; + } + + return ret; + +} + +int Codel::getLoad() { + return std::min(100, (int)codelMinDelay_.count() / + (2 * FLAGS_codel_target_delay)); +} + +int Codel::getMinDelay() { + return (int) codelMinDelay_.count(); +} + +}} //namespace diff --git a/folly/wangle/concurrent/Codel.h b/folly/wangle/concurrent/Codel.h new file mode 100644 index 00000000..4c3fcd49 --- /dev/null +++ b/folly/wangle/concurrent/Codel.h @@ -0,0 +1,66 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +/* Codel algorithm implementation: + * http://en.wikipedia.org/wiki/CoDel + * + * Algorithm modified slightly: Instead of changing the interval time + * based on the average min delay, instead we use an alternate timeout + * for each task if the min delay during the interval period is too + * high. + * + * This was found to have better latency metrics than changing the + * window size, since we can communicate with the sender via thrift + * instead of only via the tcp window size congestion control, as in TCP. + */ +class Codel { + + public: + Codel(); + + // Given a delay, returns wether the codel algorithm would + // reject a queued request with this delay. + // + // Internally, it also keeps track of the interval + bool overloaded(std::chrono::microseconds delay); + + // Get the queue load, as seen by the codel algorithm + // Gives a rough guess at how bad the queue delay is. + // + // Return: 0 = no delay, 100 = At the queueing limit + int getLoad(); + + int getMinDelay(); + + private: + std::chrono::microseconds codelMinDelay_; + std::chrono::time_point codelIntervalTime_; + + // flag to make overloaded() thread-safe, since we only want + // to reset the delay once per time period + std::atomic codelResetDelay_; + + bool overloaded_; +}; + +}} // Namespace diff --git a/folly/wangle/concurrent/FiberIOExecutor.h b/folly/wangle/concurrent/FiberIOExecutor.h new file mode 100644 index 00000000..cd6e6758 --- /dev/null +++ b/folly/wangle/concurrent/FiberIOExecutor.h @@ -0,0 +1,49 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace folly { namespace fibers { + +/** + * @class FiberIOExecutor + * @brief An IOExecutor that executes funcs under mapped fiber context + * + * A FiberIOExecutor wraps an IOExecutor, but executes funcs on the FiberManager + * mapped to the underlying IOExector's event base. + */ +class FiberIOExecutor : public folly::wangle::IOExecutor { + public: + explicit FiberIOExecutor( + const std::shared_ptr& ioExecutor) + : ioExecutor_(ioExecutor) {} + + virtual void add(std::function f) override { + auto eventBase = ioExecutor_->getEventBase(); + getFiberManager(*eventBase).add(std::move(f)); + } + + virtual EventBase* getEventBase() override { + return ioExecutor_->getEventBase(); + } + + private: + std::shared_ptr ioExecutor_; +}; + +}} diff --git a/folly/wangle/concurrent/FutureExecutor.h b/folly/wangle/concurrent/FutureExecutor.h new file mode 100644 index 00000000..cbf4123a --- /dev/null +++ b/folly/wangle/concurrent/FutureExecutor.h @@ -0,0 +1,79 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace folly { namespace wangle { + +template +class FutureExecutor : public ExecutorImpl { + public: + template + explicit FutureExecutor(Args&&... args) + : ExecutorImpl(std::forward(args)...) {} + + /* + * Given a function func that returns a Future, adds that function to the + * contained Executor and returns a Future which will be fulfilled with + * func's result once it has been executed. + * + * For example: auto f = futureExecutor.addFuture([](){ + * return doAsyncWorkAndReturnAFuture(); + * }); + */ + template + typename std::enable_if::type>::value, + typename std::result_of::type>::type + addFuture(F func) { + typedef typename std::result_of::type::value_type T; + Promise promise; + auto future = promise.getFuture(); + auto movePromise = folly::makeMoveWrapper(std::move(promise)); + auto moveFunc = folly::makeMoveWrapper(std::move(func)); + ExecutorImpl::add([movePromise, moveFunc] () mutable { + (*moveFunc)().then([movePromise] (Try&& t) mutable { + movePromise->setTry(std::move(t)); + }); + }); + return future; + } + + /* + * Similar to addFuture above, but takes a func that returns some non-Future + * type T. + * + * For example: auto f = futureExecutor.addFuture([]() { + * return 42; + * }); + */ + template + typename std::enable_if::type>::value, + Future::type>>::type + addFuture(F func) { + typedef typename std::result_of::type T; + Promise promise; + auto future = promise.getFuture(); + auto movePromise = folly::makeMoveWrapper(std::move(promise)); + auto moveFunc = folly::makeMoveWrapper(std::move(func)); + ExecutorImpl::add([movePromise, moveFunc] () mutable { + movePromise->setWith(std::move(*moveFunc)); + }); + return future; + } +}; + +}} diff --git a/folly/wangle/concurrent/GlobalExecutor.cpp b/folly/wangle/concurrent/GlobalExecutor.cpp new file mode 100644 index 00000000..36a82155 --- /dev/null +++ b/folly/wangle/concurrent/GlobalExecutor.cpp @@ -0,0 +1,120 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; + +namespace { + +// lock protecting global CPU executor +struct CPUExecutorLock {}; +Singleton globalCPUExecutorLock; +// global CPU executor +Singleton> globalCPUExecutor; +// default global CPU executor is an InlineExecutor +Singleton> globalInlineExecutor( + []{ + return new std::shared_ptr( + std::make_shared()); + }); + +// lock protecting global IO executor +struct IOExecutorLock {}; +Singleton globalIOExecutorLock; +// global IO executor +Singleton> globalIOExecutor; +// default global IO executor is an IOThreadPoolExecutor +Singleton> globalIOThreadPool( + []{ + return new std::shared_ptr( + std::make_shared( + sysconf(_SC_NPROCESSORS_ONLN), + std::make_shared("GlobalIOThreadPool"))); + }); + +} + +namespace folly { namespace wangle { + +template +std::shared_ptr getExecutor( + Singleton>& sExecutor, + Singleton>& sDefaultExecutor, + Singleton& sExecutorLock) { + std::shared_ptr executor; + auto singleton = sExecutor.get(); + auto lock = sExecutorLock.get(); + + { + RWSpinLock::ReadHolder guard(lock); + if ((executor = sExecutor->lock())) { + return executor; + } + } + + + RWSpinLock::WriteHolder guard(lock); + executor = singleton->lock(); + if (!executor) { + executor = *sDefaultExecutor.get(); + *singleton = executor; + } + return executor; +} + +template +void setExecutor( + std::shared_ptr executor, + Singleton>& sExecutor, + Singleton& sExecutorLock) { + RWSpinLock::WriteHolder guard(sExecutorLock.get()); + *sExecutor.get() = std::move(executor); +} + +std::shared_ptr getCPUExecutor() { + return getExecutor( + globalCPUExecutor, + globalInlineExecutor, + globalCPUExecutorLock); +} + +void setCPUExecutor(std::shared_ptr executor) { + setExecutor( + std::move(executor), + globalCPUExecutor, + globalCPUExecutorLock); +} + +std::shared_ptr getIOExecutor() { + return getExecutor( + globalIOExecutor, + globalIOThreadPool, + globalIOExecutorLock); +} + +void setIOExecutor(std::shared_ptr executor) { + setExecutor( + std::move(executor), + globalIOExecutor, + globalIOExecutorLock); +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/GlobalExecutor.h b/folly/wangle/concurrent/GlobalExecutor.h new file mode 100644 index 00000000..fa7d06c6 --- /dev/null +++ b/folly/wangle/concurrent/GlobalExecutor.h @@ -0,0 +1,46 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +namespace folly { namespace wangle { + +// Retrieve the global Executor. If there is none, a default InlineExecutor +// will be constructed and returned. This is named CPUExecutor to distinguish +// it from IOExecutor below and to hint that it's intended for CPU-bound tasks. +std::shared_ptr getCPUExecutor(); + +// Set an Executor to be the global Executor which will be returned by +// subsequent calls to getCPUExecutor(). Takes a non-owning (weak) reference. +void setCPUExecutor(std::shared_ptr executor); + +// Retrieve the global IOExecutor. If there is none, a default +// IOThreadPoolExecutor will be constructed and returned. +// +// IOExecutors differ from Executors in that they drive and provide access to +// one or more EventBases. +std::shared_ptr getIOExecutor(); + +// Set an IOExecutor to be the global IOExecutor which will be returned by +// subsequent calls to getIOExecutor(). Takes a non-owning (weak) reference. +void setIOExecutor(std::shared_ptr executor); + +}} diff --git a/folly/wangle/concurrent/IOExecutor.h b/folly/wangle/concurrent/IOExecutor.h new file mode 100644 index 00000000..08ed94da --- /dev/null +++ b/folly/wangle/concurrent/IOExecutor.h @@ -0,0 +1,47 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { +class EventBase; +} + +namespace folly { namespace wangle { + +// An IOExecutor is an executor that operates on at least one EventBase. One of +// these EventBases should be accessible via getEventBase(). The event base +// returned by a call to getEventBase() is implementation dependent. +// +// Note that IOExecutors don't necessarily loop on the base themselves - for +// instance, EventBase itself is an IOExecutor but doesn't drive itself. +// +// Implementations of IOExecutor are eligible to become the global IO executor, +// returned on every call to getIOExecutor(), via setIOExecutor(). +// These functions are declared in GlobalExecutor.h +// +// If getIOExecutor is called and none has been set, a default global +// IOThreadPoolExecutor will be created and returned. +class IOExecutor : public virtual Executor { + public: + virtual ~IOExecutor() {} + virtual EventBase* getEventBase() = 0; +}; + +}} diff --git a/folly/wangle/concurrent/IOThreadPoolExecutor.cpp b/folly/wangle/concurrent/IOThreadPoolExecutor.cpp new file mode 100644 index 00000000..f8a68e25 --- /dev/null +++ b/folly/wangle/concurrent/IOThreadPoolExecutor.cpp @@ -0,0 +1,188 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include + +#include + +namespace folly { namespace wangle { + +using folly::detail::MemoryIdler; + +/* Class that will free jemalloc caches and madvise the stack away + * if the event loop is unused for some period of time + */ +class MemoryIdlerTimeout + : public AsyncTimeout , public EventBase::LoopCallback { + public: + explicit MemoryIdlerTimeout(EventBase* b) : AsyncTimeout(b), base_(b) {} + + virtual void timeoutExpired() noexcept { + idled = true; + } + + virtual void runLoopCallback() noexcept { + if (idled) { + MemoryIdler::flushLocalMallocCaches(); + MemoryIdler::unmapUnusedStack(MemoryIdler::kDefaultStackToRetain); + + idled = false; + } else { + std::chrono::steady_clock::duration idleTimeout = + MemoryIdler::defaultIdleTimeout.load( + std::memory_order_acquire); + + idleTimeout = MemoryIdler::getVariationTimeout(idleTimeout); + + scheduleTimeout(std::chrono::duration_cast( + idleTimeout).count()); + } + + // reschedule this callback for the next event loop. + base_->runBeforeLoop(this); + } + private: + EventBase* base_; + bool idled{false}; +} ; + +IOThreadPoolExecutor::IOThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory, + EventBaseManager* ebm) + : ThreadPoolExecutor(numThreads, std::move(threadFactory)), + nextThread_(0), + eventBaseManager_(ebm) { + addThreads(numThreads); + CHECK(threadList_.get().size() == numThreads); +} + +IOThreadPoolExecutor::~IOThreadPoolExecutor() { + stop(); +} + +void IOThreadPoolExecutor::add(Func func) { + add(std::move(func), std::chrono::milliseconds(0)); +} + +void IOThreadPoolExecutor::add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback) { + RWSpinLock::ReadHolder{&threadListLock_}; + if (threadList_.get().empty()) { + throw std::runtime_error("No threads available"); + } + auto ioThread = pickThread(); + + auto moveTask = folly::makeMoveWrapper( + Task(std::move(func), expiration, std::move(expireCallback))); + auto wrappedFunc = [ioThread, moveTask] () mutable { + runTask(ioThread, std::move(*moveTask)); + ioThread->pendingTasks--; + }; + + ioThread->pendingTasks++; + if (!ioThread->eventBase->runInEventBaseThread(std::move(wrappedFunc))) { + ioThread->pendingTasks--; + throw std::runtime_error("Unable to run func in event base thread"); + } +} + +std::shared_ptr +IOThreadPoolExecutor::pickThread() { + if (*thisThread_) { + return *thisThread_; + } + auto thread = threadList_.get()[nextThread_++ % threadList_.get().size()]; + return std::static_pointer_cast(thread); +} + +EventBase* IOThreadPoolExecutor::getEventBase() { + return pickThread()->eventBase; +} + +EventBase* IOThreadPoolExecutor::getEventBase( + ThreadPoolExecutor::ThreadHandle* h) { + auto thread = dynamic_cast(h); + + if (thread) { + return thread->eventBase; + } + + return nullptr; +} + +EventBaseManager* IOThreadPoolExecutor::getEventBaseManager() { + return eventBaseManager_; +} + +std::shared_ptr +IOThreadPoolExecutor::makeThread() { + return std::make_shared(this); +} + +void IOThreadPoolExecutor::threadRun(ThreadPtr thread) { + const auto ioThread = std::static_pointer_cast(thread); + ioThread->eventBase = eventBaseManager_->getEventBase(); + thisThread_.reset(new std::shared_ptr(ioThread)); + + auto idler = new MemoryIdlerTimeout(ioThread->eventBase); + ioThread->eventBase->runBeforeLoop(idler); + + thread->startupBaton.post(); + while (ioThread->shouldRun) { + ioThread->eventBase->loopForever(); + } + if (isJoin_) { + while (ioThread->pendingTasks > 0) { + ioThread->eventBase->loopOnce(); + } + } + stoppedThreads_.add(ioThread); +} + +// threadListLock_ is writelocked +void IOThreadPoolExecutor::stopThreads(size_t n) { + for (size_t i = 0; i < n; i++) { + const auto ioThread = std::static_pointer_cast( + threadList_.get()[i]); + for (auto& o : observers_) { + o->threadStopped(ioThread.get()); + } + ioThread->shouldRun = false; + ioThread->eventBase->terminateLoopSoon(); + } +} + +// threadListLock_ is readlocked +uint64_t IOThreadPoolExecutor::getPendingTaskCount() { + uint64_t count = 0; + for (const auto& thread : threadList_.get()) { + auto ioThread = std::static_pointer_cast(thread); + size_t pendingTasks = ioThread->pendingTasks; + if (pendingTasks > 0 && !ioThread->idle) { + pendingTasks--; + } + count += pendingTasks; + } + return count; +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/IOThreadPoolExecutor.h b/folly/wangle/concurrent/IOThreadPoolExecutor.h new file mode 100644 index 00000000..b298ccae --- /dev/null +++ b/folly/wangle/concurrent/IOThreadPoolExecutor.h @@ -0,0 +1,71 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace folly { namespace wangle { + +// N.B. For this thread pool, stop() behaves like join() because outstanding +// tasks belong to the event base and will be executed upon its destruction. +class IOThreadPoolExecutor : public ThreadPoolExecutor, public IOExecutor { + public: + explicit IOThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory = + std::make_shared("IOThreadPool"), + EventBaseManager* ebm = folly::EventBaseManager::get()); + + ~IOThreadPoolExecutor(); + + void add(Func func) override; + void add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback = nullptr) override; + + EventBase* getEventBase() override; + + static EventBase* getEventBase(ThreadPoolExecutor::ThreadHandle*); + + EventBaseManager* getEventBaseManager(); + + private: + struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING IOThread : public Thread { + IOThread(IOThreadPoolExecutor* pool) + : Thread(pool), + shouldRun(true), + pendingTasks(0) {}; + std::atomic shouldRun; + std::atomic pendingTasks; + EventBase* eventBase; + }; + + ThreadPtr makeThread() override; + std::shared_ptr pickThread(); + void threadRun(ThreadPtr thread) override; + void stopThreads(size_t n) override; + uint64_t getPendingTaskCount() override; + + size_t nextThread_; + ThreadLocal> thisThread_; + EventBaseManager* eventBaseManager_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/LifoSemMPMCQueue.h b/folly/wangle/concurrent/LifoSemMPMCQueue.h new file mode 100644 index 00000000..5c79cf19 --- /dev/null +++ b/folly/wangle/concurrent/LifoSemMPMCQueue.h @@ -0,0 +1,57 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace folly { namespace wangle { + +template +class LifoSemMPMCQueue : public BlockingQueue { + public: + explicit LifoSemMPMCQueue(size_t max_capacity) : queue_(max_capacity) {} + + void add(T item) override { + if (!queue_.write(std::move(item))) { + throw std::runtime_error("LifoSemMPMCQueue full, can't add item"); + } + sem_.post(); + } + + T take() override { + T item; + while (!queue_.read(item)) { + sem_.wait(); + } + return item; + } + + size_t capacity() { + return queue_.capacity(); + } + + size_t size() override { + return queue_.size(); + } + + private: + LifoSem sem_; + MPMCQueue queue_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/NamedThreadFactory.h b/folly/wangle/concurrent/NamedThreadFactory.h new file mode 100644 index 00000000..668fb2bc --- /dev/null +++ b/folly/wangle/concurrent/NamedThreadFactory.h @@ -0,0 +1,56 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +namespace folly { namespace wangle { + +class NamedThreadFactory : public ThreadFactory { + public: + explicit NamedThreadFactory(folly::StringPiece prefix) + : prefix_(prefix.str()), suffix_(0) {} + + std::thread newThread(Func&& func) override { + auto thread = std::thread(std::move(func)); + folly::setThreadName( + thread.native_handle(), + folly::to(prefix_, suffix_++)); + return thread; + } + + void setNamePrefix(folly::StringPiece prefix) { + prefix_ = prefix.str(); + } + + std::string getNamePrefix() { + return prefix_; + } + + private: + std::string prefix_; + std::atomic suffix_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/PriorityLifoSemMPMCQueue.h b/folly/wangle/concurrent/PriorityLifoSemMPMCQueue.h new file mode 100644 index 00000000..583a9a34 --- /dev/null +++ b/folly/wangle/concurrent/PriorityLifoSemMPMCQueue.h @@ -0,0 +1,80 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace folly { namespace wangle { + +template +class PriorityLifoSemMPMCQueue : public BlockingQueue { + public: + explicit PriorityLifoSemMPMCQueue(uint8_t numPriorities, size_t capacity) { + queues_.reserve(numPriorities); + for (int8_t i = 0; i < numPriorities; i++) { + queues_.push_back(MPMCQueue(capacity)); + } + } + + uint8_t getNumPriorities() override { + return queues_.size(); + } + + // Add at medium priority by default + void add(T item) override { + addWithPriority(std::move(item), Executor::MID_PRI); + } + + void addWithPriority(T item, int8_t priority) override { + int mid = getNumPriorities() / 2; + size_t queue = priority < 0 ? + std::max(0, mid + priority) : + std::min(getNumPriorities() - 1, mid + priority); + CHECK(queue < queues_.size()); + if (!queues_[queue].write(std::move(item))) { + throw std::runtime_error("LifoSemMPMCQueue full, can't add item"); + } + sem_.post(); + } + + T take() override { + T item; + while (true) { + for (auto it = queues_.rbegin(); it != queues_.rend(); it++) { + if (it->read(item)) { + return item; + } + } + sem_.wait(); + } + } + + size_t size() override { + size_t size = 0; + for (auto& q : queues_) { + size += q.size(); + } + return size; + } + + private: + LifoSem sem_; + std::vector> queues_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/ThreadFactory.h b/folly/wangle/concurrent/ThreadFactory.h new file mode 100644 index 00000000..effd7e09 --- /dev/null +++ b/folly/wangle/concurrent/ThreadFactory.h @@ -0,0 +1,30 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include + +namespace folly { namespace wangle { + +class ThreadFactory { + public: + virtual ~ThreadFactory() {} + virtual std::thread newThread(Func&& func) = 0; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/ThreadPoolExecutor.cpp b/folly/wangle/concurrent/ThreadPoolExecutor.cpp new file mode 100644 index 00000000..2cf4029c --- /dev/null +++ b/folly/wangle/concurrent/ThreadPoolExecutor.cpp @@ -0,0 +1,202 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace folly { namespace wangle { + +ThreadPoolExecutor::ThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory) + : threadFactory_(std::move(threadFactory)), + taskStatsSubject_(std::make_shared>()) {} + +ThreadPoolExecutor::~ThreadPoolExecutor() { + CHECK(threadList_.get().size() == 0); +} + +ThreadPoolExecutor::Task::Task( + Func&& func, + std::chrono::milliseconds expiration, + Func&& expireCallback) + : func_(std::move(func)), + expiration_(expiration), + expireCallback_(std::move(expireCallback)) { + // Assume that the task in enqueued on creation + enqueueTime_ = std::chrono::steady_clock::now(); +} + +void ThreadPoolExecutor::runTask( + const ThreadPtr& thread, + Task&& task) { + thread->idle = false; + auto startTime = std::chrono::steady_clock::now(); + task.stats_.waitTime = startTime - task.enqueueTime_; + if (task.expiration_ > std::chrono::milliseconds(0) && + task.stats_.waitTime >= task.expiration_) { + task.stats_.expired = true; + if (task.expireCallback_ != nullptr) { + task.expireCallback_(); + } + } else { + try { + task.func_(); + } catch (const std::exception& e) { + LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled " << + typeid(e).name() << " exception: " << e.what(); + } catch (...) { + LOG(ERROR) << "ThreadPoolExecutor: func threw unhandled non-exception " + "object"; + } + task.stats_.runTime = std::chrono::steady_clock::now() - startTime; + } + thread->idle = true; + thread->taskStatsSubject->onNext(std::move(task.stats_)); +} + +size_t ThreadPoolExecutor::numThreads() { + RWSpinLock::ReadHolder{&threadListLock_}; + return threadList_.get().size(); +} + +void ThreadPoolExecutor::setNumThreads(size_t n) { + RWSpinLock::WriteHolder{&threadListLock_}; + const auto current = threadList_.get().size(); + if (n > current ) { + addThreads(n - current); + } else if (n < current) { + removeThreads(current - n, true); + } + CHECK(threadList_.get().size() == n); +} + +// threadListLock_ is writelocked +void ThreadPoolExecutor::addThreads(size_t n) { + std::vector newThreads; + for (size_t i = 0; i < n; i++) { + newThreads.push_back(makeThread()); + } + for (auto& thread : newThreads) { + // TODO need a notion of failing to create the thread + // and then handling for that case + thread->handle = threadFactory_->newThread( + std::bind(&ThreadPoolExecutor::threadRun, this, thread)); + threadList_.add(thread); + } + for (auto& thread : newThreads) { + thread->startupBaton.wait(); + } + for (auto& o : observers_) { + for (auto& thread : newThreads) { + o->threadStarted(thread.get()); + } + } +} + +// threadListLock_ is writelocked +void ThreadPoolExecutor::removeThreads(size_t n, bool isJoin) { + CHECK(n <= threadList_.get().size()); + CHECK(stoppedThreads_.size() == 0); + isJoin_ = isJoin; + stopThreads(n); + for (size_t i = 0; i < n; i++) { + auto thread = stoppedThreads_.take(); + thread->handle.join(); + threadList_.remove(thread); + } + CHECK(stoppedThreads_.size() == 0); +} + +void ThreadPoolExecutor::stop() { + RWSpinLock::WriteHolder{&threadListLock_}; + removeThreads(threadList_.get().size(), false); + CHECK(threadList_.get().size() == 0); +} + +void ThreadPoolExecutor::join() { + RWSpinLock::WriteHolder{&threadListLock_}; + removeThreads(threadList_.get().size(), true); + CHECK(threadList_.get().size() == 0); +} + +ThreadPoolExecutor::PoolStats ThreadPoolExecutor::getPoolStats() { + RWSpinLock::ReadHolder{&threadListLock_}; + ThreadPoolExecutor::PoolStats stats; + stats.threadCount = threadList_.get().size(); + for (auto thread : threadList_.get()) { + if (thread->idle) { + stats.idleThreadCount++; + } else { + stats.activeThreadCount++; + } + } + stats.pendingTaskCount = getPendingTaskCount(); + stats.totalTaskCount = stats.pendingTaskCount + stats.activeThreadCount; + return stats; +} + +std::atomic ThreadPoolExecutor::Thread::nextId(0); + +void ThreadPoolExecutor::StoppedThreadQueue::add( + ThreadPoolExecutor::ThreadPtr item) { + std::lock_guard guard(mutex_); + queue_.push(std::move(item)); + sem_.post(); +} + +ThreadPoolExecutor::ThreadPtr ThreadPoolExecutor::StoppedThreadQueue::take() { + while(1) { + { + std::lock_guard guard(mutex_); + if (queue_.size() > 0) { + auto item = std::move(queue_.front()); + queue_.pop(); + return item; + } + } + sem_.wait(); + } +} + +size_t ThreadPoolExecutor::StoppedThreadQueue::size() { + std::lock_guard guard(mutex_); + return queue_.size(); +} + +void ThreadPoolExecutor::addObserver(std::shared_ptr o) { + RWSpinLock::ReadHolder{&threadListLock_}; + observers_.push_back(o); + for (auto& thread : threadList_.get()) { + o->threadPreviouslyStarted(thread.get()); + } +} + +void ThreadPoolExecutor::removeObserver(std::shared_ptr o) { + RWSpinLock::ReadHolder{&threadListLock_}; + for (auto& thread : threadList_.get()) { + o->threadNotYetStopped(thread.get()); + } + + for (auto it = observers_.begin(); it != observers_.end(); it++) { + if (*it == o) { + observers_.erase(it); + return; + } + } + DCHECK(false); +} + +}} // folly::wangle diff --git a/folly/wangle/concurrent/ThreadPoolExecutor.h b/folly/wangle/concurrent/ThreadPoolExecutor.h new file mode 100644 index 00000000..c8ca7bb1 --- /dev/null +++ b/folly/wangle/concurrent/ThreadPoolExecutor.h @@ -0,0 +1,234 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +namespace folly { namespace wangle { + +class ThreadPoolExecutor : public virtual Executor { + public: + explicit ThreadPoolExecutor( + size_t numThreads, + std::shared_ptr threadFactory); + + ~ThreadPoolExecutor(); + + virtual void add(Func func) override = 0; + virtual void add( + Func func, + std::chrono::milliseconds expiration, + Func expireCallback) = 0; + + void setThreadFactory(std::shared_ptr threadFactory) { + CHECK(numThreads() == 0); + threadFactory_ = std::move(threadFactory); + } + + std::shared_ptr getThreadFactory(void) { + return threadFactory_; + } + + size_t numThreads(); + void setNumThreads(size_t numThreads); + /* + * stop() is best effort - there is no guarantee that unexecuted tasks won't + * be executed before it returns. Specifically, IOThreadPoolExecutor's stop() + * behaves like join(). + */ + void stop(); + void join(); + + struct PoolStats { + PoolStats() : threadCount(0), idleThreadCount(0), activeThreadCount(0), + pendingTaskCount(0), totalTaskCount(0) {} + size_t threadCount, idleThreadCount, activeThreadCount; + uint64_t pendingTaskCount, totalTaskCount; + }; + + PoolStats getPoolStats(); + + struct TaskStats { + TaskStats() : expired(false), waitTime(0), runTime(0) {} + bool expired; + std::chrono::nanoseconds waitTime; + std::chrono::nanoseconds runTime; + }; + + Subscription subscribeToTaskStats( + const ObserverPtr& observer) { + return taskStatsSubject_->subscribe(observer); + } + + /** + * Base class for threads created with ThreadPoolExecutor. + * Some subclasses have methods that operate on these + * handles. + */ + class ThreadHandle { + public: + virtual ~ThreadHandle() = default; + }; + + /** + * Observer interface for thread start/stop. + * Provides hooks so actions can be taken when + * threads are created + */ + class Observer { + public: + virtual void threadStarted(ThreadHandle*) = 0; + virtual void threadStopped(ThreadHandle*) = 0; + virtual void threadPreviouslyStarted(ThreadHandle* h) { + threadStarted(h); + } + virtual void threadNotYetStopped(ThreadHandle* h) { + threadStopped(h); + } + virtual ~Observer() = default; + }; + + void addObserver(std::shared_ptr); + void removeObserver(std::shared_ptr); + + protected: + // Prerequisite: threadListLock_ writelocked + void addThreads(size_t n); + // Prerequisite: threadListLock_ writelocked + void removeThreads(size_t n, bool isJoin); + + struct FOLLY_ALIGN_TO_AVOID_FALSE_SHARING Thread : public ThreadHandle { + explicit Thread(ThreadPoolExecutor* pool) + : id(nextId++), + handle(), + idle(true), + taskStatsSubject(pool->taskStatsSubject_) {} + + virtual ~Thread() {} + + static std::atomic nextId; + uint64_t id; + std::thread handle; + bool idle; + Baton<> startupBaton; + std::shared_ptr> taskStatsSubject; + }; + + typedef std::shared_ptr ThreadPtr; + + struct Task { + explicit Task( + Func&& func, + std::chrono::milliseconds expiration, + Func&& expireCallback); + Func func_; + TaskStats stats_; + std::chrono::steady_clock::time_point enqueueTime_; + std::chrono::milliseconds expiration_; + Func expireCallback_; + }; + + static void runTask(const ThreadPtr& thread, Task&& task); + + // The function that will be bound to pool threads. It must call + // thread->startupBaton.post() when it's ready to consume work. + virtual void threadRun(ThreadPtr thread) = 0; + + // Stop n threads and put their ThreadPtrs in the threadsStopped_ queue + // Prerequisite: threadListLock_ writelocked + virtual void stopThreads(size_t n) = 0; + + // Create a suitable Thread struct + virtual ThreadPtr makeThread() { + return std::make_shared(this); + } + + // Prerequisite: threadListLock_ readlocked + virtual uint64_t getPendingTaskCount() = 0; + + class ThreadList { + public: + void add(const ThreadPtr& state) { + auto it = std::lower_bound(vec_.begin(), vec_.end(), state, + // compare method is a static method of class + // and therefore cannot be inlined by compiler + // as a template predicate of the STL algorithm + // but wrapped up with the lambda function (lambda will be inlined) + // compiler can inline compare method as well + [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline + return compare(ts1, ts2); + }); + vec_.insert(it, state); + } + + void remove(const ThreadPtr& state) { + auto itPair = std::equal_range(vec_.begin(), vec_.end(), state, + // the same as above + [&](const ThreadPtr& ts1, const ThreadPtr& ts2) -> bool { // inline + return compare(ts1, ts2); + }); + CHECK(itPair.first != vec_.end()); + CHECK(std::next(itPair.first) == itPair.second); + vec_.erase(itPair.first); + } + + const std::vector& get() const { + return vec_; + } + + private: + static bool compare(const ThreadPtr& ts1, const ThreadPtr& ts2) { + return ts1->id < ts2->id; + } + + std::vector vec_; + }; + + class StoppedThreadQueue : public BlockingQueue { + public: + void add(ThreadPtr item) override; + ThreadPtr take() override; + size_t size() override; + + private: + LifoSem sem_; + std::mutex mutex_; + std::queue queue_; + }; + + std::shared_ptr threadFactory_; + ThreadList threadList_; + RWSpinLock threadListLock_; + StoppedThreadQueue stoppedThreads_; + std::atomic isJoin_; // whether the current downsizing is a join + + std::shared_ptr> taskStatsSubject_; + std::vector> observers_; +}; + +}} // folly::wangle diff --git a/folly/wangle/concurrent/test/CodelTest.cpp b/folly/wangle/concurrent/test/CodelTest.cpp new file mode 100644 index 00000000..fd420bac --- /dev/null +++ b/folly/wangle/concurrent/test/CodelTest.cpp @@ -0,0 +1,38 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +TEST(CodelTest, Basic) { + using std::chrono::milliseconds; + folly::wangle::Codel c; + std::this_thread::sleep_for(milliseconds(110)); + // This interval is overloaded + EXPECT_FALSE(c.overloaded(milliseconds(100))); + std::this_thread::sleep_for(milliseconds(90)); + // At least two requests must happen in an interval before they will fail + EXPECT_FALSE(c.overloaded(milliseconds(50))); + EXPECT_TRUE(c.overloaded(milliseconds(50))); + std::this_thread::sleep_for(milliseconds(110)); + // Previous interval is overloaded, but 2ms isn't enough to fail + EXPECT_FALSE(c.overloaded(milliseconds(2))); + std::this_thread::sleep_for(milliseconds(90)); + // 20 ms > target interval * 2 + EXPECT_TRUE(c.overloaded(milliseconds(20))); +} diff --git a/folly/wangle/concurrent/test/GlobalExecutorTest.cpp b/folly/wangle/concurrent/test/GlobalExecutorTest.cpp new file mode 100644 index 00000000..6fedebb7 --- /dev/null +++ b/folly/wangle/concurrent/test/GlobalExecutorTest.cpp @@ -0,0 +1,85 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +using namespace folly::wangle; + +TEST(GlobalExecutorTest, GlobalCPUExecutor) { + class DummyExecutor : public folly::Executor { + public: + void add(folly::Func f) override { + f(); + count++; + } + int count{0}; + }; + + // The default CPU executor is a synchronous inline executor, lets verify + // that work we add is executed + auto count = 0; + auto f = [&](){ count++; }; + + // Don't explode, we should create the default global CPUExecutor lazily here. + getCPUExecutor()->add(f); + EXPECT_EQ(1, count); + + { + auto dummy = std::make_shared(); + setCPUExecutor(dummy); + getCPUExecutor()->add(f); + // Make sure we were properly installed. + EXPECT_EQ(1, dummy->count); + EXPECT_EQ(2, count); + } + + // Don't explode, we should restore the default global CPUExecutor because our + // weak reference to dummy has expired + getCPUExecutor()->add(f); + EXPECT_EQ(3, count); +} + +TEST(GlobalExecutorTest, GlobalIOExecutor) { + class DummyExecutor : public IOExecutor { + public: + void add(folly::Func f) override { + count++; + } + folly::EventBase* getEventBase() override { + return nullptr; + } + int count{0}; + }; + + auto f = [](){}; + + // Don't explode, we should create the default global IOExecutor lazily here. + getIOExecutor()->add(f); + + { + auto dummy = std::make_shared(); + setIOExecutor(dummy); + getIOExecutor()->add(f); + // Make sure we were properly installed. + EXPECT_EQ(1, dummy->count); + } + + // Don't explode, we should restore the default global IOExecutor because our + // weak reference to dummy has expired + getIOExecutor()->add(f); +} diff --git a/folly/wangle/concurrent/test/ThreadPoolExecutorTest.cpp b/folly/wangle/concurrent/test/ThreadPoolExecutorTest.cpp new file mode 100644 index 00000000..8a6fcc02 --- /dev/null +++ b/folly/wangle/concurrent/test/ThreadPoolExecutorTest.cpp @@ -0,0 +1,395 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +using namespace folly; +using namespace folly::wangle; +using namespace std::chrono; + +static folly::Func burnMs(uint64_t ms) { + return [ms]() { std::this_thread::sleep_for(milliseconds(ms)); }; +} + +template +static void basic() { + // Create and destroy + TPE tpe(10); +} + +TEST(ThreadPoolExecutorTest, CPUBasic) { + basic(); +} + +TEST(IOThreadPoolExecutorTest, IOBasic) { + basic(); +} + +template +static void resize() { + TPE tpe(100); + EXPECT_EQ(100, tpe.numThreads()); + tpe.setNumThreads(50); + EXPECT_EQ(50, tpe.numThreads()); + tpe.setNumThreads(150); + EXPECT_EQ(150, tpe.numThreads()); +} + +TEST(ThreadPoolExecutorTest, CPUResize) { + resize(); +} + +TEST(ThreadPoolExecutorTest, IOResize) { + resize(); +} + +template +static void stop() { + TPE tpe(1); + std::atomic completed(0); + auto f = [&](){ + burnMs(10)(); + completed++; + }; + for (int i = 0; i < 1000; i++) { + tpe.add(f); + } + tpe.stop(); + EXPECT_GT(1000, completed); +} + +// IOThreadPoolExecutor's stop() behaves like join(). Outstanding tasks belong +// to the event base, will be executed upon its destruction, and cannot be +// taken back. +template <> +void stop() { + IOThreadPoolExecutor tpe(1); + std::atomic completed(0); + auto f = [&](){ + burnMs(10)(); + completed++; + }; + for (int i = 0; i < 10; i++) { + tpe.add(f); + } + tpe.stop(); + EXPECT_EQ(10, completed); +} + +TEST(ThreadPoolExecutorTest, CPUStop) { + stop(); +} + +TEST(ThreadPoolExecutorTest, IOStop) { + stop(); +} + +template +static void join() { + TPE tpe(10); + std::atomic completed(0); + auto f = [&](){ + burnMs(1)(); + completed++; + }; + for (int i = 0; i < 1000; i++) { + tpe.add(f); + } + tpe.join(); + EXPECT_EQ(1000, completed); +} + +TEST(ThreadPoolExecutorTest, CPUJoin) { + join(); +} + +TEST(ThreadPoolExecutorTest, IOJoin) { + join(); +} + +template +static void resizeUnderLoad() { + TPE tpe(10); + std::atomic completed(0); + auto f = [&](){ + burnMs(1)(); + completed++; + }; + for (int i = 0; i < 1000; i++) { + tpe.add(f); + } + tpe.setNumThreads(5); + tpe.setNumThreads(15); + tpe.join(); + EXPECT_EQ(1000, completed); +} + +TEST(ThreadPoolExecutorTest, CPUResizeUnderLoad) { + resizeUnderLoad(); +} + +TEST(ThreadPoolExecutorTest, IOResizeUnderLoad) { + resizeUnderLoad(); +} + +template +static void poolStats() { + folly::Baton<> startBaton, endBaton; + TPE tpe(1); + auto stats = tpe.getPoolStats(); + EXPECT_EQ(1, stats.threadCount); + EXPECT_EQ(1, stats.idleThreadCount); + EXPECT_EQ(0, stats.activeThreadCount); + EXPECT_EQ(0, stats.pendingTaskCount); + EXPECT_EQ(0, stats.totalTaskCount); + tpe.add([&](){ startBaton.post(); endBaton.wait(); }); + tpe.add([&](){}); + startBaton.wait(); + stats = tpe.getPoolStats(); + EXPECT_EQ(1, stats.threadCount); + EXPECT_EQ(0, stats.idleThreadCount); + EXPECT_EQ(1, stats.activeThreadCount); + EXPECT_EQ(1, stats.pendingTaskCount); + EXPECT_EQ(2, stats.totalTaskCount); + endBaton.post(); +} + +TEST(ThreadPoolExecutorTest, CPUPoolStats) { + poolStats(); +} + +TEST(ThreadPoolExecutorTest, IOPoolStats) { + poolStats(); +} + +template +static void taskStats() { + TPE tpe(1); + std::atomic c(0); + auto s = tpe.subscribeToTaskStats( + Observer::create( + [&](ThreadPoolExecutor::TaskStats stats) { + int i = c++; + EXPECT_LT(milliseconds(0), stats.runTime); + if (i == 1) { + EXPECT_LT(milliseconds(0), stats.waitTime); + } + })); + tpe.add(burnMs(10)); + tpe.add(burnMs(10)); + tpe.join(); + EXPECT_EQ(2, c); +} + +TEST(ThreadPoolExecutorTest, CPUTaskStats) { + taskStats(); +} + +TEST(ThreadPoolExecutorTest, IOTaskStats) { + taskStats(); +} + +template +static void expiration() { + TPE tpe(1); + std::atomic statCbCount(0); + auto s = tpe.subscribeToTaskStats( + Observer::create( + [&](ThreadPoolExecutor::TaskStats stats) { + int i = statCbCount++; + if (i == 0) { + EXPECT_FALSE(stats.expired); + } else if (i == 1) { + EXPECT_TRUE(stats.expired); + } else { + FAIL(); + } + })); + std::atomic expireCbCount(0); + auto expireCb = [&] () { expireCbCount++; }; + tpe.add(burnMs(10), seconds(60), expireCb); + tpe.add(burnMs(10), milliseconds(10), expireCb); + tpe.join(); + EXPECT_EQ(2, statCbCount); + EXPECT_EQ(1, expireCbCount); +} + +TEST(ThreadPoolExecutorTest, CPUExpiration) { + expiration(); +} + +TEST(ThreadPoolExecutorTest, IOExpiration) { + expiration(); +} + +template +static void futureExecutor() { + FutureExecutor fe(2); + std::atomic c{0}; + fe.addFuture([] () { return makeFuture(42); }).then( + [&] (Try&& t) { + c++; + EXPECT_EQ(42, t.value()); + }); + fe.addFuture([] () { return 100; }).then( + [&] (Try&& t) { + c++; + EXPECT_EQ(100, t.value()); + }); + fe.addFuture([] () { return makeFuture(); }).then( + [&] (Try&& t) { + c++; + EXPECT_NO_THROW(t.value()); + }); + fe.addFuture([] () { return; }).then( + [&] (Try&& t) { + c++; + EXPECT_NO_THROW(t.value()); + }); + fe.addFuture([] () { throw std::runtime_error("oops"); }).then( + [&] (Try&& t) { + c++; + EXPECT_THROW(t.value(), std::runtime_error); + }); + // Test doing actual async work + folly::Baton<> baton; + fe.addFuture([&] () { + auto p = std::make_shared>(); + std::thread t([p](){ + burnMs(10)(); + p->setValue(42); + }); + t.detach(); + return p->getFuture(); + }).then([&] (Try&& t) { + EXPECT_EQ(42, t.value()); + c++; + baton.post(); + }); + baton.wait(); + fe.join(); + EXPECT_EQ(6, c); +} + +TEST(ThreadPoolExecutorTest, CPUFuturePool) { + futureExecutor(); +} + +TEST(ThreadPoolExecutorTest, IOFuturePool) { + futureExecutor(); +} + +TEST(ThreadPoolExecutorTest, PriorityPreemptionTest) { + bool tookLopri = false; + auto completed = 0; + auto hipri = [&] { + EXPECT_FALSE(tookLopri); + completed++; + }; + auto lopri = [&] { + tookLopri = true; + completed++; + }; + CPUThreadPoolExecutor pool(0, 2); + for (int i = 0; i < 50; i++) { + pool.addWithPriority(lopri, Executor::LO_PRI); + } + for (int i = 0; i < 50; i++) { + pool.addWithPriority(hipri, Executor::HI_PRI); + } + pool.setNumThreads(1); + pool.join(); + EXPECT_EQ(100, completed); +} + +class TestObserver : public ThreadPoolExecutor::Observer { + public: + void threadStarted(ThreadPoolExecutor::ThreadHandle*) { + threads_++; + } + void threadStopped(ThreadPoolExecutor::ThreadHandle*) { + threads_--; + } + void threadPreviouslyStarted(ThreadPoolExecutor::ThreadHandle*) { + threads_++; + } + void threadNotYetStopped(ThreadPoolExecutor::ThreadHandle*) { + threads_--; + } + void checkCalls() { + ASSERT_EQ(threads_, 0); + } + private: + std::atomic threads_{0}; +}; + +TEST(ThreadPoolExecutorTest, IOObserver) { + auto observer = std::make_shared(); + + { + IOThreadPoolExecutor exe(10); + exe.addObserver(observer); + exe.setNumThreads(3); + exe.setNumThreads(0); + exe.setNumThreads(7); + exe.removeObserver(observer); + exe.setNumThreads(10); + } + + observer->checkCalls(); +} + +TEST(ThreadPoolExecutorTest, CPUObserver) { + auto observer = std::make_shared(); + + { + CPUThreadPoolExecutor exe(10); + exe.addObserver(observer); + exe.setNumThreads(3); + exe.setNumThreads(0); + exe.setNumThreads(7); + exe.removeObserver(observer); + exe.setNumThreads(10); + } + + observer->checkCalls(); +} + +TEST(ThreadPoolExecutorTest, AddWithPriority) { + std::atomic_int c{0}; + auto f = [&]{ c++; }; + + // IO exe doesn't support priorities + IOThreadPoolExecutor ioExe(10); + EXPECT_THROW(ioExe.addWithPriority(f, 0), std::runtime_error); + + CPUThreadPoolExecutor cpuExe(10, 3); + cpuExe.addWithPriority(f, -1); + cpuExe.addWithPriority(f, 0); + cpuExe.addWithPriority(f, 1); + cpuExe.addWithPriority(f, -2); // will add at the lowest priority + cpuExe.addWithPriority(f, 2); // will add at the highest priority + cpuExe.addWithPriority(f, Executor::LO_PRI); + cpuExe.addWithPriority(f, Executor::HI_PRI); + cpuExe.join(); + + EXPECT_EQ(7, c); +} diff --git a/folly/wangle/rx/Dummy.cpp b/folly/wangle/rx/Dummy.cpp new file mode 100644 index 00000000..ec999ca4 --- /dev/null +++ b/folly/wangle/rx/Dummy.cpp @@ -0,0 +1,19 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// fbbuild is too dumb to know that .h files in the directory affect +// our project, unless we have a .cpp file in the target, in the same +// directory. diff --git a/folly/wangle/rx/Observable.h b/folly/wangle/rx/Observable.h new file mode 100644 index 00000000..95b60bfc --- /dev/null +++ b/folly/wangle/rx/Observable.h @@ -0,0 +1,285 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // must come first +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace folly { namespace wangle { + +template +class Observable { + public: + Observable() : nextSubscriptionId_{1} {} + + // TODO perhaps we want to provide this #5283229 + Observable(Observable&& other) = delete; + + virtual ~Observable() { + if (unsubscriber_) { + unsubscriber_->disable(); + } + } + + // The next three methods subscribe the given Observer to this Observable. + // + // If these are called within an Observer callback, the new observer will not + // get the current update but will get subsequent updates. + // + // subscribe() returns a Subscription object. The observer will continue to + // get updates until the Subscription is destroyed. + // + // observe(ObserverPtr) creates an indefinite subscription + // + // observe(Observer*) also creates an indefinite subscription, but the + // caller is responsible for ensuring that the given Observer outlives this + // Observable. This might be useful in high performance environments where + // allocations must be kept to a minimum. Template parameter InlineObservers + // specifies how many observers can been subscribed inline without any + // allocations (it's just the size of a folly::small_vector). + virtual Subscription subscribe(ObserverPtr observer) { + return subscribeImpl(observer, false); + } + + virtual void observe(ObserverPtr observer) { + subscribeImpl(observer, true); + } + + virtual void observe(Observer* observer) { + if (inCallback_ && *inCallback_) { + if (!newObservers_) { + newObservers_.reset(new ObserverList()); + } + newObservers_->push_back(observer); + } else { + RWSpinLock::WriteHolder{&observersLock_}; + observers_.push_back(observer); + } + } + + // TODO unobserve(ObserverPtr), unobserve(Observer*) + + /// Returns a new Observable that will call back on the given Scheduler. + /// The returned Observable must outlive the parent Observable. + + // This and subscribeOn should maybe just be a first-class feature of an + // Observable, rather than making new ones whose lifetimes are tied to their + // parents. In that case it'd return a reference to this object for + // chaining. + ObservablePtr observeOn(SchedulerPtr scheduler) { + // you're right Hannes, if we have Observable::create we don't need this + // helper class. + struct ViaSubject : public Observable + { + ViaSubject(SchedulerPtr sched, + Observable* obs) + : scheduler_(sched), observable_(obs) + {} + + Subscription subscribe(ObserverPtr o) override { + return observable_->subscribe( + Observer::create( + [=](T val) { scheduler_->add([o, val] { o->onNext(val); }); }, + [=](Error e) { scheduler_->add([o, e] { o->onError(e); }); }, + [=]() { scheduler_->add([o] { o->onCompleted(); }); })); + } + + protected: + SchedulerPtr scheduler_; + Observable* observable_; + }; + + return std::make_shared(scheduler, this); + } + + /// Returns a new Observable that will subscribe to this parent Observable + /// via the given Scheduler. This can be subtle and confusing at first, see + /// http://www.introtorx.com/Content/v1.0.10621.0/15_SchedulingAndThreading.html#SubscribeOnObserveOn + std::unique_ptr subscribeOn(SchedulerPtr scheduler) { + struct Subject_ : public Subject { + public: + Subject_(SchedulerPtr s, Observable* o) : scheduler_(s), observable_(o) { + } + + Subscription subscribe(ObserverPtr o) { + scheduler_->add([=] { + observable_->subscribe(o); + }); + return Subscription(nullptr, 0); // TODO + } + + protected: + SchedulerPtr scheduler_; + Observable* observable_; + }; + + return folly::make_unique(scheduler, this); + } + + protected: + // Safely execute an operation on each observer. F must take a single + // Observer* as its argument. + template + void forEachObserver(F f) { + if (UNLIKELY(!inCallback_)) { + inCallback_.reset(new bool{false}); + } + CHECK(!(*inCallback_)); + *inCallback_ = true; + + { + RWSpinLock::ReadHolder rh(observersLock_); + for (auto o : observers_) { + f(o); + } + + for (auto& kv : subscribers_) { + f(kv.second.get()); + } + } + + if (UNLIKELY((newObservers_ && !newObservers_->empty()) || + (newSubscribers_ && !newSubscribers_->empty()) || + (oldSubscribers_ && !oldSubscribers_->empty()))) { + { + RWSpinLock::WriteHolder wh(observersLock_); + if (newObservers_) { + for (auto observer : *(newObservers_)) { + observers_.push_back(observer); + } + newObservers_->clear(); + } + if (newSubscribers_) { + for (auto& kv : *(newSubscribers_)) { + subscribers_.insert(std::move(kv)); + } + newSubscribers_->clear(); + } + if (oldSubscribers_) { + for (auto id : *(oldSubscribers_)) { + subscribers_.erase(id); + } + oldSubscribers_->clear(); + } + } + } + *inCallback_ = false; + } + + private: + Subscription subscribeImpl(ObserverPtr observer, bool indefinite) { + auto subscription = makeSubscription(indefinite); + typename SubscriberMap::value_type kv{subscription.id_, std::move(observer)}; + if (inCallback_ && *inCallback_) { + if (!newSubscribers_) { + newSubscribers_.reset(new SubscriberMap()); + } + newSubscribers_->insert(std::move(kv)); + } else { + RWSpinLock::WriteHolder{&observersLock_}; + subscribers_.insert(std::move(kv)); + } + return subscription; + } + + class Unsubscriber { + public: + explicit Unsubscriber(Observable* observable) : observable_(observable) { + CHECK(observable_); + } + + void unsubscribe(uint64_t id) { + CHECK(id > 0); + RWSpinLock::ReadHolder guard(lock_); + if (observable_) { + observable_->unsubscribe(id); + } + } + + void disable() { + RWSpinLock::WriteHolder guard(lock_); + observable_ = nullptr; + } + + private: + RWSpinLock lock_; + Observable* observable_; + }; + + std::shared_ptr unsubscriber_{nullptr}; + MicroSpinLock unsubscriberLock_{0}; + + friend class Subscription; + + void unsubscribe(uint64_t id) { + if (inCallback_ && *inCallback_) { + if (!oldSubscribers_) { + oldSubscribers_.reset(new std::vector()); + } + if (newSubscribers_) { + auto it = newSubscribers_->find(id); + if (it != newSubscribers_->end()) { + newSubscribers_->erase(it); + return; + } + } + oldSubscribers_->push_back(id); + } else { + RWSpinLock::WriteHolder{&observersLock_}; + subscribers_.erase(id); + } + } + + Subscription makeSubscription(bool indefinite) { + if (indefinite) { + return Subscription(nullptr, nextSubscriptionId_++); + } else { + if (!unsubscriber_) { + std::lock_guard guard(unsubscriberLock_); + if (!unsubscriber_) { + unsubscriber_ = std::make_shared(this); + } + } + return Subscription(unsubscriber_, nextSubscriptionId_++); + } + } + + std::atomic nextSubscriptionId_; + RWSpinLock observersLock_; + folly::ThreadLocalPtr inCallback_; + + typedef folly::small_vector*, InlineObservers> ObserverList; + ObserverList observers_; + folly::ThreadLocalPtr newObservers_; + + typedef std::map> SubscriberMap; + SubscriberMap subscribers_; + folly::ThreadLocalPtr newSubscribers_; + folly::ThreadLocalPtr> oldSubscribers_; +}; + +}} diff --git a/folly/wangle/rx/Observer.h b/folly/wangle/rx/Observer.h new file mode 100644 index 00000000..5797a0c1 --- /dev/null +++ b/folly/wangle/rx/Observer.h @@ -0,0 +1,113 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // must come first +#include +#include +#include +#include + +namespace folly { namespace wangle { + +template struct FunctionObserver; + +/// Observer interface. You can subclass it, or you can just use create() +/// to use std::functions. +template +struct Observer { + // These are what it means to be an Observer. + virtual void onNext(const T&) = 0; + virtual void onError(Error) = 0; + virtual void onCompleted() = 0; + + virtual ~Observer() = default; + + /// Create an Observer with std::function callbacks. Handy to make ad-hoc + /// Observers with lambdas. + /// + /// Templated for maximum perfect forwarding flexibility, but ultimately + /// whatever you pass in has to implicitly become a std::function for the + /// same signature as onNext(), onError(), and onCompleted() respectively. + /// (see the FunctionObserver typedefs) + template + static std::unique_ptr create( + N&& onNextFn, E&& onErrorFn, C&& onCompletedFn) + { + return folly::make_unique>( + std::forward(onNextFn), + std::forward(onErrorFn), + std::forward(onCompletedFn)); + } + + /// Create an Observer with only onNext and onError callbacks. + /// onCompleted will just be a no-op. + template + static std::unique_ptr create(N&& onNextFn, E&& onErrorFn) { + return folly::make_unique>( + std::forward(onNextFn), + std::forward(onErrorFn), + nullptr); + } + + /// Create an Observer with only an onNext callback. + /// onError and onCompleted will just be no-ops. + template + static std::unique_ptr create(N&& onNextFn) { + return folly::make_unique>( + std::forward(onNextFn), + nullptr, + nullptr); + } +}; + +/// An observer that uses std::function callbacks. You don't really want to +/// make one of these directly - instead use the Observer::create() methods. +template +struct FunctionObserver : public Observer { + typedef std::function OnNext; + typedef std::function OnError; + typedef std::function OnCompleted; + + /// We don't need any fancy overloads of this constructor because that's + /// what Observer::create() is for. + template + FunctionObserver(N&& n, E&& e, C&& c) + : onNext_(std::forward(n)), + onError_(std::forward(e)), + onCompleted_(std::forward(c)) + {} + + void onNext(const T& val) override { + if (onNext_) onNext_(val); + } + + void onError(Error e) override { + if (onError_) onError_(e); + } + + void onCompleted() override { + if (onCompleted_) onCompleted_(); + } + + protected: + OnNext onNext_; + OnError onError_; + OnCompleted onCompleted_; +}; + +}} diff --git a/folly/wangle/rx/README.md b/folly/wangle/rx/README.md new file mode 100644 index 00000000..0f7f6972 --- /dev/null +++ b/folly/wangle/rx/README.md @@ -0,0 +1,36 @@ +Rx is a pattern for "functional reactive programming" that started at +Microsoft in C#, and has been reimplemented in various languages, notably +RxJava for JVM languages. + +It is basically the plural of Futures (a la Wangle). + +``` + singular | plural + +---------------------------------+----------------------------------- + sync | Foo getData() | std::vector getData() + async | wangle::Future getData() | wangle::Observable getData() +``` + +For more on Rx, I recommend these resources: + +Netflix blog post (RxJava): http://techblog.netflix.com/2013/02/rxjava-netflix-api.html +Introduction to Rx eBook (C#): http://www.introtorx.com/content/v1.0.10621.0/01_WhyRx.html +The RxJava wiki: https://github.com/Netflix/RxJava/wiki +Netflix QCon presentation: http://www.infoq.com/presentations/netflix-functional-rx +https://rx.codeplex.com/ + +There are open source C++ implementations, I haven't looked at them. They +might be the best way to go rather than writing it NIH-style. I mostly did it +as an exercise, to think through how closely we might want to integrate +something like this with Wangle, and to get a feel for how it works in C++. + +I haven't even tried to support move-only data in this version. I'm on the +fence about the usage of shared_ptr. Subject is underdeveloped. A whole rich +set of operations is obviously missing. I haven't decided how to handle +subscriptions (and therefore cancellation), but I'm pretty sure C#'s +"Disposable" is thoroughly un-C++ (opposite of RAII). So for now subscribe +returns nothing at all and you can't cancel anything ever. The whole thing is +probably riddled with lifetime corner case bugs that will come out like a +swarm of angry bees as soon as someone tries an infinite sequence, or tries to +partially observe a long sequence. I'm pretty sure subscribeOn has a bug that +I haven't tracked down yet. diff --git a/folly/wangle/rx/Subject.h b/folly/wangle/rx/Subject.h new file mode 100644 index 00000000..c806d705 --- /dev/null +++ b/folly/wangle/rx/Subject.h @@ -0,0 +1,47 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // must come first +#include +#include + +namespace folly { namespace wangle { + +/// Subject interface. A Subject is both an Observable and an Observer. There +/// is a default implementation of the Observer methods that just forwards the +/// observed events to the Subject's observers. +template +struct Subject : public Observable, public Observer { + void onNext(const T& val) override { + this->forEachObserver([&](Observer* o){ + o->onNext(val); + }); + } + void onError(Error e) override { + this->forEachObserver([&](Observer* o){ + o->onError(e); + }); + } + void onCompleted() override { + this->forEachObserver([](Observer* o){ + o->onCompleted(); + }); + } +}; + +}} diff --git a/folly/wangle/rx/Subscription.h b/folly/wangle/rx/Subscription.h new file mode 100644 index 00000000..8445ccf2 --- /dev/null +++ b/folly/wangle/rx/Subscription.h @@ -0,0 +1,70 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // must come first +#include + +namespace folly { namespace wangle { + +template +class Subscription { + public: + Subscription() {} + + Subscription(const Subscription&) = delete; + + Subscription(Subscription&& other) noexcept { + *this = std::move(other); + } + + Subscription& operator=(Subscription&& other) noexcept { + unsubscribe(); + unsubscriber_ = std::move(other.unsubscriber_); + id_ = other.id_; + other.unsubscriber_ = nullptr; + other.id_ = 0; + return *this; + } + + ~Subscription() { + unsubscribe(); + } + + private: + typedef typename Observable::Unsubscriber Unsubscriber; + + Subscription(std::shared_ptr unsubscriber, uint64_t id) + : unsubscriber_(std::move(unsubscriber)), id_(id) { + CHECK(id_ > 0); + } + + void unsubscribe() { + if (unsubscriber_ && id_ > 0) { + unsubscriber_->unsubscribe(id_); + id_ = 0; + unsubscriber_ = nullptr; + } + } + + std::shared_ptr unsubscriber_; + uint64_t id_{0}; + + friend class Observable; +}; + +}} diff --git a/folly/wangle/rx/test/RxBenchmark.cpp b/folly/wangle/rx/test/RxBenchmark.cpp new file mode 100644 index 00000000..4e174942 --- /dev/null +++ b/folly/wangle/rx/test/RxBenchmark.cpp @@ -0,0 +1,155 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +using namespace folly::wangle; +using folly::BenchmarkSuspender; + +static std::unique_ptr> makeObserver() { + return Observer::create([&] (int x) {}); +} + +void subscribeImpl(uint iters, int N, bool countUnsubscribe) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector>> observers; + std::vector> subscriptions; + subscriptions.reserve(N); + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver()); + } + bs.dismiss(); + for (int i = 0; i < N; i++) { + subscriptions.push_back(subject.subscribe(std::move(observers[i]))); + } + if (countUnsubscribe) { + subscriptions.clear(); + } + bs.rehire(); + } +} + +void subscribeAndUnsubscribe(uint iters, int N) { + subscribeImpl(iters, N, true); +} + +void subscribe(uint iters, int N) { + subscribeImpl(iters, N, false); +} + +void observe(uint iters, int N) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector>> observers; + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver()); + } + bs.dismiss(); + for (int i = 0; i < N; i++) { + subject.observe(std::move(observers[i])); + } + bs.rehire(); + } +} + +void inlineObserve(uint iters, int N) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector*> observers; + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver().release()); + } + bs.dismiss(); + for (int i = 0; i < N; i++) { + subject.observe(observers[i]); + } + bs.rehire(); + for (int i = 0; i < N; i++) { + delete observers[i]; + } + } +} + +void notifySubscribers(uint iters, int N) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector>> observers; + std::vector> subscriptions; + subscriptions.reserve(N); + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver()); + } + for (int i = 0; i < N; i++) { + subscriptions.push_back(subject.subscribe(std::move(observers[i]))); + } + bs.dismiss(); + subject.onNext(42); + bs.rehire(); + } +} + +void notifyInlineObservers(uint iters, int N) { + for (uint iter = 0; iter < iters; iter++) { + BenchmarkSuspender bs; + Subject subject; + std::vector*> observers; + for (int i = 0; i < N; i++) { + observers.push_back(makeObserver().release()); + } + for (int i = 0; i < N; i++) { + subject.observe(observers[i]); + } + bs.dismiss(); + subject.onNext(42); + bs.rehire(); + } +} + +BENCHMARK_PARAM(subscribeAndUnsubscribe, 1); +BENCHMARK_RELATIVE_PARAM(subscribe, 1); +BENCHMARK_RELATIVE_PARAM(observe, 1); +BENCHMARK_RELATIVE_PARAM(inlineObserve, 1); + +BENCHMARK_DRAW_LINE(); + +BENCHMARK_PARAM(subscribeAndUnsubscribe, 1000); +BENCHMARK_RELATIVE_PARAM(subscribe, 1000); +BENCHMARK_RELATIVE_PARAM(observe, 1000); +BENCHMARK_RELATIVE_PARAM(inlineObserve, 1000); + +BENCHMARK_DRAW_LINE(); + +BENCHMARK_PARAM(notifySubscribers, 1); +BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1); + +BENCHMARK_DRAW_LINE(); + +BENCHMARK_PARAM(notifySubscribers, 1000); +BENCHMARK_RELATIVE_PARAM(notifyInlineObservers, 1000); + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + folly::runBenchmarks(); + return 0; +} diff --git a/folly/wangle/rx/test/RxTest.cpp b/folly/wangle/rx/test/RxTest.cpp new file mode 100644 index 00000000..012a8c2f --- /dev/null +++ b/folly/wangle/rx/test/RxTest.cpp @@ -0,0 +1,195 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +using namespace folly::wangle; + +static std::unique_ptr> incrementer(int& counter) { + return Observer::create([&] (int x) { + counter++; + }); +} + +TEST(RxTest, Observe) { + Subject subject; + auto count = 0; + subject.observe(incrementer(count)); + subject.onNext(1); + EXPECT_EQ(1, count); +} + +TEST(RxTest, ObserveInline) { + Subject subject; + auto count = 0; + auto o = incrementer(count).release(); + subject.observe(o); + subject.onNext(1); + EXPECT_EQ(1, count); + delete o; +} + +TEST(RxTest, Subscription) { + Subject subject; + auto count = 0; + { + auto s = subject.subscribe(incrementer(count)); + subject.onNext(1); + } + // The subscription has gone out of scope so no one should get this. + subject.onNext(2); + EXPECT_EQ(1, count); +} + +TEST(RxTest, SubscriptionMove) { + Subject subject; + auto count = 0; + auto s = subject.subscribe(incrementer(count)); + auto s2 = subject.subscribe(incrementer(count)); + s2 = std::move(s); + subject.onNext(1); + Subscription s3(std::move(s2)); + subject.onNext(2); + EXPECT_EQ(2, count); +} + +TEST(RxTest, SubscriptionOutlivesSubject) { + Subscription s; + { + Subject subject; + s = subject.subscribe(Observer::create([](int){})); + } + // Don't explode when s is destroyed +} + +TEST(RxTest, SubscribeDuringCallback) { + // A subscriber who was subscribed in the course of a callback should get + // subsequent updates but not the current update. + Subject subject; + int outerCount = 0, innerCount = 0; + Subscription s1, s2; + s1 = subject.subscribe(Observer::create([&] (int x) { + outerCount++; + s2 = subject.subscribe(incrementer(innerCount)); + })); + subject.onNext(42); + subject.onNext(0xDEADBEEF); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(1, innerCount); +} + +TEST(RxTest, ObserveDuringCallback) { + Subject subject; + int outerCount = 0, innerCount = 0; + subject.observe(Observer::create([&] (int x) { + outerCount++; + subject.observe(incrementer(innerCount)); + })); + subject.onNext(42); + subject.onNext(0xDEADBEEF); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(1, innerCount); +} + +TEST(RxTest, ObserveInlineDuringCallback) { + Subject subject; + int outerCount = 0, innerCount = 0; + auto innerO = incrementer(innerCount).release(); + auto outerO = Observer::create([&] (int x) { + outerCount++; + subject.observe(innerO); + }).release(); + subject.observe(outerO); + subject.onNext(42); + subject.onNext(0xDEADBEEF); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(1, innerCount); + delete innerO; + delete outerO; +} + +TEST(RxTest, UnsubscribeDuringCallback) { + // A subscriber who was unsubscribed in the course of a callback should get + // the current update but not subsequent ones + Subject subject; + int count1 = 0, count2 = 0; + auto s1 = subject.subscribe(incrementer(count1)); + auto s2 = subject.subscribe(Observer::create([&] (int x) { + count2++; + s1.~Subscription(); + })); + subject.onNext(1); + subject.onNext(2); + EXPECT_EQ(1, count1); + EXPECT_EQ(2, count2); +} + +TEST(RxTest, SubscribeUnsubscribeDuringCallback) { + // A subscriber who was subscribed and unsubscribed in the course of a + // callback should not get any updates + Subject subject; + int outerCount = 0, innerCount = 0; + auto s2 = subject.subscribe(Observer::create([&] (int x) { + outerCount++; + auto s2 = subject.subscribe(incrementer(innerCount)); + })); + subject.onNext(1); + subject.onNext(2); + EXPECT_EQ(2, outerCount); + EXPECT_EQ(0, innerCount); +} + +// Move only type +typedef std::unique_ptr MO; +static MO makeMO() { return folly::make_unique(1); } +template +static ObserverPtr makeMOObserver() { + return Observer::create([](const T& mo) { + EXPECT_EQ(1, *mo); + }); +} + +TEST(RxTest, MoveOnlyRvalue) { + Subject subject; + auto s1 = subject.subscribe(makeMOObserver()); + auto s2 = subject.subscribe(makeMOObserver()); + auto mo = makeMO(); + // Can't bind lvalues to rvalue references + // subject.onNext(mo); + subject.onNext(std::move(mo)); + subject.onNext(makeMO()); +} + +// Copy only type +struct CO { + CO() = default; + CO(const CO&) = default; + CO(CO&&) = delete; +}; + +template +static ObserverPtr makeCOObserver() { + return Observer::create([](const T& mo) {}); +} + +TEST(RxTest, CopyOnly) { + Subject subject; + auto s1 = subject.subscribe(makeCOObserver()); + CO co; + subject.onNext(co); +} diff --git a/folly/wangle/rx/types.h b/folly/wangle/rx/types.h new file mode 100644 index 00000000..3bb540e6 --- /dev/null +++ b/folly/wangle/rx/types.h @@ -0,0 +1,35 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace folly { namespace wangle { + typedef folly::exception_wrapper Error; + // The Executor is basically an rx Scheduler (by design). So just + // alias it. + typedef std::shared_ptr SchedulerPtr; + + template class Observable; + template struct Observer; + template struct Subject; + + template using ObservablePtr = std::shared_ptr>; + template using ObserverPtr = std::shared_ptr>; + template using SubjectPtr = std::shared_ptr>; +}} diff --git a/folly/wangle/service/ClientDispatcher.h b/folly/wangle/service/ClientDispatcher.h new file mode 100644 index 00000000..d6354f04 --- /dev/null +++ b/folly/wangle/service/ClientDispatcher.h @@ -0,0 +1,69 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +/** + * Dispatch a request, satisfying Promise `p` with the response; + * the returned Future is satisfied when the response is received: + * only one request is allowed at a time. + */ +template +class SerialClientDispatcher : public HandlerAdapter + , public Service { + public: + + typedef typename HandlerAdapter::Context Context; + + void setPipeline(Pipeline* pipeline) { + pipeline_ = pipeline; + pipeline->addBack(this); + pipeline->finalize(); + } + + void read(Context* ctx, Req in) override { + DCHECK(p_); + p_->setValue(std::move(in)); + p_ = none; + } + + virtual Future operator()(Req arg) override { + CHECK(!p_); + DCHECK(pipeline_); + + p_ = Promise(); + auto f = p_->getFuture(); + pipeline_->write(std::move(arg)); + return f; + } + + virtual Future close() override { + return HandlerAdapter::close(nullptr); + } + + virtual Future close(Context* ctx) override { + return HandlerAdapter::close(ctx); + } + private: + Pipeline* pipeline_{nullptr}; + folly::Optional> p_; +}; + +}} // namespace diff --git a/folly/wangle/service/ServerDispatcher.h b/folly/wangle/service/ServerDispatcher.h new file mode 100644 index 00000000..0b3167e0 --- /dev/null +++ b/folly/wangle/service/ServerDispatcher.h @@ -0,0 +1,46 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace folly { namespace wangle { + +/** + * Dispatch requests from pipeline one at a time synchronously. + * Concurrent requests are queued in the pipeline. + */ +template +class SerialServerDispatcher : public HandlerAdapter { + public: + + typedef typename HandlerAdapter::Context Context; + + explicit SerialServerDispatcher(Service* service) + : service_(service) {} + + void read(Context* ctx, Req in) override { + auto resp = (*service_)(std::move(in)).get(); + ctx->fireWrite(std::move(resp)); + } + + private: + + Service* service_; +}; + +}} // namespace diff --git a/folly/wangle/service/Service.h b/folly/wangle/service/Service.h new file mode 100644 index 00000000..7eb54241 --- /dev/null +++ b/folly/wangle/service/Service.h @@ -0,0 +1,154 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include +#include +#include +#include + +namespace folly { + +/** + * A Service is an asynchronous function from Request to + * Future. It is the basic unit of the RPC interface. + */ +template +class Service { + public: + virtual Future operator()(Req request) = 0; + virtual ~Service() {} + virtual Future close() { + return makeFuture(); + } + virtual bool isAvailable() { + return true; + } +}; + +/** + * A Filter acts as a decorator/transformer of a service. It may apply + * transformations to the input and output of that service: + * + * class MyService + * + * ReqA -> | + * | -> ReqB + * | <- RespB + * RespA <- | + * + * For example, you may have a service that takes Strings and parses + * them as Ints. If you want to expose this as a Network Service via + * Thrift, it is nice to isolate the protocol handling from the + * business rules. Hence you might have a Filter that converts back + * and forth between Thrift structs: + * + * [ThriftIn -> (String -> Int) -> ThriftOut] + */ +template +class ServiceFilter : public Service { + public: + explicit ServiceFilter(std::shared_ptr> service) + : service_(service) {} + virtual ~ServiceFilter() {} + + virtual Future close() override { + return service_->close(); + } + + virtual bool isAvailable() override { + return service_->isAvailable(); + } + + protected: + std::shared_ptr> service_; +}; + +/** + * A factory that creates services, given a client. This lets you + * make RPC calls on the Service interface over a client's pipeline. + * + * Clients can be reused after you are done using the service. + */ +template +class ServiceFactory { + public: + virtual Future>> operator()( + std::shared_ptr> client) = 0; + + virtual ~ServiceFactory() = default; + +}; + + +template +class ConstFactory : public ServiceFactory { + public: + explicit ConstFactory(std::shared_ptr> service) + : service_(service) {} + + virtual Future>> operator()( + std::shared_ptr> client) { + return service_; + } + private: + std::shared_ptr> service_; +}; + +template +class ServiceFactoryFilter : public ServiceFactory { + public: + explicit ServiceFactoryFilter( + std::shared_ptr> serviceFactory) + : serviceFactory_(std::move(serviceFactory)) {} + + virtual ~ServiceFactoryFilter() = default; + + protected: + std::shared_ptr> serviceFactory_; +}; + +template +class FactoryToService : public Service { + public: + explicit FactoryToService( + std::shared_ptr> factory) + : factory_(factory) {} + virtual ~FactoryToService() {} + + virtual Future operator()(Req request) override { + DCHECK(factory_); + return ((*factory_)(nullptr)).then( + [=](std::shared_ptr> service) + { + return (*service)(std::move(request)).ensure( + [this]() { + this->close(); + }); + }); + } + + private: + std::shared_ptr> factory_; +}; + + +} // namespace diff --git a/folly/wangle/service/ServiceTest.cpp b/folly/wangle/service/ServiceTest.cpp new file mode 100644 index 00000000..4bf37df8 --- /dev/null +++ b/folly/wangle/service/ServiceTest.cpp @@ -0,0 +1,258 @@ +/* + * Copyright 2015 Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +namespace folly { + +using namespace wangle; + +typedef Pipeline ServicePipeline; + +class SimpleDecode : public ByteToMessageCodec { + public: + virtual std::unique_ptr decode( + Context* ctx, IOBufQueue& buf, size_t&) { + return buf.move(); + } +}; + +class EchoService : public Service { + public: + virtual Future operator()(std::string req) override { + return req; + } +}; + +class EchoIntService : public Service { + public: + virtual Future operator()(std::string req) override { + return folly::to(req); + } +}; + +template +class ServerPipelineFactory + : public PipelineFactory { + public: + + std::unique_ptr + newPipeline(std::shared_ptr socket) override { + std::unique_ptr pipeline( + new ServicePipeline()); + pipeline->addBack(AsyncSocketHandler(socket)); + pipeline->addBack(SimpleDecode()); + pipeline->addBack(StringCodec()); + pipeline->addBack(SerialServerDispatcher(&service_)); + pipeline->finalize(); + return pipeline; + } + + private: + EchoService service_; +}; + +template +class ClientPipelineFactory : public PipelineFactory { + public: + + std::unique_ptr + newPipeline(std::shared_ptr socket) override { + std::unique_ptr pipeline( + new ServicePipeline()); + pipeline->addBack(AsyncSocketHandler(socket)); + pipeline->addBack(SimpleDecode()); + pipeline->addBack(StringCodec()); + pipeline->finalize(); + return pipeline; + } +}; + +template +class ClientServiceFactory : public ServiceFactory { + public: + class ClientService : public Service { + public: + explicit ClientService(Pipeline* pipeline) { + dispatcher_.setPipeline(pipeline); + } + Future operator()(Req request) override { + return dispatcher_(std::move(request)); + } + private: + SerialClientDispatcher dispatcher_; + }; + + Future>> operator() ( + std::shared_ptr> client) override { + return Future>>( + std::make_shared(client->getPipeline())); + } +}; + +TEST(Wangle, ClientServerTest) { + int port = 1234; + // server + + ServerBootstrap server; + server.childPipeline( + std::make_shared>()); + server.bind(port); + + // client + auto client = std::make_shared>(); + ClientServiceFactory serviceFactory; + client->pipelineFactory( + std::make_shared>()); + SocketAddress addr("127.0.0.1", port); + client->connect(addr); + auto service = serviceFactory(client).value(); + auto rep = (*service)("test"); + + rep.then([&](std::string value) { + EXPECT_EQ("test", value); + EventBaseManager::get()->getEventBase()->terminateLoopSoon(); + + }); + EventBaseManager::get()->getEventBase()->loopForever(); + server.stop(); + client.reset(); +} + +class AppendFilter : public ServiceFilter { + public: + explicit AppendFilter( + std::shared_ptr> service) : + ServiceFilter(service) {} + + virtual Future operator()(std::string req) { + return (*service_)(req + "\n"); + } +}; + +class IntToStringFilter + : public ServiceFilter { + public: + explicit IntToStringFilter( + std::shared_ptr> service) : + ServiceFilter(service) {} + + virtual Future operator()(int req) { + return (*service_)(folly::to(req)).then([](std::string resp) { + return folly::to(resp); + }); + } +}; + +TEST(Wangle, FilterTest) { + auto service = std::make_shared(); + auto filter = std::make_shared(service); + auto result = (*filter)("test"); + EXPECT_EQ(result.value(), "test\n"); +} + +TEST(Wangle, ComplexFilterTest) { + auto service = std::make_shared(); + auto filter = std::make_shared(service); + auto result = (*filter)(1); + EXPECT_EQ(result.value(), 1); +} + +class ChangeTypeFilter + : public ServiceFilter { + public: + explicit ChangeTypeFilter( + std::shared_ptr> service) : + ServiceFilter(service) {} + + virtual Future operator()(int req) { + return (*service_)(folly::to(req)).then([](int resp) { + return folly::to(resp); + }); + } +}; + +TEST(Wangle, SuperComplexFilterTest) { + auto service = std::make_shared(); + auto filter = std::make_shared(service); + auto result = (*filter)(1); + EXPECT_EQ(result.value(), "1"); +} + +template +class ConnectionCountFilter : public ServiceFactoryFilter { + public: + explicit ConnectionCountFilter( + std::shared_ptr> factory) + : ServiceFactoryFilter(factory) {} + + virtual Future>> operator()( + std::shared_ptr> client) { + connectionCount++; + return (*this->serviceFactory_)(client); + } + + int connectionCount{0}; +}; + +TEST(Wangle, ServiceFactoryFilter) { + auto clientFactory = + std::make_shared< + ClientServiceFactory>(); + auto countingFactory = + std::make_shared< + ConnectionCountFilter>( + clientFactory); + + auto client = std::make_shared>(); + + client->pipelineFactory( + std::make_shared>()); + // It doesn't matter if connect succeds or not, but it needs to be called + // to create a pipeline + client->connect(folly::SocketAddress("::1", 8090)); + + auto service = (*countingFactory)(client).value(); + + // After the first service goes away, the client can be reused + service = (*countingFactory)(client).value(); + EXPECT_EQ(2, countingFactory->connectionCount); +} + +TEST(Wangle, FactoryToService) { + auto constfactory = + std::make_shared>( + std::make_shared()); + FactoryToService service( + constfactory); + + EXPECT_EQ("test", service("test").value()); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + google::InitGoogleLogging(argv[0]); + gflags::ParseCommandLineFlags(&argc, &argv, true); + + return RUN_ALL_TESTS(); +} + +} // namespace diff --git a/folly/wangle/ssl/ClientHelloExtStats.h b/folly/wangle/ssl/ClientHelloExtStats.h new file mode 100644 index 00000000..02afbbb5 --- /dev/null +++ b/folly/wangle/ssl/ClientHelloExtStats.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +namespace folly { + +class ClientHelloExtStats { + public: + virtual ~ClientHelloExtStats() noexcept {} + + // client hello + virtual void recordAbsentHostname() noexcept = 0; + virtual void recordMatch() noexcept = 0; + virtual void recordNotMatch() noexcept = 0; +}; + +} diff --git a/folly/wangle/ssl/DHParam.h b/folly/wangle/ssl/DHParam.h new file mode 100644 index 00000000..b1965373 --- /dev/null +++ b/folly/wangle/ssl/DHParam.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +// The following was auto-generated by +// openssl dhparam -C 2048 +DH *get_dh2048() + { + static unsigned char dh2048_p[]={ + 0xF8,0x87,0xA5,0x15,0x98,0x35,0x20,0x1E,0xF5,0x81,0xE5,0x95, + 0x1B,0xE4,0x54,0xEA,0x53,0xF5,0xE7,0x26,0x30,0x03,0x06,0x79, + 0x3C,0xC1,0x0B,0xAD,0x3B,0x59,0x3C,0x61,0x13,0x03,0x7B,0x02, + 0x70,0xDE,0xC1,0x20,0x11,0x9E,0x94,0x13,0x50,0xF7,0x62,0xFC, + 0x99,0x0D,0xC1,0x12,0x6E,0x03,0x95,0xA3,0x57,0xC7,0x3C,0xB8, + 0x6B,0x40,0x56,0x65,0x70,0xFB,0x7A,0xE9,0x02,0xEC,0xD2,0xB6, + 0x54,0xD7,0x34,0xAD,0x3D,0x9E,0x11,0x61,0x53,0xBE,0xEA,0xB8, + 0x17,0x48,0xA8,0xDC,0x70,0xAE,0x65,0x99,0x3F,0x82,0x4C,0xFF, + 0x6A,0xC9,0xFA,0xB1,0xFA,0xE4,0x4F,0x5D,0xA4,0x05,0xC2,0x8E, + 0x55,0xC0,0xB1,0x1D,0xCC,0x17,0xF3,0xFA,0x65,0xD8,0x6B,0x09, + 0x13,0x01,0x2A,0x39,0xF1,0x86,0x73,0xE3,0x7A,0xC8,0xDB,0x7D, + 0xDA,0x1C,0xA1,0x2D,0xBA,0x2C,0x00,0x6B,0x2C,0x55,0x28,0x2B, + 0xD5,0xF5,0x3C,0x9F,0x50,0xA7,0xB7,0x28,0x9F,0x22,0xD5,0x3A, + 0xC4,0x53,0x01,0xC9,0xF3,0x69,0xB1,0x8D,0x01,0x36,0xF8,0xA8, + 0x89,0xCA,0x2E,0x72,0xBC,0x36,0x3A,0x42,0xC1,0x06,0xD6,0x0E, + 0xCB,0x4D,0x5C,0x1F,0xE4,0xA1,0x17,0xBF,0x55,0x64,0x1B,0xB4, + 0x52,0xEC,0x15,0xED,0x32,0xB1,0x81,0x07,0xC9,0x71,0x25,0xF9, + 0x4D,0x48,0x3D,0x18,0xF4,0x12,0x09,0x32,0xC4,0x0B,0x7A,0x4E, + 0x83,0xC3,0x10,0x90,0x51,0x2E,0xBE,0x87,0xF9,0xDE,0xB4,0xE6, + 0x3C,0x29,0xB5,0x32,0x01,0x9D,0x95,0x04,0xBD,0x42,0x89,0xFD, + 0x21,0xEB,0xE9,0x88,0x5A,0x27,0xBB,0x31,0xC4,0x26,0x99,0xAB, + 0x8C,0xA1,0x76,0xDB, + }; + static unsigned char dh2048_g[]={ + 0x02, + }; + DH *dh; + + if ((dh=DH_new()) == nullptr) return(nullptr); + dh->p=BN_bin2bn(dh2048_p,(int)sizeof(dh2048_p),nullptr); + dh->g=BN_bin2bn(dh2048_g,(int)sizeof(dh2048_g),nullptr); + if ((dh->p == nullptr) || (dh->g == nullptr)) + { DH_free(dh); return(nullptr); } + return(dh); + } diff --git a/folly/wangle/ssl/PasswordInFile.cpp b/folly/wangle/ssl/PasswordInFile.cpp new file mode 100644 index 00000000..3c1c4136 --- /dev/null +++ b/folly/wangle/ssl/PasswordInFile.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include + +using namespace std; + +namespace folly { + +PasswordInFile::PasswordInFile(const string& file) + : fileName_(file) { + folly::readFile(file.c_str(), password_); + auto p = password_.find('\0'); + if (p != std::string::npos) { + password_.erase(p); + } +} + +PasswordInFile::~PasswordInFile() { + OPENSSL_cleanse((char *)password_.data(), password_.length()); +} + +} diff --git a/folly/wangle/ssl/PasswordInFile.h b/folly/wangle/ssl/PasswordInFile.h new file mode 100644 index 00000000..7cd908e3 --- /dev/null +++ b/folly/wangle/ssl/PasswordInFile.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include // PasswordCollector + +namespace folly { + +class PasswordInFile: public folly::PasswordCollector { + public: + explicit PasswordInFile(const std::string& file); + ~PasswordInFile(); + + void getPassword(std::string& password, int size) override { + password = password_; + } + + const char* getPasswordStr() const { + return password_.c_str(); + } + + std::string describe() const override { + return fileName_; + } + + protected: + std::string fileName_; + std::string password_; +}; + +} diff --git a/folly/wangle/ssl/SSLCacheOptions.h b/folly/wangle/ssl/SSLCacheOptions.h new file mode 100644 index 00000000..02b0a5b0 --- /dev/null +++ b/folly/wangle/ssl/SSLCacheOptions.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include + +namespace folly { + +struct SSLCacheOptions { + std::chrono::seconds sslCacheTimeout; + uint64_t maxSSLCacheSize; + uint64_t sslCacheFlushSize; +}; + +} diff --git a/folly/wangle/ssl/SSLCacheProvider.h b/folly/wangle/ssl/SSLCacheProvider.h new file mode 100644 index 00000000..8817b143 --- /dev/null +++ b/folly/wangle/ssl/SSLCacheProvider.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include + +namespace folly { + +class SSLSessionCacheManager; + +/** + * Interface to be implemented by providers of external session caches + */ +class SSLCacheProvider { +public: + /** + * Context saved during an external cache request that is used to + * resume the waiting client. + */ + typedef struct { + std::string sessionId; + SSL_SESSION* session; + SSLSessionCacheManager* manager; + AsyncSSLSocket* sslSocket; + std::unique_ptr< + folly::DelayedDestruction::DestructorGuard> guard; + } CacheContext; + + virtual ~SSLCacheProvider() {} + + /** + * Store a session in the external cache. + * @param sessionId Identifier that can be used later to fetch the + * session with getAsync() + * @param value Serialized session to store + * @param expiration Relative expiration time: seconds from now + * @return true if the storing of the session is initiated successfully + * (though not necessarily completed; the completion may + * happen either before or after this method returns), or + * false if the storing cannot be initiated due to an error. + */ + virtual bool setAsync(const std::string& sessionId, + const std::string& value, + std::chrono::seconds expiration) = 0; + + /** + * Retrieve a session from the external cache. When done, call + * the cache manager's onGetSuccess() or onGetFailure() callback. + * @param sessionId Session ID to fetch + * @param context Data to pass back to the SSLSessionCacheManager + * in the completion callback + * @return true if the lookup of the session is initiated successfully + * (though not necessarily completed; the completion may + * happen either before or after this method returns), or + * false if the lookup cannot be initiated due to an error. + */ + virtual bool getAsync(const std::string& sessionId, + CacheContext* context) = 0; + +}; + +} diff --git a/folly/wangle/ssl/SSLContextConfig.h b/folly/wangle/ssl/SSLContextConfig.h new file mode 100644 index 00000000..9c6987af --- /dev/null +++ b/folly/wangle/ssl/SSLContextConfig.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include + +/** + * SSLContextConfig helps to describe the configs/options for + * a SSL_CTX. For example: + * + * 1. Filename of X509, private key and its password. + * 2. ciphers list + * 3. NPN list + * 4. Is session cache enabled? + * 5. Is it the default X509 in SNI operation? + * 6. .... and a few more + */ +namespace folly { + +struct SSLContextConfig { + SSLContextConfig() {} + ~SSLContextConfig() {} + + struct CertificateInfo { + std::string certPath; + std::string keyPath; + std::string passwordPath; + }; + + /** + * Helpers to set/add a certificate + */ + void setCertificate(const std::string& certPath, + const std::string& keyPath, + const std::string& passwordPath) { + certificates.clear(); + addCertificate(certPath, keyPath, passwordPath); + } + + void addCertificate(const std::string& certPath, + const std::string& keyPath, + const std::string& passwordPath) { + certificates.emplace_back(CertificateInfo{certPath, keyPath, passwordPath}); + } + + /** + * Set the optional list of protocols to advertise via TLS + * Next Protocol Negotiation. An empty list means NPN is not enabled. + */ + void setNextProtocols(const std::list& inNextProtocols) { + nextProtocols.clear(); + nextProtocols.push_back({1, inNextProtocols}); + } + + typedef std::function SNINoMatchFn; + + std::vector certificates; + folly::SSLContext::SSLVersion sslVersion{ + folly::SSLContext::TLSv1}; + bool sessionCacheEnabled{true}; + bool sessionTicketEnabled{true}; + bool clientHelloParsingEnabled{false}; + std::string sslCiphers{ + "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:" + "ECDHE-ECDSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES128-GCM-SHA256:" + "ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-RSA-AES256-SHA:" + "AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA:AES256-SHA:" + "ECDHE-ECDSA-RC4-SHA:ECDHE-RSA-RC4-SHA:RC4-SHA:RC4-MD5:" + "ECDHE-RSA-DES-CBC3-SHA:DES-CBC3-SHA"}; + std::string eccCurveName; + // Ciphers to negotiate if TLS version >= 1.1 + std::string tls11Ciphers{""}; + // Weighted lists of NPN strings to advertise + std::list + nextProtocols; + bool isLocalPrivateKey{true}; + // Should this SSLContextConfig be the default for SNI purposes + bool isDefault{false}; + // Callback function to invoke when there are no matching certificates + // (will only be invoked once) + SNINoMatchFn sniNoMatchFn; + // File containing trusted CA's to validate client certificates + std::string clientCAFile; +}; + +} diff --git a/folly/wangle/ssl/SSLContextManager.cpp b/folly/wangle/ssl/SSLContextManager.cpp new file mode 100644 index 00000000..101dde79 --- /dev/null +++ b/folly/wangle/ssl/SSLContextManager.cpp @@ -0,0 +1,651 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#define OPENSSL_MISSING_FEATURE(name) \ +do { \ + throw std::runtime_error("missing " #name " support in openssl"); \ +} while(0) + + +using std::string; +using std::shared_ptr; + +/** + * SSLContextManager helps to create and manage all SSL_CTX, + * SSLSessionCacheManager and TLSTicketManager for a listening + * VIP:PORT. (Note, in SNI, a listening VIP:PORT can have >1 SSL_CTX(s)). + * + * Other responsibilities: + * 1. It also handles the SSL_CTX selection after getting the tlsext_hostname + * in the client hello message. + * + * Usage: + * 1. Each listening VIP:PORT serving SSL should have one SSLContextManager. + * It maps to Acceptor in the wangle vocabulary. + * + * 2. Create a SSLContextConfig object (e.g. by parsing the JSON config). + * + * 3. Call SSLContextManager::addSSLContextConfig() which will + * then create and configure the SSL_CTX + * + * Note: Each Acceptor, with SSL support, should have one SSLContextManager to + * manage all SSL_CTX for the VIP:PORT. + */ + +namespace folly { + +namespace { + +X509* getX509(SSL_CTX* ctx) { + SSL* ssl = SSL_new(ctx); + SSL_set_connect_state(ssl); + X509* x509 = SSL_get_certificate(ssl); + CRYPTO_add(&x509->references, 1, CRYPTO_LOCK_X509); + SSL_free(ssl); + return x509; +} + +void set_key_from_curve(SSL_CTX* ctx, const std::string& curveName) { +#if OPENSSL_VERSION_NUMBER >= 0x0090800fL +#ifndef OPENSSL_NO_ECDH + EC_KEY* ecdh = nullptr; + int nid; + + /* + * Elliptic-Curve Diffie-Hellman parameters are either "named curves" + * from RFC 4492 section 5.1.1, or explicitly described curves over + * binary fields. OpenSSL only supports the "named curves", which provide + * maximum interoperability. + */ + + nid = OBJ_sn2nid(curveName.c_str()); + if (nid == 0) { + LOG(FATAL) << "Unknown curve name:" << curveName.c_str(); + return; + } + ecdh = EC_KEY_new_by_curve_name(nid); + if (ecdh == nullptr) { + LOG(FATAL) << "Unable to create curve:" << curveName.c_str(); + return; + } + + SSL_CTX_set_tmp_ecdh(ctx, ecdh); + EC_KEY_free(ecdh); +#endif +#endif +} + +// Helper to create TLSTicketKeyManger and aware of the needed openssl +// version/feature. +std::unique_ptr createTicketManagerHelper( + std::shared_ptr ctx, + const TLSTicketKeySeeds* ticketSeeds, + const SSLContextConfig& ctxConfig, + SSLStats* stats) { + + std::unique_ptr ticketManager; +#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB + if (ticketSeeds && ctxConfig.sessionTicketEnabled) { + ticketManager = folly::make_unique(ctx.get(), stats); + ticketManager->setTLSTicketKeySeeds( + ticketSeeds->oldSeeds, + ticketSeeds->currentSeeds, + ticketSeeds->newSeeds); + } else { + ctx->setOptions(SSL_OP_NO_TICKET); + } +#else + if (ticketSeeds && ctxConfig.sessionTicketEnabled) { + OPENSSL_MISSING_FEATURE(TLSTicket); + } +#endif + return ticketManager; +} + +std::string flattenList(const std::list& list) { + std::string s; + bool first = true; + for (auto& item : list) { + if (first) { + first = false; + } else { + s.append(", "); + } + s.append(item); + } + return s; +} + +} + +SSLContextManager::~SSLContextManager() {} + +SSLContextManager::SSLContextManager( + EventBase* eventBase, + const std::string& vipName, + bool strict, + SSLStats* stats) : + stats_(stats), + eventBase_(eventBase), + strict_(strict) { +} + +void SSLContextManager::addSSLContextConfig( + const SSLContextConfig& ctxConfig, + const SSLCacheOptions& cacheOptions, + const TLSTicketKeySeeds* ticketSeeds, + const folly::SocketAddress& vipAddress, + const std::shared_ptr& externalCache) { + + unsigned numCerts = 0; + std::string commonName; + std::string lastCertPath; + std::unique_ptr> subjectAltName; + auto sslCtx = std::make_shared(ctxConfig.sslVersion); + for (const auto& cert : ctxConfig.certificates) { + try { + sslCtx->loadCertificate(cert.certPath.c_str()); + } catch (const std::exception& ex) { + // The exception isn't very useful without the certificate path name, + // so throw a new exception that includes the path to the certificate. + string msg = folly::to("error loading SSL certificate ", + cert.certPath, ": ", + folly::exceptionStr(ex)); + LOG(ERROR) << msg; + throw std::runtime_error(msg); + } + + // Verify that the Common Name and (if present) Subject Alternative Names + // are the same for all the certs specified for the SSL context. + numCerts++; + X509* x509 = getX509(sslCtx->getSSLCtx()); + auto guard = folly::makeGuard([x509] { X509_free(x509); }); + auto cn = SSLUtil::getCommonName(x509); + if (!cn) { + throw std::runtime_error(folly::to("Cannot get CN for X509 ", + cert.certPath)); + } + auto altName = SSLUtil::getSubjectAltName(x509); + VLOG(2) << "cert " << cert.certPath << " CN: " << *cn; + if (altName) { + altName->sort(); + VLOG(2) << "cert " << cert.certPath << " SAN: " << flattenList(*altName); + } else { + VLOG(2) << "cert " << cert.certPath << " SAN: " << "{none}"; + } + if (numCerts == 1) { + commonName = *cn; + subjectAltName = std::move(altName); + } else { + if (commonName != *cn) { + throw std::runtime_error(folly::to("X509 ", cert.certPath, + " does not have same CN as ", + lastCertPath)); + } + if (altName == nullptr) { + if (subjectAltName != nullptr) { + throw std::runtime_error(folly::to("X509 ", cert.certPath, + " does not have same SAN as ", + lastCertPath)); + } + } else { + if ((subjectAltName == nullptr) || (*altName != *subjectAltName)) { + throw std::runtime_error(folly::to("X509 ", cert.certPath, + " does not have same SAN as ", + lastCertPath)); + } + } + } + lastCertPath = cert.certPath; + + // TODO t4438250 - Add ECDSA support to the crypto_ssl offload server + // so we can avoid storing the ECDSA private key in the + // address space of the Internet-facing process. For + // now, if cert name includes "-EC" to denote elliptic + // curve, we load its private key even if the server as + // a whole has been configured for async crypto. + if (ctxConfig.isLocalPrivateKey || + (cert.certPath.find("-EC") != std::string::npos)) { + // The private key lives in the same process + + // This needs to be called before loadPrivateKey(). + if (!cert.passwordPath.empty()) { + auto sslPassword = std::make_shared(cert.passwordPath); + sslCtx->passwordCollector(sslPassword); + } + + try { + sslCtx->loadPrivateKey(cert.keyPath.c_str()); + } catch (const std::exception& ex) { + // Throw an error that includes the key path, so the user can tell + // which key had a problem. + string msg = folly::to("error loading private SSL key ", + cert.keyPath, ": ", + folly::exceptionStr(ex)); + LOG(ERROR) << msg; + throw std::runtime_error(msg); + } + } + } + if (!ctxConfig.isLocalPrivateKey) { + enableAsyncCrypto(sslCtx); + } + + // Let the server pick the highest performing cipher from among the client's + // choices. + // + // Let's use a unique private key for all DH key exchanges. + // + // Because some old implementations choke on empty fragments, most SSL + // applications disable them (it's part of SSL_OP_ALL). This + // will improve performance and decrease write buffer fragmentation. + sslCtx->setOptions(SSL_OP_CIPHER_SERVER_PREFERENCE | + SSL_OP_SINGLE_DH_USE | + SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS); + + // Configure SSL ciphers list + if (!ctxConfig.tls11Ciphers.empty()) { + // FIXME: create a dummy SSL_CTX for cipher testing purpose? It can + // remove the ordering dependency + + // Test to see if the specified TLS1.1 ciphers are valid. Note that + // these will be overwritten by the ciphers() call below. + sslCtx->setCiphersOrThrow(ctxConfig.tls11Ciphers); + } + + // Important that we do this *after* checking the TLS1.1 ciphers above, + // since we test their validity by actually setting them. + sslCtx->ciphers(ctxConfig.sslCiphers); + + // Use a fix DH param + DH* dh = get_dh2048(); + SSL_CTX_set_tmp_dh(sslCtx->getSSLCtx(), dh); + DH_free(dh); + + const string& curve = ctxConfig.eccCurveName; + if (!curve.empty()) { + set_key_from_curve(sslCtx->getSSLCtx(), curve); + } + + if (!ctxConfig.clientCAFile.empty()) { + try { + sslCtx->setVerificationOption(SSLContext::VERIFY_REQ_CLIENT_CERT); + sslCtx->loadTrustedCertificates(ctxConfig.clientCAFile.c_str()); + sslCtx->loadClientCAList(ctxConfig.clientCAFile.c_str()); + } catch (const std::exception& ex) { + string msg = folly::to("error loading client CA", + ctxConfig.clientCAFile, ": ", + folly::exceptionStr(ex)); + LOG(ERROR) << msg; + throw std::runtime_error(msg); + } + } + + // - start - SSL session cache config + // the internal cache never does what we want (per-thread-per-vip). + // Disable it. SSLSessionCacheManager will set it appropriately. + SSL_CTX_set_session_cache_mode(sslCtx->getSSLCtx(), SSL_SESS_CACHE_OFF); + SSL_CTX_set_timeout(sslCtx->getSSLCtx(), + cacheOptions.sslCacheTimeout.count()); + std::unique_ptr sessionCacheManager; + if (ctxConfig.sessionCacheEnabled && + cacheOptions.maxSSLCacheSize > 0 && + cacheOptions.sslCacheFlushSize > 0) { + sessionCacheManager = + folly::make_unique( + cacheOptions.maxSSLCacheSize, + cacheOptions.sslCacheFlushSize, + sslCtx.get(), + vipAddress, + commonName, + eventBase_, + stats_, + externalCache); + } + // - end - SSL session cache config + + std::unique_ptr ticketManager = + createTicketManagerHelper(sslCtx, ticketSeeds, ctxConfig, stats_); + + // finalize sslCtx setup by the individual features supported by openssl + ctxSetupByOpensslFeature(sslCtx, ctxConfig); + + try { + insert(sslCtx, + std::move(sessionCacheManager), + std::move(ticketManager), + ctxConfig.isDefault); + } catch (const std::exception& ex) { + string msg = folly::to("Error adding certificate : ", + folly::exceptionStr(ex)); + LOG(ERROR) << msg; + throw std::runtime_error(msg); + } + +} + +#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK +SSLContext::ServerNameCallbackResult +SSLContextManager::serverNameCallback(SSL* ssl) { + shared_ptr ctx; + + const char* sn = SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); + if (!sn) { + VLOG(6) << "Server Name (tlsext_hostname) is missing"; + if (clientHelloTLSExtStats_) { + clientHelloTLSExtStats_->recordAbsentHostname(); + } + return SSLContext::SERVER_NAME_NOT_FOUND; + } + size_t snLen = strlen(sn); + VLOG(6) << "Server Name (SNI TLS extension): '" << sn << "' "; + + // FIXME: This code breaks the abstraction. Suggestion? + AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl); + CHECK(sslSocket); + + DNString dnstr(sn, snLen); + + uint32_t count = 0; + do { + // Try exact match first + ctx = getSSLCtx(dnstr); + if (ctx) { + sslSocket->switchServerSSLContext(ctx); + if (clientHelloTLSExtStats_) { + clientHelloTLSExtStats_->recordMatch(); + } + return SSLContext::SERVER_NAME_FOUND; + } + + ctx = getSSLCtxBySuffix(dnstr); + if (ctx) { + sslSocket->switchServerSSLContext(ctx); + if (clientHelloTLSExtStats_) { + clientHelloTLSExtStats_->recordMatch(); + } + return SSLContext::SERVER_NAME_FOUND; + } + + // Give the noMatchFn one chance to add the correct cert + } + while (count++ == 0 && noMatchFn_ && noMatchFn_(sn)); + + VLOG(6) << folly::stringPrintf("Cannot find a SSL_CTX for \"%s\"", sn); + + if (clientHelloTLSExtStats_) { + clientHelloTLSExtStats_->recordNotMatch(); + } + return SSLContext::SERVER_NAME_NOT_FOUND; +} +#endif + +// Consolidate all SSL_CTX setup which depends on openssl version/feature +void +SSLContextManager::ctxSetupByOpensslFeature( + shared_ptr sslCtx, + const SSLContextConfig& ctxConfig) { + // Disable compression - profiling shows this to be very expensive in + // terms of CPU and memory consumption. + // +#ifdef SSL_OP_NO_COMPRESSION + sslCtx->setOptions(SSL_OP_NO_COMPRESSION); +#endif + + // Enable early release of SSL buffers to reduce the memory footprint +#ifdef SSL_MODE_RELEASE_BUFFERS + sslCtx->getSSLCtx()->mode |= SSL_MODE_RELEASE_BUFFERS; +#endif +#ifdef SSL_MODE_EARLY_RELEASE_BBIO + sslCtx->getSSLCtx()->mode |= SSL_MODE_EARLY_RELEASE_BBIO; +#endif + + // This number should (probably) correspond to HTTPSession::kMaxReadSize + // For now, this number must also be large enough to accommodate our + // largest certificate, because some older clients (IE6/7) require the + // cert to be in a single fragment. +#ifdef SSL_CTRL_SET_MAX_SEND_FRAGMENT + SSL_CTX_set_max_send_fragment(sslCtx->getSSLCtx(), 8000); +#endif + + // Specify cipher(s) to be used for TLS1.1 client + if (!ctxConfig.tls11Ciphers.empty()) { +#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK + // Specified TLS1.1 ciphers are valid + sslCtx->addClientHelloCallback( + std::bind( + &SSLContext::switchCiphersIfTLS11, + sslCtx.get(), + std::placeholders::_1, + ctxConfig.tls11Ciphers + ) + ); +#else + OPENSSL_MISSING_FEATURE(SNI); +#endif + } + + // NPN (Next Protocol Negotiation) + if (!ctxConfig.nextProtocols.empty()) { +#ifdef OPENSSL_NPN_NEGOTIATED + sslCtx->setRandomizedAdvertisedNextProtocols(ctxConfig.nextProtocols); +#else + OPENSSL_MISSING_FEATURE(NPN); +#endif + } + + // SNI +#ifdef PROXYGEN_HAVE_SERVERNAMECALLBACK + noMatchFn_ = ctxConfig.sniNoMatchFn; + if (ctxConfig.isDefault) { + if (defaultCtx_) { + throw std::runtime_error(">1 X509 is set as default"); + } + + defaultCtx_ = sslCtx; + defaultCtx_->setServerNameCallback( + std::bind(&SSLContextManager::serverNameCallback, this, + std::placeholders::_1)); + } +#else + if (ctxs_.size() > 1) { + OPENSSL_MISSING_FEATURE(SNI); + } +#endif +} + +void +SSLContextManager::insert(shared_ptr sslCtx, + std::unique_ptr smanager, + std::unique_ptr tmanager, + bool defaultFallback) { + X509* x509 = getX509(sslCtx->getSSLCtx()); + auto guard = folly::makeGuard([x509] { X509_free(x509); }); + auto cn = SSLUtil::getCommonName(x509); + if (!cn) { + throw std::runtime_error("Cannot get CN"); + } + + /** + * Some notes from RFC 2818. Only for future quick references in case of bugs + * + * RFC 2818 section 3.1: + * "...... + * If a subjectAltName extension of type dNSName is present, that MUST + * be used as the identity. Otherwise, the (most specific) Common Name + * field in the Subject field of the certificate MUST be used. Although + * the use of the Common Name is existing practice, it is deprecated and + * Certification Authorities are encouraged to use the dNSName instead. + * ...... + * In some cases, the URI is specified as an IP address rather than a + * hostname. In this case, the iPAddress subjectAltName must be present + * in the certificate and must exactly match the IP in the URI. + * ......" + */ + + // Not sure if we ever get this kind of X509... + // If we do, assume '*' is always in the CN and ignore all subject alternative + // names. + if (cn->length() == 1 && (*cn)[0] == '*') { + if (!defaultFallback) { + throw std::runtime_error("STAR X509 is not the default"); + } + ctxs_.emplace_back(sslCtx); + sessionCacheManagers_.emplace_back(std::move(smanager)); + ticketManagers_.emplace_back(std::move(tmanager)); + return; + } + + // Insert by CN + insertSSLCtxByDomainName(cn->c_str(), cn->length(), sslCtx); + + // Insert by subject alternative name(s) + auto altNames = SSLUtil::getSubjectAltName(x509); + if (altNames) { + for (auto& name : *altNames) { + insertSSLCtxByDomainName(name.c_str(), name.length(), sslCtx); + } + } + + ctxs_.emplace_back(sslCtx); + sessionCacheManagers_.emplace_back(std::move(smanager)); + ticketManagers_.emplace_back(std::move(tmanager)); +} + +void +SSLContextManager::insertSSLCtxByDomainName(const char* dn, size_t len, + shared_ptr sslCtx) { + try { + insertSSLCtxByDomainNameImpl(dn, len, sslCtx); + } catch (const std::runtime_error& ex) { + if (strict_) { + throw ex; + } else { + LOG(ERROR) << ex.what() << " DN=" << dn; + } + } +} +void +SSLContextManager::insertSSLCtxByDomainNameImpl(const char* dn, size_t len, + shared_ptr sslCtx) +{ + VLOG(4) << + folly::stringPrintf("Adding CN/Subject-alternative-name \"%s\" for " + "SNI search", dn); + + // Only support wildcard domains which are prefixed exactly by "*." . + // "*" appearing at other locations is not accepted. + + if (len > 2 && dn[0] == '*') { + if (dn[1] == '.') { + // skip the first '*' + dn++; + len--; + } else { + throw std::runtime_error( + "Invalid wildcard CN/subject-alternative-name \"" + std::string(dn) + "\" " + "(only allow character \".\" after \"*\""); + } + } + + if (len == 1 && *dn == '.') { + throw std::runtime_error("X509 has only '.' in the CN or subject alternative name " + "(after removing any preceding '*')"); + } + + if (strchr(dn, '*')) { + throw std::runtime_error("X509 has '*' in the the CN or subject alternative name " + "(after removing any preceding '*')"); + } + + DNString dnstr(dn, len); + const auto v = dnMap_.find(dnstr); + if (v == dnMap_.end()) { + dnMap_.emplace(dnstr, sslCtx); + } else if (v->second == sslCtx) { + VLOG(6)<< "Duplicate CN or subject alternative name found in the same X509." + " Ignore the later name."; + } else { + throw std::runtime_error("Duplicate CN or subject alternative name found: \"" + + std::string(dnstr.c_str()) + "\""); + } +} + +shared_ptr +SSLContextManager::getSSLCtxBySuffix(const DNString& dnstr) const +{ + size_t dot; + + if ((dot = dnstr.find_first_of(".")) != DNString::npos) { + DNString suffixDNStr(dnstr, dot); + const auto v = dnMap_.find(suffixDNStr); + if (v != dnMap_.end()) { + VLOG(6) << folly::stringPrintf("\"%s\" is a willcard match to \"%s\"", + dnstr.c_str(), suffixDNStr.c_str()); + return v->second; + } + } + + VLOG(6) << folly::stringPrintf("\"%s\" is not a wildcard match", + dnstr.c_str()); + return shared_ptr(); +} + +shared_ptr +SSLContextManager::getSSLCtx(const DNString& dnstr) const +{ + const auto v = dnMap_.find(dnstr); + if (v == dnMap_.end()) { + VLOG(6) << folly::stringPrintf("\"%s\" is not an exact match", + dnstr.c_str()); + return shared_ptr(); + } else { + VLOG(6) << folly::stringPrintf("\"%s\" is an exact match", dnstr.c_str()); + return v->second; + } +} + +shared_ptr +SSLContextManager::getDefaultSSLCtx() const { + return defaultCtx_; +} + +void +SSLContextManager::reloadTLSTicketKeys( + const std::vector& oldSeeds, + const std::vector& currentSeeds, + const std::vector& newSeeds) { +#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB + for (auto& tmgr: ticketManagers_) { + tmgr->setTLSTicketKeySeeds(oldSeeds, currentSeeds, newSeeds); + } +#endif +} + +} // namespace diff --git a/folly/wangle/ssl/SSLContextManager.h b/folly/wangle/ssl/SSLContextManager.h new file mode 100644 index 00000000..5877f1d4 --- /dev/null +++ b/folly/wangle/ssl/SSLContextManager.h @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace folly { + +class SocketAddress; +class SSLContext; +class ClientHelloExtStats; +struct SSLCacheOptions; +class SSLStats; +class TLSTicketKeyManager; +struct TLSTicketKeySeeds; + +class SSLContextManager { + public: + + explicit SSLContextManager(EventBase* eventBase, + const std::string& vipName, bool strict, + SSLStats* stats); + virtual ~SSLContextManager(); + + /** + * Add a new X509 to SSLContextManager. The details of a X509 + * is passed as a SSLContextConfig object. + * + * @param ctxConfig Details of a X509, its private key, password, etc. + * @param cacheOptions Options for how to do session caching. + * @param ticketSeeds If non-null, the initial ticket key seeds to use. + * @param vipAddress Which VIP are the X509(s) used for? It is only for + * for user friendly log message + * @param externalCache Optional external provider for the session cache; + * may be null + */ + void addSSLContextConfig( + const SSLContextConfig& ctxConfig, + const SSLCacheOptions& cacheOptions, + const TLSTicketKeySeeds* ticketSeeds, + const folly::SocketAddress& vipAddress, + const std::shared_ptr &externalCache); + + /** + * Get the default SSL_CTX for a VIP + */ + std::shared_ptr + getDefaultSSLCtx() const; + + /** + * Search by the _one_ level up subdomain + */ + std::shared_ptr + getSSLCtxBySuffix(const DNString& dnstr) const; + + /** + * Search by the full-string domain name + */ + std::shared_ptr + getSSLCtx(const DNString& dnstr) const; + + /** + * Insert a SSLContext by domain name. + */ + void insertSSLCtxByDomainName( + const char* dn, + size_t len, + std::shared_ptr sslCtx); + + void insertSSLCtxByDomainNameImpl( + const char* dn, + size_t len, + std::shared_ptr sslCtx); + + void reloadTLSTicketKeys(const std::vector& oldSeeds, + const std::vector& currentSeeds, + const std::vector& newSeeds); + + /** + * SSLContextManager only collects SNI stats now + */ + + void setClientHelloExtStats(ClientHelloExtStats* stats) { + clientHelloTLSExtStats_ = stats; + } + + protected: + virtual void enableAsyncCrypto( + const std::shared_ptr& sslCtx) { + LOG(FATAL) << "Unsupported in base SSLContextManager"; + } + SSLStats* stats_{nullptr}; + + private: + SSLContextManager(const SSLContextManager&) = delete; + + void ctxSetupByOpensslFeature( + std::shared_ptr sslCtx, + const SSLContextConfig& ctxConfig); + + /** + * Callback function from openssl to find the right X509 to + * use during SSL handshake + */ +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL && \ + !defined(OPENSSL_NO_TLSEXT) && \ + defined(SSL_CTRL_SET_TLSEXT_SERVERNAME_CB) +# define PROXYGEN_HAVE_SERVERNAMECALLBACK + SSLContext::ServerNameCallbackResult + serverNameCallback(SSL* ssl); +#endif + + /** + * The following functions help to maintain the data structure for + * domain name matching in SNI. Some notes: + * + * 1. It is a best match. + * + * 2. It allows wildcard CN and wildcard subject alternative name in a X509. + * The wildcard name must be _prefixed_ by '*.'. It errors out whenever + * it sees '*' in any other locations. + * + * 3. It uses one std::unordered_map object to + * do this. For wildcard name like "*.facebook.com", ".facebook.com" + * is used as the key. + * + * 4. After getting tlsext_hostname from the client hello message, it + * will do a full string search first and then try one level up to + * match any wildcard name (if any) in the X509. + * [Note, browser also only looks one level up when matching the requesting + * domain name with the wildcard name in the server X509]. + */ + + void insert( + std::shared_ptr sslCtx, + std::unique_ptr cmanager, + std::unique_ptr tManager, + bool defaultFallback); + + /** + * Container to own the SSLContext, SSLSessionCacheManager and + * TLSTicketKeyManager. + */ + std::vector> ctxs_; + std::vector> + sessionCacheManagers_; + std::vector> ticketManagers_; + + std::shared_ptr defaultCtx_; + + /** + * Container to store the (DomainName -> SSL_CTX) mapping + */ + std::unordered_map< + DNString, + std::shared_ptr, + DNStringHash> dnMap_; + + EventBase* eventBase_; + ClientHelloExtStats* clientHelloTLSExtStats_{nullptr}; + SSLContextConfig::SNINoMatchFn noMatchFn_; + bool strict_{true}; +}; + +} // namespace diff --git a/folly/wangle/ssl/SSLSessionCacheManager.cpp b/folly/wangle/ssl/SSLSessionCacheManager.cpp new file mode 100644 index 00000000..2b1f8a48 --- /dev/null +++ b/folly/wangle/ssl/SSLSessionCacheManager.cpp @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include +#include + +#include + +#ifndef NO_LIB_GFLAGS +#include +#endif + +using std::string; +using std::shared_ptr; + +namespace { + +const uint32_t NUM_CACHE_BUCKETS = 16; + +// We use the default ID generator which fills the maximum ID length +// for the protocol. 16 bytes for SSLv2 or 32 for SSLv3+ +const int MIN_SESSION_ID_LENGTH = 16; + +} + +#ifndef NO_LIB_GFLAGS +DEFINE_bool(dcache_unit_test, false, "All VIPs share one session cache"); +#else +const bool FLAGS_dcache_unit_test = false; +#endif + +namespace folly { + + +int SSLSessionCacheManager::sExDataIndex_ = -1; +shared_ptr SSLSessionCacheManager::sCache_; +std::mutex SSLSessionCacheManager::sCacheLock_; + +LocalSSLSessionCache::LocalSSLSessionCache(uint32_t maxCacheSize, + uint32_t cacheCullSize) + : sessionCache(maxCacheSize, cacheCullSize) { + sessionCache.setPruneHook(std::bind( + &LocalSSLSessionCache::pruneSessionCallback, + this, std::placeholders::_1, + std::placeholders::_2)); +} + +void LocalSSLSessionCache::pruneSessionCallback(const string& sessionId, + SSL_SESSION* session) { + VLOG(4) << "Free SSL session from local cache; id=" + << SSLUtil::hexlify(sessionId); + SSL_SESSION_free(session); + ++removedSessions_; +} + + +// SSLSessionCacheManager implementation + +SSLSessionCacheManager::SSLSessionCacheManager( + uint32_t maxCacheSize, + uint32_t cacheCullSize, + SSLContext* ctx, + const folly::SocketAddress& sockaddr, + const string& context, + EventBase* eventBase, + SSLStats* stats, + const std::shared_ptr& externalCache): + ctx_(ctx), + stats_(stats), + externalCache_(externalCache) { + + SSL_CTX* sslCtx = ctx->getSSLCtx(); + + SSLUtil::getSSLCtxExIndex(&sExDataIndex_); + + SSL_CTX_set_ex_data(sslCtx, sExDataIndex_, this); + SSL_CTX_sess_set_new_cb(sslCtx, SSLSessionCacheManager::newSessionCallback); + SSL_CTX_sess_set_get_cb(sslCtx, SSLSessionCacheManager::getSessionCallback); + SSL_CTX_sess_set_remove_cb(sslCtx, + SSLSessionCacheManager::removeSessionCallback); + if (!FLAGS_dcache_unit_test && !context.empty()) { + // Use the passed in context + SSL_CTX_set_session_id_context(sslCtx, (const uint8_t *)context.data(), + std::min((int)context.length(), + SSL_MAX_SSL_SESSION_ID_LENGTH)); + } + + SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_NO_INTERNAL + | SSL_SESS_CACHE_SERVER); + + localCache_ = SSLSessionCacheManager::getLocalCache(maxCacheSize, + cacheCullSize); + + VLOG(2) << "On VipID=" << sockaddr.describe() << " context=" << context; +} + +SSLSessionCacheManager::~SSLSessionCacheManager() { +} + +void SSLSessionCacheManager::shutdown() { + std::lock_guard g(sCacheLock_); + sCache_.reset(); +} + +shared_ptr SSLSessionCacheManager::getLocalCache( + uint32_t maxCacheSize, + uint32_t cacheCullSize) { + + std::lock_guard g(sCacheLock_); + if (!sCache_) { + sCache_.reset(new ShardedLocalSSLSessionCache(NUM_CACHE_BUCKETS, + maxCacheSize, + cacheCullSize)); + } + return sCache_; +} + +int SSLSessionCacheManager::newSessionCallback(SSL* ssl, SSL_SESSION* session) { + SSLSessionCacheManager* manager = nullptr; + SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); + manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); + + if (manager == nullptr) { + LOG(FATAL) << "Null SSLSessionCacheManager in callback"; + return -1; + } + return manager->newSession(ssl, session); +} + + +int SSLSessionCacheManager::newSession(SSL* ssl, SSL_SESSION* session) { + string sessionId((char*)session->session_id, session->session_id_length); + VLOG(4) << "New SSL session; id=" << SSLUtil::hexlify(sessionId); + + if (stats_) { + stats_->recordSSLSession(true /* new session */, false, false); + } + + localCache_->storeSession(sessionId, session, stats_); + + if (externalCache_) { + VLOG(4) << "New SSL session: send session to external cache; id=" << + SSLUtil::hexlify(sessionId); + storeCacheRecord(sessionId, session); + } + + return 1; +} + +void SSLSessionCacheManager::removeSessionCallback(SSL_CTX* ctx, + SSL_SESSION* session) { + SSLSessionCacheManager* manager = nullptr; + manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); + + if (manager == nullptr) { + LOG(FATAL) << "Null SSLSessionCacheManager in callback"; + return; + } + return manager->removeSession(ctx, session); +} + +void SSLSessionCacheManager::removeSession(SSL_CTX* ctx, + SSL_SESSION* session) { + string sessionId((char*)session->session_id, session->session_id_length); + + // This hook is only called from SSL when the internal session cache needs to + // flush sessions. Since we run with the internal cache disabled, this should + // never be called + VLOG(3) << "Remove SSL session; id=" << SSLUtil::hexlify(sessionId); + + localCache_->removeSession(sessionId); + + if (stats_) { + stats_->recordSSLSessionRemove(); + } +} + +SSL_SESSION* SSLSessionCacheManager::getSessionCallback(SSL* ssl, + unsigned char* sess_id, + int id_len, + int* copyflag) { + SSLSessionCacheManager* manager = nullptr; + SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); + manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); + + if (manager == nullptr) { + LOG(FATAL) << "Null SSLSessionCacheManager in callback"; + return nullptr; + } + return manager->getSession(ssl, sess_id, id_len, copyflag); +} + +SSL_SESSION* SSLSessionCacheManager::getSession(SSL* ssl, + unsigned char* session_id, + int id_len, + int* copyflag) { + VLOG(7) << "SSL get session callback"; + SSL_SESSION* session = nullptr; + bool foreign = false; + char const* missReason = nullptr; + + if (id_len < MIN_SESSION_ID_LENGTH) { + // We didn't generate this session so it's going to be a miss. + // This doesn't get logged or counted in the stats. + return nullptr; + } + string sessionId((char*)session_id, id_len); + + AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl); + + assert(sslSocket != nullptr); + + // look it up in the local cache first + session = localCache_->lookupSession(sessionId); +#ifdef SSL_SESSION_CB_WOULD_BLOCK + if (session == nullptr && externalCache_) { + // external cache might have the session + foreign = true; + if (!SSL_want_sess_cache_lookup(ssl)) { + missReason = "reason: No async cache support;"; + } else { + PendingLookupMap::iterator pit = pendingLookups_.find(sessionId); + if (pit == pendingLookups_.end()) { + auto result = pendingLookups_.emplace(sessionId, PendingLookup()); + // initiate fetch + VLOG(4) << "Get SSL session [Pending]: Initiate Fetch; fd=" << + sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId); + if (lookupCacheRecord(sessionId, sslSocket)) { + // response is pending + *copyflag = SSL_SESSION_CB_WOULD_BLOCK; + return nullptr; + } else { + missReason = "reason: failed to send lookup request;"; + pendingLookups_.erase(result.first); + } + } else { + // A lookup was already initiated from this thread + if (pit->second.request_in_progress) { + // Someone else initiated the request, attach + VLOG(4) << "Get SSL session [Pending]: Request in progess: attach; " + "fd=" << sslSocket->getFd() << " id=" << + SSLUtil::hexlify(sessionId); + std::unique_ptr dg( + new DelayedDestruction::DestructorGuard(sslSocket)); + pit->second.waiters.push_back( + std::make_pair(sslSocket, std::move(dg))); + *copyflag = SSL_SESSION_CB_WOULD_BLOCK; + return nullptr; + } + // request is complete + session = pit->second.session; // nullptr if our friend didn't have it + if (session != nullptr) { + CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); + } + } + } + } +#endif + + bool hit = (session != nullptr); + if (stats_) { + stats_->recordSSLSession(false, hit, foreign); + } + if (hit) { + sslSocket->setSessionIDResumed(true); + } + + VLOG(4) << "Get SSL session [" << + ((hit) ? "Hit" : "Miss") << "]: " << + ((foreign) ? "external" : "local") << " cache; " << + ((missReason != nullptr) ? missReason : "") << "fd=" << + sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId); + + // We already bumped the refcount + *copyflag = 0; + + return session; +} + +bool SSLSessionCacheManager::storeCacheRecord(const string& sessionId, + SSL_SESSION* session) { + std::string sessionString; + uint32_t sessionLen = i2d_SSL_SESSION(session, nullptr); + sessionString.resize(sessionLen); + uint8_t* cp = (uint8_t *)sessionString.data(); + i2d_SSL_SESSION(session, &cp); + size_t expiration = SSL_CTX_get_timeout(ctx_->getSSLCtx()); + return externalCache_->setAsync(sessionId, sessionString, + std::chrono::seconds(expiration)); +} + +bool SSLSessionCacheManager::lookupCacheRecord(const string& sessionId, + AsyncSSLSocket* sslSocket) { + auto cacheCtx = new SSLCacheProvider::CacheContext(); + cacheCtx->sessionId = sessionId; + cacheCtx->session = nullptr; + cacheCtx->sslSocket = sslSocket; + cacheCtx->guard.reset( + new DelayedDestruction::DestructorGuard(cacheCtx->sslSocket)); + cacheCtx->manager = this; + bool res = externalCache_->getAsync(sessionId, cacheCtx); + if (!res) { + delete cacheCtx; + } + return res; +} + +void SSLSessionCacheManager::restartSSLAccept( + const SSLCacheProvider::CacheContext* cacheCtx) { + PendingLookupMap::iterator pit = pendingLookups_.find(cacheCtx->sessionId); + CHECK(pit != pendingLookups_.end()); + pit->second.request_in_progress = false; + pit->second.session = cacheCtx->session; + VLOG(7) << "Restart SSL accept"; + cacheCtx->sslSocket->restartSSLAccept(); + for (const auto& attachedLookup: pit->second.waiters) { + // Wake up anyone else who was waiting for this session + VLOG(4) << "Restart SSL accept (waiters) for fd=" << + attachedLookup.first->getFd(); + attachedLookup.first->restartSSLAccept(); + } + pendingLookups_.erase(pit); +} + +void SSLSessionCacheManager::onGetSuccess( + SSLCacheProvider::CacheContext* cacheCtx, + const std::string& value) { + const uint8_t* cp = (uint8_t*)value.data(); + cacheCtx->session = d2i_SSL_SESSION(nullptr, &cp, value.length()); + restartSSLAccept(cacheCtx); + + /* Insert in the LRU after restarting all clients. The stats logic + * in getSession would treat this as a local hit otherwise. + */ + localCache_->storeSession(cacheCtx->sessionId, cacheCtx->session, stats_); + delete cacheCtx; +} + +void SSLSessionCacheManager::onGetFailure( + SSLCacheProvider::CacheContext* cacheCtx) { + restartSSLAccept(cacheCtx); + delete cacheCtx; +} + +} // namespace diff --git a/folly/wangle/ssl/SSLSessionCacheManager.h b/folly/wangle/ssl/SSLSessionCacheManager.h new file mode 100644 index 00000000..977b2a2a --- /dev/null +++ b/folly/wangle/ssl/SSLSessionCacheManager.h @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include + +#include +#include +#include + +namespace folly { + +class SSLStats; + +/** + * Basic SSL session cache map: Maps session id -> session + */ +typedef folly::EvictingCacheMap SSLSessionCacheMap; + +/** + * Holds an SSLSessionCacheMap and associated lock + */ +class LocalSSLSessionCache: private boost::noncopyable { + public: + LocalSSLSessionCache(uint32_t maxCacheSize, uint32_t cacheCullSize); + + ~LocalSSLSessionCache() { + std::lock_guard g(lock); + // EvictingCacheMap dtor doesn't free values + sessionCache.clear(); + } + + SSLSessionCacheMap sessionCache; + std::mutex lock; + uint32_t removedSessions_{0}; + + private: + + void pruneSessionCallback(const std::string& sessionId, + SSL_SESSION* session); +}; + +/** + * A sharded LRU for SSL sessions. The sharding is inteneded to reduce + * contention for the LRU locks. Assuming uniform distribution, two workers + * will contend for the same lock with probability 1 / n_buckets^2. + */ +class ShardedLocalSSLSessionCache : private boost::noncopyable { + public: + ShardedLocalSSLSessionCache(uint32_t n_buckets, uint32_t maxCacheSize, + uint32_t cacheCullSize) { + CHECK(n_buckets > 0); + maxCacheSize = (uint32_t)(((double)maxCacheSize) / n_buckets); + cacheCullSize = (uint32_t)(((double)cacheCullSize) / n_buckets); + if (maxCacheSize == 0) { + maxCacheSize = 1; + } + if (cacheCullSize == 0) { + cacheCullSize = 1; + } + for (uint32_t i = 0; i < n_buckets; i++) { + caches_.push_back( + std::unique_ptr( + new LocalSSLSessionCache(maxCacheSize, cacheCullSize))); + } + } + + SSL_SESSION* lookupSession(const std::string& sessionId) { + size_t bucket = hash(sessionId); + SSL_SESSION* session = nullptr; + std::lock_guard g(caches_[bucket]->lock); + + auto itr = caches_[bucket]->sessionCache.find(sessionId); + if (itr != caches_[bucket]->sessionCache.end()) { + session = itr->second; + } + + if (session) { + CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION); + } + return session; + } + + void storeSession(const std::string& sessionId, SSL_SESSION* session, + SSLStats* stats) { + size_t bucket = hash(sessionId); + SSL_SESSION* oldSession = nullptr; + std::lock_guard g(caches_[bucket]->lock); + + auto itr = caches_[bucket]->sessionCache.find(sessionId); + if (itr != caches_[bucket]->sessionCache.end()) { + oldSession = itr->second; + } + + if (oldSession) { + // LRUCacheMap doesn't free on overwrite, so 2x the work for us + // This can happen in race conditions + SSL_SESSION_free(oldSession); + } + caches_[bucket]->removedSessions_ = 0; + caches_[bucket]->sessionCache.set(sessionId, session, true); + if (stats) { + stats->recordSSLSessionFree(caches_[bucket]->removedSessions_); + } + } + + void removeSession(const std::string& sessionId) { + size_t bucket = hash(sessionId); + std::lock_guard g(caches_[bucket]->lock); + caches_[bucket]->sessionCache.erase(sessionId); + } + + private: + + /* SSL session IDs are 32 bytes of random data, hash based on first 16 bits */ + size_t hash(const std::string& key) { + CHECK(key.length() >= 2); + return (key[0] << 8 | key[1]) % caches_.size(); + } + + std::vector< std::unique_ptr > caches_; +}; + +/* A socket/DestructorGuard pair */ +typedef std::pair> + AttachedLookup; + +/** + * PendingLookup structure + * + * Keeps track of clients waiting for an SSL session to be retrieved from + * the external cache provider. + */ +struct PendingLookup { + bool request_in_progress; + SSL_SESSION* session; + std::list waiters; + + PendingLookup() { + request_in_progress = true; + session = nullptr; + } +}; + +/* Maps SSL session id to a PendingLookup structure */ +typedef std::map PendingLookupMap; + +/** + * SSLSessionCacheManager handles all stateful session caching. There is an + * instance of this object per SSL VIP per thread, with a 1:1 correlation with + * SSL_CTX. The cache can work locally or in concert with an external cache + * to share sessions across instances. + * + * There is a single in memory session cache shared by all VIPs. The cache is + * split into N buckets (currently 16) with a separate lock per bucket. The + * VIP ID is hashed and stored as part of the session to handle the + * (very unlikely) case of session ID collision. + * + * When a new SSL session is created, it is added to the LRU cache and + * sent to the external cache to be stored. The external cache + * expiration is equal to the SSL session's expiration. + * + * When a resume request is received, SSLSessionCacheManager first looks in the + * local LRU cache for the VIP. If there is a miss there, an asynchronous + * request for this session is dispatched to the external cache. When the + * external cache query returns, the LRU cache is updated if the session was + * found, and the SSL_accept call is resumed. + * + * If additional resume requests for the same session ID arrive in the same + * thread while the request is pending, the 2nd - Nth callers attach to the + * original external cache requests and are resumed when it comes back. No + * attempt is made to coalesce external cache requests for the same session + * ID in different worker threads. Previous work did this, but the + * complexity was deemed to outweigh the potential savings. + * + */ +class SSLSessionCacheManager : private boost::noncopyable { + public: + /** + * Constructor. SSL session related callbacks will be set on the underlying + * SSL_CTX. vipId is assumed to a unique string identifying the VIP and must + * be the same on all servers that wish to share sessions via the same + * external cache. + */ + SSLSessionCacheManager( + uint32_t maxCacheSize, + uint32_t cacheCullSize, + SSLContext* ctx, + const folly::SocketAddress& sockaddr, + const std::string& context, + EventBase* eventBase, + SSLStats* stats, + const std::shared_ptr& externalCache); + + virtual ~SSLSessionCacheManager(); + + /** + * Call this on shutdown to release the global instance of the + * ShardedLocalSSLSessionCache. + */ + static void shutdown(); + + /** + * Callback for ExternalCache to call when an async get succeeds + * @param context The context that was passed to the async get request + * @param value Serialized session + */ + void onGetSuccess(SSLCacheProvider::CacheContext* context, + const std::string& value); + + /** + * Callback for ExternalCache to call when an async get fails, either + * because the requested session is not in the external cache or because + * of an error. + * @param context The context that was passed to the async get request + */ + void onGetFailure(SSLCacheProvider::CacheContext* context); + + private: + + SSLContext* ctx_; + std::shared_ptr localCache_; + PendingLookupMap pendingLookups_; + SSLStats* stats_{nullptr}; + std::shared_ptr externalCache_; + + /** + * Invoked by openssl when a new SSL session is created + */ + int newSession(SSL* ssl, SSL_SESSION* session); + + /** + * Invoked by openssl when an SSL session is ejected from its internal cache. + * This can't be invoked in the current implementation because SSL's internal + * caching is disabled. + */ + void removeSession(SSL_CTX* ctx, SSL_SESSION* session); + + /** + * Invoked by openssl when a client requests a stateful session resumption. + * Triggers a lookup in our local cache and potentially an asynchronous + * request to an external cache. + */ + SSL_SESSION* getSession(SSL* ssl, unsigned char* session_id, + int id_len, int* copyflag); + + /** + * Store a new session record in the external cache + */ + bool storeCacheRecord(const std::string& sessionId, SSL_SESSION* session); + + /** + * Lookup a session in the external cache for the specified SSL socket. + */ + bool lookupCacheRecord(const std::string& sessionId, + AsyncSSLSocket* sslSock); + + /** + * Restart all clients waiting for the answer to an external cache query + */ + void restartSSLAccept(const SSLCacheProvider::CacheContext* cacheCtx); + + /** + * Get or create the LRU cache for the given VIP ID + */ + static std::shared_ptr getLocalCache( + uint32_t maxCacheSize, uint32_t cacheCullSize); + + /** + * static functions registered as callbacks to openssl via + * SSL_CTX_sess_set_new/get/remove_cb + */ + static int newSessionCallback(SSL* ssl, SSL_SESSION* session); + static void removeSessionCallback(SSL_CTX* ctx, SSL_SESSION* session); + static SSL_SESSION* getSessionCallback(SSL* ssl, unsigned char* session_id, + int id_len, int* copyflag); + + static int32_t sExDataIndex_; + static std::shared_ptr sCache_; + static std::mutex sCacheLock_; +}; + +} diff --git a/folly/wangle/ssl/SSLStats.h b/folly/wangle/ssl/SSLStats.h new file mode 100644 index 00000000..0cfc9f22 --- /dev/null +++ b/folly/wangle/ssl/SSLStats.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +namespace folly { + +class SSLStats { + public: + virtual ~SSLStats() noexcept {} + + // downstream + virtual void recordSSLAcceptLatency(int64_t latency) noexcept = 0; + virtual void recordTLSTicket(bool ticketNew, bool ticketHit) noexcept = 0; + virtual void recordSSLSession(bool sessionNew, bool sessionHit, bool foreign) + noexcept = 0; + virtual void recordSSLSessionRemove() noexcept = 0; + virtual void recordSSLSessionFree(uint32_t freed) noexcept = 0; + virtual void recordSSLSessionSetError(uint32_t err) noexcept = 0; + virtual void recordSSLSessionGetError(uint32_t err) noexcept = 0; + virtual void recordClientRenegotiation() noexcept = 0; + + // upstream + virtual void recordSSLUpstreamConnection(bool handshake) noexcept = 0; + virtual void recordSSLUpstreamConnectionError(bool verifyError) noexcept = 0; + virtual void recordCryptoSSLExternalAttempt() noexcept = 0; + virtual void recordCryptoSSLExternalConnAlreadyClosed() noexcept = 0; + virtual void recordCryptoSSLExternalApplicationException() noexcept = 0; + virtual void recordCryptoSSLExternalSuccess() noexcept = 0; + virtual void recordCryptoSSLExternalDuration(uint64_t duration) noexcept = 0; + virtual void recordCryptoSSLLocalAttempt() noexcept = 0; + virtual void recordCryptoSSLLocalSuccess() noexcept = 0; + +}; + +} diff --git a/folly/wangle/ssl/SSLUtil.cpp b/folly/wangle/ssl/SSLUtil.cpp new file mode 100644 index 00000000..f900003d --- /dev/null +++ b/folly/wangle/ssl/SSLUtil.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include + +#if OPENSSL_VERSION_NUMBER >= 0x1000105fL +#define OPENSSL_GE_101 1 +#include +#include +#else +#undef OPENSSL_GE_101 +#endif + +namespace folly { + +std::mutex SSLUtil::sIndexLock_; + +std::unique_ptr SSLUtil::getCommonName(const X509* cert) { + X509_NAME* subject = X509_get_subject_name((X509*)cert); + if (!subject) { + return nullptr; + } + char cn[ub_common_name + 1]; + int res = X509_NAME_get_text_by_NID(subject, NID_commonName, + cn, ub_common_name); + if (res <= 0) { + return nullptr; + } else { + cn[ub_common_name] = '\0'; + return folly::make_unique(cn); + } +} + +std::unique_ptr> SSLUtil::getSubjectAltName( + const X509* cert) { +#ifdef OPENSSL_GE_101 + auto nameList = folly::make_unique>(); + GENERAL_NAMES* names = (GENERAL_NAMES*)X509_get_ext_d2i( + (X509*)cert, NID_subject_alt_name, nullptr, nullptr); + if (names) { + auto guard = folly::makeGuard([names] { GENERAL_NAMES_free(names); }); + size_t count = sk_GENERAL_NAME_num(names); + CHECK(count < std::numeric_limits::max()); + for (int i = 0; i < (int)count; ++i) { + GENERAL_NAME* generalName = sk_GENERAL_NAME_value(names, i); + if (generalName->type == GEN_DNS) { + ASN1_STRING* s = generalName->d.dNSName; + const char* name = (const char*)ASN1_STRING_data(s); + // I can't find any docs on what a negative return value here + // would mean, so I'm going to ignore it. + auto len = ASN1_STRING_length(s); + DCHECK(len >= 0); + if (size_t(len) != strlen(name)) { + // Null byte(s) in the name; return an error rather than depending on + // the caller to safely handle this case. + return nullptr; + } + nameList->emplace_back(name); + } + } + } + return nameList; +#else + return nullptr; +#endif +} + +} diff --git a/folly/wangle/ssl/SSLUtil.h b/folly/wangle/ssl/SSLUtil.h new file mode 100644 index 00000000..1e5b720a --- /dev/null +++ b/folly/wangle/ssl/SSLUtil.h @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include +#include + +namespace folly { + +/** + * SSL session establish/resume status + * + * changing these values will break logging pipelines + */ +enum class SSLResumeEnum : uint8_t { + HANDSHAKE = 0, + RESUME_SESSION_ID = 1, + RESUME_TICKET = 3, + NA = 2 +}; + +enum class SSLErrorEnum { + NO_ERROR, + TIMEOUT, + DROPPED +}; + +class SSLUtil { + private: + static std::mutex sIndexLock_; + + public: + /** + * Ensures only one caller will allocate an ex_data index for a given static + * or global. + */ + static void getSSLCtxExIndex(int* pindex) { + std::lock_guard g(sIndexLock_); + if (*pindex < 0) { + *pindex = SSL_CTX_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + } + } + + static void getRSAExIndex(int* pindex) { + std::lock_guard g(sIndexLock_); + if (*pindex < 0) { + *pindex = RSA_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + } + } + + static inline std::string hexlify(const std::string& binary) { + std::string hex; + folly::hexlify(binary, hex); + + return hex; + } + + static inline const std::string& hexlify(const std::string& binary, + std::string& hex) { + folly::hexlify(binary, hex); + + return hex; + } + + /** + * Return the SSL resume type for the given socket. + */ + static inline SSLResumeEnum getResumeState( + AsyncSSLSocket* sslSocket) { + return sslSocket->getSSLSessionReused() ? + (sslSocket->sessionIDResumed() ? + SSLResumeEnum::RESUME_SESSION_ID : + SSLResumeEnum::RESUME_TICKET) : + SSLResumeEnum::HANDSHAKE; + } + + /** + * Get the Common Name from an X.509 certificate + * @param cert certificate to inspect + * @return common name, or null if an error occurs + */ + static std::unique_ptr getCommonName(const X509* cert); + + /** + * Get the Subject Alternative Name value(s) from an X.509 certificate + * @param cert certificate to inspect + * @return set of zero or more alternative names, or null if + * an error occurs + */ + static std::unique_ptr> getSubjectAltName( + const X509* cert); +}; + +} diff --git a/folly/wangle/ssl/TLSTicketKeyManager.cpp b/folly/wangle/ssl/TLSTicketKeyManager.cpp new file mode 100644 index 00000000..45761c11 --- /dev/null +++ b/folly/wangle/ssl/TLSTicketKeyManager.cpp @@ -0,0 +1,305 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include + +#include +#include + +#include +#include +#include +#include +#include + +#ifdef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB +using std::string; + +namespace { + +const int kTLSTicketKeyNameLen = 4; +const int kTLSTicketKeySaltLen = 12; + +} + +namespace folly { + + +// TLSTicketKeyManager Implementation +int32_t TLSTicketKeyManager::sExDataIndex_ = -1; + +TLSTicketKeyManager::TLSTicketKeyManager(SSLContext* ctx, SSLStats* stats) + : ctx_(ctx), + randState_(0), + stats_(stats) { + SSLUtil::getSSLCtxExIndex(&sExDataIndex_); + SSL_CTX_set_ex_data(ctx_->getSSLCtx(), sExDataIndex_, this); +} + +TLSTicketKeyManager::~TLSTicketKeyManager() { +} + +int +TLSTicketKeyManager::callback(SSL* ssl, unsigned char* keyName, + unsigned char* iv, + EVP_CIPHER_CTX* cipherCtx, + HMAC_CTX* hmacCtx, int encrypt) { + TLSTicketKeyManager* manager = nullptr; + SSL_CTX* ctx = SSL_get_SSL_CTX(ssl); + manager = (TLSTicketKeyManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_); + + if (manager == nullptr) { + LOG(FATAL) << "Null TLSTicketKeyManager in callback" ; + return -1; + } + return manager->processTicket(ssl, keyName, iv, cipherCtx, hmacCtx, encrypt); +} + +int +TLSTicketKeyManager::processTicket(SSL* ssl, unsigned char* keyName, + unsigned char* iv, + EVP_CIPHER_CTX* cipherCtx, + HMAC_CTX* hmacCtx, int encrypt) { + uint8_t salt[kTLSTicketKeySaltLen]; + uint8_t* saltptr = nullptr; + uint8_t output[SHA256_DIGEST_LENGTH]; + uint8_t* hmacKey = nullptr; + uint8_t* aesKey = nullptr; + TLSTicketKeySource* key = nullptr; + int result = 0; + + if (encrypt) { + key = findEncryptionKey(); + if (key == nullptr) { + // no keys available to encrypt + VLOG(2) << "No TLS ticket key found"; + return -1; + } + VLOG(4) << "Encrypting new ticket with key name=" << + SSLUtil::hexlify(key->keyName_); + + // Get a random salt and write out key name + RAND_pseudo_bytes(salt, (int)sizeof(salt)); + memcpy(keyName, key->keyName_.data(), kTLSTicketKeyNameLen); + memcpy(keyName + kTLSTicketKeyNameLen, salt, kTLSTicketKeySaltLen); + + // Create the unique keys by hashing with the salt + makeUniqueKeys(key->keySource_, sizeof(key->keySource_), salt, output); + // This relies on the fact that SHA256 has 32 bytes of output + // and that AES-128 keys are 16 bytes + hmacKey = output; + aesKey = output + SHA256_DIGEST_LENGTH / 2; + + // Initialize iv and cipher/mac CTX + RAND_pseudo_bytes(iv, AES_BLOCK_SIZE); + HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2, + EVP_sha256(), nullptr); + EVP_EncryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv); + + result = 1; + } else { + key = findDecryptionKey(keyName); + if (key == nullptr) { + // no ticket found for decryption - will issue a new ticket + if (VLOG_IS_ON(4)) { + string skeyName((char *)keyName, kTLSTicketKeyNameLen); + VLOG(4) << "Can't find ticket key with name=" << + SSLUtil::hexlify(skeyName)<< ", will generate new ticket"; + } + + result = 0; + } else { + VLOG(4) << "Decrypting ticket with key name=" << + SSLUtil::hexlify(key->keyName_); + + // Reconstruct the unique key via the salt + saltptr = keyName + kTLSTicketKeyNameLen; + makeUniqueKeys(key->keySource_, sizeof(key->keySource_), saltptr, output); + hmacKey = output; + aesKey = output + SHA256_DIGEST_LENGTH / 2; + + // Initialize cipher/mac CTX + HMAC_Init_ex(hmacCtx, hmacKey, SHA256_DIGEST_LENGTH / 2, + EVP_sha256(), nullptr); + EVP_DecryptInit_ex(cipherCtx, EVP_aes_128_cbc(), nullptr, aesKey, iv); + + result = 1; + } + } + // result records whether a ticket key was found to decrypt this ticket, + // not wether the session was re-used. + if (stats_) { + stats_->recordTLSTicket(encrypt, result); + } + + return result; +} + +bool +TLSTicketKeyManager::setTLSTicketKeySeeds( + const std::vector& oldSeeds, + const std::vector& currentSeeds, + const std::vector& newSeeds) { + + bool result = true; + + activeKeys_.clear(); + ticketKeys_.clear(); + ticketSeeds_.clear(); + const std::vector *seedList = &oldSeeds; + for (uint32_t i = 0; i < 3; i++) { + TLSTicketSeedType type = (TLSTicketSeedType)i; + if (type == SEED_CURRENT) { + seedList = ¤tSeeds; + } else if (type == SEED_NEW) { + seedList = &newSeeds; + } + + for (const auto& seedInput: *seedList) { + TLSTicketSeed* seed = insertSeed(seedInput, type); + if (seed == nullptr) { + result = false; + continue; + } + insertNewKey(seed, 1, nullptr); + } + } + if (!result) { + VLOG(2) << "One or more seeds failed to decode"; + } + + if (ticketKeys_.size() == 0 || activeKeys_.size() == 0) { + LOG(WARNING) << "No keys configured, falling back to default"; + SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), nullptr); + return false; + } + SSL_CTX_set_tlsext_ticket_key_cb(ctx_->getSSLCtx(), + TLSTicketKeyManager::callback); + + return true; +} + +string +TLSTicketKeyManager::makeKeyName(TLSTicketSeed* seed, uint32_t n, + unsigned char* nameBuf) { + SHA256_CTX ctx; + + SHA256_Init(&ctx); + SHA256_Update(&ctx, seed->seedName_, sizeof(seed->seedName_)); + SHA256_Update(&ctx, &n, sizeof(n)); + SHA256_Final(nameBuf, &ctx); + return string((char *)nameBuf, kTLSTicketKeyNameLen); +} + +TLSTicketKeyManager::TLSTicketKeySource* +TLSTicketKeyManager::insertNewKey(TLSTicketSeed* seed, uint32_t hashCount, + TLSTicketKeySource* prevKey) { + unsigned char nameBuf[SHA256_DIGEST_LENGTH]; + std::unique_ptr newKey(new TLSTicketKeySource); + + // This function supports hash chaining but it is not currently used. + + if (prevKey != nullptr) { + hashNth(prevKey->keySource_, sizeof(prevKey->keySource_), + newKey->keySource_, 1); + } else { + // can't go backwards or the current is missing, start from the beginning + hashNth((unsigned char *)seed->seed_.data(), seed->seed_.length(), + newKey->keySource_, hashCount); + } + + newKey->hashCount_ = hashCount; + newKey->keyName_ = makeKeyName(seed, hashCount, nameBuf); + newKey->type_ = seed->type_; + auto it = ticketKeys_.insert(std::make_pair(newKey->keyName_, + std::move(newKey))); + + auto key = it.first->second.get(); + if (key->type_ == SEED_CURRENT) { + activeKeys_.push_back(key); + } + VLOG(4) << "Adding key for " << hashCount << " type=" << + (uint32_t)key->type_ << " Name=" << SSLUtil::hexlify(key->keyName_); + + return key; +} + +void +TLSTicketKeyManager::hashNth(const unsigned char* input, size_t input_len, + unsigned char* output, uint32_t n) { + assert(n > 0); + for (uint32_t i = 0; i < n; i++) { + SHA256(input, input_len, output); + input = output; + input_len = SHA256_DIGEST_LENGTH; + } +} + +TLSTicketKeyManager::TLSTicketSeed * +TLSTicketKeyManager::insertSeed(const string& seedInput, + TLSTicketSeedType type) { + TLSTicketSeed* seed = nullptr; + string seedOutput; + + if (!folly::unhexlify(seedInput, seedOutput)) { + LOG(WARNING) << "Failed to decode seed type=" << (uint32_t)type << + " seed=" << seedInput; + return seed; + } + + seed = new TLSTicketSeed(); + seed->seed_ = seedOutput; + seed->type_ = type; + SHA256((unsigned char *)seedOutput.data(), seedOutput.length(), + seed->seedName_); + ticketSeeds_.push_back(std::unique_ptr(seed)); + + return seed; +} + +TLSTicketKeyManager::TLSTicketKeySource * +TLSTicketKeyManager::findEncryptionKey() { + TLSTicketKeySource* result = nullptr; + // call to rand here is a bit hokey since it's not cryptographically + // random, and is predictably seeded with 0. However, activeKeys_ + // is probably not going to have very many keys in it, and most + // likely only 1. + size_t numKeys = activeKeys_.size(); + if (numKeys > 0) { + result = activeKeys_[rand_r(&randState_) % numKeys]; + } + return result; +} + +TLSTicketKeyManager::TLSTicketKeySource * +TLSTicketKeyManager::findDecryptionKey(unsigned char* keyName) { + string name((char *)keyName, kTLSTicketKeyNameLen); + TLSTicketKeySource* key = nullptr; + TLSTicketKeyMap::iterator mapit = ticketKeys_.find(name); + if (mapit != ticketKeys_.end()) { + key = mapit->second.get(); + } + return key; +} + +void +TLSTicketKeyManager::makeUniqueKeys(unsigned char* parentKey, + size_t keyLen, + unsigned char* salt, + unsigned char* output) { + SHA256_CTX hash_ctx; + + SHA256_Init(&hash_ctx); + SHA256_Update(&hash_ctx, parentKey, keyLen); + SHA256_Update(&hash_ctx, salt, kTLSTicketKeySaltLen); + SHA256_Final(output, &hash_ctx); +} + +} // namespace +#endif diff --git a/folly/wangle/ssl/TLSTicketKeyManager.h b/folly/wangle/ssl/TLSTicketKeyManager.h new file mode 100644 index 00000000..bd22e62b --- /dev/null +++ b/folly/wangle/ssl/TLSTicketKeyManager.h @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +#include +#include + +namespace folly { + +#ifndef SSL_CTRL_SET_TLSEXT_TICKET_KEY_CB +class TLSTicketKeyManager {}; +#else +class SSLStats; +/** + * The TLSTicketKeyManager handles TLS ticket key encryption and decryption in + * a way that facilitates sharing the ticket keys across a range of servers. + * Hash chaining is employed to achieve frequent key rotation with minimal + * configuration change. The scheme is as follows: + * + * The manager is supplied with three lists of seeds (old, current and new). + * The config should be updated with new seeds periodically (e.g., daily). + * 3 config changes are recommended to achieve the smoothest seed rotation + * eg: + * 1. Introduce new seed in the push prior to rotation + * 2. Rotation push + * 3. Remove old seeds in the push following rotation + * + * Multiple seeds are supported but only a single seed is required. + * + * Generating encryption keys from the seed works as follows. For a given + * seed, hash forward N times where N is currently the constant 1. + * This is the base key. The name of the base key is the first 4 + * bytes of hash(hash(seed), N). This is copied into the first 4 bytes of the + * TLS ticket key name field. + * + * For each new ticket encryption, the manager generates a random 12 byte salt. + * Hash the salt and the base key together to form the encryption key for + * that ticket. The salt is included in the ticket's 'key name' field so it + * can be used to derive the decryption key. The salt is copied into the second + * 8 bytes of the TLS ticket key name field. + * + * A key is valid for decryption for the lifetime of the instance. + * Sessions will be valid for less time than that, which results in an extra + * symmetric decryption to discover the session is expired. + * + * A TLSTicketKeyManager should be used in only one thread, and should have + * a 1:1 relationship with the SSLContext provided. + * + */ +class TLSTicketKeyManager : private boost::noncopyable { + public: + + explicit TLSTicketKeyManager(folly::SSLContext* ctx, + SSLStats* stats); + + virtual ~TLSTicketKeyManager(); + + /** + * SSL callback to set up encryption/decryption context for a TLS Ticket Key. + * + * This will be supplied to the SSL library via + * SSL_CTX_set_tlsext_ticket_key_cb. + */ + static int callback(SSL* ssl, unsigned char* keyName, + unsigned char* iv, + EVP_CIPHER_CTX* cipherCtx, + HMAC_CTX* hmacCtx, int encrypt); + + /** + * Initialize the manager with three sets of seeds. There must be at least + * one current seed, or the manager will revert to the default SSL behavior. + * + * @param oldSeeds Seeds previously used which can still decrypt. + * @param currentSeeds Seeds to use for new ticket encryptions. + * @param newSeeds Seeds which will be used soon, can be used to decrypt + * in case some servers in the cluster have already rotated. + */ + bool setTLSTicketKeySeeds(const std::vector& oldSeeds, + const std::vector& currentSeeds, + const std::vector& newSeeds); + + private: + enum TLSTicketSeedType { + SEED_OLD = 0, + SEED_CURRENT, + SEED_NEW + }; + + /* The seeds supplied by the configuration */ + struct TLSTicketSeed { + std::string seed_; + TLSTicketSeedType type_; + unsigned char seedName_[SHA256_DIGEST_LENGTH]; + }; + + struct TLSTicketKeySource { + int32_t hashCount_; + std::string keyName_; + TLSTicketSeedType type_; + unsigned char keySource_[SHA256_DIGEST_LENGTH]; + }; + + /** + * Method to setup encryption/decryption context for a TLS Ticket Key + * + * OpenSSL documentation is thin on the return value semantics. + * + * For encrypt=1, return < 0 on error, >= 0 for successfully initialized + * For encrypt=0, return < 0 on error, 0 on key not found + * 1 on key found, 2 renew_ticket + * + * renew_ticket means a new ticket will be issued. We could return this value + * when receiving a ticket encrypted with a key derived from an OLD seed. + * However, session_timeout seconds after deploying with a seed + * rotated from CURRENT -> OLD, there will be no valid tickets outstanding + * encrypted with the old key. This grace period means no unnecessary + * handshakes will be performed. If the seed is believed compromised, it + * should NOT be configured as an OLD seed. + */ + int processTicket(SSL* ssl, unsigned char* keyName, + unsigned char* iv, + EVP_CIPHER_CTX* cipherCtx, + HMAC_CTX* hmacCtx, int encrypt); + + // Creates the name for the nth key generated from seed + std::string makeKeyName(TLSTicketSeed* seed, uint32_t n, + unsigned char* nameBuf); + + /** + * Creates the key hashCount hashes from the given seed and inserts it in + * ticketKeys. A naked pointer to the key is returned for additional + * processing if needed. + */ + TLSTicketKeySource* insertNewKey(TLSTicketSeed* seed, uint32_t hashCount, + TLSTicketKeySource* prevKeySource); + + /** + * hashes input N times placing result in output, which must be at least + * SHA256_DIGEST_LENGTH long. + */ + void hashNth(const unsigned char* input, size_t input_len, + unsigned char* output, uint32_t n); + + /** + * Adds the given seed to the manager + */ + TLSTicketSeed* insertSeed(const std::string& seedInput, + TLSTicketSeedType type); + + /** + * Locate a key for encrypting a new ticket + */ + TLSTicketKeySource* findEncryptionKey(); + + /** + * Locate a key for decrypting a ticket with the given keyName + */ + TLSTicketKeySource* findDecryptionKey(unsigned char* keyName); + + /** + * Derive a unique key from the parent key and the salt via hashing + */ + void makeUniqueKeys(unsigned char* parentKey, size_t keyLen, + unsigned char* salt, unsigned char* output); + + /** + * For standalone decryption utility + */ + friend int decrypt_fb_ticket(folly::TLSTicketKeyManager* manager, + const std::string& testTicket, + SSL_SESSION **psess); + + typedef std::vector> TLSTicketSeedList; + typedef std::map > + TLSTicketKeyMap; + typedef std::vector TLSActiveKeyList; + + TLSTicketSeedList ticketSeeds_; + // All key sources that can be used for decryption + TLSTicketKeyMap ticketKeys_; + // Key sources that can be used for encryption + TLSActiveKeyList activeKeys_; + + folly::SSLContext* ctx_; + uint32_t randState_; + SSLStats* stats_{nullptr}; + + static int32_t sExDataIndex_; +}; +#endif +} diff --git a/folly/wangle/ssl/TLSTicketKeySeeds.h b/folly/wangle/ssl/TLSTicketKeySeeds.h new file mode 100644 index 00000000..b74ec52e --- /dev/null +++ b/folly/wangle/ssl/TLSTicketKeySeeds.h @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#pragma once + +namespace folly { + +struct TLSTicketKeySeeds { + std::vector oldSeeds; + std::vector currentSeeds; + std::vector newSeeds; +}; + +} diff --git a/folly/wangle/ssl/test/SSLCacheTest.cpp b/folly/wangle/ssl/test/SSLCacheTest.cpp new file mode 100644 index 00000000..f3129e48 --- /dev/null +++ b/folly/wangle/ssl/test/SSLCacheTest.cpp @@ -0,0 +1,272 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace folly; + +DEFINE_int32(clients, 1, "Number of simulated SSL clients"); +DEFINE_int32(threads, 1, "Number of threads to spread clients across"); +DEFINE_int32(requests, 2, "Total number of requests per client"); +DEFINE_int32(port, 9423, "Server port"); +DEFINE_bool(sticky, false, "A given client sends all reqs to one " + "(random) server"); +DEFINE_bool(global, false, "All clients in a thread use the same SSL session"); +DEFINE_bool(handshakes, false, "Force 100% handshakes"); + +string f_servers[10]; +int f_num_servers = 0; +int tnum = 0; + +class ClientRunner { + public: + + ClientRunner(): reqs(0), hits(0), miss(0), num(tnum++) {} + void run(); + + int reqs; + int hits; + int miss; + int num; +}; + +class SSLCacheClient : public AsyncSocket::ConnectCallback, + public AsyncSSLSocket::HandshakeCB +{ +private: + EventBase* eventBase_; + int currReq_; + int serverIdx_; + AsyncSocket* socket_; + AsyncSSLSocket* sslSocket_; + SSL_SESSION* session_; + SSL_SESSION **pSess_; + std::shared_ptr ctx_; + ClientRunner* cr_; + +public: + SSLCacheClient(EventBase* eventBase, SSL_SESSION **pSess, ClientRunner* cr); + ~SSLCacheClient() { + if (session_ && !FLAGS_global) + SSL_SESSION_free(session_); + if (socket_ != nullptr) { + if (sslSocket_ != nullptr) { + sslSocket_->destroy(); + sslSocket_ = nullptr; + } + socket_->destroy(); + socket_ = nullptr; + } + }; + + void start(); + + virtual void connectSuccess() noexcept; + + virtual void connectErr(const AsyncSocketException& ex) + noexcept ; + + virtual void handshakeSuc(AsyncSSLSocket* sock) noexcept; + + virtual void handshakeErr( + AsyncSSLSocket* sock, + const AsyncSocketException& ex) noexcept; + +}; + +int +main(int argc, char* argv[]) +{ + gflags::SetUsageMessage(std::string("\n\n" +"usage: sslcachetest [options] -c -t servers\n" +)); + gflags::ParseCommandLineFlags(&argc, &argv, true); + int reqs = 0; + int hits = 0; + int miss = 0; + struct timeval start; + struct timeval end; + struct timeval result; + + srand((unsigned int)time(nullptr)); + + for (int i = 1; i < argc; i++) { + f_servers[f_num_servers++] = argv[i]; + } + if (f_num_servers == 0) { + cout << "require at least one server\n"; + return 1; + } + + gettimeofday(&start, nullptr); + if (FLAGS_threads == 1) { + ClientRunner r; + r.run(); + gettimeofday(&end, nullptr); + reqs = r.reqs; + hits = r.hits; + miss = r.miss; + } + else { + std::vector clients; + std::vector threads; + for (int t = 0; t < FLAGS_threads; t++) { + threads.emplace_back([&] { + clients[t].run(); + }); + } + for (auto& thr: threads) { + thr.join(); + } + gettimeofday(&end, nullptr); + + for (const auto& client: clients) { + reqs += client.reqs; + hits += client.hits; + miss += client.miss; + } + } + + timersub(&end, &start, &result); + + cout << "Requests: " << reqs << endl; + cout << "Handshakes: " << miss << endl; + cout << "Resumes: " << hits << endl; + cout << "Runtime(ms): " << result.tv_sec << "." << result.tv_usec / 1000 << + endl; + + cout << "ops/sec: " << (reqs * 1.0) / + ((double)result.tv_sec * 1.0 + (double)result.tv_usec / 1000000.0) << endl; + + return 0; +} + +void +ClientRunner::run() +{ + EventBase eb; + std::list clients; + SSL_SESSION* session = nullptr; + + for (int i = 0; i < FLAGS_clients; i++) { + SSLCacheClient* c = new SSLCacheClient(&eb, &session, this); + c->start(); + clients.push_back(c); + } + + eb.loop(); + + for (auto it = clients.begin(); it != clients.end(); it++) { + delete* it; + } + + reqs += hits + miss; +} + +SSLCacheClient::SSLCacheClient(EventBase* eb, + SSL_SESSION **pSess, + ClientRunner* cr) + : eventBase_(eb), + currReq_(0), + serverIdx_(0), + socket_(nullptr), + sslSocket_(nullptr), + session_(nullptr), + pSess_(pSess), + cr_(cr) +{ + ctx_.reset(new SSLContext()); + ctx_->setOptions(SSL_OP_NO_TICKET); +} + +void +SSLCacheClient::start() +{ + if (currReq_ >= FLAGS_requests) { + cout << "+"; + return; + } + + if (currReq_ == 0 || !FLAGS_sticky) { + serverIdx_ = rand() % f_num_servers; + } + if (socket_ != nullptr) { + if (sslSocket_ != nullptr) { + sslSocket_->destroy(); + sslSocket_ = nullptr; + } + socket_->destroy(); + socket_ = nullptr; + } + socket_ = new AsyncSocket(eventBase_); + socket_->connect(this, f_servers[serverIdx_], (uint16_t)FLAGS_port); +} + +void +SSLCacheClient::connectSuccess() noexcept +{ + sslSocket_ = new AsyncSSLSocket(ctx_, eventBase_, socket_->detachFd(), + false); + + if (!FLAGS_handshakes) { + if (session_ != nullptr) + sslSocket_->setSSLSession(session_); + else if (FLAGS_global && pSess_ != nullptr) + sslSocket_->setSSLSession(*pSess_); + } + sslSocket_->sslConn(this); +} + +void +SSLCacheClient::connectErr(const AsyncSocketException& ex) + noexcept +{ + cout << "connectError: " << ex.what() << endl; +} + +void +SSLCacheClient::handshakeSuc(AsyncSSLSocket* socket) noexcept +{ + if (sslSocket_->getSSLSessionReused()) { + cr_->hits++; + } else { + cr_->miss++; + if (session_ != nullptr) { + SSL_SESSION_free(session_); + } + session_ = sslSocket_->getSSLSession(); + if (FLAGS_global && pSess_ != nullptr && *pSess_ == nullptr) { + *pSess_ = session_; + } + } + if ( ((cr_->hits + cr_->miss) % 100) == ((100 / FLAGS_threads) * cr_->num)) { + cout << "."; + cout.flush(); + } + sslSocket_->closeNow(); + currReq_++; + this->start(); +} + +void +SSLCacheClient::handshakeErr( + AsyncSSLSocket* sock, + const AsyncSocketException& ex) + noexcept +{ + cout << "handshakeError: " << ex.what() << endl; +} diff --git a/folly/wangle/ssl/test/SSLContextManagerTest.cpp b/folly/wangle/ssl/test/SSLContextManagerTest.cpp new file mode 100644 index 00000000..ab489397 --- /dev/null +++ b/folly/wangle/ssl/test/SSLContextManagerTest.cpp @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2015, Facebook, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + * + */ +#include +#include +#include +#include +#include +#include + +using std::shared_ptr; + +namespace folly { + +TEST(SSLContextManagerTest, Test1) +{ + EventBase eventBase; + SSLContextManager sslCtxMgr(&eventBase, "vip_ssl_context_manager_test_", + true, nullptr); + auto www_facebook_com_ctx = std::make_shared(); + auto start_facebook_com_ctx = std::make_shared(); + auto start_abc_facebook_com_ctx = std::make_shared(); + + sslCtxMgr.insertSSLCtxByDomainName( + "www.facebook.com", + strlen("www.facebook.com"), + www_facebook_com_ctx); + sslCtxMgr.insertSSLCtxByDomainName( + "www.facebook.com", + strlen("www.facebook.com"), + www_facebook_com_ctx); + try { + sslCtxMgr.insertSSLCtxByDomainName( + "www.facebook.com", + strlen("www.facebook.com"), + std::make_shared()); + } catch (const std::exception& ex) { + } + sslCtxMgr.insertSSLCtxByDomainName( + "*.facebook.com", + strlen("*.facebook.com"), + start_facebook_com_ctx); + sslCtxMgr.insertSSLCtxByDomainName( + "*.abc.facebook.com", + strlen("*.abc.facebook.com"), + start_abc_facebook_com_ctx); + try { + sslCtxMgr.insertSSLCtxByDomainName( + "*.abc.facebook.com", + strlen("*.abc.facebook.com"), + std::make_shared()); + FAIL(); + } catch (const std::exception& ex) { + } + + shared_ptr retCtx; + retCtx = sslCtxMgr.getSSLCtx(DNString("www.facebook.com")); + EXPECT_EQ(retCtx, www_facebook_com_ctx); + retCtx = sslCtxMgr.getSSLCtx(DNString("WWW.facebook.com")); + EXPECT_EQ(retCtx, www_facebook_com_ctx); + EXPECT_FALSE(sslCtxMgr.getSSLCtx(DNString("xyz.facebook.com"))); + + retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("xyz.facebook.com")); + EXPECT_EQ(retCtx, start_facebook_com_ctx); + retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("XYZ.facebook.com")); + EXPECT_EQ(retCtx, start_facebook_com_ctx); + + retCtx = sslCtxMgr.getSSLCtxBySuffix(DNString("www.abc.facebook.com")); + EXPECT_EQ(retCtx, start_abc_facebook_com_ctx); + + // ensure "facebook.com" does not match "*.facebook.com" + EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("facebook.com"))); + // ensure "Xfacebook.com" does not match "*.facebook.com" + EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("Xfacebook.com"))); + // ensure wildcard name only matches one domain up + EXPECT_FALSE(sslCtxMgr.getSSLCtxBySuffix(DNString("abc.xyz.facebook.com"))); + + eventBase.loop(); // Clean up events before SSLContextManager is destructed +} + +} -- 2.34.1