Make RequestContext provider overridable in order to save cost of setContext() on...
[folly.git] / folly / io / async / Request.cpp
1 /*
2  * Copyright 2004-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/io/async/Request.h>
17
18 #include <algorithm>
19 #include <stdexcept>
20 #include <utility>
21
22 #include <glog/logging.h>
23
24 #include <folly/MapUtil.h>
25 #include <folly/SingletonThreadLocal.h>
26 #include <folly/tracing/StaticTracepoint.h>
27
28 namespace folly {
29
30 void RequestContext::setContextData(
31     const std::string& val,
32     std::unique_ptr<RequestData> data) {
33   auto wlock = data_.wlock();
34   if (wlock->count(val)) {
35     LOG_FIRST_N(WARNING, 1)
36         << "Called RequestContext::setContextData with data already set";
37
38     (*wlock)[val] = nullptr;
39   } else {
40     (*wlock)[val] = std::move(data);
41   }
42 }
43
44 bool RequestContext::setContextDataIfAbsent(
45     const std::string& val,
46     std::unique_ptr<RequestData> data) {
47   auto ulock = data_.ulock();
48   if (ulock->count(val)) {
49     return false;
50   }
51
52   auto wlock = ulock.moveFromUpgradeToWrite();
53   (*wlock)[val] = std::move(data);
54   return true;
55 }
56
57 bool RequestContext::hasContextData(const std::string& val) const {
58   return data_.rlock()->count(val);
59 }
60
61 RequestData* RequestContext::getContextData(const std::string& val) {
62   const std::unique_ptr<RequestData> dflt{nullptr};
63   return get_ref_default(*data_.rlock(), val, dflt).get();
64 }
65
66 const RequestData* RequestContext::getContextData(
67     const std::string& val) const {
68   const std::unique_ptr<RequestData> dflt{nullptr};
69   return get_ref_default(*data_.rlock(), val, dflt).get();
70 }
71
72 void RequestContext::onSet() {
73   auto rlock = data_.rlock();
74   for (auto const& ent : *rlock) {
75     if (auto& data = ent.second) {
76       data->onSet();
77     }
78   }
79 }
80
81 void RequestContext::onUnset() {
82   auto rlock = data_.rlock();
83   for (auto const& ent : *rlock) {
84     if (auto& data = ent.second) {
85       data->onUnset();
86     }
87   }
88 }
89
90 void RequestContext::clearContextData(const std::string& val) {
91   std::unique_ptr<RequestData> requestData;
92   // Delete the RequestData after giving up the wlock just in case one of the
93   // RequestData destructors will try to grab the lock again.
94   {
95     auto wlock = data_.wlock();
96     auto it = wlock->find(val);
97     if (it != wlock->end()) {
98       requestData = std::move(it->second);
99       wlock->erase(it);
100     }
101   }
102 }
103
104 std::shared_ptr<RequestContext> RequestContext::setContext(
105     std::shared_ptr<RequestContext> ctx) {
106   auto& curCtx = getStaticContext();
107   if (ctx != curCtx) {
108     FOLLY_SDT(folly, request_context_switch_before, curCtx.get(), ctx.get());
109     using std::swap;
110     if (curCtx) {
111       curCtx->onUnset();
112     }
113     swap(ctx, curCtx);
114     if (curCtx) {
115       curCtx->onSet();
116     }
117   }
118   return ctx;
119 }
120
121 RequestContext::Provider& RequestContext::requestContextProvider() {
122   class DefaultProvider {
123    public:
124     constexpr DefaultProvider() = default;
125     DefaultProvider(const DefaultProvider&) = delete;
126     DefaultProvider& operator=(const DefaultProvider&) = delete;
127     DefaultProvider(DefaultProvider&&) = default;
128     DefaultProvider& operator=(DefaultProvider&&) = default;
129
130     std::shared_ptr<RequestContext>& operator()() {
131       return context;
132     }
133
134    private:
135     std::shared_ptr<RequestContext> context;
136   };
137
138   static SingletonThreadLocal<Provider> providerSingleton(
139       []() { return new Provider(DefaultProvider()); });
140   return providerSingleton.get();
141 }
142
143 std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
144   auto& provider = requestContextProvider();
145   return provider();
146 }
147
148 RequestContext* RequestContext::get() {
149   auto& context = getStaticContext();
150   if (!context) {
151     static RequestContext defaultContext;
152     return std::addressof(defaultContext);
153   }
154   return context.get();
155 }
156
157 RequestContext::Provider RequestContext::setRequestContextProvider(
158     RequestContext::Provider newProvider) {
159   if (!newProvider) {
160     throw std::runtime_error("RequestContext provider must be non-empty");
161   }
162
163   auto& provider = requestContextProvider();
164   std::swap(provider, newProvider);
165   return newProvider;
166 }
167 }