more fix
[cdsspec-compiler.git] / benchmark / mpmc-queue / mpmc-queue.h
1 #include <stdatomic.h>
2 #include <unrelacy.h>
3
4 template <typename t_element, size_t t_size>
5 struct mpmc_boundq_1_alt
6 {
7 private:
8
9         // elements should generally be cache-line-size padded :
10         t_element               m_array[t_size];
11
12         // rdwr counts the reads & writes that have started
13         atomic<unsigned int>    m_rdwr;
14         // "read" and "written" count the number completed
15         atomic<unsigned int>    m_read;
16         atomic<unsigned int>    m_written;
17
18 public:
19
20         mpmc_boundq_1_alt()
21         {
22                 m_rdwr = 0;
23                 m_read = 0;
24                 m_written = 0;
25         }
26         
27
28         /**
29                 @Global_define:
30                 @Options:
31                         LANG = CPP;
32                         CLASS = mpmc_boundq_1_alt;
33                 @DeclareStruct:
34                         typedef struct elem {
35                                 t_element *pos;
36                                 boolean written;
37                                 thread_id_t tid;
38                                 call_id_t id;
39                         } elem;
40                 @DeclareVar:
41                         spec_list *list;
42                         id_tag_t *tag;
43                 @InitVar:
44                         list = new_spec_list();
45                         tag = new_id_tag();
46                 @DefineFunc:
47                         elem* new_elem(t_element *pos, call_id_t id, thread_id_t tid) {
48                                 elem *e = (elem*) MODEL_MALLOC(sizeof(elem));
49                                 e->pos = pos;
50                                 e->written = false;
51                                 e->id = id;
52                                 e->tid = tid;
53                         }
54                 @DefineFunc:
55                         elem* get_elem_by_pos(t_element *pos) {
56                                 for (int i = 0; i < size(list); i++) {
57                                         elem *e = (elem*) elem_at_index(list, i);
58                                         if (e->pos == pos) {
59                                                 return e;
60                                         }
61                                 }
62                                 return NULL;
63                         }
64                 @DefineFunc:
65                         elem* get_elem_by_tid(thread_id_t tid) {
66                                 for (int i = 0; i < size(list); i++) {
67                                         elem *e = (elem*) elem_at_index(list, i);
68                                         if (e->tid== tid) {
69                                                 return e;
70                                         }
71                                 }
72                                 return NULL;
73                         }
74                 @DefineFunc:
75                         int elem_idx_by_pos(t_element *pos) {
76                                 for (int i = 0; i < size(list); i++) {
77                                         elem *existing = (elem*) elem_at_index(list, i);
78                                         if (pos == existing->pos) {
79                                                 return i;
80                                         }
81                                 }
82                                 return -1;
83                         }
84                 @DefineFunc:
85                         int elem_idx_by_tid(thread_id_t tid) {
86                                 for (int i = 0; i < size(list); i++) {
87                                         elem *existing = (elem*) elem_at_index(list, i);
88                                         if (tid == existing->tid) {
89                                                 return i;
90                                         }
91                                 }
92                                 return -1;
93                         }
94                 @DefineFunc:
95                         call_id_t prepare_id() {
96                                 return get_and_inc(tag);
97                         }
98                 @DefineFunc:
99                         bool prepare_check(t_element *pos, thread_id_t tid) {
100                                 elem *e = get_elem_by_tid(tid);
101                                 return NULL == e;
102                         }
103                 @DefineFunc:
104                         void prepare(call_id_t id, t_element *pos, thread_id_t tid) {
105                                 call_id_t id = get_and_inc(tag);
106                                 elem *e = new_elem(pos, id, tid);
107                                 push_back(list, e);
108                         }
109                 @DefineFunc:
110                         call_id_t publish_id(thread_id_t tid) {
111                                 elem *e = get_elem_by_tid(tid);
112                                 if (NULL == e)
113                                         return DEFAULT_CALL_ID;
114                                 return e->id;
115                         }
116                 @DefineFunc:
117                         bool publish_check(thread_id_t tid) {
118                                 elem *e = get_elem_by_tid(tid);
119                                 if (NULL == e)
120                                         return false;
121                                 return e->written;
122                         }
123                 @DefineFunc:
124                         void publish(thread_id_t tid) {
125                                 elem *e = get_elem_by_tid(tid);
126                                 e->written = true;
127                         }
128                 @DefineFunc:
129                         call_id_t fetch_id(t_element *pos) {
130                                 elem *e = get_elem_by_pos(pos);
131                                 if (NULL == e)
132                                         return DEFAULT_CALL_ID;
133                                 return e->id;
134                         }
135                 @DefineFunc:
136                         bool fetch_check(t_element *pos) {
137                                 int idx = elem_idx_by_pos(pos);
138                                 if (idx == -1)
139                                         return false;
140                                 else
141                                         return true;
142                         }
143                 @DefineFunc:
144                         void fetch(t_element *pos) {
145                                 int idx = elem_idx_by_pos(pos);
146                                 if (idx == -1)
147                                         return;
148                                 remove_at_index(list, idx);
149                         }
150                 @DefineFunc:
151                         bool consume_check(thread_id_t tid) {
152                                 elem *e = get_elem_by_tid(tid);
153                                 if (NULL == e)
154                                         return false;
155                                 return e->written;
156                         }
157                 @DefineFunc:
158                         call_id_t consume_id(thread_id_t tid) {
159                                 elem *e = get_elem_by_tid(tid);
160                                 if (NULL == e)
161                                         return DEFAULT_CALL_ID;
162                                 return e->id;
163                         }
164                 @DefineFunc:
165                         void consume(thread_id_t tid) {
166                                 int idx = elem_idx_by_tid(tid);
167                                 if (idx == -1)
168                                         return;
169                                 remove_at_index(list, idx);
170                         }
171         @Happens_before:
172                 Prepare -> Fetch
173                 Publish -> Consume
174         @End
175         */
176
177         //-----------------------------------------------------
178
179         /**
180                 @Begin
181                 @Interface: Fetch
182                 @Commit_point_set: Fetch_Point
183                 @ID: fetch_id(__RET__)
184                 @Check:
185                         fetch_check(__RET__)
186                 @Action:
187                         fetch(__RET__);
188                 @End
189         */
190         t_element * read_fetch() {
191                 unsigned int rdwr = m_rdwr.load(mo_acquire);
192                 unsigned int rd,wr;
193                 for(;;) {
194                         rd = (rdwr>>16) & 0xFFFF;
195                         wr = rdwr & 0xFFFF;
196
197                         if ( wr == rd ) { // empty
198                                 return false;
199                         }
200
201                         if ( m_rdwr.compare_exchange_weak(rdwr,rdwr+(1<<16),mo_acq_rel) )
202                                 break;
203                         else
204                                 thrd_yield();
205                 }
206
207                 // (*1)
208                 rl::backoff bo;
209                 while ( (m_written.load(mo_acquire) & 0xFFFF) != wr ) {
210                         thrd_yield();
211                 }
212
213                 t_element * p = & ( m_array[ rd % t_size ] );
214                 
215                 return p;
216         }
217
218         /**
219                 @Begin
220                 @Interface: Consume
221                 @Commit_point_set: Consume_Point
222                 @ID: consume_id(__TID__)
223                 @Check:
224                         consume_check(__TID__)
225                 @Action:
226                         consume(__TID__);
227                 @End
228         */
229         void read_consume() {
230                 m_read.fetch_add(1,mo_release);
231         }
232
233         //-----------------------------------------------------
234
235         /**
236                 @Begin
237                 @Interface: Prepare 
238                 @Commit_point_set: Prepare_Point
239                 @ID: prepare_id(__RET__)
240                 @Check:
241                         prepare_check(__RET__)
242                 @Action:
243                         prepare(__RET__);
244                 @End
245         */
246         t_element * write_prepare() {
247                 unsigned int rdwr = m_rdwr.load(mo_acquire);
248                 unsigned int rd,wr;
249                 for(;;) {
250                         rd = (rdwr>>16) & 0xFFFF;
251                         wr = rdwr & 0xFFFF;
252
253                         if ( wr == ((rd + t_size)&0xFFFF) ) // full
254                                 return NULL;
255
256                         if ( m_rdwr.compare_exchange_weak(rdwr,(rd<<16) | ((wr+1)&0xFFFF),mo_acq_rel) )
257                                 break;
258                         else
259                                 thrd_yield();
260                 }
261
262                 // (*1)
263                 rl::backoff bo;
264                 while ( (m_read.load(mo_acquire) & 0xFFFF) != rd ) {
265                         thrd_yield();
266                 }
267
268
269                 t_element * p = & ( m_array[ wr % t_size ] );
270
271                 return p;
272         }
273
274         /**
275                 @Begin
276                 @Interface: Publish 
277                 @Commit_point_set: Publish_Point
278                 @ID: publish_id(__TID__)
279                 @Check:
280                         publish_check(__TID__)
281                 @Action:
282                         publish(__TID__);
283                 @End
284         */
285         void write_publish()
286         {
287                 m_written.fetch_add(1,mo_release);
288         }
289
290         //-----------------------------------------------------
291
292
293 };