Fixes misc test cases
[libcds.git] / test / stress / misc / rwlock_driver.cpp
index d2dfce90db47c7d64b5a06b03f91364429a2f881..2c61d440e4a75c3e563db772e35d177f98aafb95 100644 (file)
@@ -1,3 +1,4 @@
+#include "common.h"
 #include <atomic>
 #include <cds/gc/dhp.h>
 #include <cds/gc/hp.h>
@@ -10,23 +11,25 @@ using namespace std;
 
 namespace {
 
+static size_t s_nRWLockThreadCount = 6;
+static size_t s_nRWLockPassCount = 200000;
+
 typedef cds_others::RWLock RWLock;
 class RWLockTest : public cds_test::stress_fixture {
 protected:
   static int sum;
   static int x;
   static RWLock *rwlock;
-  static const int kReaderThreads = 0;
-  static const int kWriterThreads = 0;
-  static const int kReaderWriterThreads = 6;
-  static const int kWriterPercentage = 20;
-  static const int kRWPassCount = 20000;
 
-  static void SetUpTestCase() {}
+  static void SetUpTestCase() {
+    cds_test::config const &cfg = get_config("Misc");
+    GetConfig(RWLockThreadCount);
+    GetConfig(RWLockPassCount);
+  }
 
-  static void ReaderThread() {
-    for (int i = 0; i < 10000; i++) {
-      for (int j = 0; j < 10; i++) {
+  static void ReaderWriterThread(int write_percentage) {
+    for (size_t i = 0; i < s_nRWLockPassCount; i++) {
+      if (rand(100) < write_percentage) {
         if (rwlock->read_can_lock()) {
           if (!rwlock->read_trylock()) {
             rwlock->read_lock();
@@ -38,53 +41,17 @@ protected:
           sum += x;
           rwlock->read_unlock();
         }
-      }
-    }
-  }
-
-  static void WriterThread() {
-    for (int i = 0; i < 10000; i++) {
-      if (rwlock->write_can_lock()) {
-        if (!rwlock->write_trylock()) {
-          rwlock->write_lock();
-        }
-        x += 1;
-        rwlock->write_unlock();
       } else {
-        rwlock->write_lock();
-        x += 1;
-        rwlock->write_unlock();
-      }
-    }
-  }
-
-  static void ReaderWriterThread() {
-    for (int i = 0; i < kRWPassCount; i++) {
-      for (int j = 0; j < kRWPassCount; j++) {
-        if (rand(100) < kWriterPercentage) {
-          if (rwlock->read_can_lock()) {
-            if (!rwlock->read_trylock()) {
-              rwlock->read_lock();
-            }
-            sum += x;
-            rwlock->read_unlock();
-          } else {
-            rwlock->read_lock();
-            sum += x;
-            rwlock->read_unlock();
-          }
-        } else {
-          if (rwlock->write_can_lock()) {
-            if (!rwlock->write_trylock()) {
-              rwlock->write_lock();
-            }
-            x += 1;
-            rwlock->write_unlock();
-          } else {
+        if (rwlock->write_can_lock()) {
+          if (!rwlock->write_trylock()) {
             rwlock->write_lock();
-            x += 1;
-            rwlock->write_unlock();
           }
+          x++;
+          rwlock->write_unlock();
+        } else {
+          rwlock->write_lock();
+          x++;
+          rwlock->write_unlock();
         }
       }
     }
@@ -94,27 +61,19 @@ protected:
 int RWLockTest::x;
 int RWLockTest::sum;
 RWLock *RWLockTest::rwlock;
-const int RWLockTest::kReaderThreads;
-const int RWLockTest::kWriterThreads;
-const int RWLockTest::kReaderWriterThreads;
-const int RWLockTest::kRWPassCount;
 
 TEST_F(RWLockTest, BasicLockUnlock) {
   rwlock = new RWLock();
-  int num_threads = kReaderThreads + kWriterThreads + kReaderWriterThreads;
-  std::thread *threads = new std::thread[num_threads];
-  for (int i = 0; i < kReaderThreads; i++) {
-    threads[i] = std::thread(ReaderThread);
-  }
-  for (int i = kReaderThreads; i < (kReaderThreads + kWriterThreads); i++) {
-    threads[i] = std::thread(WriterThread);
-  }
-  for (int i = (kReaderThreads + kWriterThreads); i < num_threads; i++) {
-    threads[i] = std::thread(ReaderWriterThread);
-  }
+  int num_threads = s_nRWLockThreadCount;
+  for (int write_percentage = 5; write_percentage < 40; write_percentage += 5) {
+    std::thread *threads = new std::thread[num_threads];
+    for (size_t i = 0; i < num_threads; i++) {
+      threads[i] = std::thread(ReaderWriterThread, write_percentage);
+    }
 
-  for (int i = 0; i < num_threads; i++) {
-    threads[i].join();
+    for (int i = 0; i < num_threads; i++) {
+      threads[i].join();
+    }
   }
 }