Off by 1 bug
[satune.git] / src / ASTAnalyses / Encoding / subgraph.cc
1 #include "subgraph.h"
2 #include "encodinggraph.h"
3 #include "set.h"
4 #include "qsort.h"
5
6 EncodingSubGraph::EncodingSubGraph() :
7         encodingSize(0),
8         numElements(0),
9         maxEncodingVal(0) {
10 }
11
12 uint hashNodeValuePair(NodeValuePair *nvp) {
13         return (uint) (nvp->value ^ ((uintptr_t)nvp->node));
14 }
15
16 bool equalsNodeValuePair(NodeValuePair *nvp1, NodeValuePair *nvp2) {
17         return nvp1->value == nvp2->value && nvp1->node == nvp2->node;
18 }
19
20 int sortEncodingValue(const void *p1, const void *p2) {
21         const EncodingValue *e1 = *(const EncodingValue **) p1;
22         const EncodingValue *e2 = *(const EncodingValue **) p2;
23         uint se1 = e1->notequals.getSize();
24         uint se2 = e2->notequals.getSize();
25         if (se1 > se2)
26                 return -1;
27         else if (se2 == se1)
28                 return 0;
29         else
30                 return 1;
31 }
32
33 uint EncodingSubGraph::getEncoding(EncodingNode *n, uint64_t val) {
34         NodeValuePair nvp(n, val);
35         EncodingValue *ev = map.get(&nvp);
36         return ev->encoding;
37 }
38
39 void EncodingSubGraph::solveEquals() {
40         Vector<EncodingValue *> toEncode;
41         Vector<bool> encodingArray;
42         SetIteratorEncodingValue *valIt = values.iterator();
43         while (valIt->hasNext()) {
44                 EncodingValue *ev = valIt->next();
45                 if (!ev->inComparison)
46                         toEncode.push(ev);
47                 else
48                         ev->assigned = true;
49         }
50         delete valIt;
51         bsdqsort(toEncode.expose(), toEncode.getSize(), sizeof(EncodingValue *), sortEncodingValue);
52         uint toEncodeSize = toEncode.getSize();
53         for (uint i = 0; i < toEncodeSize; i++) {
54                 EncodingValue *ev = toEncode.get(i);
55                 encodingArray.clear();
56                 SetIteratorEncodingValue *conflictIt = ev->notequals.iterator();
57                 while (conflictIt->hasNext()) {
58                         EncodingValue *conflict = conflictIt->next();
59                         if (conflict->assigned) {
60                                 encodingArray.setExpand(conflict->encoding, true);
61                         }
62                 }
63                 delete conflictIt;
64                 uint encoding = 0;
65                 for (; encoding < encodingArray.getSize(); encoding++) {
66                         //See if this is unassigned
67                         if (!encodingArray.get(encoding))
68                                 break;
69                 }
70                 if (encoding > maxEncodingVal)
71                         maxEncodingVal = encoding;
72                 ev->encoding = encoding;
73                 ev->assigned = true;
74         }
75 }
76
77 void EncodingSubGraph::solveComparisons() {
78         HashsetEncodingValue discovered;
79         Vector<EncodingValue *> tovisit;
80         SetIteratorEncodingValue *valIt = values.iterator();
81         while (valIt->hasNext()) {
82                 EncodingValue *ev = valIt->next();
83                 if (discovered.add(ev)) {
84                         tovisit.push(ev);
85                         while (tovisit.getSize() != 0) {
86                                 EncodingValue *val = tovisit.last(); tovisit.pop();
87                                 SetIteratorEncodingValue *nextIt = val->larger.iterator();
88                                 uint minVal = val->encoding + 1;
89                                 while (nextIt->hasNext()) {
90                                         EncodingValue *nextVal = nextIt->next();
91                                         if (nextVal->encoding < minVal) {
92                                                 if (minVal > maxEncodingVal)
93                                                         maxEncodingVal = minVal;
94                                                 nextVal->encoding = minVal;
95                                                 discovered.add(nextVal);
96                                                 tovisit.push(nextVal);
97                                         }
98                                 }
99                                 delete nextIt;
100                         }
101                 }
102         }
103         delete valIt;
104 }
105
106 uint EncodingSubGraph::estimateNewSize(EncodingSubGraph *sg) {
107         uint newSize = 0;
108         SetIteratorEncodingNode *nit = sg->nodes.iterator();
109         while (nit->hasNext()) {
110                 EncodingNode *en = nit->next();
111                 uint size = estimateNewSize(en);
112                 if (size > newSize)
113                         newSize = size;
114         }
115         delete nit;
116         return newSize;
117 }
118
119 uint EncodingSubGraph::estimateNewSize(EncodingNode *n) {
120         SetIteratorEncodingEdge *eeit = n->edges.iterator();
121         uint newsize = n->getSize();
122         while (eeit->hasNext()) {
123                 EncodingEdge *ee = eeit->next();
124                 if (ee->left != NULL && ee->left != n && nodes.contains(ee->left)) {
125                         uint intersectSize = n->s->getUnionSize(ee->left->s);
126                         if (intersectSize > newsize)
127                                 newsize = intersectSize;
128                 }
129                 if (ee->right != NULL && ee->right != n && nodes.contains(ee->right)) {
130                         uint intersectSize = n->s->getUnionSize(ee->right->s);
131                         if (intersectSize > newsize)
132                                 newsize = intersectSize;
133                 }
134                 if (ee->dst != NULL && ee->dst != n && nodes.contains(ee->dst)) {
135                         uint intersectSize = n->s->getUnionSize(ee->dst->s);
136                         if (intersectSize > newsize)
137                                 newsize = intersectSize;
138                 }
139         }
140         delete eeit;
141         return newsize;
142 }
143
144 void EncodingSubGraph::addNode(EncodingNode *n) {
145         nodes.add(n);
146         uint newSize = estimateNewSize(n);
147         numElements += n->elements.getSize();
148         if (newSize > encodingSize)
149                 encodingSize = newSize;
150 }
151
152 SetIteratorEncodingNode *EncodingSubGraph::nodeIterator() {
153         return nodes.iterator();
154 }
155
156 void EncodingSubGraph::encode() {
157         computeEncodingValue();
158         computeComparisons();
159         computeEqualities();
160         solveComparisons();
161         solveEquals();
162 }
163
164 void EncodingSubGraph::computeEqualities() {
165         SetIteratorEncodingNode *nodeit = nodes.iterator();
166         while (nodeit->hasNext()) {
167                 EncodingNode *node = nodeit->next();
168                 generateEquals(node, node);
169
170                 SetIteratorEncodingEdge *edgeit = node->edges.iterator();
171                 while (edgeit->hasNext()) {
172                         EncodingEdge *edge = edgeit->next();
173                         //skip over comparisons as we have already handled them
174                         if (edge->numComparisons != 0)
175                                 continue;
176                         if (edge->numEquals == 0)
177                                 continue;
178                         if (edge->left == NULL || !nodes.contains(edge->left))
179                                 continue;
180                         if (edge->right == NULL || !nodes.contains(edge->right))
181                                 continue;
182                         //examine only once
183                         if (edge->left != node)
184                                 continue;
185                         //We have a comparison edge between two nodes in the subgraph
186                         //For now we don't support multiple encoding values with the same encoding....
187                         //So we enforce != constraints for every Set...
188                         if (edge->left != edge->right)
189                                 generateEquals(edge->left, edge->right);
190                 }
191                 delete edgeit;
192         }
193         delete nodeit;
194 }
195
196 void EncodingSubGraph::computeComparisons() {
197         SetIteratorEncodingNode *nodeit = nodes.iterator();
198         while (nodeit->hasNext()) {
199                 EncodingNode *node = nodeit->next();
200                 SetIteratorEncodingEdge *edgeit = node->edges.iterator();
201                 while (edgeit->hasNext()) {
202                         EncodingEdge *edge = edgeit->next();
203                         if (edge->numComparisons == 0)
204                                 continue;
205                         if (edge->left == NULL || !nodes.contains(edge->left))
206                                 continue;
207                         if (edge->right == NULL || !nodes.contains(edge->right))
208                                 continue;
209                         //examine only once
210                         if (edge->left != node)
211                                 continue;
212                         //We have a comparison edge between two nodes in the subgraph
213                         generateComparison(edge->left, edge->right);
214                 }
215                 delete edgeit;
216         }
217         delete nodeit;
218 }
219
220 void EncodingSubGraph::orderEV(EncodingValue *earlier, EncodingValue *later) {
221         earlier->larger.add(later);
222 }
223
224 void EncodingSubGraph::generateEquals(EncodingNode *left, EncodingNode *right) {
225         Set *lset = left->s;
226         Set *rset = right->s;
227         uint lSize = lset->getSize(), rSize = rset->getSize();
228         for (uint lindex = 0; lindex < lSize; lindex++) {
229                 for (uint rindex = 0; rindex < rSize; rindex++) {
230                         uint64_t lVal = lset->getElement(lindex);
231                         NodeValuePair nvp1(left, lVal);
232                         EncodingValue *lev = map.get(&nvp1);
233                         uint64_t rVal = rset->getElement(rindex);
234                         NodeValuePair nvp2(right, rVal);
235                         EncodingValue *rev = map.get(&nvp2);
236                         if (lev != rev) {
237                                 if (lev->inComparison && rev->inComparison) {
238                                         //Need to assign during comparison stage...
239                                         //Thus promote to comparison
240                                         if (lVal < rVal) {
241                                                 orderEV(lev, rev);
242                                         } else {
243                                                 orderEV(rev, lev);
244                                         }
245                                 } else {
246                                         lev->notequals.add(rev);
247                                         rev->notequals.add(lev);
248                                 }
249                         }
250                 }
251         }
252 }
253
254 void EncodingSubGraph::generateComparison(EncodingNode *left, EncodingNode *right) {
255         Set *lset = left->s;
256         Set *rset = right->s;
257         uint lindex = 0, rindex = 0;
258         uint lSize = lset->getSize(), rSize = rset->getSize();
259         uint64_t lVal = lset->getElement(lindex);
260         NodeValuePair nvp1(left, lVal);
261         EncodingValue *lev = map.get(&nvp1);
262         lev->inComparison = true;
263         uint64_t rVal = rset->getElement(rindex);
264         NodeValuePair nvp2(right, rVal);
265         EncodingValue *rev = map.get(&nvp2);
266         rev->inComparison = true;
267         EncodingValue *last = NULL;
268
269         while (lindex < lSize || rindex < rSize) {
270                 if (last != NULL) {
271                         if (lev != NULL)
272                                 orderEV(last, lev);
273                         if (rev != NULL && lev != rev)
274                                 orderEV(last, rev);
275                 }
276                 if (lev != rev) {
277                         if (rev == NULL ||
278                                         (lev != NULL && lVal < rVal)) {
279                                 if (rev != NULL)
280                                         orderEV(lev, rev);
281                                 last = lev;
282                                 if (++lindex < lSize) {
283                                         lVal = lset->getElement(lindex);
284                                         NodeValuePair nvpl(left, lVal);
285                                         lev = map.get(&nvpl);
286                                         lev->inComparison = true;
287                                 } else
288                                         lev = NULL;
289                         } else {
290                                 if (lev != NULL)
291                                         orderEV(rev, lev);
292                                 last = rev;
293                                 if (++rindex < rSize) {
294                                         rVal = rset->getElement(rindex);
295                                         NodeValuePair nvpr(right, rVal);
296                                         rev = map.get(&nvpr);
297                                         rev->inComparison = true;
298                                 } else
299                                         rev = NULL;
300                         }
301                 } else {
302                         last = lev;
303                         if (++lindex < lSize) {
304                                 lVal = lset->getElement(lindex);
305                                 NodeValuePair nvpl(left, lVal);
306                                 lev = map.get(&nvpl);
307                                 lev->inComparison = true;
308                         } else
309                                 lev = NULL;
310
311                         if (++rindex < rSize) {
312                                 rVal = rset->getElement(rindex);
313                                 NodeValuePair nvpr(right, rVal);
314                                 rev = map.get(&nvpr);
315                                 rev->inComparison = true;
316                         } else
317                                 rev = NULL;
318                 }
319         }
320 }
321
322 void EncodingSubGraph::computeEncodingValue() {
323         SetIteratorEncodingNode *nodeit = nodes.iterator();
324         while (nodeit->hasNext()) {
325                 EncodingNode *node = nodeit->next();
326                 Set *set = node->s;
327                 uint setSize = set->getSize();
328                 for (uint i = 0; i < setSize; i++) {
329                         uint64_t val = set->getElement(i);
330                         NodeValuePair nvp(node, val);
331                         if (!map.contains(&nvp)) {
332                                 traverseValue(node, val);
333                         }
334                 }
335         }
336         delete nodeit;
337 }
338
339 void EncodingSubGraph::traverseValue(EncodingNode *node, uint64_t value) {
340         EncodingValue *ecv = new EncodingValue(value);
341         values.add(ecv);
342         HashsetEncodingNode discovered;
343         Vector<EncodingNode *> tovisit;
344         tovisit.push(node);
345         discovered.add(node);
346         while (tovisit.getSize() != 0) {
347                 EncodingNode *n = tovisit.last();tovisit.pop();
348                 //Add encoding node to structures
349                 ecv->nodes.add(n);
350                 NodeValuePair *nvp = new NodeValuePair(n, value);
351                 map.put(nvp, ecv);
352                 SetIteratorEncodingEdge *edgeit = node->edges.iterator();
353                 while (edgeit->hasNext()) {
354                         EncodingEdge *ee = edgeit->next();
355                         if (!discovered.contains(ee->left) && nodes.contains(ee->left) && ee->left->s->exists(value)) {
356                                 tovisit.push(ee->left);
357                                 discovered.add(ee->left);
358                         }
359                         if (!discovered.contains(ee->right) && nodes.contains(ee->right) && ee->right->s->exists(value)) {
360                                 tovisit.push(ee->right);
361                                 discovered.add(ee->right);
362                         }
363                 }
364                 delete edgeit;
365         }
366 }
367