* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
#include <folly/io/async/Request.h>
-#include <folly/tracing/StaticTracepoint.h>
+
+#include <algorithm>
+#include <stdexcept>
+#include <utility>
#include <glog/logging.h>
#include <folly/MapUtil.h>
#include <folly/SingletonThreadLocal.h>
+#include <folly/tracing/StaticTracepoint.h>
namespace folly {
return ctx;
}
-std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
- using SingletonT = SingletonThreadLocal<std::shared_ptr<RequestContext>>;
- static SingletonT singleton;
+RequestContext::Provider& RequestContext::requestContextProvider() {
+ class DefaultProvider {
+ public:
+ constexpr DefaultProvider() = default;
+ DefaultProvider(const DefaultProvider&) = delete;
+ DefaultProvider& operator=(const DefaultProvider&) = delete;
+ DefaultProvider(DefaultProvider&&) = default;
+ DefaultProvider& operator=(DefaultProvider&&) = default;
+
+ std::shared_ptr<RequestContext>& operator()() {
+ return context;
+ }
+
+ private:
+ std::shared_ptr<RequestContext> context;
+ };
- return singleton.get();
+ static SingletonThreadLocal<Provider> providerSingleton(
+ []() { return new Provider(DefaultProvider()); });
+ return providerSingleton.get();
+}
+
+std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
+ auto& provider = requestContextProvider();
+ return provider();
}
RequestContext* RequestContext::get() {
- auto context = getStaticContext();
+ auto& context = getStaticContext();
if (!context) {
static RequestContext defaultContext;
return std::addressof(defaultContext);
}
return context.get();
}
+
+RequestContext::Provider RequestContext::setRequestContextProvider(
+ RequestContext::Provider newProvider) {
+ if (!newProvider) {
+ throw std::runtime_error("RequestContext provider must be non-empty");
+ }
+
+ auto& provider = requestContextProvider();
+ std::swap(provider, newProvider);
+ return newProvider;
+}
}