copy wangle back into folly
[folly.git] / folly / wangle / ssl / SSLSessionCacheManager.cpp
1 /*
2  *  Copyright (c) 2015, Facebook, Inc.
3  *  All rights reserved.
4  *
5  *  This source code is licensed under the BSD-style license found in the
6  *  LICENSE file in the root directory of this source tree. An additional grant
7  *  of patent rights can be found in the PATENTS file in the same directory.
8  *
9  */
10 #include <folly/wangle/ssl/SSLSessionCacheManager.h>
11
12 #include <folly/wangle/ssl/SSLCacheProvider.h>
13 #include <folly/wangle/ssl/SSLStats.h>
14 #include <folly/wangle/ssl/SSLUtil.h>
15
16 #include <folly/io/async/EventBase.h>
17
18 #ifndef NO_LIB_GFLAGS
19 #include <gflags/gflags.h>
20 #endif
21
22 using std::string;
23 using std::shared_ptr;
24
25 namespace {
26
27 const uint32_t NUM_CACHE_BUCKETS = 16;
28
29 // We use the default ID generator which fills the maximum ID length
30 // for the protocol.  16 bytes for SSLv2 or 32 for SSLv3+
31 const int MIN_SESSION_ID_LENGTH = 16;
32
33 }
34
35 #ifndef NO_LIB_GFLAGS
36 DEFINE_bool(dcache_unit_test, false, "All VIPs share one session cache");
37 #else
38 const bool FLAGS_dcache_unit_test = false;
39 #endif
40
41 namespace folly {
42
43
44 int SSLSessionCacheManager::sExDataIndex_ = -1;
45 shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::sCache_;
46 std::mutex SSLSessionCacheManager::sCacheLock_;
47
48 LocalSSLSessionCache::LocalSSLSessionCache(uint32_t maxCacheSize,
49                                            uint32_t cacheCullSize)
50     : sessionCache(maxCacheSize, cacheCullSize) {
51   sessionCache.setPruneHook(std::bind(
52                               &LocalSSLSessionCache::pruneSessionCallback,
53                               this, std::placeholders::_1,
54                               std::placeholders::_2));
55 }
56
57 void LocalSSLSessionCache::pruneSessionCallback(const string& sessionId,
58                                                 SSL_SESSION* session) {
59   VLOG(4) << "Free SSL session from local cache; id="
60           << SSLUtil::hexlify(sessionId);
61   SSL_SESSION_free(session);
62   ++removedSessions_;
63 }
64
65
66 // SSLSessionCacheManager implementation
67
68 SSLSessionCacheManager::SSLSessionCacheManager(
69   uint32_t maxCacheSize,
70   uint32_t cacheCullSize,
71   SSLContext* ctx,
72   const folly::SocketAddress& sockaddr,
73   const string& context,
74   EventBase* eventBase,
75   SSLStats* stats,
76   const std::shared_ptr<SSLCacheProvider>& externalCache):
77     ctx_(ctx),
78     stats_(stats),
79     externalCache_(externalCache) {
80
81   SSL_CTX* sslCtx = ctx->getSSLCtx();
82
83   SSLUtil::getSSLCtxExIndex(&sExDataIndex_);
84
85   SSL_CTX_set_ex_data(sslCtx, sExDataIndex_, this);
86   SSL_CTX_sess_set_new_cb(sslCtx, SSLSessionCacheManager::newSessionCallback);
87   SSL_CTX_sess_set_get_cb(sslCtx, SSLSessionCacheManager::getSessionCallback);
88   SSL_CTX_sess_set_remove_cb(sslCtx,
89                              SSLSessionCacheManager::removeSessionCallback);
90   if (!FLAGS_dcache_unit_test && !context.empty()) {
91     // Use the passed in context
92     SSL_CTX_set_session_id_context(sslCtx, (const uint8_t *)context.data(),
93                                    std::min((int)context.length(),
94                                             SSL_MAX_SSL_SESSION_ID_LENGTH));
95   }
96
97   SSL_CTX_set_session_cache_mode(sslCtx, SSL_SESS_CACHE_NO_INTERNAL
98                                  | SSL_SESS_CACHE_SERVER);
99
100   localCache_ = SSLSessionCacheManager::getLocalCache(maxCacheSize,
101                                                       cacheCullSize);
102
103   VLOG(2) << "On VipID=" << sockaddr.describe() << " context=" << context;
104 }
105
106 SSLSessionCacheManager::~SSLSessionCacheManager() {
107 }
108
109 void SSLSessionCacheManager::shutdown() {
110   std::lock_guard<std::mutex> g(sCacheLock_);
111   sCache_.reset();
112 }
113
114 shared_ptr<ShardedLocalSSLSessionCache> SSLSessionCacheManager::getLocalCache(
115   uint32_t maxCacheSize,
116   uint32_t cacheCullSize) {
117
118   std::lock_guard<std::mutex> g(sCacheLock_);
119   if (!sCache_) {
120     sCache_.reset(new ShardedLocalSSLSessionCache(NUM_CACHE_BUCKETS,
121                                                   maxCacheSize,
122                                                   cacheCullSize));
123   }
124   return sCache_;
125 }
126
127 int SSLSessionCacheManager::newSessionCallback(SSL* ssl, SSL_SESSION* session) {
128   SSLSessionCacheManager* manager = nullptr;
129   SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
130   manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
131
132   if (manager == nullptr) {
133     LOG(FATAL) << "Null SSLSessionCacheManager in callback";
134     return -1;
135   }
136   return manager->newSession(ssl, session);
137 }
138
139
140 int SSLSessionCacheManager::newSession(SSL* ssl, SSL_SESSION* session) {
141   string sessionId((char*)session->session_id, session->session_id_length);
142   VLOG(4) << "New SSL session; id=" << SSLUtil::hexlify(sessionId);
143
144   if (stats_) {
145     stats_->recordSSLSession(true /* new session */, false, false);
146   }
147
148   localCache_->storeSession(sessionId, session, stats_);
149
150   if (externalCache_) {
151     VLOG(4) << "New SSL session: send session to external cache; id=" <<
152       SSLUtil::hexlify(sessionId);
153     storeCacheRecord(sessionId, session);
154   }
155
156   return 1;
157 }
158
159 void SSLSessionCacheManager::removeSessionCallback(SSL_CTX* ctx,
160                                                    SSL_SESSION* session) {
161   SSLSessionCacheManager* manager = nullptr;
162   manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
163
164   if (manager == nullptr) {
165     LOG(FATAL) << "Null SSLSessionCacheManager in callback";
166     return;
167   }
168   return manager->removeSession(ctx, session);
169 }
170
171 void SSLSessionCacheManager::removeSession(SSL_CTX* ctx,
172                                            SSL_SESSION* session) {
173   string sessionId((char*)session->session_id, session->session_id_length);
174
175   // This hook is only called from SSL when the internal session cache needs to
176   // flush sessions.  Since we run with the internal cache disabled, this should
177   // never be called
178   VLOG(3) << "Remove SSL session; id=" << SSLUtil::hexlify(sessionId);
179
180   localCache_->removeSession(sessionId);
181
182   if (stats_) {
183     stats_->recordSSLSessionRemove();
184   }
185 }
186
187 SSL_SESSION* SSLSessionCacheManager::getSessionCallback(SSL* ssl,
188                                                         unsigned char* sess_id,
189                                                         int id_len,
190                                                         int* copyflag) {
191   SSLSessionCacheManager* manager = nullptr;
192   SSL_CTX* ctx = SSL_get_SSL_CTX(ssl);
193   manager = (SSLSessionCacheManager *)SSL_CTX_get_ex_data(ctx, sExDataIndex_);
194
195   if (manager == nullptr) {
196     LOG(FATAL) << "Null SSLSessionCacheManager in callback";
197     return nullptr;
198   }
199   return manager->getSession(ssl, sess_id, id_len, copyflag);
200 }
201
202 SSL_SESSION* SSLSessionCacheManager::getSession(SSL* ssl,
203                                                 unsigned char* session_id,
204                                                 int id_len,
205                                                 int* copyflag) {
206   VLOG(7) << "SSL get session callback";
207   SSL_SESSION* session = nullptr;
208   bool foreign = false;
209   char const* missReason = nullptr;
210
211   if (id_len < MIN_SESSION_ID_LENGTH) {
212     // We didn't generate this session so it's going to be a miss.
213     // This doesn't get logged or counted in the stats.
214     return nullptr;
215   }
216   string sessionId((char*)session_id, id_len);
217
218   AsyncSSLSocket* sslSocket = AsyncSSLSocket::getFromSSL(ssl);
219
220   assert(sslSocket != nullptr);
221
222   // look it up in the local cache first
223   session = localCache_->lookupSession(sessionId);
224 #ifdef SSL_SESSION_CB_WOULD_BLOCK
225   if (session == nullptr && externalCache_) {
226     // external cache might have the session
227     foreign = true;
228     if (!SSL_want_sess_cache_lookup(ssl)) {
229       missReason = "reason: No async cache support;";
230     } else {
231       PendingLookupMap::iterator pit = pendingLookups_.find(sessionId);
232       if (pit == pendingLookups_.end()) {
233         auto result = pendingLookups_.emplace(sessionId, PendingLookup());
234         // initiate fetch
235         VLOG(4) << "Get SSL session [Pending]: Initiate Fetch; fd=" <<
236           sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
237         if (lookupCacheRecord(sessionId, sslSocket)) {
238           // response is pending
239           *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
240           return nullptr;
241         } else {
242           missReason = "reason: failed to send lookup request;";
243           pendingLookups_.erase(result.first);
244         }
245       } else {
246         // A lookup was already initiated from this thread
247         if (pit->second.request_in_progress) {
248           // Someone else initiated the request, attach
249           VLOG(4) << "Get SSL session [Pending]: Request in progess: attach; "
250             "fd=" << sslSocket->getFd() << " id=" <<
251             SSLUtil::hexlify(sessionId);
252           std::unique_ptr<DelayedDestruction::DestructorGuard> dg(
253             new DelayedDestruction::DestructorGuard(sslSocket));
254           pit->second.waiters.push_back(
255             std::make_pair(sslSocket, std::move(dg)));
256           *copyflag = SSL_SESSION_CB_WOULD_BLOCK;
257           return nullptr;
258         }
259         // request is complete
260         session = pit->second.session; // nullptr if our friend didn't have it
261         if (session != nullptr) {
262           CRYPTO_add(&session->references, 1, CRYPTO_LOCK_SSL_SESSION);
263         }
264       }
265     }
266   }
267 #endif
268
269   bool hit = (session != nullptr);
270   if (stats_) {
271     stats_->recordSSLSession(false, hit, foreign);
272   }
273   if (hit) {
274     sslSocket->setSessionIDResumed(true);
275   }
276
277   VLOG(4) << "Get SSL session [" <<
278     ((hit) ? "Hit" : "Miss") << "]: " <<
279     ((foreign) ? "external" : "local") << " cache; " <<
280     ((missReason != nullptr) ? missReason : "") << "fd=" <<
281     sslSocket->getFd() << " id=" << SSLUtil::hexlify(sessionId);
282
283   // We already bumped the refcount
284   *copyflag = 0;
285
286   return session;
287 }
288
289 bool SSLSessionCacheManager::storeCacheRecord(const string& sessionId,
290                                               SSL_SESSION* session) {
291   std::string sessionString;
292   uint32_t sessionLen = i2d_SSL_SESSION(session, nullptr);
293   sessionString.resize(sessionLen);
294   uint8_t* cp = (uint8_t *)sessionString.data();
295   i2d_SSL_SESSION(session, &cp);
296   size_t expiration = SSL_CTX_get_timeout(ctx_->getSSLCtx());
297   return externalCache_->setAsync(sessionId, sessionString,
298                                   std::chrono::seconds(expiration));
299 }
300
301 bool SSLSessionCacheManager::lookupCacheRecord(const string& sessionId,
302                                                AsyncSSLSocket* sslSocket) {
303   auto cacheCtx = new SSLCacheProvider::CacheContext();
304   cacheCtx->sessionId = sessionId;
305   cacheCtx->session = nullptr;
306   cacheCtx->sslSocket = sslSocket;
307   cacheCtx->guard.reset(
308       new DelayedDestruction::DestructorGuard(cacheCtx->sslSocket));
309   cacheCtx->manager = this;
310   bool res = externalCache_->getAsync(sessionId, cacheCtx);
311   if (!res) {
312     delete cacheCtx;
313   }
314   return res;
315 }
316
317 void SSLSessionCacheManager::restartSSLAccept(
318     const SSLCacheProvider::CacheContext* cacheCtx) {
319   PendingLookupMap::iterator pit = pendingLookups_.find(cacheCtx->sessionId);
320   CHECK(pit != pendingLookups_.end());
321   pit->second.request_in_progress = false;
322   pit->second.session = cacheCtx->session;
323   VLOG(7) << "Restart SSL accept";
324   cacheCtx->sslSocket->restartSSLAccept();
325   for (const auto& attachedLookup: pit->second.waiters) {
326     // Wake up anyone else who was waiting for this session
327     VLOG(4) << "Restart SSL accept (waiters) for fd=" <<
328       attachedLookup.first->getFd();
329     attachedLookup.first->restartSSLAccept();
330   }
331   pendingLookups_.erase(pit);
332 }
333
334 void SSLSessionCacheManager::onGetSuccess(
335     SSLCacheProvider::CacheContext* cacheCtx,
336     const std::string& value) {
337   const uint8_t* cp = (uint8_t*)value.data();
338   cacheCtx->session = d2i_SSL_SESSION(nullptr, &cp, value.length());
339   restartSSLAccept(cacheCtx);
340
341   /* Insert in the LRU after restarting all clients.  The stats logic
342    * in getSession would treat this as a local hit otherwise.
343    */
344   localCache_->storeSession(cacheCtx->sessionId, cacheCtx->session, stats_);
345   delete cacheCtx;
346 }
347
348 void SSLSessionCacheManager::onGetFailure(
349     SSLCacheProvider::CacheContext* cacheCtx) {
350   restartSSLAccept(cacheCtx);
351   delete cacheCtx;
352 }
353
354 } // namespace