Avoid passing temporary to get_ref_default()
[folly.git] / folly / io / async / Request.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed to the Apache Software Foundation (ASF) under one
5  * or more contributor license agreements. See the NOTICE file
6  * distributed with this work for additional information
7  * regarding copyright ownership. The ASF licenses this file
8  * to you under the Apache License, Version 2.0 (the
9  * "License"); you may not use this file except in compliance
10  * with the License. You may obtain a copy of the License at
11  *
12  *   http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing,
15  * software distributed under the License is distributed on an
16  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17  * KIND, either express or implied. See the License for the
18  * specific language governing permissions and limitations
19  * under the License.
20  */
21
22 #include <folly/io/async/Request.h>
23 #include <folly/tracing/StaticTracepoint.h>
24
25 #include <glog/logging.h>
26
27 #include <folly/MapUtil.h>
28 #include <folly/SingletonThreadLocal.h>
29
30 namespace folly {
31
32 void RequestContext::setContextData(
33     const std::string& val,
34     std::unique_ptr<RequestData> data) {
35   auto wlock = data_.wlock();
36   if (wlock->count(val)) {
37     LOG_FIRST_N(WARNING, 1)
38         << "Called RequestContext::setContextData with data already set";
39
40     (*wlock)[val] = nullptr;
41   } else {
42     (*wlock)[val] = std::move(data);
43   }
44 }
45
46 bool RequestContext::setContextDataIfAbsent(
47     const std::string& val,
48     std::unique_ptr<RequestData> data) {
49   auto ulock = data_.ulock();
50   if (ulock->count(val)) {
51     return false;
52   }
53
54   auto wlock = ulock.moveFromUpgradeToWrite();
55   (*wlock)[val] = std::move(data);
56   return true;
57 }
58
59 bool RequestContext::hasContextData(const std::string& val) const {
60   return data_.rlock()->count(val);
61 }
62
63 RequestData* RequestContext::getContextData(const std::string& val) {
64   const std::unique_ptr<RequestData> dflt{nullptr};
65   return get_ref_default(*data_.rlock(), val, dflt).get();
66 }
67
68 const RequestData* RequestContext::getContextData(
69     const std::string& val) const {
70   const std::unique_ptr<RequestData> dflt{nullptr};
71   return get_ref_default(*data_.rlock(), val, dflt).get();
72 }
73
74 void RequestContext::onSet() {
75   auto rlock = data_.rlock();
76   for (auto const& ent : *rlock) {
77     if (auto& data = ent.second) {
78       data->onSet();
79     }
80   }
81 }
82
83 void RequestContext::onUnset() {
84   auto rlock = data_.rlock();
85   for (auto const& ent : *rlock) {
86     if (auto& data = ent.second) {
87       data->onUnset();
88     }
89   }
90 }
91
92 void RequestContext::clearContextData(const std::string& val) {
93   std::unique_ptr<RequestData> requestData;
94   // Delete the RequestData after giving up the wlock just in case one of the
95   // RequestData destructors will try to grab the lock again.
96   {
97     auto wlock = data_.wlock();
98     auto it = wlock->find(val);
99     if (it != wlock->end()) {
100       requestData = std::move(it->second);
101       wlock->erase(it);
102     }
103   }
104 }
105
106 std::shared_ptr<RequestContext> RequestContext::setContext(
107     std::shared_ptr<RequestContext> ctx) {
108   auto& curCtx = getStaticContext();
109   if (ctx != curCtx) {
110     FOLLY_SDT(folly, request_context_switch_before, curCtx.get(), ctx.get());
111     using std::swap;
112     if (curCtx) {
113       curCtx->onUnset();
114     }
115     swap(ctx, curCtx);
116     if (curCtx) {
117       curCtx->onSet();
118     }
119   }
120   return ctx;
121 }
122
123 std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
124   using SingletonT = SingletonThreadLocal<std::shared_ptr<RequestContext>>;
125   static SingletonT singleton;
126
127   return singleton.get();
128 }
129
130 RequestContext* RequestContext::get() {
131   auto context = getStaticContext();
132   if (!context) {
133     static RequestContext defaultContext;
134     return std::addressof(defaultContext);
135   }
136   return context.get();
137 }
138 }