Make RequestContext provider overridable in order to save cost of setContext() on...
[folly.git] / folly / io / async / Request.cpp
index 11894b5d04eac9067ec1d3c243cc24ef624e08f0..a3bd7382facfd63a756ef470e6af17a3e97b670f 100644 (file)
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include <folly/io/async/Request.h>
-#include <folly/tracing/StaticTracepoint.h>
+
+#include <algorithm>
+#include <stdexcept>
+#include <utility>
 
 #include <glog/logging.h>
 
 #include <folly/MapUtil.h>
 #include <folly/SingletonThreadLocal.h>
+#include <folly/tracing/StaticTracepoint.h>
 
 namespace folly {
 
@@ -115,19 +118,50 @@ std::shared_ptr<RequestContext> RequestContext::setContext(
   return ctx;
 }
 
-std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
-  using SingletonT = SingletonThreadLocal<std::shared_ptr<RequestContext>>;
-  static SingletonT singleton;
+RequestContext::Provider& RequestContext::requestContextProvider() {
+  class DefaultProvider {
+   public:
+    constexpr DefaultProvider() = default;
+    DefaultProvider(const DefaultProvider&) = delete;
+    DefaultProvider& operator=(const DefaultProvider&) = delete;
+    DefaultProvider(DefaultProvider&&) = default;
+    DefaultProvider& operator=(DefaultProvider&&) = default;
+
+    std::shared_ptr<RequestContext>& operator()() {
+      return context;
+    }
+
+   private:
+    std::shared_ptr<RequestContext> context;
+  };
 
-  return singleton.get();
+  static SingletonThreadLocal<Provider> providerSingleton(
+      []() { return new Provider(DefaultProvider()); });
+  return providerSingleton.get();
+}
+
+std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
+  auto& provider = requestContextProvider();
+  return provider();
 }
 
 RequestContext* RequestContext::get() {
-  auto context = getStaticContext();
+  auto& context = getStaticContext();
   if (!context) {
     static RequestContext defaultContext;
     return std::addressof(defaultContext);
   }
   return context.get();
 }
+
+RequestContext::Provider RequestContext::setRequestContextProvider(
+    RequestContext::Provider newProvider) {
+  if (!newProvider) {
+    throw std::runtime_error("RequestContext provider must be non-empty");
+  }
+
+  auto& provider = requestContextProvider();
+  std::swap(provider, newProvider);
+  return newProvider;
+}
 }