Revert D6725091: [Folly] Use thread-local in RequestContext::getStaticContext
[folly.git] / folly / io / async / Request.cpp
index a3bd7382facfd63a756ef470e6af17a3e97b670f..f8cca8b92a866c53bf34253fbf6ea96343208425 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2004-present Facebook, Inc.
+ * Copyright 2016-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <folly/io/async/Request.h>
 
-#include <algorithm>
-#include <stdexcept>
-#include <utility>
+#include <folly/io/async/Request.h>
+#include <folly/tracing/StaticTracepoint.h>
 
 #include <glog/logging.h>
 
 #include <folly/MapUtil.h>
 #include <folly/SingletonThreadLocal.h>
-#include <folly/tracing/StaticTracepoint.h>
 
 namespace folly {
 
+bool RequestContext::doSetContextData(
+    const std::string& val,
+    std::unique_ptr<RequestData>& data,
+    bool strict) {
+  auto ulock = state_.ulock();
+
+  bool conflict = false;
+  auto it = ulock->requestData_.find(val);
+  if (it != ulock->requestData_.end()) {
+    if (strict) {
+      return false;
+    } else {
+      LOG_FIRST_N(WARNING, 1) << "Calling RequestContext::setContextData for "
+                              << val << " but it is already set";
+      conflict = true;
+    }
+  }
+
+  auto wlock = ulock.moveFromUpgradeToWrite();
+  if (conflict) {
+    if (it->second) {
+      if (it->second->hasCallback()) {
+        wlock->callbackData_.erase(it->second.get());
+      }
+      it->second.reset(nullptr);
+    }
+    return true;
+  }
+
+  if (data && data->hasCallback()) {
+    wlock->callbackData_.insert(data.get());
+  }
+  wlock->requestData_[val] = std::move(data);
+
+  return true;
+}
+
 void RequestContext::setContextData(
     const std::string& val,
     std::unique_ptr<RequestData> data) {
-  auto wlock = data_.wlock();
-  if (wlock->count(val)) {
-    LOG_FIRST_N(WARNING, 1)
-        << "Called RequestContext::setContextData with data already set";
-
-    (*wlock)[val] = nullptr;
-  } else {
-    (*wlock)[val] = std::move(data);
-  }
+  doSetContextData(val, data, false /* strict */);
 }
 
 bool RequestContext::setContextDataIfAbsent(
     const std::string& val,
     std::unique_ptr<RequestData> data) {
-  auto ulock = data_.ulock();
-  if (ulock->count(val)) {
-    return false;
-  }
-
-  auto wlock = ulock.moveFromUpgradeToWrite();
-  (*wlock)[val] = std::move(data);
-  return true;
+  return doSetContextData(val, data, true /* strict */);
 }
 
 bool RequestContext::hasContextData(const std::string& val) const {
-  return data_.rlock()->count(val);
+  return state_.rlock()->requestData_.count(val);
 }
 
 RequestData* RequestContext::getContextData(const std::string& val) {
   const std::unique_ptr<RequestData> dflt{nullptr};
-  return get_ref_default(*data_.rlock(), val, dflt).get();
+  return get_ref_default(state_.rlock()->requestData_, val, dflt).get();
 }
 
 const RequestData* RequestContext::getContextData(
     const std::string& val) const {
   const std::unique_ptr<RequestData> dflt{nullptr};
-  return get_ref_default(*data_.rlock(), val, dflt).get();
+  return get_ref_default(state_.rlock()->requestData_, val, dflt).get();
 }
 
 void RequestContext::onSet() {
-  auto rlock = data_.rlock();
-  for (auto const& ent : *rlock) {
-    if (auto& data = ent.second) {
-      data->onSet();
-    }
+  auto rlock = state_.rlock();
+  for (const auto& data : rlock->callbackData_) {
+    data->onSet();
   }
 }
 
 void RequestContext::onUnset() {
-  auto rlock = data_.rlock();
-  for (auto const& ent : *rlock) {
-    if (auto& data = ent.second) {
-      data->onUnset();
-    }
+  auto rlock = state_.rlock();
+  for (const auto& data : rlock->callbackData_) {
+    data->onUnset();
   }
 }
 
@@ -92,12 +107,19 @@ void RequestContext::clearContextData(const std::string& val) {
   // Delete the RequestData after giving up the wlock just in case one of the
   // RequestData destructors will try to grab the lock again.
   {
-    auto wlock = data_.wlock();
-    auto it = wlock->find(val);
-    if (it != wlock->end()) {
-      requestData = std::move(it->second);
-      wlock->erase(it);
+    auto ulock = state_.ulock();
+    auto it = ulock->requestData_.find(val);
+    if (it == ulock->requestData_.end()) {
+      return;
     }
+
+    auto wlock = ulock.moveFromUpgradeToWrite();
+    if (it->second && it->second->hasCallback()) {
+      wlock->callbackData_.erase(it->second.get());
+    }
+
+    requestData = std::move(it->second);
+    wlock->requestData_.erase(it);
   }
 }
 
@@ -118,31 +140,11 @@ std::shared_ptr<RequestContext> RequestContext::setContext(
   return ctx;
 }
 
-RequestContext::Provider& RequestContext::requestContextProvider() {
-  class DefaultProvider {
-   public:
-    constexpr DefaultProvider() = default;
-    DefaultProvider(const DefaultProvider&) = delete;
-    DefaultProvider& operator=(const DefaultProvider&) = delete;
-    DefaultProvider(DefaultProvider&&) = default;
-    DefaultProvider& operator=(DefaultProvider&&) = default;
-
-    std::shared_ptr<RequestContext>& operator()() {
-      return context;
-    }
-
-   private:
-    std::shared_ptr<RequestContext> context;
-  };
-
-  static SingletonThreadLocal<Provider> providerSingleton(
-      []() { return new Provider(DefaultProvider()); });
-  return providerSingleton.get();
-}
-
 std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
-  auto& provider = requestContextProvider();
-  return provider();
+  using SingletonT = SingletonThreadLocal<std::shared_ptr<RequestContext>>;
+  static SingletonT singleton;
+
+  return singleton.get();
 }
 
 RequestContext* RequestContext::get() {
@@ -153,15 +155,4 @@ RequestContext* RequestContext::get() {
   }
   return context.get();
 }
-
-RequestContext::Provider RequestContext::setRequestContextProvider(
-    RequestContext::Provider newProvider) {
-  if (!newProvider) {
-    throw std::runtime_error("RequestContext provider must be non-empty");
-  }
-
-  auto& provider = requestContextProvider();
-  std::swap(provider, newProvider);
-  return newProvider;
-}
-}
+} // namespace folly