Fix a bug
[c11tester.git] / waitobj.cc
1 #include "waitobj.h"
2 #include "threads-model.h"
3
4 WaitObj::WaitObj(thread_id_t tid) :
5         tid(tid),
6         waiting_for(32),
7         waited_by(32),
8         thrd_dist_maps(),
9         thrd_target_nodes()
10 {}
11
12 WaitObj::~WaitObj()
13 {
14         for (uint i = 0; i < thrd_dist_maps.size(); i++)
15                 delete thrd_dist_maps[i];
16
17         for (uint i = 0; i < thrd_target_nodes.size(); i++)
18                 delete thrd_target_nodes[i];
19 }
20
21 void WaitObj::add_waiting_for(thread_id_t other, FuncNode * node, int dist)
22 {
23         waiting_for.add(other);
24
25         dist_map_t * dist_map = getDistMap(other);
26         dist_map->put(node, dist);
27
28         node_set_t * target_nodes = getTargetNodes(other);
29         target_nodes->add(node);
30 }
31
32 void WaitObj::add_waited_by(thread_id_t other)
33 {
34         waited_by.add(other);
35 }
36
37 /**
38  * Stop waiting for the thread to reach the target node
39  *
40  * @param other The thread to be removed
41  * @param node The target node
42  * @return true if "other" is removed from waiting_for set
43  */
44 bool WaitObj::remove_waiting_for(thread_id_t other, FuncNode * node)
45 {
46         dist_map_t * dist_map = getDistMap(other);
47         dist_map->remove(node);
48
49         node_set_t * target_nodes = getTargetNodes(other);
50         target_nodes->remove(node);
51
52         /* The thread has not nodes to reach */
53         if (target_nodes->isEmpty()) {
54                 waiting_for.remove(other);
55                 return true;
56         }
57
58         return false;
59 }
60
61 void WaitObj::remove_waited_by(thread_id_t other)
62 {
63         waited_by.remove(other);
64 }
65
66 int WaitObj::lookup_dist(thread_id_t tid, FuncNode * target)
67 {
68         dist_map_t * map = getDistMap(tid);
69         if (map->contains(target))
70                 return map->get(target);
71
72         return -1;
73 }
74
75 dist_map_t * WaitObj::getDistMap(thread_id_t tid)
76 {
77         int thread_id = id_to_int(tid);
78         int old_size = thrd_dist_maps.size();
79
80         if (old_size <= thread_id) {
81                 thrd_dist_maps.resize(thread_id + 1);
82                 for (int i = old_size; i < thread_id + 1; i++) {
83                         thrd_dist_maps[i] = new dist_map_t(16);
84                 }
85         }
86
87         return thrd_dist_maps[thread_id];
88 }
89
90 node_set_t * WaitObj::getTargetNodes(thread_id_t tid)
91 {
92         int thread_id = id_to_int(tid);
93         int old_size = thrd_target_nodes.size();
94
95         if (old_size <= thread_id) {
96                 thrd_target_nodes.resize(thread_id + 1);
97                 for (int i = old_size; i < thread_id + 1; i++) {
98                         thrd_target_nodes[i] = new node_set_t(16);
99                 }
100         }
101
102         return thrd_target_nodes[thread_id];
103 }
104
105 void WaitObj::reset()
106 {
107         thrd_id_set_iter * iter = waiting_for.iterator();
108         while (iter->hasNext()) {
109                 thread_id_t tid = iter->next();
110                 int index = id_to_int(tid);
111                 thrd_target_nodes[index]->reset();
112                 /* thrd_dist_maps are not reset because distances
113                  * will be overwritten */
114         }
115
116         waiting_for.reset();
117         waited_by.reset();
118 }
119
120 void WaitObj::print_waiting_for()
121 {
122         if (waiting_for.getSize() == 0)
123                 return;
124
125         model_print("thread %d is waiting for: ", tid);
126         thrd_id_set_iter * it = waiting_for.iterator();
127
128         while (it->hasNext()) {
129                 thread_id_t thread_id = it->next();
130                 model_print("%d ", thread_id);
131         }
132         model_print("\n");
133 }
134
135 void WaitObj::print_waited_by()
136 {
137         if (waited_by.getSize() == 0)
138                 return;
139
140         model_print("thread %d is waited by: ", tid);
141         thrd_id_set_iter * it = waited_by.iterator();
142
143         while (it->hasNext()) {
144                 thread_id_t thread_id = it->next();
145                 model_print("%d ", thread_id);
146         }
147         model_print("\n");
148
149 }