Fixed use-after-free bug in SkipList<HP>
authorkhizmax <libcds.dev@gmail.com>
Sun, 27 Nov 2016 16:43:23 +0000 (19:43 +0300)
committerkhizmax <libcds.dev@gmail.com>
Sun, 27 Nov 2016 16:43:23 +0000 (19:43 +0300)
cds/container/details/make_skip_list_map.h
cds/container/details/make_skip_list_set.h
cds/intrusive/details/skip_list_base.h
cds/intrusive/impl/skip_list.h
test/include/cds_test/stat_skiplist_out.h

index 8e456fe..c494139 100644 (file)
@@ -57,22 +57,15 @@ namespace cds { namespace container { namespace details {
             //atomic_marked_ptr m_arrTower[] ;  // allocated together with node_type in single memory block
 
             template <typename Q>
-            node_type( unsigned int nHeight, atomic_marked_ptr * pTower, Q const& key )
-                : m_Value( std::make_pair( key, mapped_type()))
-            {
-                init_tower( nHeight, pTower );
-            }
-
-            template <typename Q, typename U>
-            node_type( unsigned int nHeight, atomic_marked_ptr * pTower, Q const& key, U const& val )
-                : m_Value( std::make_pair( key, val ))
+            node_type( unsigned int nHeight, atomic_marked_ptr * pTower, Q&& key )
+                : m_Value( std::make_pair( std::forward<Q>( key ), mapped_type()))
             {
                 init_tower( nHeight, pTower );
             }
 
             template <typename Q, typename... Args>
             node_type( unsigned int nHeight, atomic_marked_ptr * pTower, Q&& key, Args&&... args )
-                : m_Value( std::forward<Q>(key), std::move( mapped_type( std::forward<Args>(args)... )))
+                : m_Value( std::forward<Q>(key), mapped_type( std::forward<Args>(args)... ))
             {
                 init_tower( nHeight, pTower );
             }
index e63bbe8..658c523 100644 (file)
@@ -55,30 +55,32 @@ namespace cds { namespace container { namespace details {
             //atomic_marked_ptr m_arrTower[] ;  // allocated together with node_type in single memory block
 
             template <typename Q>
-            node_type( unsigned int nHeight, atomic_marked_ptr * pTower, Q const& v )
-                : m_Value(v)
+            node_type( unsigned int nHeight, atomic_marked_ptr * pTower, Q&& v )
+                : m_Value( std::forward<Q>( v ))
             {
-                if ( nHeight > 1 ) {
-                    // TSan: make_tower() issues atomic_thread_fence( release )
-                    CDS_TSAN_ANNOTATE_IGNORE_WRITES_BEGIN;
-                    new (pTower) atomic_marked_ptr[ nHeight - 1 ];
-                    base_class::make_tower( nHeight, pTower );
-                    CDS_TSAN_ANNOTATE_IGNORE_WRITES_END;
-                }
+                init_tower( nHeight, pTower );
             }
 
             template <typename Q, typename... Args>
             node_type( unsigned int nHeight, atomic_marked_ptr * pTower, Q&& q, Args&&... args )
                 : m_Value( std::forward<Q>(q), std::forward<Args>(args)... )
+            {
+                init_tower( nHeight, pTower );
+            }
+
+            node_type() = delete;
+
+        private:
+            void init_tower( unsigned nHeight, atomic_marked_ptr* pTower )
             {
                 if ( nHeight > 1 ) {
-                    new (pTower) atomic_marked_ptr[ nHeight - 1 ];
+                    // TSan: make_tower() issues atomic_thread_fence( release )
+                    CDS_TSAN_ANNOTATE_IGNORE_WRITES_BEGIN;
+                    new ( pTower ) atomic_marked_ptr[nHeight - 1];
                     base_class::make_tower( nHeight, pTower );
+                    CDS_TSAN_ANNOTATE_IGNORE_WRITES_END;
                 }
             }
-
-        private:
-            node_type() ;   // no default ctor
         };
 
         typedef skip_list::details::node_allocator< node_type, traits> node_allocator;
index 00ddad2..eb61006 100644 (file)
@@ -62,27 +62,29 @@ namespace cds { namespace intrusive {
             typedef typename gc::template atomic_marked_ptr< marked_ptr>  atomic_marked_ptr; ///< atomic marked pointer specific for GC
             //@cond
             typedef atomic_marked_ptr tower_item_type;
+
+            enum state {
+                clean,      // initial state
+                removed,    // final state
+                hand_off    // temp state
+            };
             //@endcond
 
         protected:
-            atomic_marked_ptr       m_pNext;   ///< Next item in bottom-list (list at level 0)
-            unsigned int            m_nHeight; ///< Node height (size of m_arrNext array). For node at level 0 the height is 1.
-            atomic_marked_ptr *     m_arrNext; ///< Array of next items for levels 1 .. m_nHeight - 1. For node at level 0 \p m_arrNext is \p nullptr
+            //@cond
+            atomic_marked_ptr           m_pNext{ nullptr };     ///< Next item in bottom-list (list at level 0)
+            unsigned int                m_nHeight{ 1 };         ///< Node height (size of \p m_arrNext array). For node at level 0 the height is 1.
+            atomic_marked_ptr *         m_arrNext{ nullptr };   ///< Array of next items for levels 1 .. m_nHeight - 1. For node at level 0 \p m_arrNext is \p nullptr
+            atomics::atomic< state >    m_state{ clean };
+            //@endcond
 
         public:
-            /// Constructs a node of height 1 (a bottom-list node)
-            node()
-                : m_pNext( nullptr )
-                , m_nHeight(1)
-                , m_arrNext( nullptr )
-            {}
-
-            /// Constructs a node of height \p nHeight
+            /// Constructs a node's tower of height \p nHeight
             void make_tower( unsigned int nHeight, atomic_marked_ptr * nextTower )
             {
                 assert( nHeight > 0 );
-                assert( (nHeight == 1 && nextTower == nullptr)  // bottom-list node
-                        || (nHeight > 1 && nextTower != nullptr)   // node at level of more than 0
+                assert( (nHeight == 1 && nextTower == nullptr)      // bottom-list node
+                        || (nHeight > 1 && nextTower != nullptr)    // node at level of more than 0
                 );
 
                 m_arrNext = nextTower;
@@ -160,6 +162,16 @@ namespace cds { namespace intrusive {
                     && m_arrNext == nullptr
                     && m_nHeight <= 1;
             }
+
+            bool set_state( state& cur_state, state new_state, atomics::memory_order order )
+            {
+                return m_state.compare_exchange_strong( cur_state, new_state, order, atomics::memory_order_relaxed );
+            }
+
+            void clear_state( atomics::memory_order order )
+            {
+                m_state.store( clean, order );
+            }
             //@endcond
         };
 
@@ -409,6 +421,7 @@ namespace cds { namespace intrusive {
             event_counter   m_nExtractMaxRetries    ; ///< Count of retries of \p extract_max call
             event_counter   m_nEraseWhileFind       ; ///< Count of erased item while searching
             event_counter   m_nExtractWhileFind     ; ///< Count of extracted item while searching (RCU only)
+            event_counter   m_nNodeHandOffFailed    ; ///< Cannot set "hand-off" node state
 
             //@cond
             void onAddNode( unsigned int nHeight )
@@ -454,6 +467,7 @@ namespace cds { namespace intrusive {
             void onExtractMaxSuccess()      { ++m_nExtractMaxSuccess; }
             void onExtractMaxFailed()       { ++m_nExtractMaxFailed;  }
             void onExtractMaxRetry()        { ++m_nExtractMaxRetries; }
+            void onNodeHandOffFailed()      { ++m_nNodeHandOffFailed; }
 
             //@endcond
         };
@@ -495,7 +509,7 @@ namespace cds { namespace intrusive {
             void onExtractMaxSuccess()      const {}
             void onExtractMaxFailed()       const {}
             void onExtractMaxRetry()        const {}
-
+            void onNodeHandOffFailed()      const {}
             //@endcond
         };
 
index 4f7d1ea..c72c5e8 100644 (file)
@@ -394,7 +394,8 @@ namespace cds { namespace intrusive {
         // c_nMaxHeight * 2 - pPred/pSucc guards
         // + 1 - for erase, unlink
         // + 1 - for clear
-        static size_t const c_nHazardPtrCount = c_nMaxHeight * 2 + 2; ///< Count of hazard pointer required for the skip-list
+        // + 1 - for help_remove
+        static size_t const c_nHazardPtrCount = c_nMaxHeight * 2 + 3; ///< Count of hazard pointer required for the skip-list
 
     protected:
         typedef typename node_type::atomic_marked_ptr   atomic_node_ptr;   ///< Atomic marked node pointer
@@ -639,7 +640,7 @@ namespace cds { namespace intrusive {
             node_type * pNode = node_traits::to_node_ptr( val );
             scoped_node_ptr scp( pNode );
             unsigned int nHeight = pNode->height();
-            bool bTowerOk = pNode->has_tower(); // nHeight > 1 && pNode->get_tower() != nullptr;
+            bool bTowerOk = pNode->has_tower();
             bool bTowerMade = false;
 
             position pos;
@@ -1147,6 +1148,32 @@ namespace cds { namespace intrusive {
             disposer()( pVal );
         }
 
+        void help_remove( int nLevel, node_type* pPred, marked_node_ptr pCur, marked_node_ptr pSucc )
+        {
+            typename gc::Guard succ_guard;
+            marked_node_ptr succ = succ_guard.protect( pCur->next( nLevel ), gc_protect );
+
+            typename node_type::state state = node_type::clean;
+            if ( succ == pSucc && ( succ.ptr() == nullptr ||
+                succ.ptr()->set_state( state, node_type::hand_off, memory_model::memory_order_acquire )))
+            {
+                marked_node_ptr p( pCur.ptr() );
+                if ( pPred->next( nLevel ).compare_exchange_strong( p, marked_node_ptr( succ.ptr()),
+                    memory_model::memory_order_acquire, atomics::memory_order_relaxed ) )
+                {
+                    if ( nLevel == 0 ) {
+                        gc::retire( node_traits::to_value_ptr( pCur.ptr() ), dispose_node );
+                        m_Stat.onEraseWhileFind();
+                    }
+                }
+
+                if ( succ.ptr() )
+                    succ.ptr()->clear_state( memory_model::memory_order_release );
+            }
+            else
+                m_Stat.onNodeHandOffFailed();
+        }
+
         template <typename Q, typename Compare >
         bool find_position( Q const& val, position& pos, Compare cmp, bool bStopIfFound )
         {
@@ -1183,16 +1210,9 @@ namespace cds { namespace intrusive {
                         goto retry;
 
                     if ( pSucc.bits() ) {
-                        // pCur is marked, i.e. logically deleted.
-                        marked_node_ptr p( pCur.ptr() );
-                        if ( pPred->next( nLevel ).compare_exchange_strong( p, marked_node_ptr( pSucc.ptr() ),
-                            memory_model::memory_order_acquire, atomics::memory_order_relaxed ) )
-                        {
-                            if ( nLevel == 0 ) {
-                                gc::retire( node_traits::to_value_ptr( pCur.ptr() ), dispose_node );
-                                m_Stat.onEraseWhileFind();
-                            }
-                        }
+                        // pCur is marked, i.e. logically deleted
+                        // try to help deleting pCur if pSucc is not being deleted
+                        help_remove( nLevel, pPred, pCur, pSucc );
                         goto retry;
                     }
                     else {
@@ -1252,15 +1272,8 @@ namespace cds { namespace intrusive {
 
                     if ( pSucc.bits() ) {
                         // pCur is marked, i.e. logically deleted.
-                        marked_node_ptr p( pCur.ptr() );
-                        if ( pPred->next( nLevel ).compare_exchange_strong( p, marked_node_ptr( pSucc.ptr() ),
-                            memory_model::memory_order_acquire, atomics::memory_order_relaxed ) )
-                        {
-                            if ( nLevel == 0 ) {
-                                gc::retire( node_traits::to_value_ptr( pCur.ptr() ), dispose_node );
-                                m_Stat.onEraseWhileFind();
-                            }
-                        }
+                        // try to help deleting pCur if pSucc is not being deleted
+                        help_remove( nLevel, pPred, pCur, pSucc );
                         goto retry;
                     }
                 }
@@ -1308,15 +1321,8 @@ namespace cds { namespace intrusive {
 
                     if ( pSucc.bits() ) {
                         // pCur is marked, i.e. logically deleted.
-                        marked_node_ptr p( pCur.ptr() );
-                        if ( pPred->next( nLevel ).compare_exchange_strong( p, marked_node_ptr( pSucc.ptr() ),
-                            memory_model::memory_order_acquire, atomics::memory_order_relaxed ) )
-                        {
-                            if ( nLevel == 0 ) {
-                                gc::retire( node_traits::to_value_ptr( pCur.ptr() ), dispose_node );
-                                m_Stat.onEraseWhileFind();
-                            }
-                        }
+                        // try to help deleting pCur if pSucc is not being deleted
+                        help_remove( nLevel, pPred, pCur, pSucc );
                         goto retry;
                     }
                     else {
@@ -1346,11 +1352,21 @@ namespace cds { namespace intrusive {
 
             // Insert at level 0
             {
-                marked_node_ptr p( pos.pSucc[0] );
+                node_type* succ = pos.pSucc[0];
+                typename node_type::state state = node_type::clean;
+                if ( succ != nullptr && !succ->set_state( state, node_type::hand_off, memory_model::memory_order_acquire ) )
+                    return false;
+
+                marked_node_ptr p( succ );
                 pNode->next( 0 ).store( p, memory_model::memory_order_release );
-                if ( !pos.pPrev[0]->next( 0 ).compare_exchange_strong( p, marked_node_ptr( pNode ), memory_model::memory_order_release, atomics::memory_order_relaxed ) )
+                if ( !pos.pPrev[0]->next( 0 ).compare_exchange_strong( p, marked_node_ptr( pNode ), memory_model::memory_order_release, atomics::memory_order_relaxed ) ) {
+                    if ( succ )
+                        succ->clear_state( memory_model::memory_order_release );
                     return false;
+                }
 
+                if ( succ )
+                    succ->clear_state( memory_model::memory_order_release );
                 f( val );
             }
 
@@ -1358,17 +1374,30 @@ namespace cds { namespace intrusive {
             for ( unsigned int nLevel = 1; nLevel < nHeight; ++nLevel ) {
                 marked_node_ptr p;
                 while ( true ) {
-                    marked_node_ptr q( pos.pSucc[nLevel] );
-                    if ( !pNode->next( nLevel ).compare_exchange_strong( p, q, memory_model::memory_order_release, atomics::memory_order_relaxed ) ) {
-                        // pNode has been marked as removed while we are inserting it
-                        // Stop inserting
-                        assert( p.bits() );
-                        m_Stat.onLogicDeleteWhileInsert();
-                        return true;
+                    typename node_type::state state = node_type::clean;
+                    node_type* succ = pos.pSucc[nLevel];
+                    if ( succ == nullptr ||
+                        succ->set_state( state, node_type::hand_off, memory_model::memory_order_acquire ) ) 
+                    {
+                        marked_node_ptr q( succ );
+                        if ( !pNode->next( nLevel ).compare_exchange_strong( p, q, memory_model::memory_order_release, atomics::memory_order_relaxed )) {
+                            // pNode has been marked as removed while we are inserting it
+                            // Stop inserting
+                            if ( succ )
+                                succ->clear_state( memory_model::memory_order_release );
+                            assert( p.bits() );
+                            m_Stat.onLogicDeleteWhileInsert();
+                            return true;
+                        }
+
+                        p = q;
+                        bool const result = pos.pPrev[nLevel]->next( nLevel ).compare_exchange_strong( q, marked_node_ptr( pNode ),
+                            memory_model::memory_order_release, atomics::memory_order_relaxed );
+                        if ( succ )
+                            succ->clear_state( memory_model::memory_order_release );
+                        if ( result )
+                            break;
                     }
-                    p = q;
-                    if ( pos.pPrev[nLevel]->next( nLevel ).compare_exchange_strong( q, marked_node_ptr( pNode ), memory_model::memory_order_release, atomics::memory_order_relaxed ) )
-                        break;
 
                     // Renew insert position
                     m_Stat.onRenewInsertPosition();
@@ -1387,6 +1416,17 @@ namespace cds { namespace intrusive {
         {
             assert( pDel != nullptr );
 
+            // set "removed" node state
+            {
+                back_off bkoff;
+                typename node_type::state state = node_type::clean;
+                while ( !( pDel->set_state( state, node_type::removed, memory_model::memory_order_release )
+                    || state == node_type::removed ))
+                {
+                    bkoff();
+                }
+            }
+
             marked_node_ptr pSucc;
 
             // logical deletion (marking)
@@ -1574,7 +1614,7 @@ namespace cds { namespace intrusive {
             position pos;
 
             guarded_ptr gp;
-            for ( ;;) {
+            for (;;) {
                 if ( !find_position( val, pos, cmp, false ) ) {
                     m_Stat.onExtractFailed();
                     return guarded_ptr();
index 8e5b22e..66ecdae 100644 (file)
@@ -81,7 +81,8 @@ namespace cds_test {
             << CDSSTRESS_STAT_OUT( s, m_nFastExtract )
             << CDSSTRESS_STAT_OUT( s, m_nSlowExtract )
             << CDSSTRESS_STAT_OUT( s, m_nEraseWhileFind )
-            << CDSSTRESS_STAT_OUT( s, m_nExtractWhileFind );
+            << CDSSTRESS_STAT_OUT( s, m_nExtractWhileFind )
+            << CDSSTRESS_STAT_OUT( s, m_nNodeHandOffFailed );
     }
 
 } // namespace cds_test