Fixed thread launching in stress test framework
[libcds.git] / test / include / cds_test / thread.h
index 0048294b97538a7ab342cf26a535c5182b8529e7..a56987adf34fb89efdcd81b40ab6fb582dd7a5d0 100644 (file)
@@ -56,11 +56,6 @@ namespace cds_test {
         virtual ~thread()
         {}
 
-        void join()
-        {
-            m_impl.join();
-        }
-
     protected:
         virtual thread * clone() = 0;
         virtual void test() = 0;
@@ -87,22 +82,83 @@ namespace cds_test {
         friend class thread_pool;
 
         thread_pool&    m_pool;
-        int             m_type;
-        size_t          m_id;
-        std::thread     m_impl;
+        int const       m_type;
+        size_t const    m_id;
     };
 
     // Pool of test threads
     class thread_pool
     {
+        class barrier
+        {
+        public:
+            barrier()
+                : m_count( 0 )
+            {}
+
+            void reset( size_t count )
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                m_count = count;
+            }
+
+            bool wait()
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                if ( --m_count == 0 ) {
+                    m_cv.notify_all();
+                    return true;
+                }
+
+                while ( m_count != 0 )
+                    m_cv.wait( lock );
+
+                return false;
+            }
+
+        private:
+            size_t      m_count;
+            std::mutex  m_mtx;
+            std::condition_variable m_cv;
+        };
+
+        class initial_gate
+        {
+        public:
+            initial_gate()
+                : m_ready( false )
+            {}
+
+            void wait()
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                while ( !m_ready )
+                    m_cv.wait( lock );
+            }
+
+            void ready()
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                m_ready = true;
+                m_cv.notify_all();
+            }
+
+            void reset()
+            {
+                std::unique_lock< std::mutex > lock( m_mtx );
+                m_ready = false;
+            }
+
+        private:
+            std::mutex  m_mtx;
+            std::condition_variable m_cv;
+            bool        m_ready;
+        };
+
     public:
         explicit thread_pool( ::testing::Test& fixture )
             : m_fixture( fixture )
-            , m_bRunning( false )
-            , m_bStopped( false )
-            , m_doneCount( 0 )
             , m_bTimeElapsed( false )
-            , m_readyCount( 0 )
         {}
 
         ~thread_pool()
