hashtable still buggy
[cdsspec-compiler.git] / benchmark / mpmc-queue / mpmc-queue.h
1 #include <stdatomic.h>
2 #include <unrelacy.h>
3 #include <common.h>
4
5 /**
6         @Begin
7         @Class_begin
8         @End
9 */
10 template <typename t_element, size_t t_size>
11 struct mpmc_boundq_1_alt
12 {
13 private:
14
15         // elements should generally be cache-line-size padded :
16         t_element               m_array[t_size];
17
18         // rdwr counts the reads & writes that have started
19         atomic<unsigned int>    m_rdwr;
20         // "read" and "written" count the number completed
21         atomic<unsigned int>    m_read;
22         atomic<unsigned int>    m_written;
23
24 public:
25
26         mpmc_boundq_1_alt()
27         {
28         /**
29                 @Begin
30                         @Entry_point
31                         @End
32                 */
33                 m_rdwr = 0;
34                 m_read = 0;
35                 m_written = 0;
36         }
37         
38
39         /**
40                 @Begin
41                 @Options:
42                         LANG = CPP;
43                         CLASS = mpmc_boundq_1_alt;
44                 @Global_define:
45                 @DeclareStruct:
46                         typedef struct elem {
47                                 t_element *pos;
48                                 bool written;
49                                 thread_id_t tid;
50                                 thread_id_t fetch_tid;
51                                 call_id_t id;
52                         } elem;
53                 @DeclareVar:
54                         spec_list *list;
55                         id_tag_t *tag;
56                 @InitVar:
57                         list = new_spec_list();
58                         tag = new_id_tag();
59                 @DefineFunc:
60                         elem* new_elem(t_element *pos, call_id_t id, thread_id_t tid) {
61                                 elem *e = (elem*) MODEL_MALLOC(sizeof(elem));
62                                 e->pos = pos;
63                                 e->written = false;
64                                 e->id = id;
65                                 e->tid = tid;
66                                 e->fetch_tid = -1;
67                         }
68                 @DefineFunc:
69                         elem* get_elem_by_pos(t_element *pos) {
70                                 for (int i = 0; i < size(list); i++) {
71                                         elem *e = (elem*) elem_at_index(list, i);
72                                         if (e->pos == pos) {
73                                                 return e;
74                                         }
75                                 }
76                                 return NULL;
77                         }
78                 @DefineFunc:
79                         void show_list() {
80                                 //model_print("Status:\n");
81                                 for (int i = 0; i < size(list); i++) {
82                                         elem *e = (elem*) elem_at_index(list, i);
83                                         //model_print("%d: pos %d, written %d, tid %d, fetch_tid %d, call_id %d\n", i, e->pos, e->written, e->tid, e->fetch_tid, e->id); 
84                                 }
85                         }
86                 @DefineFunc:
87                         elem* get_elem_by_tid(thread_id_t tid) {
88                                 for (int i = 0; i < size(list); i++) {
89                                         elem *e = (elem*) elem_at_index(list, i);
90                                         if (e->tid== tid) {
91                                                 return e;
92                                         }
93                                 }
94                                 return NULL;
95                         }
96                 @DefineFunc:
97                         elem* get_elem_by_fetch_tid(thread_id_t fetch_tid) {
98                                 for (int i = 0; i < size(list); i++) {
99                                         elem *e = (elem*) elem_at_index(list, i);
100                                         if (e->fetch_tid== fetch_tid) {
101                                                 return e;
102                                         }
103                                 }
104                                 return NULL;
105                         }
106                 @DefineFunc:
107                         int elem_idx_by_pos(t_element *pos) {
108                                 for (int i = 0; i < size(list); i++) {
109                                         elem *existing = (elem*) elem_at_index(list, i);
110                                         if (pos == existing->pos) {
111                                                 return i;
112                                         }
113                                 }
114                                 return -1;
115                         }
116                 @DefineFunc:
117                         int elem_idx_by_tid(thread_id_t tid) {
118                                 for (int i = 0; i < size(list); i++) {
119                                         elem *existing = (elem*) elem_at_index(list, i);
120                                         if (tid == existing->tid) {
121                                                 return i;
122                                         }
123                                 }
124                                 return -1;
125                         }
126                 @DefineFunc:
127                         int elem_idx_by_fetch_tid(thread_id_t fetch_tid) {
128                                 for (int i = 0; i < size(list); i++) {
129                                         elem *existing = (elem*) elem_at_index(list, i);
130                                         if (fetch_tid == existing->fetch_tid) {
131                                                 return i;
132                                         }
133                                 }
134                                 return -1;
135                         }
136                 @DefineFunc:
137                         int elem_num(t_element *pos) {
138                                 int cnt = 0;
139                                 for (int i = 0; i < size(list); i++) {
140                                         elem *existing = (elem*) elem_at_index(list, i);
141                                         if (pos == existing->pos) {
142                                                 cnt++;
143                                         }
144                                 }
145                                 return cnt;
146                         }
147                 @DefineFunc:
148                         call_id_t prepare_id() {
149                                 call_id_t res = get_and_inc(tag);
150                                 //model_print("prepare_id: %d\n", res);
151                                 return res;
152                         }
153                 @DefineFunc:
154                         bool prepare_check(t_element *pos, thread_id_t tid) {
155                                 show_list();
156                                 elem *e = get_elem_by_pos(pos);
157                                 //model_print("prepare_check: e %d\n", e);
158                                 return NULL == e;
159                         }
160                 @DefineFunc:
161                         void prepare(call_id_t id, t_element *pos, thread_id_t tid) {
162                                 //model_print("prepare: id %d, pos %d, tid %d\n", id, pos, tid);
163                                 elem *e = new_elem(pos, id, tid);
164                                 push_back(list, e);
165                         }
166                 @DefineFunc:
167                         call_id_t publish_id(thread_id_t tid) {
168                                 elem *e = get_elem_by_tid(tid);
169                                 //model_print("publish_id: id %d\n", e == NULL ? 0 : e->id);
170                                 if (NULL == e)
171                                         return DEFAULT_CALL_ID;
172                                 return e->id;
173                         }
174                 @DefineFunc:
175                         bool publish_check(thread_id_t tid) {
176                                 show_list();
177                                 elem *e = get_elem_by_tid(tid);
178                                 //model_print("publish_check: tid %d\n", tid);
179                                 if (NULL == e)
180                                         return false;
181                                 if (elem_num(e->pos) > 1)
182                                         return false;
183                                 return !e->written;
184                         }
185                 @DefineFunc:
186                         void publish(thread_id_t tid) {
187                                 //model_print("publish: tid %d\n", tid);
188                                 elem *e = get_elem_by_tid(tid);
189                                 e->written = true;
190                         }
191                 @DefineFunc:
192                         call_id_t fetch_id(t_element *pos) {
193                                 elem *e = get_elem_by_pos(pos);
194                                 //model_print("fetch_id: id %d\n", e == NULL ? 0 : e->id);
195                                 if (NULL == e)
196                                         return DEFAULT_CALL_ID;
197                                 return e->id;
198                         }
199                 @DefineFunc:
200                         bool fetch_check(t_element *pos) {
201                                 show_list();
202                                 if (pos == NULL) return true;
203                                 elem *e = get_elem_by_pos(pos);
204                                 //model_print("fetch_check: pos %d, e %d\n", pos, e);
205                                 if (e == NULL) return false;
206                                 if (elem_num(e->pos) > 1)
207                                         return false;
208                                 return true;
209                         }
210                 @DefineFunc:
211                         void fetch(t_element *pos, thread_id_t tid) {
212                                 if (pos == NULL) return;
213                                 elem *e = (elem*) get_elem_by_pos(pos);
214                                 //model_print("fetch: pos %d, tid %d\n", pos, tid);
215                                 // Remember the thread that fetches the position
216                                 e->fetch_tid = tid;
217                         }
218                 @DefineFunc:
219                         bool consume_check(thread_id_t tid) {
220                                 show_list();
221                                 elem *e = get_elem_by_fetch_tid(tid);
222                                 //model_print("consume_check: tid %d, e %d\n", tid, e);
223                                 if (NULL == e)
224                                         return false;
225                                 if (elem_num(e->pos) > 1)
226                                         return false;
227                                 return e->written;
228                         }
229                 @DefineFunc:
230                         call_id_t consume_id(thread_id_t tid) {
231                                 elem *e = get_elem_by_fetch_tid(tid);
232                                 //model_print("consume_id: id %d\n", e == NULL ? 0 : e->id);
233                                 if (NULL == e)
234                                         return DEFAULT_CALL_ID;
235                                 return e->id;
236                         }
237                 @DefineFunc:
238                         void consume(thread_id_t tid) {
239                                 //model_print("consume: tid %d\n", tid);
240                                 int idx = elem_idx_by_fetch_tid(tid);
241                                 if (idx == -1)
242                                         return;
243                                 remove_at_index(list, idx);
244                         }
245         @Happens_before:
246                 Prepare -> Fetch
247                 Publish -> Consume
248         @End
249         */
250
251         //-----------------------------------------------------
252
253         /**
254                 @Begin
255                 @Interface: Fetch
256                 @Commit_point_set: Fetch_Empty_Point | Fetch_Succ_Point
257                 @ID: fetch_id(__RET__)
258                 @Check:
259                         fetch_check(__RET__)
260                 @Action:
261                         fetch(__RET__, __TID__);
262                 @End
263         */
264         t_element * read_fetch() {
265                 unsigned int rdwr = m_rdwr.load(mo_acquire);
266                 /**
267                         @Begin
268                         @Potential_commit_point_define: true
269                         @Label: Fetch_Potential_Point
270                         @End
271                 */
272                 unsigned int rd,wr;
273                 for(;;) {
274                         rd = (rdwr>>16) & 0xFFFF;
275                         wr = rdwr & 0xFFFF;
276
277                         if ( wr == rd ) { // empty
278                                 /**
279                                         @Begin
280                                         @Commit_point_define: true
281                                         @Potential_commit_point_label: Fetch_Potential_Point 
282                                         @Label: Fetch_Empty_Point
283                                         @End
284                                 */
285                                 return false;
286                         }
287                         
288                         bool succ = m_rdwr.compare_exchange_weak(rdwr,rdwr+(1<<16),mo_acq_rel);
289                         /**
290                                 @Begin
291                                 @Commit_point_define_check: succ == true
292                                 @Label: Fetch_Succ_Point
293                                 @End
294                         */
295                         if (succ)
296                                 break;
297                         else
298                                 thrd_yield();
299                 }
300
301                 // (*1)
302                 rl::backoff bo;
303                 while ( (m_written.load(mo_acquire) & 0xFFFF) != wr ) {
304                         thrd_yield();
305                 }
306
307                 t_element * p = & ( m_array[ rd % t_size ] );
308                 
309                 return p;
310         }
311
312         /**
313                 @Begin
314                 @Interface: Consume
315                 @Commit_point_set: Consume_Point
316                 @ID: consume_id(__TID__)
317                 @Check:
318                         consume_check(__TID__)
319                 @Action:
320                         consume(__TID__);
321                 @End
322         */
323         void read_consume() {
324                 m_read.fetch_add(1,mo_release);
325                 /**
326                         @Begin
327                         @Commit_point_define_check: true
328                         @Label: Consume_Point
329                         @End
330                 */
331         }
332
333         //-----------------------------------------------------
334
335         /**
336                 @Begin
337                 @Interface: Prepare 
338                 @Commit_point_set: Prepare_Full_Point | Prepare_Succ_Point
339                 @ID: prepare_id()
340                 @Check:
341                         prepare_check(__RET__, __TID__)
342                 @Action:
343                         prepare(__ID__, __RET__, __TID__);
344                 @End
345         */
346         t_element * write_prepare() {
347                 unsigned int rdwr = m_rdwr.load(mo_acquire);
348                 /**
349                         @Begin
350                         @Potential_commit_point_define: true
351                         @Label: Prepare_Potential_Point
352                         @End
353                 */
354                 unsigned int rd,wr;
355                 for(;;) {
356                         rd = (rdwr>>16) & 0xFFFF;
357                         wr = rdwr & 0xFFFF;
358
359                         if ( wr == ((rd + t_size)&0xFFFF) ) { // full
360                                 /**
361                                         @Begin
362                                         @Commit_point_define: true
363                                         @Potential_commit_point_label: Prepare_Potential_Point 
364                                         @Label: Prepare_Full_Point
365                                         @End
366                                 */
367                                 return NULL;
368                         }
369                         
370                         bool succ = m_rdwr.compare_exchange_weak(rdwr,(rd<<16) |
371                                 ((wr+1)&0xFFFF),mo_acq_rel);
372                         /**
373                                 @Begin
374                                 @Commit_point_define_check: succ == true
375                                 @Label: Prepare_Succ_Point
376                                 @End
377                         */
378                         if (succ)
379                                 break;
380                         else
381                                 thrd_yield();
382                 }
383
384                 // (*1)
385                 rl::backoff bo;
386                 while ( (m_read.load(mo_acquire) & 0xFFFF) != rd ) {
387                         thrd_yield();
388                 }
389
390                 t_element * p = & ( m_array[ wr % t_size ] );
391
392                 return p;
393         }
394
395         /**
396                 @Begin
397                 @Interface: Publish 
398                 @Commit_point_set: Publish_Point
399                 @ID: publish_id(__TID__)
400                 @Check:
401                         publish_check(__TID__)
402                 @Action:
403                         publish(__TID__);
404                 @End
405         */
406         void write_publish()
407         {
408                 m_written.fetch_add(1,mo_release);
409                 /**
410                         @Begin
411                         @Commit_point_define_check: true
412                         @Label: Publish_Point
413                         @End
414                 */
415         }
416
417         //-----------------------------------------------------
418
419
420 };
421 /**
422         @Begin
423         @Class_end
424         @End
425 */