mcs-queue: a few changes
[model-checker-benchmarks.git] / williams-queue / williams-queue.h
1 /*
2  * Lock-free queue code from
3  * "C++ Concurrency in Action: Practical Multithreading", by Anthony Williams
4  *
5  * Code taken from:
6  *  http://www.manning.com/williams/CCiA_SourceCode.zip
7  *  http://www.manning.com/williams/
8  */
9
#include <atomic>
#include <cstdint>
#include <memory>
12
/*
 * lock_free_queue<T>: lock-free FIFO queue using split (external/internal)
 * reference counting, following Anthony Williams, "C++ Concurrency in
 * Action", chapter 7.
 *
 * Layout: a singly linked list running from `head` to `tail`. The node that
 * `tail` refers to is always an empty placeholder: push() stores the new
 * value into that node's `data` and then links a fresh placeholder behind
 * it, while pop() unlinks the node at `head`. The queue is empty when head
 * and tail refer to the same node.
 *
 * Reclamation: before dereferencing a node reached through an atomic
 * counted_node_ptr, a thread bumps that pointer's external_count
 * (increase_external_count). When head or tail is moved off a node (pop(),
 * set_new_tail()), the retired pointer's external count is folded into the
 * node's internal_count and the node's external_counters field is
 * decremented (free_external_counter). A thread that borrowed a node without
 * retiring a pointer returns its reference with node::release_ref(). The
 * node is deleted once both fields of its node_counter are zero.
 */
template<typename T>
class lock_free_queue
{
private:
    struct node;

    // A node pointer paired with the number of references handed out through
    // the atomic slot that holds it. A freshly published pointer starts at 1
    // (the slot's own reference) -- see push() and the constructor -- which
    // is why free_external_counter() subtracts 2 (that 1 plus the retiring
    // thread's own increment).
    //
    // external_count is pointer-sized so the struct has no padding bytes.
    // Before C++20 (P0528), compare_exchange compares whole object
    // representations, so indeterminate padding could make a CAS against a
    // locally built expected value such as {0, nullptr} fail spuriously.
    struct counted_node_ptr
    {
        std::intptr_t external_count;
        node* ptr;
    };
    static_assert(sizeof(counted_node_ptr) ==
                      sizeof(std::intptr_t) + sizeof(node*),
                  "counted_node_ptr must not contain padding bytes");

    std::atomic<counted_node_ptr> head;
    std::atomic<counted_node_ptr> tail;

    // Node-local half of the split count. Both fields are packed into one
    // unsigned so std::atomic<node_counter> can update them with one CAS.
    struct node_counter
    {
        // Adjustments folded in by free_external_counter() minus references
        // returned by release_ref(). Unsigned 30-bit arithmetic: it may
        // transiently wrap below zero when a release_ref() lands before the
        // matching free_external_counter().
        unsigned internal_count:30;
        // Counted pointers (head, tail) that still refer to this node.
        unsigned external_counters:2;
    };

    struct node
    {
        // Payload: nullptr until push() claims the node, and again once
        // pop() has taken the value out.
        std::atomic<T*> data;
        std::atomic<node_counter> count;
        std::atomic<counted_node_ptr> next;

        // data is explicitly initialised to nullptr because push() claims a
        // node by CAS-ing data from nullptr, and before C++20 a
        // default-constructed std::atomic holds an indeterminate value.
        node()
            : data(nullptr)
        {
            node_counter initial_count;
            initial_count.internal_count = 0;
            initial_count.external_counters = 2; // one for head, one for tail
            count.store(initial_count);

            counted_node_ptr const no_successor = {0, nullptr};
            next.store(no_successor);
        }

        // Return a reference borrowed via increase_external_count() that did
        // not end up retiring a counted pointer. Deletes the node if this was
        // the last outstanding reference of either kind.
        void release_ref()
        {
            node_counter old_counter = count.load(std::memory_order_relaxed);
            node_counter new_counter;
            do
            {
                new_counter = old_counter;
                --new_counter.internal_count;
            }
            // acq_rel: release so this thread's earlier accesses to the node
            // happen-before a delete by whichever thread drops the last
            // reference; acquire so that, if that thread is us, we in turn
            // observe everyone else's accesses before deleting.
            while(!count.compare_exchange_strong(
                      old_counter, new_counter,
                      std::memory_order_acq_rel, std::memory_order_relaxed));
            if(!new_counter.internal_count &&
               !new_counter.external_counters)
            {
                delete this;
            }
        }
    };

    // Atomically bump the external count of the pointer currently stored in
    // `counter`, starting from the caller's snapshot `old_counter`. On return
    // old_counter holds the pointer that was actually incremented, together
    // with its new count, so the caller now holds a counted reference to
    // old_counter.ptr.
    static void increase_external_count(
        std::atomic<counted_node_ptr>& counter,
        counted_node_ptr& old_counter)
    {
        counted_node_ptr new_counter;
        do
        {
            new_counter = old_counter;
            ++new_counter.external_count;
        }
        while(!counter.compare_exchange_strong(
                  old_counter, new_counter,
                  std::memory_order_acquire, std::memory_order_relaxed));
        old_counter.external_count = new_counter.external_count;
    }

