Remove a RequestContext deadlock
[folly.git] / folly / io / async / Request.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed to the Apache Software Foundation (ASF) under one
5  * or more contributor license agreements. See the NOTICE file
6  * distributed with this work for additional information
7  * regarding copyright ownership. The ASF licenses this file
8  * to you under the Apache License, Version 2.0 (the
9  * "License"); you may not use this file except in compliance
10  * with the License. You may obtain a copy of the License at
11  *
12  *   http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17  * KIND, either express or implied. See the License for the
18  * specific language governing permissions and limitations
19  * under the License.
20  */
21
22 #include <folly/io/async/Request.h>
23 #include <folly/tracing/StaticTracepoint.h>
24
25 #include <glog/logging.h>
26
27 #include <folly/MapUtil.h>
28 #include <folly/SingletonThreadLocal.h>
29
30 namespace folly {
31
32 void RequestContext::setContextData(
33     const std::string& val,
34     std::unique_ptr<RequestData> data) {
35   auto wlock = data_.wlock();
36   if (wlock->count(val)) {
37     LOG_FIRST_N(WARNING, 1)
38         << "Called RequestContext::setContextData with data already set";
39
40     (*wlock)[val] = nullptr;
41   } else {
42     (*wlock)[val] = std::move(data);
43   }
44 }
45
46 bool RequestContext::setContextDataIfAbsent(
47     const std::string& val,
48     std::unique_ptr<RequestData> data) {
49   auto ulock = data_.ulock();
50   if (ulock->count(val)) {
51     return false;
52   }
53
54   auto wlock = ulock.moveFromUpgradeToWrite();
55   (*wlock)[val] = std::move(data);
56   return true;
57 }
58
59 bool RequestContext::hasContextData(const std::string& val) const {
60   return data_.rlock()->count(val);
61 }
62
63 RequestData* RequestContext::getContextData(const std::string& val) {
64   return get_ref_default(*data_.rlock(), val, nullptr).get();
65 }
66
67 const RequestData* RequestContext::getContextData(
68     const std::string& val) const {
69   return get_ref_default(*data_.rlock(), val, nullptr).get();
70 }
71
72 void RequestContext::onSet() {
73   auto rlock = data_.rlock();
74   for (auto const& ent : *rlock) {
75     if (auto& data = ent.second) {
76       data->onSet();
77     }
78   }
79 }
80
81 void RequestContext::onUnset() {
82   auto rlock = data_.rlock();
83   for (auto const& ent : *rlock) {
84     if (auto& data = ent.second) {
85       data->onUnset();
86     }
87   }
88 }
89
90 void RequestContext::clearContextData(const std::string& val) {
91   std::unique_ptr<RequestData> requestData;
92   // Delete the RequestData after giving up the wlock just in case one of the
93   // RequestData destructors will try to grab the lock again.
94   {
95     auto wlock = data_.wlock();
96     auto it = wlock->find(val);
97     if (it != wlock->end()) {
98       requestData = std::move(it->second);
99       wlock->erase(it);
100     }
101   }
102 }
103
104 std::shared_ptr<RequestContext> RequestContext::setContext(
105     std::shared_ptr<RequestContext> ctx) {
106   auto& curCtx = getStaticContext();
107   if (ctx != curCtx) {
108     FOLLY_SDT(folly, request_context_switch_before, curCtx.get(), ctx.get());
109     using std::swap;
110     if (curCtx) {
111       curCtx->onUnset();
112     }
113     swap(ctx, curCtx);
114     if (curCtx) {
115       curCtx->onSet();
116     }
117   }
118   return ctx;
119 }
120
121 std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
122   using SingletonT = SingletonThreadLocal<std::shared_ptr<RequestContext>>;
123   static SingletonT singleton;
124
125   return singleton.get();
126 }
127
128 RequestContext* RequestContext::get() {
129   auto context = getStaticContext();
130   if (!context) {
131     static RequestContext defaultContext;
132     return std::addressof(defaultContext);
133   }
134   return context.get();
135 }
136 }