2017
[folly.git] / folly / io / async / Request.cpp
index 8ea7f3e311789d82d576a4b401d6eced3114d24a..8310033b2020529c1a9f16453f44c0a521a27455 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016 Facebook, Inc.
+ * Copyright 2017 Facebook, Inc.
  *
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements. See the NOTICE file
  */
 
 #include <folly/io/async/Request.h>
+#include <folly/tracing/StaticTracepoint.h>
 
 #include <glog/logging.h>
 
+#include <folly/MapUtil.h>
 #include <folly/SingletonThreadLocal.h>
 
 namespace folly {
@@ -30,91 +32,78 @@ namespace folly {
 void RequestContext::setContextData(
     const std::string& val,
     std::unique_ptr<RequestData> data) {
-  folly::RWSpinLock::WriteHolder guard(lock);
-  if (data_.find(val) != data_.end()) {
+  auto wlock = data_.wlock();
+  if (wlock->count(val)) {
     LOG_FIRST_N(WARNING, 1)
         << "Called RequestContext::setContextData with data already set";
 
-    data_[val] = nullptr;
+    (*wlock)[val] = nullptr;
   } else {
-    data_[val] = std::move(data);
+    (*wlock)[val] = std::move(data);
   }
 }
 
 bool RequestContext::setContextDataIfAbsent(
     const std::string& val,
     std::unique_ptr<RequestData> data) {
-  folly::RWSpinLock::UpgradedHolder guard(lock);
-  if (data_.find(val) != data_.end()) {
+  auto ulock = data_.ulock();
+  if (ulock->count(val)) {
     return false;
   }
 
-  folly::RWSpinLock::WriteHolder writeGuard(std::move(guard));
-  data_[val] = std::move(data);
+  auto wlock = ulock.moveFromUpgradeToWrite();
+  (*wlock)[val] = std::move(data);
   return true;
 }
 
 bool RequestContext::hasContextData(const std::string& val) const {
-  folly::RWSpinLock::ReadHolder guard(lock);
-  return data_.find(val) != data_.end();
+  return data_.rlock()->count(val);
 }
 
 RequestData* RequestContext::getContextData(const std::string& val) {
-  folly::RWSpinLock::ReadHolder guard(lock);
-  auto r = data_.find(val);
-  if (r == data_.end()) {
-    return nullptr;
-  } else {
-    return r->second.get();
-  }
+  return get_ref_default(*data_.rlock(), val, nullptr).get();
 }
 
 const RequestData* RequestContext::getContextData(
     const std::string& val) const {
-  folly::RWSpinLock::ReadHolder guard(lock);
-  auto r = data_.find(val);
-  if (r == data_.end()) {
-    return nullptr;
-  } else {
-    return r->second.get();
-  }
+  return get_ref_default(*data_.rlock(), val, nullptr).get();
 }
 
 void RequestContext::onSet() {
-  folly::RWSpinLock::ReadHolder guard(lock);
-  for (auto const& ent : data_) {
-    if (RequestData* data = ent.second.get()) {
+  auto rlock = data_.rlock();
+  for (auto const& ent : *rlock) {
+    if (auto& data = ent.second) {
       data->onSet();
     }
   }
 }
 
 void RequestContext::onUnset() {
-  folly::RWSpinLock::ReadHolder guard(lock);
-  for (auto const& ent : data_) {
-    if (RequestData* data = ent.second.get()) {
+  auto rlock = data_.rlock();
+  for (auto const& ent : *rlock) {
+    if (auto& data = ent.second) {
       data->onUnset();
     }
   }
 }
 
 void RequestContext::clearContextData(const std::string& val) {
-  folly::RWSpinLock::WriteHolder guard(lock);
-  data_.erase(val);
+  data_.wlock()->erase(val);
 }
 
 std::shared_ptr<RequestContext> RequestContext::setContext(
     std::shared_ptr<RequestContext> ctx) {
-  auto& prev = getStaticContext();
-  if (ctx != prev) {
+  auto& curCtx = getStaticContext();
+  if (ctx != curCtx) {
+    FOLLY_SDT(folly, request_context_switch_before, curCtx.get(), ctx.get());
     using std::swap;
-    if (prev) {
-      prev->onUnset();
+    if (curCtx) {
+      curCtx->onUnset();
     }
-    if (ctx) {
-      ctx->onSet();
+    swap(ctx, curCtx);
+    if (curCtx) {
+      curCtx->onSet();
     }
-    swap(ctx, prev);
   }
   return ctx;
 }
@@ -125,4 +114,13 @@ std::shared_ptr<RequestContext>& RequestContext::getStaticContext() {
 
   return singleton.get();
 }
+
+RequestContext* RequestContext::get() {
+  auto context = getStaticContext();
+  if (!context) {
+    static RequestContext defaultContext;
+    return std::addressof(defaultContext);
+  }
+  return context.get();
+}
 }