    // Retire a counted pointer that this thread has just swung head or tail
    // away from: fold its external count into the node (minus 2, see
    // counted_node_ptr) and drop one external counter. Deletes the node if
    // no references of either kind remain.
    static void free_external_counter(counted_node_ptr& old_node_ptr)
    {
        node* const ptr = old_node_ptr.ptr;
        std::intptr_t const count_increase = old_node_ptr.external_count - 2;
        node_counter old_counter = ptr->count.load(std::memory_order_relaxed);
        node_counter new_counter;
        do
        {
            new_counter = old_counter;
            --new_counter.external_counters;
            new_counter.internal_count += count_increase;
        }
        // acq_rel for the same reason as in node::release_ref(): this update
        // may be the one that lets another thread delete the node, or the
        // one after which this thread deletes it.
        while(!ptr->count.compare_exchange_strong(
                  old_counter, new_counter,
                  std::memory_order_acq_rel, std::memory_order_relaxed));
        if(!new_counter.internal_count &&
           !new_counter.external_counters)
        {
            delete ptr;
        }
    }

public:
    // Remove and return the oldest value, or an empty unique_ptr when the
    // queue is empty (head and tail refer to the same node).
    std::unique_ptr<T> pop()
    {
        counted_node_ptr old_head = head.load(std::memory_order_relaxed);
        for(;;)
        {
            increase_external_count(head, old_head);
            node* const ptr = old_head.ptr;
            if(ptr == tail.load().ptr)
            {
                // Hand back the reference taken just above; otherwise the
                // extra external count stays on head and, once the node is
                // eventually popped, its internal count never reaches zero.
                ptr->release_ref();
                return std::unique_ptr<T>();
            }
            counted_node_ptr next = ptr->next.load();
            if(head.compare_exchange_strong(old_head, next))
            {
                // This thread unlinked ptr, so it owns the payload.
                T* const res = ptr->data.exchange(nullptr);
                free_external_counter(old_head);
                return std::unique_ptr<T>(res);
            }
            // Lost the race: old_head now holds the current head. Drop the
            // reference to the stale node and retry.
            ptr->release_ref();
        }
    }

private:
    // Swing tail from old_tail (to which the caller holds a counted
    // reference) to new_tail. If this thread's CAS moves tail off the old
    // node, it retires old_tail's external count; if another thread already
    // moved tail to a different node, it only returns the borrowed reference.
    void set_new_tail(counted_node_ptr& old_tail,
                      counted_node_ptr const& new_tail)
    {
        node* const current_tail_ptr = old_tail.ptr;
        // Retry while tail still names the same node (only its count moved).
        while(!tail.compare_exchange_weak(old_tail, new_tail) &&
              old_tail.ptr == current_tail_ptr);
        if(old_tail.ptr == current_tail_ptr)
            free_external_counter(old_tail);
        else
            current_tail_ptr->release_ref();
    }

public:
    // Start with a single empty placeholder node shared by head and tail.
    // Its external_count is 1, like every pointer push() publishes, so that
    // free_external_counter()'s "- 2" balances when head/tail move off it.
    lock_free_queue()
    {
        counted_node_ptr const dummy = {1, new node};
        head.store(dummy);
        tail.store(dummy);
    }

    // Left commented out as in the original source. With the standard
    // std::atomic members the implicit copy operations are deleted anyway.
    // lock_free_queue(const lock_free_queue& other)=delete;
    // lock_free_queue& operator=(const lock_free_queue& other)=delete;

    // Free every node still linked from head, together with any value that
    // was pushed but never popped. Must only run once no other thread is
    // using the queue.
    ~lock_free_queue()
    {
        node* current = head.load().ptr;
        while(current)
        {
            node* const successor = current->next.load().ptr;
            delete current->data.load();
            delete current;
            current = successor;
        }
    }

    // Append a copy of new_value. The value is stored in the current tail
    // (placeholder) node, and a fresh placeholder is linked behind it. If
    // another thread has already claimed the tail's data slot, this thread
    // helps by linking a placeholder and advancing tail, then retries.
    void push(T new_value)
    {
        std::unique_ptr<T> new_data(new T(new_value));
        counted_node_ptr new_next;
        new_next.ptr = new node;
        new_next.external_count = 1;
        counted_node_ptr old_tail = tail.load();
        for(;;)
        {
            increase_external_count(tail, old_tail);
            T* old_data = nullptr;
            if(old_tail.ptr->data.compare_exchange_strong(
                   old_data, new_data.get()))
            {
                // We own this node's data slot. Link our placeholder unless a
                // helping thread already linked one, in which case use theirs.
                counted_node_ptr old_next = {0, nullptr};
                if(!old_tail.ptr->next.compare_exchange_strong(
                       old_next, new_next))
                {
                    delete new_next.ptr;
                    new_next = old_next;
                }
                set_new_tail(old_tail, new_next);
                new_data.release();
                break;
            }
            else
            {
                // Another pusher owns this node. Help it by linking our
                // placeholder (and allocate a replacement for ourselves), or,
                // if a successor is already linked, just advance tail to it.
                counted_node_ptr old_next = {0, nullptr};
                if(old_tail.ptr->next.compare_exchange_strong(
                       old_next, new_next))
                {
                    old_next = new_next;
                    new_next.ptr = new node;
                }
                set_new_tail(old_tail, old_next);
            }
        }
    }
};