save
[cdsspec-compiler.git] / benchmark / mpmc-queue / mpmc-queue.h
1 #include <stdatomic.h>
2 #include <unrelacy.h>
3 #include <common.h>
4
5 #include <spec_lib.h>
6 #include <stdlib.h>
7 #include <cdsannotate.h>
8 #include <specannotation.h>
9 #include <model_memory.h>
10 #include "common.h" 
11
/**
	@Begin
	@Class_begin
	@End
*/
17 template <typename t_element, size_t t_size>
18 struct mpmc_boundq_1_alt
19 {
20 private:
21
22         // elements should generally be cache-line-size padded :
23         t_element               m_array[t_size];
24
25         // rdwr counts the reads & writes that have started
26         atomic<unsigned int>    m_rdwr;
27         // "read" and "written" count the number completed
28         atomic<unsigned int>    m_read;
29         atomic<unsigned int>    m_written;
30
31 public:
32
33         mpmc_boundq_1_alt()
34         {
35         /**
36                 @Begin
37                         @Entry_point
38                         @End
39                 */
40                 m_rdwr = 0;
41                 m_read = 0;
42                 m_written = 0;
43         }
44         
45
46         /**
47                 @Begin
48                 @Options:
49                         LANG = CPP;
50                         CLASS = mpmc_boundq_1_alt;
51                 @Global_define:
52                 @DeclareStruct:
53                         typedef struct elem {
54                                 t_element *pos;
55                                 bool written;
56                                 thread_id_t tid;
57                                 thread_id_t fetch_tid;
58                                 call_id_t id;
59                         } elem;
60                 @DeclareVar:
61                         spec_list *list;
62                         id_tag_t *tag;
63                 @InitVar:
64                         list = new_spec_list();
65                         tag = new_id_tag();
66                 @DefineFunc:
67                         elem* new_elem(t_element *pos, call_id_t id, thread_id_t tid) {
68                                 elem *e = (elem*) MODEL_MALLOC(sizeof(elem));
69                                 e->pos = pos;
70                                 e->written = false;
71                                 e->id = id;
72                                 e->tid = tid;
73                                 e->fetch_tid = -1;
74                         }
75                 @DefineFunc:
76                         elem* get_elem_by_pos(t_element *pos) {
77                                 for (int i = 0; i < size(list); i++) {
78                                         elem *e = (elem*) elem_at_index(list, i);
79                                         if (e->pos == pos) {
80                                                 return e;
81                                         }
82                                 }
83                                 return NULL;
84                         }
85                 @DefineFunc:
86                         void show_list() {
87                                 //model_print("Status:\n");
88                                 for (int i = 0; i < size(list); i++) {
89                                         elem *e = (elem*) elem_at_index(list, i);
90                                         //model_print("%d: pos %d, written %d, tid %d, fetch_tid %d, call_id %d\n", i, e->pos, e->written, e->tid, e->fetch_tid, e->id); 
91                                 }
92                         }
93                 @DefineFunc:
94                         elem* get_elem_by_tid(thread_id_t tid) {
95                                 for (int i = 0; i < size(list); i++) {
96                                         elem *e = (elem*) elem_at_index(list, i);
97                                         if (e->tid== tid) {
98                                                 return e;
99                                         }
100                                 }
101                                 return NULL;
102                         }
103                 @DefineFunc:
104                         elem* get_elem_by_fetch_tid(thread_id_t fetch_tid) {
105                                 for (int i = 0; i < size(list); i++) {
106                                         elem *e = (elem*) elem_at_index(list, i);
107                                         if (e->fetch_tid== fetch_tid) {
108                                                 return e;
109                                         }
110                                 }
111                                 return NULL;
112                         }
113                 @DefineFunc:
114                         int elem_idx_by_pos(t_element *pos) {
115                                 for (int i = 0; i < size(list); i++) {
116                                         elem *existing = (elem*) elem_at_index(list, i);
117                                         if (pos == existing->pos) {
118                                                 return i;
119                                         }
120                                 }
121                                 return -1;
122                         }
123                 @DefineFunc:
124                         int elem_idx_by_tid(thread_id_t tid) {
125                                 for (int i = 0; i < size(list); i++) {
126                                         elem *existing = (elem*) elem_at_index(list, i);
127                                         if (tid == existing->tid) {
128                                                 return i;
129                                         }
130                                 }
131                                 return -1;
132                         }
133                 @DefineFunc:
134                         int elem_idx_by_fetch_tid(thread_id_t fetch_tid) {
135                                 for (int i = 0; i < size(list); i++) {
136                                         elem *existing = (elem*) elem_at_index(list, i);
137                                         if (fetch_tid == existing->fetch_tid) {
138                                                 return i;
139                                         }
140                                 }
141                                 return -1;
142                         }
143                 @DefineFunc:
144                         int elem_num(t_element *pos) {
145                                 int cnt = 0;
146                                 for (int i = 0; i < size(list); i++) {
147                                         elem *existing = (elem*) elem_at_index(list, i);
148                                         if (pos == existing->pos) {
149                                                 cnt++;
150                                         }
151                                 }
152                                 return cnt;
153                         }
154                 @DefineFunc:
155                         call_id_t prepare_id() {
156                                 call_id_t res = get_and_inc(tag);
157                                 //model_print("prepare_id: %d\n", res);
158                                 return res;
159                         }
160                 @DefineFunc:
161                         bool prepare_check(t_element *pos, thread_id_t tid) {
162                                 show_list();
163                                 elem *e = get_elem_by_pos(pos);
164                                 //model_print("prepare_check: e %d\n", e);
165                                 return NULL == e;
166                         }
167                 @DefineFunc:
168                         void prepare(call_id_t id, t_element *pos, thread_id_t tid) {
169                                 //model_print("prepare: id %d, pos %d, tid %d\n", id, pos, tid);
170                                 elem *e = new_elem(pos, id, tid);
171                                 push_back(list, e);
172                         }
173                 @DefineFunc:
174                         call_id_t publish_id(thread_id_t tid) {
175                                 elem *e = get_elem_by_tid(tid);
176                                 //model_print("publish_id: id %d\n", e == NULL ? 0 : e->id);
177                                 if (NULL == e)
178                                         return DEFAULT_CALL_ID;
179                                 return e->id;
180                         }
181                 @DefineFunc:
182                         bool publish_check(thread_id_t tid) {
183                                 show_list();
184                                 elem *e = get_elem_by_tid(tid);
185                                 //model_print("publish_check: tid %d\n", tid);
186                                 if (NULL == e)
187                                         return false;
188                                 if (elem_num(e->pos) > 1)
189                                         return false;
190                                 return !e->written;
191                         }
192                 @DefineFunc:
193                         void publish(thread_id_t tid) {
194                                 //model_print("publish: tid %d\n", tid);
195                                 elem *e = get_elem_by_tid(tid);
196                                 e->written = true;
197                         }
198                 @DefineFunc:
199                         call_id_t fetch_id(t_element *pos) {
200                                 elem *e = get_elem_by_pos(pos);
201                                 //model_print("fetch_id: id %d\n", e == NULL ? 0 : e->id);
202                                 if (NULL == e)
203                                         return DEFAULT_CALL_ID;
204                                 return e->id;
205                         }
206                 @DefineFunc:
207                         bool fetch_check(t_element *pos) {
208                                 show_list();
209                                 if (pos == NULL) return true;
210                                 elem *e = get_elem_by_pos(pos);
211                                 //model_print("fetch_check: pos %d, e %d\n", pos, e);
212                                 if (e == NULL) return false;
213                                 if (elem_num(e->pos) > 1)
214                                         return false;
215                                 return true;
216                         }
217                 @DefineFunc:
218                         void fetch(t_element *pos, thread_id_t tid) {
219                                 if (pos == NULL) return;
220                                 elem *e = (elem*) get_elem_by_pos(pos);
221                                 //model_print("fetch: pos %d, tid %d\n", pos, tid);
222                                 // Remember the thread that fetches the position
223                                 e->fetch_tid = tid;
224                         }
225                 @DefineFunc:
226                         bool consume_check(thread_id_t tid) {
227                                 show_list();
228                                 elem *e = get_elem_by_fetch_tid(tid);
229                                 //model_print("consume_check: tid %d, e %d\n", tid, e);
230                                 if (NULL == e)
231                                         return false;
232                                 if (elem_num(e->pos) > 1)
233                                         return false;
234                                 return e->written;
235                         }
236                 @DefineFunc:
237                         call_id_t consume_id(thread_id_t tid) {
238                                 elem *e = get_elem_by_fetch_tid(tid);
239                                 //model_print("consume_id: id %d\n", e == NULL ? 0 : e->id);
240                                 if (NULL == e)
241                                         return DEFAULT_CALL_ID;
242                                 return e->id;
243                         }
244                 @DefineFunc:
245                         void consume(thread_id_t tid) {
246                                 //model_print("consume: tid %d\n", tid);
247                                 int idx = elem_idx_by_fetch_tid(tid);
248                                 if (idx == -1)
249                                         return;
250                                 remove_at_index(list, idx);
251                         }
252         @Happens_before:
253                 Prepare -> Fetch
254                 Publish -> Consume
255         @End
256         */
257
258         //-----------------------------------------------------
259
260         /**
261                 @Begin
262                 @Interface: Fetch
263                 @Commit_point_set: Fetch_Empty_Point | Fetch_Succ_Point
264                 @ID: fetch_id(__RET__)
265                 @Check:
266                         fetch_check(__RET__)
267                 @Action:
268                         fetch(__RET__, __TID__);
269                 @End
270         */
271         t_element * read_fetch() {
272                 unsigned int rdwr = m_rdwr.load(mo_acquire);
273                 /**
274                         @Begin
275                         @Potential_commit_point_define: true
276                         @Label: Fetch_Potential_Point
277                         @End
278                 */
279                 unsigned int rd,wr;
280                 for(;;) {
281                         rd = (rdwr>>16) & 0xFFFF;
282                         wr = rdwr & 0xFFFF;
283
284                         if ( wr == rd ) { // empty
285                                 /**
286                                         @Begin
287                                         @Commit_point_define: true
288                                         @Potential_commit_point_label: Fetch_Potential_Point 
289                                         @Label: Fetch_Empty_Point
290                                         @End
291                                 */
292                                 return false;
293                         }
294                         
295                         bool succ = m_rdwr.compare_exchange_weak(rdwr,rdwr+(1<<16),mo_acq_rel);
296                         /**
297                                 @Begin
298                                 @Commit_point_define_check: succ == true
299                                 @Label: Fetch_Succ_Point
300                                 @End
301                         */
302                         if (succ)
303                                 break;
304                         else
305                                 thrd_yield();
306                 }
307
308                 // (*1)
309                 rl::backoff bo;
310                 while ( (m_written.load(mo_acquire) & 0xFFFF) != wr ) {
311                         thrd_yield();
312                 }
313
314                 t_element * p = & ( m_array[ rd % t_size ] );
315                 
316                 return p;
317         }
318
319         /**
320                 @Begin
321                 @Interface: Consume
322                 @Commit_point_set: Consume_Point
323                 @ID: consume_id(__TID__)
324                 @Check:
325                         consume_check(__TID__)
326                 @Action:
327                         consume(__TID__);
328                 @End
329         */
330         void read_consume() {
331                 m_read.fetch_add(1,mo_release);
332                 /**
333                         @Begin
334                         @Commit_point_define_check: true
335                         @Label: Consume_Point
336                         @End
337                 */
338         }
339
340         //-----------------------------------------------------
341
342         /**
343                 @Begin
344                 @Interface: Prepare 
345                 @Commit_point_set: Prepare_Full_Point | Prepare_Succ_Point
346                 @ID: prepare_id()
347                 @Check:
348                         prepare_check(__RET__, __TID__)
349                 @Action:
350                         prepare(__ID__, __RET__, __TID__);
351                 @End
352         */
353         t_element * write_prepare() {
354                 unsigned int rdwr = m_rdwr.load(mo_acquire);
355                 /**
356                         @Begin
357                         @Potential_commit_point_define: true
358                         @Label: Prepare_Potential_Point
359                         @End
360                 */
361                 unsigned int rd,wr;
362                 for(;;) {
363                         rd = (rdwr>>16) & 0xFFFF;
364                         wr = rdwr & 0xFFFF;
365
366                         if ( wr == ((rd + t_size)&0xFFFF) ) { // full
367                                 /**
368                                         @Begin
369                                         @Commit_point_define: true
370                                         @Potential_commit_point_label: Prepare_Potential_Point 
371                                         @Label: Prepare_Full_Point
372                                         @End
373                                 */
374                                 return NULL;
375                         }
376                         
377                         bool succ = m_rdwr.compare_exchange_weak(rdwr,(rd<<16) |
378                                 ((wr+1)&0xFFFF),mo_acq_rel);
379                         /**
380                                 @Begin
381                                 @Commit_point_define_check: succ == true
382                                 @Label: Prepare_Succ_Point
383                                 @End
384                         */
385                         if (succ)
386                                 break;
387                         else
388                                 thrd_yield();
389                 }
390
391                 // (*1)
392                 rl::backoff bo;
393                 while ( (m_read.load(mo_acquire) & 0xFFFF) != rd ) {
394                         thrd_yield();
395                 }
396
397                 t_element * p = & ( m_array[ wr % t_size ] );
398
399                 return p;
400         }
401
402         /**
403                 @Begin
404                 @Interface: Publish 
405                 @Commit_point_set: Publish_Point
406                 @ID: publish_id(__TID__)
407                 @Check:
408                         publish_check(__TID__)
409                 @Action:
410                         publish(__TID__);
411                 @End
412         */
413         void write_publish()
414         {
415                 m_written.fetch_add(1,mo_release);
416                 /**
417                         @Begin
418                         @Commit_point_define_check: true
419                         @Label: Publish_Point
420                         @End
421                 */
422         }
423
424         //-----------------------------------------------------
425
426
427 };
/**
	@Begin
	@Class_end
	@End
*/