Added RCU lock for iterating thread
[libcds.git] / test / stress / set / iteration / set_iteration.h
1 /*
2     This file is a part of libcds - Concurrent Data Structures library
3
4     (C) Copyright Maxim Khizhinsky (libcds.dev@gmail.com) 2006-2016
5
6     Source code repo: http://github.com/khizmax/libcds/
7     Download: http://sourceforge.net/projects/libcds/files/
8     
9     Redistribution and use in source and binary forms, with or without
10     modification, are permitted provided that the following conditions are met:
11
12     * Redistributions of source code must retain the above copyright notice, this
13       list of conditions and the following disclaimer.
14
15     * Redistributions in binary form must reproduce the above copyright notice,
16       this list of conditions and the following disclaimer in the documentation
17       and/or other materials provided with the distribution.
18
19     THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20     AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21     IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22     DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23     FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24     DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25     SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26     CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27     OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28     OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 */
30
31 #include "set_type.h"
32 #include <cds_test/city.h>
33
34 namespace set {
35
// Test for set's thread-safe iterator:
//   Several threads insert/erase elements into/from the set.
//   A dedicated Iterator thread iterates over the set, calculates the CityHash of each element's key
//   and stores it in the element.
// Test goal: no crash
41
// Expands TEST_CASE(TAG, X) to the declaration `void X();`; the TAG argument is ignored by this expansion.
#define TEST_CASE(TAG, X)  void X();
43
44     class Set_Iteration: public cds_test::stress_fixture
45     {
46     public:
47         static size_t s_nSetSize;               // set size
48         static size_t s_nInsertThreadCount;     // count of insertion thread
49         static size_t s_nDeleteThreadCount;     // count of deletion thread
50         static size_t s_nThreadPassCount;       // pass count for each thread
51         static size_t s_nMaxLoadFactor;         // maximum load factor
52
53         static size_t s_nCuckooInitialSize;     // initial size for CuckooSet
54         static size_t s_nCuckooProbesetSize;    // CuckooSet probeset size (only for list-based probeset)
55         static size_t s_nCuckooProbesetThreshold; // CUckooSet probeset threshold (0 - use default)
56
57         static size_t s_nFeldmanSet_HeadBits;
58         static size_t s_nFeldmanSet_ArrayBits;
59
60         static size_t s_nLoadFactor;
61         static std::vector<std::string>  m_arrString;
62
63         static void SetUpTestCase();
64         static void TearDownTestCase();
65
66         void on_modifier_done()
67         {
68             m_nModifierCount.fetch_sub( 1, atomics::memory_order_relaxed );
69         }
70
71         bool all_modifiers_done() const
72         {
73             return m_nModifierCount.load( atomics::memory_order_relaxed ) == 0;
74         }
75
76         typedef std::string key_type;
77
78         struct value_type
79         {
80             size_t   val;
81             uint64_t hash;
82
83             explicit value_type( size_t v )
84                 : val(v)
85                 , hash(0)
86             {}
87         };
88
89     private:
90         enum {
91             insert_thread,
92             delete_thread,
93             extract_thread,
94             iterator_thread
95         };
96
97         atomics::atomic<size_t> m_nModifierCount;
98
99         template <class Set>
100         class Inserter: public cds_test::thread
101         {
102             typedef cds_test::thread base_class;
103
104             Set&     m_Set;
105             typedef typename Set::value_type keyval_type;
106
107         public:
108             size_t  m_nInsertSuccess = 0;
109             size_t  m_nInsertFailed = 0;
110
111         public:
112             Inserter( cds_test::thread_pool& pool, Set& set )
113                 : base_class( pool, insert_thread )
114                 , m_Set( set )
115             {}
116
117             Inserter( Inserter& src )
118                 : base_class( src )
119                 , m_Set( src.m_Set )
120             {}
121
122             virtual thread * clone()
123             {
124                 return new Inserter( *this );
125             }
126
127             virtual void test()
128             {
129                 Set& rSet = m_Set;
130
131                 Set_Iteration& fixture = pool().template fixture<Set_Iteration>();
132                 size_t nArrSize = m_arrString.size();
133                 size_t const nSetSize = fixture.s_nSetSize;
134                 size_t const nPassCount = fixture.s_nThreadPassCount;
135
136                 if ( id() & 1 ) {
137                     for ( size_t nPass = 0; nPass < nPassCount; ++nPass ) {
138                         for ( size_t nItem = 0; nItem < nSetSize; ++nItem ) {
139                             if ( rSet.insert( keyval_type( m_arrString[nItem % nArrSize], nItem * 8 )))
140                                 ++m_nInsertSuccess;
141                             else
142                                 ++m_nInsertFailed;
143                         }
144                     }
145                 }
146                 else {
147                     for ( size_t nPass = 0; nPass < nPassCount; ++nPass ) {
148                         for ( size_t nItem = nSetSize; nItem > 0; --nItem ) {
149                             if ( rSet.insert( keyval_type( m_arrString[nItem % nArrSize], nItem * 8 )))
150                                 ++m_nInsertSuccess;
151                             else
152                                 ++m_nInsertFailed;
153                         }
154                     }
155                 }
156
157                 fixture.on_modifier_done();
158             }
159         };
160
161         template <class Set>
162         class Deleter: public cds_test::thread
163         {
164             typedef cds_test::thread base_class;
165
166             Set&     m_Set;
167         public:
168             size_t  m_nDeleteSuccess = 0;
169             size_t  m_nDeleteFailed = 0;
170
171         public:
172             Deleter( cds_test::thread_pool& pool, Set& set )
173                 : base_class( pool, delete_thread )
174                 , m_Set( set )
175             {}
176
177             Deleter( Deleter& src )
178                 : base_class( src )
179                 , m_Set( src.m_Set )
180             {}
181
182             virtual thread * clone()
183             {
184                 return new Deleter( *this );
185             }
186
187             virtual void test()
188             {
189                 Set& rSet = m_Set;
190
191                 Set_Iteration& fixture = pool().template fixture<Set_Iteration>();
192                 size_t nArrSize = m_arrString.size();
193                 size_t const nSetSize = fixture.s_nSetSize;
194                 size_t const nPassCount = fixture.s_nThreadPassCount;
195
196                 if ( id() & 1 ) {
197                     for ( size_t nPass = 0; nPass < nPassCount; ++nPass ) {
198                         for ( size_t nItem = 0; nItem < nSetSize; ++nItem ) {
199                             if ( rSet.erase( m_arrString[nItem % nArrSize] ))
200                                 ++m_nDeleteSuccess;
201                             else
202                                 ++m_nDeleteFailed;
203                         }
204                     }
205                 }
206                 else {
207                     for ( size_t nPass = 0; nPass < nPassCount; ++nPass ) {
208                         for ( size_t nItem = nSetSize; nItem > 0; --nItem ) {
209                             if ( rSet.erase( m_arrString[nItem % nArrSize] ))
210                                 ++m_nDeleteSuccess;
211                             else
212                                 ++m_nDeleteFailed;
213                         }
214                     }
215                 }
216
217                 fixture.on_modifier_done();
218             }
219         };
220
221         template <typename GC, class Set>
222         class Extractor: public cds_test::thread
223         {
224             typedef cds_test::thread base_class;
225             Set&     m_Set;
226
227         public:
228             size_t  m_nDeleteSuccess = 0;
229             size_t  m_nDeleteFailed = 0;
230
231         public:
232             Extractor( cds_test::thread_pool& pool, Set& set )
233                 : base_class( pool, extract_thread )
234                 , m_Set( set )
235             {}
236
237             Extractor( Extractor& src )
238                 : base_class( src )
239                 , m_Set( src.m_Set )
240             {}
241
242             virtual thread * clone()
243             {
244                 return new Extractor( *this );
245             }
246
247             virtual void test()
248             {
249                 Set& rSet = m_Set;
250
251                 typename Set::guarded_ptr gp;
252
253                 Set_Iteration& fixture = pool().template fixture<Set_Iteration>();
254                 size_t nArrSize = m_arrString.size();
255                 size_t const nSetSize = fixture.s_nSetSize;
256                 size_t const nPassCount = fixture.s_nThreadPassCount;
257
258                 if ( id() & 1 ) {
259                     for ( size_t nPass = 0; nPass < nPassCount; ++nPass ) {
260                         for ( size_t nItem = 0; nItem < nSetSize; ++nItem ) {
261                             gp = rSet.extract( m_arrString[nItem % nArrSize] );
262                             if ( gp )
263                                 ++m_nDeleteSuccess;
264                             else
265                                 ++m_nDeleteFailed;
266                             gp.release();
267                         }
268                     }
269                 }
270                 else {
271                     for ( size_t nPass = 0; nPass < nPassCount; ++nPass ) {
272                         for ( size_t nItem = nSetSize; nItem > 0; --nItem ) {
273                             gp = rSet.extract( m_arrString[nItem % nArrSize] );
274                             if ( gp )
275                                 ++m_nDeleteSuccess;
276                             else
277                                 ++m_nDeleteFailed;
278                             gp.release();
279                         }
280                     }
281                 }
282
283                 fixture.on_modifier_done();
284             }
285         };
286
287         template <typename RCU, class Set>
288         class Extractor<cds::urcu::gc<RCU>, Set >: public cds_test::thread
289         {
290             typedef cds_test::thread base_class;
291             Set&     m_Set;
292
293         public:
294             size_t  m_nDeleteSuccess = 0;
295             size_t  m_nDeleteFailed = 0;
296
297         public:
298             Extractor( cds_test::thread_pool& pool, Set& set )
299                 : base_class( pool, extract_thread )
300                 , m_Set( set )
301             {}
302
303             Extractor( Extractor& src )
304                 : base_class( src )
305                 , m_Set( src.m_Set )
306             {}
307
308             virtual thread * clone()
309             {
310                 return new Extractor( *this );
311             }
312
313             virtual void test()
314             {
315                 Set& rSet = m_Set;
316
317                 typename Set::exempt_ptr xp;
318
319                 Set_Iteration& fixture = pool().template fixture<Set_Iteration>();
320                 size_t nArrSize = m_arrString.size();
321                 size_t const nSetSize = fixture.s_nSetSize;
322                 size_t const nPassCount = fixture.s_nThreadPassCount;
323
324                 if ( id() & 1 ) {
325                     for ( size_t nPass = 0; nPass < nPassCount; ++nPass ) {
326                         for ( size_t nItem = 0; nItem < nSetSize; ++nItem ) {
327                             if ( Set::c_bExtractLockExternal ) {
328                                 typename Set::rcu_lock l;
329                                 xp = rSet.extract( m_arrString[nItem % nArrSize] );
330                                 if ( xp )
331                                     ++m_nDeleteSuccess;
332                                 else
333                                     ++m_nDeleteFailed;
334                             }
335                             else {
336                                 xp = rSet.extract( m_arrString[nItem % nArrSize] );
337                                 if ( xp )
338                                     ++m_nDeleteSuccess;
339                                 else
340                                     ++m_nDeleteFailed;
341                             }
342                             xp.release();
343                         }
344                     }
345                 }
346                 else {
347                     for ( size_t nPass = 0; nPass < nPassCount; ++nPass ) {
348                         for ( size_t nItem = nSetSize; nItem > 0; --nItem ) {
349                             if ( Set::c_bExtractLockExternal ) {
350                                 typename Set::rcu_lock l;
351                                 xp = rSet.extract( m_arrString[nItem % nArrSize] );
352                                 if ( xp )
353                                     ++m_nDeleteSuccess;
354                                 else
355                                     ++m_nDeleteFailed;
356                             }
357                             else {
358                                 xp = rSet.extract( m_arrString[nItem % nArrSize] );
359                                 if ( xp )
360                                     ++m_nDeleteSuccess;
361                                 else
362                                     ++m_nDeleteFailed;
363                             }
364                             xp.release();
365                         }
366                     }
367                 }
368
369                 fixture.on_modifier_done();
370             }
371         };
372
373         template <typename GC, class Set>
374         class Iterator: public cds_test::thread
375         {
376             typedef cds_test::thread base_class;
377
378             Set&     m_Set;
379             typedef typename Set::value_type keyval_type;
380
381         public:
382             size_t  m_nPassCount = 0;
383             size_t  m_nVisitCount = 0; // how many items the iterator visited
384
385         public:
386             Iterator( cds_test::thread_pool& pool, Set& set )
387                 : base_class( pool, iterator_thread )
388                 , m_Set( set )
389             {}
390
391             Iterator( Iterator& src )
392                 : base_class( src )
393                 , m_Set( src.m_Set )
394             {}
395
396             virtual thread * clone()
397             {
398                 return new Iterator( *this );
399             }
400
401             virtual void test()
402             {
403                 Set& rSet = m_Set;
404
405                 Set_Iteration& fixture = pool().template fixture<Set_Iteration>();
406                 while ( !fixture.all_modifiers_done() ) {
407                     ++m_nPassCount;
408                     for ( auto it = rSet.begin(); it != rSet.end(); ++it ) {
409                         it->val.hash = CityHash64( it->key.c_str(), it->key.length());
410                         ++m_nVisitCount;
411                     }
412                 }
413             }
414         };
415
416         template <typename RCU, class Set>
417         class Iterator<cds::urcu::gc<RCU>, Set>: public cds_test::thread
418         {
419             typedef cds_test::thread base_class;
420
421             Set&     m_Set;
422             typedef typename Set::value_type keyval_type;
423
424         public:
425             size_t  m_nPassCount = 0;
426             size_t  m_nVisitCount = 0; // how many items the iterator visited
427
428         public:
429             Iterator( cds_test::thread_pool& pool, Set& set )
430                 : base_class( pool, iterator_thread )
431                 , m_Set( set )
432             {}
433
434             Iterator( Iterator& src )
435                 : base_class( src )
436                 , m_Set( src.m_Set )
437             {}
438
439             virtual thread * clone()
440             {
441                 return new Iterator( *this );
442             }
443
444             virtual void test()
445             {
446                 Set& rSet = m_Set;
447
448                 Set_Iteration& fixture = pool().template fixture<Set_Iteration>();
449                 while ( !fixture.all_modifiers_done() ) {
450                     ++m_nPassCount;
451                     typename Set::rcu_lock l;
452                     for ( auto it = rSet.begin(); it != rSet.end(); ++it ) {
453                         it->val.hash = CityHash64( it->key.c_str(), it->key.length() );
454                         ++m_nVisitCount;
455                     }
456                 }
457             }
458         };
459
460     protected:
461         template <class Set>
462         void do_test( Set& testSet )
463         {
464             typedef Inserter<Set> InserterThread;
465             typedef Deleter<Set>  DeleterThread;
466             typedef Iterator<typename Set::gc, Set> IteratorThread;
467
468             cds_test::thread_pool& pool = get_pool();
469             pool.add( new InserterThread( pool, testSet ), s_nInsertThreadCount );
470             pool.add( new DeleterThread( pool, testSet ), s_nDeleteThreadCount );
471
472             m_nModifierCount.store( pool.size(), atomics::memory_order_relaxed );
473             pool.add( new IteratorThread( pool, testSet ), 1 );
474
475             propout() << std::make_pair( "insert_thread_count", s_nInsertThreadCount )
476                 << std::make_pair( "delete_thread_count", s_nDeleteThreadCount )
477                 << std::make_pair( "thread_pass_count", s_nThreadPassCount )
478                 << std::make_pair( "set_size", s_nSetSize );
479
480             std::chrono::milliseconds duration = pool.run();
481
482             propout() << std::make_pair( "duration", duration );
483
484             size_t nInsertSuccess = 0;
485             size_t nInsertFailed = 0;
486             size_t nDeleteSuccess = 0;
487             size_t nDeleteFailed = 0;
488             size_t nIteratorPassCount = 0;
489             size_t nIteratorVisitCount = 0;
490             for ( size_t i = 0; i < pool.size(); ++i ) {
491                 cds_test::thread& thr = pool.get( i );
492                 switch ( thr.type() ) {
493                 case insert_thread:
494                     {
495                         InserterThread& inserter = static_cast<InserterThread&>( thr );
496                         nInsertSuccess += inserter.m_nInsertSuccess;
497                         nInsertFailed += inserter.m_nInsertFailed;
498                     }
499                     break;
500                 case delete_thread:
501                     {
502                         DeleterThread& deleter = static_cast<DeleterThread&>(thr);
503                         nDeleteSuccess += deleter.m_nDeleteSuccess;
504                         nDeleteFailed += deleter.m_nDeleteFailed;
505                     }
506                     break;
507                 case iterator_thread:
508                     {
509                         IteratorThread& iter = static_cast<IteratorThread&>(thr);
510                         nIteratorPassCount += iter.m_nPassCount;
511                         nIteratorVisitCount += iter.m_nVisitCount;
512                     }
513                     break;
514                 default:
515                     assert( false ); // Forgot anything?..
516                 }
517             }
518
519             propout()
520                 << std::make_pair( "insert_success", nInsertSuccess )
521                 << std::make_pair( "delete_success", nDeleteSuccess )
522                 << std::make_pair( "insert_failed", nInsertFailed )
523                 << std::make_pair( "delete_failed", nDeleteFailed )
524                 << std::make_pair( "iterator_pass_count", nIteratorPassCount )
525                 << std::make_pair( "iterator_visit_count", nIteratorVisitCount )
526                 << std::make_pair( "final_set_size", testSet.size() );
527
528             testSet.clear();
529             EXPECT_TRUE( testSet.empty() );
530
531             additional_check( testSet );
532             print_stat( propout(), testSet );
533             additional_cleanup( testSet );
534         }
535
536         template <class Set>
537         void do_test_extract( Set& testSet )
538         {
539             typedef Inserter<Set> InserterThread;
540             typedef Deleter<Set>  DeleterThread;
541             typedef Extractor<typename Set::gc, Set> ExtractThread;
542             typedef Iterator<typename Set::gc, Set> IteratorThread;
543
544             size_t const nDelThreadCount = s_nDeleteThreadCount / 2;
545             size_t const nExtractThreadCount = s_nDeleteThreadCount - nDelThreadCount;
546
547             cds_test::thread_pool& pool = get_pool();
548             pool.add( new InserterThread( pool, testSet ), s_nInsertThreadCount );
549             pool.add( new DeleterThread( pool, testSet ), nDelThreadCount );
550             pool.add( new ExtractThread( pool, testSet ), nExtractThreadCount );
551
552             m_nModifierCount.store( pool.size(), atomics::memory_order_relaxed );
553             pool.add( new IteratorThread( pool, testSet ), 1 );
554
555             propout() << std::make_pair( "insert_thread_count", s_nInsertThreadCount )
556                 << std::make_pair( "delete_thread_count", nDelThreadCount )
557                 << std::make_pair( "extract_thread_count", nExtractThreadCount )
558                 << std::make_pair( "thread_pass_count", s_nThreadPassCount )
559                 << std::make_pair( "set_size", s_nSetSize );
560
561             std::chrono::milliseconds duration = pool.run();
562
563             propout() << std::make_pair( "duration", duration );
564
565             size_t nInsertSuccess = 0;
566             size_t nInsertFailed = 0;
567             size_t nDeleteSuccess = 0;
568             size_t nDeleteFailed = 0;
569             size_t nExtractSuccess = 0;
570             size_t nExtractFailed = 0;
571             size_t nIteratorPassCount = 0;
572             size_t nIteratorVisitCount = 0;
573             for ( size_t i = 0; i < pool.size(); ++i ) {
574                 cds_test::thread& thr = pool.get( i );
575                 switch ( thr.type() ) {
576                 case insert_thread:
577                     {
578                         InserterThread& inserter = static_cast<InserterThread&>(thr);
579                         nInsertSuccess += inserter.m_nInsertSuccess;
580                         nInsertFailed += inserter.m_nInsertFailed;
581                     }
582                     break;
583                 case delete_thread:
584                     {
585                         DeleterThread& deleter = static_cast<DeleterThread&>(thr);
586                         nDeleteSuccess += deleter.m_nDeleteSuccess;
587                         nDeleteFailed += deleter.m_nDeleteFailed;
588                     }
589                     break;
590                 case extract_thread:
591                     {
592                         ExtractThread& extractor = static_cast<ExtractThread&>(thr);
593                         nExtractSuccess += extractor.m_nDeleteSuccess;
594                         nExtractFailed += extractor.m_nDeleteFailed;
595                     }
596                     break;
597                 case iterator_thread:
598                     {
599                         IteratorThread& iter = static_cast<IteratorThread&>(thr);
600                         nIteratorPassCount += iter.m_nPassCount;
601                         nIteratorVisitCount += iter.m_nVisitCount;
602                     }
603                     break;
604                 default:
605                     assert( false ); // Forgot anything?..
606                 }
607             }
608
609             propout()
610                 << std::make_pair( "insert_success", nInsertSuccess )
611                 << std::make_pair( "delete_success", nDeleteSuccess )
612                 << std::make_pair( "extract_success", nExtractSuccess )
613                 << std::make_pair( "insert_failed",  nInsertFailed )
614                 << std::make_pair( "delete_failed",  nDeleteFailed )
615                 << std::make_pair( "extract_failed", nExtractFailed )
616                 << std::make_pair( "iterator_pass_count", nIteratorPassCount )
617                 << std::make_pair( "iterator_visit_count", nIteratorVisitCount )
618                 << std::make_pair( "final_set_size", testSet.size() );
619
620             testSet.clear();
621             EXPECT_TRUE( testSet.empty() );
622
623             additional_check( testSet );
624             print_stat( propout(), testSet );
625             additional_cleanup( testSet );
626         }
627
628         template <class Set>
629         void run_test()
630         {
631             ASSERT_TRUE( m_arrString.size() > 0 );
632
633             Set s( *this );
634             do_test( s );
635         }
636
637         template <class Set>
638         void run_test_extract()
639         {
640             ASSERT_TRUE( m_arrString.size() > 0 );
641
642             Set s( *this );
643             do_test_extract( s );
644         }
645     };
646
647     class Set_Iteration_LF: public Set_Iteration
648         , public ::testing::WithParamInterface<size_t>
649     {
650     public:
651         template <class Set>
652         void run_test()
653         {
654             s_nLoadFactor = GetParam();
655             propout() << std::make_pair( "load_factor", s_nLoadFactor );
656             Set_Iteration::run_test<Set>();
657         }
658
659         template <class Set>
660         void run_test_extract()
661         {
662             s_nLoadFactor = GetParam();
663             propout() << std::make_pair( "load_factor", s_nLoadFactor );
664             Set_Iteration::run_test_extract<Set>();
665         }
666
667         static std::vector<size_t> get_load_factors();
668     };
669
670 } // namespace set