@@ -112,7 +168,7 @@ namespace cds_test {
 
         void add( thread * what )
         {
-            m_threads.push_back( what );
+            m_workers.push_back( what );
         }
 
         void add( thread * what, size_t count )
@@ -131,24 +187,28 @@ namespace cds_test {
 
         std::chrono::milliseconds run( std::chrono::seconds duration )
         {
-            m_bStopped = false;
-            m_doneCount = 0;
+            m_startBarrier.reset( m_workers.size() + 1 );
+            m_stopBarrier.reset( m_workers.size() + 1 );
 
-            while ( m_readyCount.load() != m_threads.size())
-                std::this_thread::yield();
+            // Create threads
+            std::vector< std::thread > threads;
+            threads.reserve( m_workers.size() );
+            for ( auto w : m_workers )
+                threads.emplace_back( &thread::run, w );
+
+            // The pool is intialized
+            m_startPoint.ready();
 
             m_bTimeElapsed.store( false, std::memory_order_release );
 
             auto native_duration = std::chrono::duration_cast<std::chrono::steady_clock::duration>(duration);
+
+            // The pool is ready to start all workers
+            m_startBarrier.wait();
+
             auto time_start = std::chrono::steady_clock::now();
             auto const expected_end = time_start + native_duration;
 
-            {
-                scoped_lock l( m_cvMutex );
-                m_bRunning = true;
-                m_cvStart.notify_all();
-            }
-
             if ( duration != std::chrono::seconds::zero()) {
                 for ( ;; ) {
                     std::this_thread::sleep_for( native_duration );
@@ -160,24 +220,19 @@ namespace cds_test {
             }
             m_bTimeElapsed.store( true, std::memory_order_release );
 
-            {
-                scoped_lock l( m_cvMutex );
-                while ( m_doneCount != m_threads.size())
-                    m_cvDone.wait( l );
-                m_bStopped = true;
-            }
-            auto time_end = std::chrono::steady_clock::now();
+            // Waiting for all workers done
+            m_stopBarrier.wait();
 
-            m_cvStop.notify_all();
+            auto time_end = std::chrono::steady_clock::now();
 
-            for ( auto t : m_threads )
-                t->join();
+            for ( auto& t : threads )
+                t.join();
 
             return m_testDuration = std::chrono::duration_cast<std::chrono::milliseconds>(time_end - time_start);
         }
 
-        size_t size() const             { return m_threads.size(); }
-        thread& get( size_t idx ) const { return *m_threads.at( idx ); }
+        size_t size() const             { return m_workers.size(); }
+        thread& get( size_t idx ) const { return *m_workers.at( idx ); }
 
         template <typename Fixture>
         Fixture& fixture()
@@ -189,67 +244,51 @@ namespace cds_test {
 
         void clear()
         {
-            for ( auto t : m_threads )
+            for ( auto t : m_workers )
                 delete t;
-            m_threads.clear();
-            m_bRunning = false;
-            m_bStopped = false;
-            m_doneCount = 0;
-            m_readyCount = 0;
+            m_workers.clear();
+            m_startPoint.reset();
+        }
+
+        void reset()
+        {
+            clear();
         }
 
     protected: // thread interface
         size_t get_next_id()
         {
-            return m_threads.size();
+            return m_workers.size();
         }
 
-        void    ready_to_start( thread& /*who*/ )
+        void ready_to_start( thread& /*who*/ )
         {
             // Called from test thread
 
-            // Wait for all thread created
-            scoped_lock l( m_cvMutex );
-            m_readyCount.fetch_add( 1 );
-            while ( !m_bRunning )
-                m_cvStart.wait( l );
+            // Wait until the pool is ready
+            m_startPoint.wait();
+
+            // Wait until all thread ready
+            m_startBarrier.wait();
         }
 
-        void    thread_done( thread& /*who*/ )
+        void thread_done( thread& /*who*/ )
         {
             // Called from test thread
-
-            {
-                scoped_lock l( m_cvMutex );
-                ++m_doneCount;
-
-                // Tell pool that the thread is done
-                m_cvDone.notify_all();
-
-                // Wait for all thread done
-                while ( !m_bStopped )
-                    m_cvStop.wait( l );
-            }
+            m_stopBarrier.wait();
         }
 
     private:
         friend class thread;
 
         ::testing::Test&        m_fixture;
-        std::vector<thread *>   m_threads;
+        std::vector<thread *>   m_workers;
 
-        typedef std::unique_lock<std::mutex> scoped_lock;
-        std::mutex              m_cvMutex;
-        std::condition_variable m_cvStart;
-        std::condition_variable m_cvStop;
-        std::condition_variable m_cvDone;
+        initial_gate            m_startPoint;
+        barrier                 m_startBarrier;
+        barrier                 m_stopBarrier;
 
-        volatile bool   m_bRunning;
-        volatile bool   m_bStopped;
-        volatile size_t m_doneCount;
         std::atomic<bool> m_bTimeElapsed;
-        std::atomic<size_t> m_readyCount;
-
         std::chrono::milliseconds m_testDuration;
     };
 
@@ -257,14 +296,12 @@ namespace cds_test {
         : m_pool( master )
         , m_type( type )
         , m_id( master.get_next_id())
-        , m_impl( &thread::run, this )
     {}
 
     inline thread::thread( thread const& sample )
         : m_pool( sample.m_pool )
         , m_type( sample.m_type )
         , m_id( m_pool.get_next_id())
-        , m_impl( &thread::run, this )
     {}
 
     inline void thread::run()