more fix
[cdsspec-compiler.git] / benchmark / mpmc-queue / mpmc-queue.h
1 #include <stdatomic.h>
2 #include <unrelacy.h>
3 #include <common.h>
4
5 #include <spec_lib.h>
6 #include <stdlib.h>
7 #include <cdsannotate.h>
8 #include <specannotation.h>
9 #include <model_memory.h>
10
11 /**
12         @Begin
13         @Class_begin
14         @End
15 */
16 template <typename t_element, size_t t_size>
17 struct mpmc_boundq_1_alt
18 {
19 private:
20
21         // elements should generally be cache-line-size padded :
22         t_element               m_array[t_size];
23
24         // rdwr counts the reads & writes that have started
25         atomic<unsigned int>    m_rdwr;
26         // "read" and "written" count the number completed
27         atomic<unsigned int>    m_read;
28         atomic<unsigned int>    m_written;
29
30 public:
31
32         mpmc_boundq_1_alt()
33         {
34         /**
35                 @Begin
36                         @Entry_point
37                         @End
38                 */
39                 m_rdwr = 0;
40                 m_read = 0;
41                 m_written = 0;
42         }
43         
44
45         /**
46                 @Begin
47                 @Options:
48                         LANG = CPP;
49                         CLASS = mpmc_boundq_1_alt;
50                 @Global_define:
51                 @DeclareStruct:
52                         typedef struct elem {
53                                 t_element *pos;
54                                 bool written;
55                                 thread_id_t tid;
56                                 thread_id_t fetch_tid;
57                                 call_id_t id;
58                         } elem;
59                 @DeclareVar:
60                         spec_list *list;
61                         id_tag_t *tag;
62                 @InitVar:
63                         list = new_spec_list();
64                         tag = new_id_tag();
65                 @DefineFunc:
66                         elem* new_elem(t_element *pos, call_id_t id, thread_id_t tid) {
67                                 elem *e = (elem*) MODEL_MALLOC(sizeof(elem));
68                                 e->pos = pos;
69                                 e->written = false;
70                                 e->id = id;
71                                 e->tid = tid;
72                                 e->fetch_tid = -1;
73                         }
74                 @DefineFunc:
75                         elem* get_elem_by_pos(t_element *pos) {
76                                 for (int i = 0; i < size(list); i++) {
77                                         elem *e = (elem*) elem_at_index(list, i);
78                                         if (e->pos == pos) {
79                                                 return e;
80                                         }
81                                 }
82                                 return NULL;
83                         }
84                 @DefineFunc:
85                         void show_list() {
86                                 //model_print("Status:\n");
87                                 for (int i = 0; i < size(list); i++) {
88                                         elem *e = (elem*) elem_at_index(list, i);
89                                         //model_print("%d: pos %d, written %d, tid %d, fetch_tid %d, call_id %d\n", i, e->pos, e->written, e->tid, e->fetch_tid, e->id); 
90                                 }
91                         }
92                 @DefineFunc:
93                         elem* get_elem_by_tid(thread_id_t tid) {
94                                 for (int i = 0; i < size(list); i++) {
95                                         elem *e = (elem*) elem_at_index(list, i);
96                                         if (e->tid== tid) {
97                                                 return e;
98                                         }
99                                 }
100                                 return NULL;
101                         }
102                 @DefineFunc:
103                         elem* get_elem_by_fetch_tid(thread_id_t fetch_tid) {
104                                 for (int i = 0; i < size(list); i++) {
105                                         elem *e = (elem*) elem_at_index(list, i);
106                                         if (e->fetch_tid== fetch_tid) {
107                                                 return e;
108                                         }
109                                 }
110                                 return NULL;
111                         }
112                 @DefineFunc:
113                         int elem_idx_by_pos(t_element *pos) {
114                                 for (int i = 0; i < size(list); i++) {
115                                         elem *existing = (elem*) elem_at_index(list, i);
116                                         if (pos == existing->pos) {
117                                                 return i;
118                                         }
119                                 }
120                                 return -1;
121                         }
122                 @DefineFunc:
123                         int elem_idx_by_tid(thread_id_t tid) {
124                                 for (int i = 0; i < size(list); i++) {
125                                         elem *existing = (elem*) elem_at_index(list, i);
126                                         if (tid == existing->tid) {
127                                                 return i;
128                                         }
129                                 }
130                                 return -1;
131                         }
132                 @DefineFunc:
133                         int elem_idx_by_fetch_tid(thread_id_t fetch_tid) {
134                                 for (int i = 0; i < size(list); i++) {
135                                         elem *existing = (elem*) elem_at_index(list, i);
136                                         if (fetch_tid == existing->fetch_tid) {
137                                                 return i;
138                                         }
139                                 }
140                                 return -1;
141                         }
142                 @DefineFunc:
143                         int elem_num(t_element *pos) {
144                                 int cnt = 0;
145                                 for (int i = 0; i < size(list); i++) {
146                                         elem *existing = (elem*) elem_at_index(list, i);
147                                         if (pos == existing->pos) {
148                                                 cnt++;
149                                         }
150                                 }
151                                 return cnt;
152                         }
153                 @DefineFunc:
154                         call_id_t prepare_id() {
155                                 call_id_t res = get_and_inc(tag);
156                                 //model_print("prepare_id: %d\n", res);
157                                 return res;
158                         }
159                 @DefineFunc:
160                         bool prepare_check(t_element *pos, thread_id_t tid) {
161                                 show_list();
162                                 elem *e = get_elem_by_pos(pos);
163                                 //model_print("prepare_check: e %d\n", e);
164                                 return NULL == e;
165                         }
166                 @DefineFunc:
167                         void prepare(call_id_t id, t_element *pos, thread_id_t tid) {
168                                 //model_print("prepare: id %d, pos %d, tid %d\n", id, pos, tid);
169                                 elem *e = new_elem(pos, id, tid);
170                                 push_back(list, e);
171                         }
172                 @DefineFunc:
173                         call_id_t publish_id(thread_id_t tid) {
174                                 elem *e = get_elem_by_tid(tid);
175                                 //model_print("publish_id: id %d\n", e == NULL ? 0 : e->id);
176                                 if (NULL == e)
177                                         return DEFAULT_CALL_ID;
178                                 return e->id;
179                         }
180                 @DefineFunc:
181                         bool publish_check(thread_id_t tid) {
182                                 show_list();
183                                 elem *e = get_elem_by_tid(tid);
184                                 //model_print("publish_check: tid %d\n", tid);
185                                 if (NULL == e)
186                                         return false;
187                                 if (elem_num(e->pos) > 1)
188                                         return false;
189                                 return !e->written;
190                         }
191                 @DefineFunc:
192                         void publish(thread_id_t tid) {
193                                 //model_print("publish: tid %d\n", tid);
194                                 elem *e = get_elem_by_tid(tid);
195                                 e->written = true;
196                         }
197                 @DefineFunc:
198                         call_id_t fetch_id(t_element *pos) {
199                                 elem *e = get_elem_by_pos(pos);
200                                 //model_print("fetch_id: id %d\n", e == NULL ? 0 : e->id);
201                                 if (NULL == e)
202                                         return DEFAULT_CALL_ID;
203                                 return e->id;
204                         }
205                 @DefineFunc:
206                         bool fetch_check(t_element *pos) {
207                                 show_list();
208                                 if (pos == NULL) return true;
209                                 elem *e = get_elem_by_pos(pos);
210                                 //model_print("fetch_check: pos %d, e %d\n", pos, e);
211                                 if (e == NULL) return false;
212                                 if (elem_num(e->pos) > 1)
213                                         return false;
214                                 return true;
215                         }
216                 @DefineFunc:
217                         void fetch(t_element *pos, thread_id_t tid) {
218                                 if (pos == NULL) return;
219                                 elem *e = (elem*) get_elem_by_pos(pos);
220                                 //model_print("fetch: pos %d, tid %d\n", pos, tid);
221                                 // Remember the thread that fetches the position
222                                 e->fetch_tid = tid;
223                         }
224                 @DefineFunc:
225                         bool consume_check(thread_id_t tid) {
226                                 show_list();
227                                 elem *e = get_elem_by_fetch_tid(tid);
228                                 //model_print("consume_check: tid %d, e %d\n", tid, e);
229                                 if (NULL == e)
230                                         return false;
231                                 if (elem_num(e->pos) > 1)
232                                         return false;
233                                 return e->written;
234                         }
235                 @DefineFunc:
236                         call_id_t consume_id(thread_id_t tid) {
237                                 elem *e = get_elem_by_fetch_tid(tid);
238                                 //model_print("consume_id: id %d\n", e == NULL ? 0 : e->id);
239                                 if (NULL == e)
240                                         return DEFAULT_CALL_ID;
241                                 return e->id;
242                         }
243                 @DefineFunc:
244                         void consume(thread_id_t tid) {
245                                 //model_print("consume: tid %d\n", tid);
246                                 int idx = elem_idx_by_fetch_tid(tid);
247                                 if (idx == -1)
248                                         return;
249                                 remove_at_index(list, idx);
250                         }
251         @Happens_before:
252                 Prepare -> Fetch
253                 Publish -> Consume
254         @End
255         */
256
257         //-----------------------------------------------------
258
259         /**
260                 @Begin
261                 @Interface: Fetch
262                 @Commit_point_set: Fetch_Empty_Point | Fetch_Succ_Point
263                 @ID: fetch_id(__RET__)
264                 @Check:
265                         fetch_check(__RET__)
266                 @Action:
267                         fetch(__RET__, __TID__);
268                 @End
269         */
270         t_element * read_fetch() {
271                 unsigned int rdwr = m_rdwr.load(mo_acquire);
272                 /**
273                         @Begin
274                         @Potential_commit_point_define: true
275                         @Label: Fetch_Potential_Point
276                         @End
277                 */
278                 unsigned int rd,wr;
279                 for(;;) {
280                         rd = (rdwr>>16) & 0xFFFF;
281                         wr = rdwr & 0xFFFF;
282
283                         if ( wr == rd ) { // empty
284                                 /**
285                                         @Begin
286                                         @Commit_point_define: true
287                                         @Potential_commit_point_label: Fetch_Potential_Point 
288                                         @Label: Fetch_Empty_Point
289                                         @End
290                                 */
291                                 return false;
292                         }
293                         
294                         bool succ = m_rdwr.compare_exchange_weak(rdwr,rdwr+(1<<16),mo_acq_rel);
295                         /**
296                                 @Begin
297                                 @Commit_point_define_check: succ == true
298                                 @Label: Fetch_Succ_Point
299                                 @End
300                         */
301                         if (succ)
302                                 break;
303                         else
304                                 thrd_yield();
305                 }
306
307                 // (*1)
308                 rl::backoff bo;
309                 while ( (m_written.load(mo_acquire) & 0xFFFF) != wr ) {
310                         thrd_yield();
311                 }
312
313                 t_element * p = & ( m_array[ rd % t_size ] );
314                 
315                 return p;
316         }
317
318         /**
319                 @Begin
320                 @Interface: Consume
321                 @Commit_point_set: Consume_Point
322                 @ID: consume_id(__TID__)
323                 @Check:
324                         consume_check(__TID__)
325                 @Action:
326                         consume(__TID__);
327                 @End
328         */
329         void read_consume() {
330                 m_read.fetch_add(1,mo_release);
331                 /**
332                         @Begin
333                         @Commit_point_define_check: true
334                         @Label: Consume_Point
335                         @End
336                 */
337         }
338
339         //-----------------------------------------------------
340
341         /**
342                 @Begin
343                 @Interface: Prepare 
344                 @Commit_point_set: Prepare_Full_Point | Prepare_Succ_Point
345                 @ID: prepare_id()
346                 @Check:
347                         prepare_check(__RET__, __TID__)
348                 @Action:
349                         prepare(__ID__, __RET__, __TID__);
350                 @End
351         */
352         t_element * write_prepare() {
353                 unsigned int rdwr = m_rdwr.load(mo_acquire);
354                 /**
355                         @Begin
356                         @Potential_commit_point_define: true
357                         @Label: Prepare_Potential_Point
358                         @End
359                 */
360                 unsigned int rd,wr;
361                 for(;;) {
362                         rd = (rdwr>>16) & 0xFFFF;
363                         wr = rdwr & 0xFFFF;
364
365                         if ( wr == ((rd + t_size)&0xFFFF) ) { // full
366                                 /**
367                                         @Begin
368                                         @Commit_point_define: true
369                                         @Potential_commit_point_label: Prepare_Potential_Point 
370                                         @Label: Prepare_Full_Point
371                                         @End
372                                 */
373                                 return NULL;
374                         }
375                         
376                         bool succ = m_rdwr.compare_exchange_weak(rdwr,(rd<<16) |
377                                 ((wr+1)&0xFFFF),mo_acq_rel);
378                         /**
379                                 @Begin
380                                 @Commit_point_define_check: succ == true
381                                 @Label: Prepare_Succ_Point
382                                 @End
383                         */
384                         if (succ)
385                                 break;
386                         else
387                                 thrd_yield();
388                 }
389
390                 // (*1)
391                 rl::backoff bo;
392                 while ( (m_read.load(mo_acquire) & 0xFFFF) != rd ) {
393                         thrd_yield();
394                 }
395
396                 t_element * p = & ( m_array[ wr % t_size ] );
397
398                 return p;
399         }
400
401         /**
402                 @Begin
403                 @Interface: Publish 
404                 @Commit_point_set: Publish_Point
405                 @ID: publish_id(__TID__)
406                 @Check:
407                         publish_check(__TID__)
408                 @Action:
409                         publish(__TID__);
410                 @End
411         */
412         void write_publish()
413         {
414                 m_written.fetch_add(1,mo_release);
415                 /**
416                         @Begin
417                         @Commit_point_define_check: true
418                         @Label: Publish_Point
419                         @End
420                 */
421         }
422
423         //-----------------------------------------------------
424
425
426 };
427 /**
428         @Begin
429         @Class_end
430         @End
431 */