fdf8337ad3bd61beb7c8ebd9c1a502991d98b5b9
[model-checker-benchmarks.git] / williams-queue / williams-queue.h
/*
 * Lock-free queue code from
 * "C++ Concurrency in Action: Practical Multithreading", by Anthony Williams
 *
 * Code taken from:
 *  http://www.manning.com/williams/CCiA_SourceCode.zip
 *  http://www.manning.com/williams/
 */
9
10 #include <memory>
11 #include <atomic>
12
13 template<typename T>
14 class lock_free_queue
15 {
16 private:
17     struct node;
18     struct node_counter;
19     node* pop_head()
20     {
21         node* const old_head=head.load();
22         if(old_head==tail.load())
23         {
24             return nullptr;
25         }
26         head.store(old_head->next);
27         return old_head;
28     }
29 public:
30     lock_free_queue():
31         head(new node),tail(head.load())
32     {}
33     // lock_free_queue(const lock_free_queue& other)=delete;
34     // lock_free_queue& operator=(const lock_free_queue& other)=delete;
35     ~lock_free_queue()
36     {
37         while(node* const old_head=head.load())
38         {
39             head.store(old_head->next);
40             delete old_head;
41         }
42     }
43
44 private:
45     struct counted_node_ptr
46     {
47         int external_count;
48         node* ptr;
49     };
50     std::atomic<counted_node_ptr> head;
51     std::atomic<counted_node_ptr> tail;
52     struct node_counter
53     {
54         unsigned internal_count:30;
55         unsigned external_counters:2;
56     };
57
58     struct node
59     {
60         std::atomic<T*> data;
61         std::atomic<node_counter> count;
62         std::atomic<counted_node_ptr> next;
63         node()
64         {
65             node_counter new_count;
66             new_count.internal_count=0;
67             new_count.external_counters=2;
68             count.store(new_count);
69             next.ptr=nullptr;
70             next.external_count=0;
71         }
72         void release_ref()
73         {
74             node_counter old_counter=
75                 count.load(std::memory_order_relaxed);
76             node_counter new_counter;
77             do
78             {
79                 new_counter=old_counter;
80                 --new_counter.internal_count;
81             }
82             while(!count.compare_exchange_strong(
83                       old_counter,new_counter,
84                       std::memory_order_acquire,std::memory_order_relaxed));
85             if(!new_counter.internal_count &&
86                !new_counter.external_counters)
87             {
88                 delete this;
89             }
90         }
91     };
92
93     static void increase_external_count(
94         std::atomic<counted_node_ptr>& counter,
95         counted_node_ptr& old_counter)
96     {
97         counted_node_ptr new_counter;
98         do
99         {
100             new_counter=old_counter;
101             ++new_counter.external_count;
102         }
103         while(!counter.compare_exchange_strong(
104                   old_counter,new_counter,
105                   std::memory_order_acquire,std::memory_order_relaxed));
106         old_counter.external_count=new_counter.external_count;
107     }
108
109     static void free_external_counter(counted_node_ptr &old_node_ptr)
110     {
111         node* const ptr=old_node_ptr.ptr;
112         int const count_increase=old_node_ptr.external_count-2;
113         node_counter old_counter=
114             ptr->count.load(std::memory_order_relaxed);
115         node_counter new_counter;
116         do
117         {
118             new_counter=old_counter;
119             --new_counter.external_counters;
120             new_counter.internal_count+=count_increase;
121         }
122         while(!ptr->count.compare_exchange_strong(
123                   old_counter,new_counter,
124                   std::memory_order_acquire,std::memory_order_relaxed));
125         if(!new_counter.internal_count &&
126            !new_counter.external_counters)
127         {
128             delete ptr;
129         }
130     }
131 public:
132     std::unique_ptr<T> pop()
133     {
134         counted_node_ptr old_head=head.load(std::memory_order_relaxed);
135         for(;;)
136         {
137             increase_external_count(head,old_head);
138             node* const ptr=old_head.ptr;
139             if(ptr==tail.load().ptr)
140             {
141                 return std::unique_ptr<T>();
142             }
143             counted_node_ptr next=ptr->next.load();
144             if(head.compare_exchange_strong(old_head,next))
145             {
146                 T* const res=ptr->data.exchange(nullptr);
147                 free_external_counter(old_head);
148                 return std::unique_ptr<T>(res);
149             }
150             ptr->release_ref();
151         }
152     }
153
154 private:
155     void set_new_tail(counted_node_ptr &old_tail,
156                       counted_node_ptr const &new_tail)
157     {
158         node* const current_tail_ptr=old_tail.ptr;
159         while(!tail.compare_exchange_weak(old_tail,new_tail) &&
160               old_tail.ptr==current_tail_ptr);
161         if(old_tail.ptr==current_tail_ptr)
162             free_external_counter(old_tail);
163         else
164             current_tail_ptr->release_ref();
165     }
166 public:
167     void push(T new_value)
168     {
169         std::unique_ptr<T> new_data(new T(new_value));
170         counted_node_ptr new_next;
171         new_next.ptr=new node;
172         new_next.external_count=1;
173         counted_node_ptr old_tail=tail.load();
174         for(;;)
175         {
176             increase_external_count(tail,old_tail);
177             T* old_data=nullptr;
178             if(old_tail.ptr->data.compare_exchange_strong(
179                    old_data,new_data.get()))
180             {
181                 counted_node_ptr old_next={0};
182                 if(!old_tail.ptr->next.compare_exchange_strong(
183                        old_next,new_next))
184                 {
185                     delete new_next.ptr;
186                     new_next=old_next;
187                 }
188                 set_new_tail(old_tail, new_next);
189                 new_data.release();
190                 break;
191             }
192             else
193             {
194                 counted_node_ptr old_next={0};
195                 if(old_tail.ptr->next.compare_exchange_strong(
196                        old_next,new_next))
197                 {
198                     old_next=new_next;
199                     new_next.ptr=new node;
200                 }
201                 set_new_tail(old_tail, old_next);
202             }
203         }
204     }
205 };