2c61d440e4a75c3e563db772e35d177f98aafb95
[libcds.git] / test / stress / misc / rwlock_driver.cpp
#include "common.h"

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

#include <cds/gc/dhp.h>
#include <cds/gc/hp.h>
#include <cds/sync/rwlock.h>
#include <cds_test/stress_test.h>
9
10 using namespace std;
11
12 namespace {
13
// Number of worker threads started in every write-ratio round.
// NOTE(review): presumably overridden from the "Misc" config section by
// GetConfig(RWLockThreadCount) in the fixture's SetUpTestCase.
size_t s_nRWLockThreadCount{6};
// Number of lock operations each worker performs per round.
size_t s_nRWLockPassCount{200000};
16
17 typedef cds_others::RWLock RWLock;
18 class RWLockTest : public cds_test::stress_fixture {
19 protected:
20   static int sum;
21   static int x;
22   static RWLock *rwlock;
23
24   static void SetUpTestCase() {
25     cds_test::config const &cfg = get_config("Misc");
26     GetConfig(RWLockThreadCount);
27     GetConfig(RWLockPassCount);
28   }
29
30   static void ReaderWriterThread(int write_percentage) {
31     for (size_t i = 0; i < s_nRWLockPassCount; i++) {
32       if (rand(100) < write_percentage) {
33         if (rwlock->read_can_lock()) {
34           if (!rwlock->read_trylock()) {
35             rwlock->read_lock();
36           }
37           sum += x;
38           rwlock->read_unlock();
39         } else {
40           rwlock->read_lock();
41           sum += x;
42           rwlock->read_unlock();
43         }
44       } else {
45         if (rwlock->write_can_lock()) {
46           if (!rwlock->write_trylock()) {
47             rwlock->write_lock();
48           }
49           x++;
50           rwlock->write_unlock();
51         } else {
52           rwlock->write_lock();
53           x++;
54           rwlock->write_unlock();
55         }
56       }
57     }
58   }
59 };
60
61 int RWLockTest::x;
62 int RWLockTest::sum;
63 RWLock *RWLockTest::rwlock;
64
65 TEST_F(RWLockTest, BasicLockUnlock) {
66   rwlock = new RWLock();
67   int num_threads = s_nRWLockThreadCount;
68   for (int write_percentage = 5; write_percentage < 40; write_percentage += 5) {
69     std::thread *threads = new std::thread[num_threads];
70     for (size_t i = 0; i < num_threads; i++) {
71       threads[i] = std::thread(ReaderWriterThread, write_percentage);
72     }
73
74     for (int i = 0; i < num_threads; i++) {
75       threads[i].join();
76     }
77   }
78 }
79
80 } // namespace