Don't throw in Singleton::get() if singleton is alive
[folly.git] / folly / experimental / Singleton.h
index eba1c57f9aee683f5496d64daa41564ef0693ab5..1d6c9a78c83397bee0216bf61870986cbab2ea29 100644 (file)
@@ -233,12 +233,9 @@ class SingletonVault {
   void reenableInstances();
 
   // Retrieve a singleton from the vault, creating it if necessary.
-  std::shared_ptr<void> get_shared(detail::TypeDescriptor type) {
+  std::weak_ptr<void> get_weak(detail::TypeDescriptor type) {
     auto entry = get_entry_create(type);
-    if (UNLIKELY(!entry)) {
-      return std::shared_ptr<void>();
-    }
-    return entry->instance;
+    return entry->instance_weak;
   }
 
   // This function is inherently racy since we don't hold the
@@ -247,7 +244,7 @@ class SingletonVault {
   // the weak_ptr interface for true safety.
   void* get_ptr(detail::TypeDescriptor type) {
     auto entry = get_entry_create(type);
-    if (UNLIKELY(!entry)) {
+    if (UNLIKELY(entry->instance_weak.expired())) {
       throw std::runtime_error(
         "Raw pointer to a singleton requested after its destruction.");
     }
@@ -318,6 +315,7 @@ class SingletonVault {
 
     // The singleton itself and related functions.
     std::shared_ptr<void> instance;
+    std::weak_ptr<void> instance_weak;
     void* instance_ptr = nullptr;
     CreateFunc create = nullptr;
     TeardownFunc teardown = nullptr;
@@ -375,7 +373,7 @@ class SingletonVault {
     if (entry->instance == nullptr) {
       RWSpinLock::ReadHolder rh(&stateMutex_);
       if (state_ == SingletonVaultState::Quiescing) {
-        return nullptr;
+        return entry;
       }
 
       CHECK(entry->state == SingletonEntryState::Dead);
@@ -396,6 +394,7 @@ class SingletonVault {
 
       CHECK(entry->state == SingletonEntryState::BeingBorn);
       entry->instance = instance;
+      entry->instance_weak = instance;
       entry->instance_ptr = instance.get();
       entry->state = SingletonEntryState::Living;
       entry->state_condvar.notify_all();
@@ -456,7 +455,16 @@ class Singleton {
 
   static std::weak_ptr<T> get_weak(
       const char* name, SingletonVault* vault = nullptr /* for testing */) {
-    return std::weak_ptr<T>(get_shared({typeid(T), name}, vault));
+    auto weak_void_ptr =
+      (vault ?: SingletonVault::singleton())->get_weak({typeid(T), name});
+
+    // This is ugly and inefficient, but there's no other way to do it, because
+    // there's no static_pointer_cast for weak_ptr.
+    auto shared_void_ptr = weak_void_ptr.lock();
+    if (!shared_void_ptr) {
+      return std::weak_ptr<T>();
+    }
+    return std::static_pointer_cast<T>(shared_void_ptr);
   }
 
   // Allow the Singleton<t> instance to also retrieve the underlying
@@ -532,7 +540,7 @@ class Singleton {
       detail::TypeDescriptor type_descriptor = {typeid(T), ""},
       SingletonVault* vault = nullptr /* for testing */) {
     return std::static_pointer_cast<T>(
-        (vault ?: SingletonVault::singleton())->get_shared(type_descriptor));
+      (vault ?: SingletonVault::singleton())->get_weak(type_descriptor).lock());
   }
 
   detail::TypeDescriptor type_descriptor_;