Removed trailing spaces
[libcds.git] / test / include / cds_test / thread.h
1 /*
2     This file is a part of libcds - Concurrent Data Structures library
3
4     (C) Copyright Maxim Khizhinsky (libcds.dev@gmail.com) 2006-2016
5
6     Source code repo: http://github.com/khizmax/libcds/
7     Download: http://sourceforge.net/projects/libcds/files/
8
9     Redistribution and use in source and binary forms, with or without
10     modification, are permitted provided that the following conditions are met:
11
12     * Redistributions of source code must retain the above copyright notice, this
13     list of conditions and the following disclaimer.
14
15     * Redistributions in binary form must reproduce the above copyright notice,
16     this list of conditions and the following disclaimer in the documentation
17     and/or other materials provided with the distribution.
18
19     THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20     AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21     IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22     DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23     FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24     DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25     SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26     CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27     OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28     OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 */
30
31 #ifndef CDSTEST_THREAD_H
32 #define CDSTEST_THREAD_H
33
34 #include <gtest/gtest.h>
35 #include <vector>
36 #include <thread>
37 #include <condition_variable>
38 #include <mutex>
39 #include <chrono>
40 #include <cds/threading/model.h>
41
42 namespace cds_test {
43
44     // Forwards
45     class thread;
46     class thread_pool;
47
48     // Test thread
49     class thread
50     {
51         void run();
52
53     protected: // thread_pool interface
54         thread( thread const& sample );
55
56         virtual ~thread()
57         {}
58
59         void join()
60         {
61             m_impl.join();
62         }
63
64     protected:
65         virtual thread * clone() = 0;
66         virtual void test() = 0;
67
68         virtual void SetUp()
69         {
70             cds::threading::Manager::attachThread();
71         }
72
73         virtual void TearDown()
74         {
75             cds::threading::Manager::detachThread();
76         }
77
78     public:
79         explicit thread( thread_pool& master, int type = 0 );
80
81         thread_pool& pool() { return m_pool; }
82         int type() const { return m_type; }
83         size_t id() const { return m_id;  }
84         bool time_elapsed() const;
85
86     private:
87         friend class thread_pool;
88
89         thread_pool&    m_pool;
90         int             m_type;
91         size_t          m_id;
92         std::thread     m_impl;
93     };
94
95     // Pool of test threads
96     class thread_pool
97     {
98     public:
99         explicit thread_pool( ::testing::Test& fixture )
100             : m_fixture( fixture )
101             , m_bRunning( false )
102             , m_bStopped( false )
103             , m_doneCount( 0 )
104             , m_bTimeElapsed( false )
105             , m_readyCount( 0 )
106         {}
107
108         ~thread_pool()
109         {
110             clear();
111         }
112
113         void add( thread * what )
114         {
115             m_threads.push_back( what );
116         }
117
118         void add( thread * what, size_t count )
119         {
120             add( what );
121             for ( size_t i = 1; i < count; ++i ) {
122                 thread * p = what->clone();
123                 add( p );
124             }
125         }
126
127         std::chrono::milliseconds run()
128         {
129             return run( std::chrono::seconds::zero() );
130         }
131
132         std::chrono::milliseconds run( std::chrono::seconds duration )
133         {
134             m_bStopped = false;
135             m_doneCount = 0;
136
137             while ( m_readyCount.load() != m_threads.size() )
138                 std::this_thread::yield();
139
140             m_bTimeElapsed.store( false, std::memory_order_release );
141
142             auto native_duration = std::chrono::duration_cast<std::chrono::steady_clock::duration>(duration);
143             auto time_start = std::chrono::steady_clock::now();
144             auto const expected_end = time_start + native_duration;
145
146             {
147                 scoped_lock l( m_cvMutex );
148                 m_bRunning = true;
149                 m_cvStart.notify_all();
150             }
151
152             if ( duration != std::chrono::seconds::zero() ) {
153                 for ( ;; ) {
154                     std::this_thread::sleep_for( native_duration );
155                     auto time_now = std::chrono::steady_clock::now();
156                     if ( time_now >= expected_end )
157                         break;
158                     native_duration = expected_end - time_now;
159                 }
160             }
161             m_bTimeElapsed.store( true, std::memory_order_release );
162
163             {
164                 scoped_lock l( m_cvMutex );
165                 while ( m_doneCount != m_threads.size() )
166                     m_cvDone.wait( l );
167                 m_bStopped = true;
168             }
169             auto time_end = std::chrono::steady_clock::now();
170
171             m_cvStop.notify_all();
172
173             for ( auto t : m_threads )
174                 t->join();
175
176             return m_testDuration = std::chrono::duration_cast<std::chrono::milliseconds>(time_end - time_start);
177         }
178
179         size_t size() const             { return m_threads.size(); }
180         thread& get( size_t idx ) const { return *m_threads.at( idx ); }
181
182         template <typename Fixture>
183         Fixture& fixture()
184         {
185             return static_cast<Fixture&>(m_fixture);
186         }
187
188         std::chrono::milliseconds duration() const { return m_testDuration; }
189
190         void clear()
191         {
192             for ( auto t : m_threads )
193                 delete t;
194             m_threads.clear();
195             m_bRunning = false;
196             m_bStopped = false;
197             m_doneCount = 0;
198             m_readyCount = 0;
199         }
200
201     protected: // thread interface
202         size_t get_next_id()
203         {
204             return m_threads.size();
205         }
206
207         void    ready_to_start( thread& /*who*/ )
208         {
209             // Called from test thread
210
211             // Wait for all thread created
212             scoped_lock l( m_cvMutex );
213             m_readyCount.fetch_add( 1 );
214             while ( !m_bRunning )
215                 m_cvStart.wait( l );
216         }
217
218         void    thread_done( thread& /*who*/ )
219         {
220             // Called from test thread
221
222             {
223                 scoped_lock l( m_cvMutex );
224                 ++m_doneCount;
225
226                 // Tell pool that the thread is done
227                 m_cvDone.notify_all();
228
229                 // Wait for all thread done
230                 while ( !m_bStopped )
231                     m_cvStop.wait( l );
232             }
233         }
234
235     private:
236         friend class thread;
237
238         ::testing::Test&        m_fixture;
239         std::vector<thread *>   m_threads;
240
241         typedef std::unique_lock<std::mutex> scoped_lock;
242         std::mutex              m_cvMutex;
243         std::condition_variable m_cvStart;
244         std::condition_variable m_cvStop;
245         std::condition_variable m_cvDone;
246
247         volatile bool   m_bRunning;
248         volatile bool   m_bStopped;
249         volatile size_t m_doneCount;
250         std::atomic<bool> m_bTimeElapsed;
251         std::atomic<size_t> m_readyCount;
252
253         std::chrono::milliseconds m_testDuration;
254     };
255
256     inline thread::thread( thread_pool& master, int type /*= 0*/ )
257         : m_pool( master )
258         , m_type( type )
259         , m_id( master.get_next_id())
260         , m_impl( &thread::run, this )
261     {}
262
263     inline thread::thread( thread const& sample )
264         : m_pool( sample.m_pool )
265         , m_type( sample.m_type )
266         , m_id( m_pool.get_next_id() )
267         , m_impl( &thread::run, this )
268     {}
269
270     inline void thread::run()
271     {
272         SetUp();
273         m_pool.ready_to_start( *this );
274         test();
275         m_pool.thread_done( *this );
276         TearDown();
277     }
278
279     inline bool thread::time_elapsed() const
280     {
281         return m_pool.m_bTimeElapsed.load( std::memory_order_acquire );
282     }
283
284 } // namespace cds_test
285
286 #endif // CDSTEST_THREAD_H