SIOF-proof thread local
authorBen Maurer <bmaurer@fb.com>
Mon, 12 Oct 2015 21:35:54 +0000 (14:35 -0700)
committerfacebook-github-bot-4 <folly-bot@fb.com>
Tue, 13 Oct 2015 05:20:56 +0000 (22:20 -0700)
Summary: Right now ThreadLocal & friends don't operate correctly when used as a
static variable (which is the idiomatic way to use them). The TLS id is
allocated in the static constructor so anybody who uses the ID prior to
first use would use an invalid ID. This makes ThreadLocal unusable for core
code such as per-thread reference counting.

This diff allocates the ID on first use. By making the invalid ID maxint we
avoid adding any extra branches in the fast path. We can then make the
constructor a constexpr meaning that initialization will happen prior to
any code running.

Reviewed By: @meyering

Differential Revision: D2457989

fb-gh-sync-id: 21d0c0d00c638fbbd36148d14d4c891f66f83706

folly/ThreadLocal.h
folly/detail/ThreadLocalDetail.h
folly/test/ThreadLocalTest.cpp

index fee6fac4b0ce4e4524dafa20f3bcb58c34af7b30..a0d93fb46f70aba615530e8f0a8a2e8fc9460ec6 100644 (file)
@@ -59,7 +59,7 @@ template<class T, class Tag> class ThreadLocalPtr;
 template<class T, class Tag=void>
 class ThreadLocal {
  public:
-  ThreadLocal() = default;
+  constexpr ThreadLocal() {}
 
   T* get() const {
     T* ptr = tlp_.get();
@@ -134,18 +134,19 @@ class ThreadLocal {
 
 template<class T, class Tag=void>
 class ThreadLocalPtr {
+ private:
+  typedef threadlocal_detail::StaticMeta<Tag> StaticMeta;
  public:
-  ThreadLocalPtr() : id_(threadlocal_detail::StaticMeta<Tag>::create()) { }
+  constexpr ThreadLocalPtr() : id_() {}
 
-  ThreadLocalPtr(ThreadLocalPtr&& other) noexcept : id_(other.id_) {
-    other.id_ = 0;
+  ThreadLocalPtr(ThreadLocalPtr&& other) noexcept :
+    id_(std::move(other.id_)) {
   }
 
   ThreadLocalPtr& operator=(ThreadLocalPtr&& other) {
     assert(this != &other);
     destroy();
-    id_ = other.id_;
-    other.id_ = 0;
+    id_ = std::move(other.id_);
     return *this;
   }
 
@@ -154,7 +155,8 @@ class ThreadLocalPtr {
   }
 
   T* get() const {
-    return static_cast<T*>(threadlocal_detail::StaticMeta<Tag>::get(id_).ptr);
+    threadlocal_detail::ElementWrapper& w = StaticMeta::get(&id_);
+    return static_cast<T*>(w.ptr);
   }
 
   T* operator->() const {
@@ -166,15 +168,14 @@ class ThreadLocalPtr {
   }
 
   T* release() {
-    threadlocal_detail::ElementWrapper& w =
-      threadlocal_detail::StaticMeta<Tag>::get(id_);
+    threadlocal_detail::ElementWrapper& w = StaticMeta::get(&id_);
 
     return static_cast<T*>(w.release());
   }
 
   void reset(T* newPtr = nullptr) {
-    threadlocal_detail::ElementWrapper& w =
-      threadlocal_detail::StaticMeta<Tag>::get(id_);
+    threadlocal_detail::ElementWrapper& w = StaticMeta::get(&id_);
+
     if (w.ptr != newPtr) {
       w.dispose(TLPDestructionMode::THIS_THREAD);
       w.set(newPtr);
@@ -194,8 +195,7 @@ class ThreadLocalPtr {
    */
   template <class Deleter>
   void reset(T* newPtr, Deleter deleter) {
-    threadlocal_detail::ElementWrapper& w =
-      threadlocal_detail::StaticMeta<Tag>::get(id_);
+    threadlocal_detail::ElementWrapper& w = StaticMeta::get(&id_);
     if (w.ptr != newPtr) {
       w.dispose(TLPDestructionMode::THIS_THREAD);
       w.set(newPtr, deleter);
@@ -330,21 +330,19 @@ class ThreadLocalPtr {
   Accessor accessAllThreads() const {
     static_assert(!std::is_same<Tag, void>::value,
                   "Must use a unique Tag to use the accessAllThreads feature");
-    return Accessor(id_);
+    return Accessor(id_.getOrAllocate());
   }
 
  private:
   void destroy() {
-    if (id_) {
-      threadlocal_detail::StaticMeta<Tag>::destroy(id_);
-    }
+    StaticMeta::destroy(&id_);
   }
 
   // non-copyable
   ThreadLocalPtr(const ThreadLocalPtr&) = delete;
   ThreadLocalPtr& operator=(const ThreadLocalPtr&) = delete;
 
-  uint32_t id_;  // every instantiation has a unique id
+  mutable typename StaticMeta::EntryID id_;
 };
 
 }  // namespace folly
index bd8658cfc3a0721e764a1bca90a2b55859d970e7..d3d62ad84c674533b34a2932388c21c4832ecc8f 100644 (file)
@@ -161,6 +161,8 @@ struct ThreadEntry {
   ThreadEntry* prev;
 };
 
+
+
 // Held in a singleton to track our global instances.
 // We have one of these per "Tag", by default one for the whole system
 // (Tag=void).
@@ -170,6 +172,50 @@ struct ThreadEntry {
 // StaticMeta; you can specify multiple Tag types to break that lock.
 template <class Tag>
 struct StaticMeta {
+  // Represents an ID of a thread local object. Initially set to the maximum
+  // uint. This representation allows us to avoid a branch in accessing TLS data
+  // (because if you test capacity > id if id = maxint then the test will always
+  // fail). It allows us to keep a constexpr constructor and avoid SIOF.
+  class EntryID {
+   public:
+    static constexpr uint32_t kInvalid = std::numeric_limits<uint32_t>::max();
+    std::atomic<uint32_t> value;
+
+    constexpr EntryID() : value(kInvalid) {
+    }
+
+    EntryID(EntryID&& other) noexcept : value(other.value.load()) {
+      other.value = kInvalid;
+    }
+
+    EntryID& operator=(EntryID&& other) {
+      assert(this != &other);
+      value = other.value.load();
+      other.value = kInvalid;
+      return *this;
+    }
+
+    EntryID(const EntryID& other) = delete;
+    EntryID& operator=(const EntryID& other) = delete;
+
+    uint32_t getOrInvalid() {
+      // It's OK for this to be relaxed, even though we're effectively doing
+      // double checked locking in using this value. We only care about the
+      // uniqueness of IDs, getOrAllocate does not modify any other memory
+      // this thread will use.
+      return value.load(std::memory_order_relaxed);
+    }
+
+    uint32_t getOrAllocate() {
+      uint32_t id = getOrInvalid();
+      if (id != kInvalid) {
+        return id;
+      }
+      // The lock inside allocate ensures that a single value is allocated
+      return instance().allocate(this);
+    }
+  };
+
   static StaticMeta<Tag>& instance() {
     // Leak it on exit, there's only one per process and we don't have to
     // worry about synchronization with exiting threads.
@@ -303,26 +349,40 @@ struct StaticMeta {
 #endif
   }
 
-  static uint32_t create() {
+  static uint32_t allocate(EntryID* ent) {
     uint32_t id;
     auto & meta = instance();
     std::lock_guard<std::mutex> g(meta.lock_);
+
+    id = ent->value.load();
+    if (id != EntryID::kInvalid) {
+      return id;
+    }
+
     if (!meta.freeIds_.empty()) {
       id = meta.freeIds_.back();
       meta.freeIds_.pop_back();
     } else {
       id = meta.nextId_++;
     }
+
+    uint32_t old_id = ent->value.exchange(id);
+    DCHECK_EQ(old_id, EntryID::kInvalid);
     return id;
   }
 
-  static void destroy(uint32_t id) {
+  static void destroy(EntryID* ent) {
     try {
       auto & meta = instance();
       // Elements in other threads that use this id.
       std::vector<ElementWrapper> elements;
       {
         std::lock_guard<std::mutex> g(meta.lock_);
+        uint32_t id = ent->value.exchange(EntryID::kInvalid);
+        if (id == EntryID::kInvalid) {
+          return;
+        }
+
         for (ThreadEntry* e = meta.head_.next; e != &meta.head_; e = e->next) {
           if (id < e->elementsCapacity && e->elements[id].ptr) {
             elements.push_back(e->elements[id]);
@@ -358,13 +418,18 @@ struct StaticMeta {
    * Reserve enough space in the ThreadEntry::elements for the item
    * @id to fit in.
    */
-  static void reserve(uint32_t id) {
+  static void reserve(EntryID* id) {
     auto& meta = instance();
     ThreadEntry* threadEntry = getThreadEntry();
     size_t prevCapacity = threadEntry->elementsCapacity;
+
+    uint32_t idval = id->getOrAllocate();
+    if (prevCapacity > idval) {
+      return;
+    }
     // Growth factor < 2, see folly/docs/FBVector.md; + 5 to prevent
     // very slow start.
-    size_t newCapacity = static_cast<size_t>((id + 5) * 1.7);
+    size_t newCapacity = static_cast<size_t>((idval + 5) * 1.7);
     assert(newCapacity > prevCapacity);
     ElementWrapper* reallocated = nullptr;
 
@@ -444,10 +509,14 @@ struct StaticMeta {
 #endif
   }
 
-  static ElementWrapper& get(uint32_t id) {
+  static ElementWrapper& get(EntryID* ent) {
     ThreadEntry* threadEntry = getThreadEntry();
+    uint32_t id = ent->getOrInvalid();
+    // if id is invalid, it is equal to uint32_t's max value.
+    // x <= max value is always true
     if (UNLIKELY(threadEntry->elementsCapacity <= id)) {
-      reserve(id);
+      reserve(ent);
+      id = ent->getOrInvalid();
       assert(threadEntry->elementsCapacity > id);
     }
     return threadEntry->elements[id];
index b6ab9cdfb53d0b4c3dc8d1a7cd7ae366419c35d1..a319c2f6580691500bd8f3beb7ffc2be45f06943 100644 (file)
@@ -538,6 +538,21 @@ TEST(ThreadLocal, Fork2) {
   }
 }
 
+// clang is unable to compile this code unless in c++14 mode.
+#if __cplusplus >= 201402L
+namespace {
+// This will fail to compile unless ThreadLocal{Ptr} has a constexpr
+// default constructor. This ensures that ThreadLocal is safe to use in
+// static constructors without worrying about initialization order
+class ConstexprThreadLocalCompile {
+  ThreadLocal<int> a_;
+  ThreadLocalPtr<int> b_;
+
+  constexpr ConstexprThreadLocalCompile() {}
+};
+}
+#endif
+
 // Simple reference implementation using pthread_get_specific
 template<typename T>
 class PThreadGetSpecific {