86937549fe578070a78914c0a27f9e9307e08b16
[satune.git] / src / ASTAnalyses / Encoding / subgraph.cc
1 #include "subgraph.h"
2 #include "encodinggraph.h"
3 #include "set.h"
4 #include "qsort.h"
5
// Constructs a subgraph with its size counters zeroed. The container members
// (nodes, values, map) are default-constructed.
EncodingSubGraph::EncodingSubGraph() :
	encodingSize(0),
	numElements(0),
	maxEncodingVal(0) {
}
11
// Releases what this subgraph allocated: the heap-allocated NodeValuePair keys
// stored in `map` and the EncodingValue objects stored in `values` (both are
// created with `new` in traverseValue()). The EncodingValues are not deleted
// through `map` because several keys can share one EncodingValue.
EncodingSubGraph::~EncodingSubGraph() {
	map.resetAndDeleteKeys();
	values.resetAndDelete();
}
16
17 uint hashNodeValuePair(NodeValuePair *nvp) {
18         return (uint) (nvp->value ^ ((uintptr_t)nvp->node));
19 }
20
21 bool equalsNodeValuePair(NodeValuePair *nvp1, NodeValuePair *nvp2) {
22         return nvp1->value == nvp2->value && nvp1->node == nvp2->node;
23 }
24
25 int sortEncodingValue(const void *p1, const void *p2) {
26         const EncodingValue *e1 = *(const EncodingValue **) p1;
27         const EncodingValue *e2 = *(const EncodingValue **) p2;
28         uint se1 = e1->notequals.getSize();
29         uint se2 = e2->notequals.getSize();
30         if (se1 > se2)
31                 return -1;
32         else if (se2 == se1)
33                 return 0;
34         else
35                 return 1;
36 }
37
38 uint EncodingSubGraph::getEncoding(EncodingNode *n, uint64_t val) {
39         NodeValuePair nvp(n, val);
40         EncodingValue *ev = map.get(&nvp);
41         return ev->encoding;
42 }
43
44 void EncodingSubGraph::solveEquals() {
45         Vector<EncodingValue *> toEncode;
46         Vector<bool> encodingArray;
47         SetIteratorEncodingValue *valIt = values.iterator();
48         while (valIt->hasNext()) {
49                 EncodingValue *ev = valIt->next();
50                 if (!ev->inComparison)
51                         toEncode.push(ev);
52                 else
53                         ev->assigned = true;
54         }
55         delete valIt;
56         bsdqsort(toEncode.expose(), toEncode.getSize(), sizeof(EncodingValue *), sortEncodingValue);
57         uint toEncodeSize = toEncode.getSize();
58         for (uint i = 0; i < toEncodeSize; i++) {
59                 EncodingValue *ev = toEncode.get(i);
60                 encodingArray.clear();
61                 SetIteratorEncodingValue *conflictIt = ev->notequals.iterator();
62                 while (conflictIt->hasNext()) {
63                         EncodingValue *conflict = conflictIt->next();
64                         if (conflict->assigned) {
65                                 encodingArray.setExpand(conflict->encoding, true);
66                         }
67                 }
68                 delete conflictIt;
69                 uint encoding = 0;
70                 for (; encoding < encodingArray.getSize(); encoding++) {
71                         //See if this is unassigned
72                         if (!encodingArray.get(encoding))
73                                 break;
74                 }
75                 if (encoding > maxEncodingVal)
76                         maxEncodingVal = encoding;
77                 ev->encoding = encoding;
78                 ev->assigned = true;
79         }
80 }
81
82 void EncodingSubGraph::solveComparisons() {
83         HashsetEncodingValue discovered;
84         Vector<EncodingValue *> tovisit;
85         SetIteratorEncodingValue *valIt = values.iterator();
86         while (valIt->hasNext()) {
87                 EncodingValue *ev = valIt->next();
88                 if (discovered.add(ev)) {
89                         tovisit.push(ev);
90                         while (tovisit.getSize() != 0) {
91                                 EncodingValue *val = tovisit.last(); tovisit.pop();
92                                 SetIteratorEncodingValue *nextIt = val->larger.iterator();
93                                 uint minVal = val->encoding + 1;
94                                 while (nextIt->hasNext()) {
95                                         EncodingValue *nextVal = nextIt->next();
96                                         if (nextVal->encoding < minVal) {
97                                                 if (minVal > maxEncodingVal)
98                                                         maxEncodingVal = minVal;
99                                                 nextVal->encoding = minVal;
100                                                 discovered.add(nextVal);
101                                                 tovisit.push(nextVal);
102                                         }
103                                 }
104                                 delete nextIt;
105                         }
106                 }
107         }
108         delete valIt;
109 }
110
111 uint EncodingSubGraph::estimateNewSize(EncodingSubGraph *sg) {
112         uint newSize = 0;
113         SetIteratorEncodingNode *nit = sg->nodes.iterator();
114         while (nit->hasNext()) {
115                 EncodingNode *en = nit->next();
116                 uint size = estimateNewSize(en);
117                 if (size > newSize)
118                         newSize = size;
119         }
120         delete nit;
121         return newSize;
122 }
123
124 uint EncodingSubGraph::estimateNewSize(EncodingNode *n) {
125         SetIteratorEncodingEdge *eeit = n->edges.iterator();
126         uint newsize = n->getSize();
127         while (eeit->hasNext()) {
128                 EncodingEdge *ee = eeit->next();
129                 if (ee->left != NULL && ee->left != n && nodes.contains(ee->left)) {
130                         uint intersectSize = n->s->getUnionSize(ee->left->s);
131                         if (intersectSize > newsize)
132                                 newsize = intersectSize;
133                 }
134                 if (ee->right != NULL && ee->right != n && nodes.contains(ee->right)) {
135                         uint intersectSize = n->s->getUnionSize(ee->right->s);
136                         if (intersectSize > newsize)
137                                 newsize = intersectSize;
138                 }
139                 if (ee->dst != NULL && ee->dst != n && nodes.contains(ee->dst)) {
140                         uint intersectSize = n->s->getUnionSize(ee->dst->s);
141                         if (intersectSize > newsize)
142                                 newsize = intersectSize;
143                 }
144         }
145         delete eeit;
146         return newsize;
147 }
148
149 void EncodingSubGraph::addNode(EncodingNode *n) {
150         nodes.add(n);
151         uint newSize = estimateNewSize(n);
152         numElements += n->elements.getSize();
153         if (newSize > encodingSize)
154                 encodingSize = newSize;
155 }
156
157 SetIteratorEncodingNode *EncodingSubGraph::nodeIterator() {
158         return nodes.iterator();
159 }
160
// Computes encodings for every value in this subgraph. The steps must run in
// this order because each one consumes state produced by the previous ones:
//  1. computeEncodingValue(): create EncodingValues and populate `map`.
//  2. computeComparisons(): add ordering constraints for comparison edges and
//     flag the involved values inComparison.
//  3. computeEqualities(): add != constraints for equality edges; pairs that
//     are both inComparison are promoted to ordering constraints instead.
//  4. solveComparisons(): assign encodings satisfying the ordering constraints.
//  5. solveEquals(): greedily encode the remaining (non-comparison) values.
void EncodingSubGraph::encode() {
	computeEncodingValue();
	computeComparisons();
	computeEqualities();
	solveComparisons();
	solveEquals();
}
168
169 void EncodingSubGraph::computeEqualities() {
170         SetIteratorEncodingNode *nodeit = nodes.iterator();
171         while (nodeit->hasNext()) {
172                 EncodingNode *node = nodeit->next();
173                 generateEquals(node, node);
174
175                 SetIteratorEncodingEdge *edgeit = node->edges.iterator();
176                 while (edgeit->hasNext()) {
177                         EncodingEdge *edge = edgeit->next();
178                         //skip over comparisons as we have already handled them
179                         if (edge->numComparisons != 0)
180                                 continue;
181                         if (edge->numEquals == 0)
182                                 continue;
183                         if (edge->left == NULL || !nodes.contains(edge->left))
184                                 continue;
185                         if (edge->right == NULL || !nodes.contains(edge->right))
186                                 continue;
187                         //examine only once
188                         if (edge->left != node)
189                                 continue;
190                         //We have a comparison edge between two nodes in the subgraph
191                         //For now we don't support multiple encoding values with the same encoding....
192                         //So we enforce != constraints for every Set...
193                         if (edge->left != edge->right)
194                                 generateEquals(edge->left, edge->right);
195                 }
196                 delete edgeit;
197         }
198         delete nodeit;
199 }
200
201 void EncodingSubGraph::computeComparisons() {
202         SetIteratorEncodingNode *nodeit = nodes.iterator();
203         while (nodeit->hasNext()) {
204                 EncodingNode *node = nodeit->next();
205                 SetIteratorEncodingEdge *edgeit = node->edges.iterator();
206                 while (edgeit->hasNext()) {
207                         EncodingEdge *edge = edgeit->next();
208                         if (edge->numComparisons == 0)
209                                 continue;
210                         if (edge->left == NULL || !nodes.contains(edge->left))
211                                 continue;
212                         if (edge->right == NULL || !nodes.contains(edge->right))
213                                 continue;
214                         //examine only once
215                         if (edge->left != node)
216                                 continue;
217                         //We have a comparison edge between two nodes in the subgraph
218                         generateComparison(edge->left, edge->right);
219                 }
220                 delete edgeit;
221         }
222         delete nodeit;
223 }
224
// Records the constraint "earlier < later": `later` is added to earlier's
// `larger` set, which solveComparisons() uses to force
// later->encoding >= earlier->encoding + 1.
void EncodingSubGraph::orderEV(EncodingValue *earlier, EncodingValue *later) {
	earlier->larger.add(later);
}
228
229 void EncodingSubGraph::generateEquals(EncodingNode *left, EncodingNode *right) {
230         Set *lset = left->s;
231         Set *rset = right->s;
232         uint lSize = lset->getSize(), rSize = rset->getSize();
233         for (uint lindex = 0; lindex < lSize; lindex++) {
234                 for (uint rindex = 0; rindex < rSize; rindex++) {
235                         uint64_t lVal = lset->getElement(lindex);
236                         NodeValuePair nvp1(left, lVal);
237                         EncodingValue *lev = map.get(&nvp1);
238                         uint64_t rVal = rset->getElement(rindex);
239                         NodeValuePair nvp2(right, rVal);
240                         EncodingValue *rev = map.get(&nvp2);
241                         if (lev != rev) {
242                                 if (lev->inComparison && rev->inComparison) {
243                                         //Need to assign during comparison stage...
244                                         //Thus promote to comparison
245                                         if (lVal < rVal) {
246                                                 orderEV(lev, rev);
247                                         } else {
248                                                 orderEV(rev, lev);
249                                         }
250                                 } else {
251                                         lev->notequals.add(rev);
252                                         rev->notequals.add(lev);
253                                 }
254                         }
255                 }
256         }
257 }
258
259 void EncodingSubGraph::generateComparison(EncodingNode *left, EncodingNode *right) {
260         Set *lset = left->s;
261         Set *rset = right->s;
262         uint lindex = 0, rindex = 0;
263         uint lSize = lset->getSize(), rSize = rset->getSize();
264         uint64_t lVal = lset->getElement(lindex);
265         NodeValuePair nvp1(left, lVal);
266         EncodingValue *lev = map.get(&nvp1);
267         lev->inComparison = true;
268         uint64_t rVal = rset->getElement(rindex);
269         NodeValuePair nvp2(right, rVal);
270         EncodingValue *rev = map.get(&nvp2);
271         rev->inComparison = true;
272         EncodingValue *last = NULL;
273
274         while (lindex < lSize || rindex < rSize) {
275                 if (last != NULL) {
276                         if (lev != NULL)
277                                 orderEV(last, lev);
278                         if (rev != NULL && lev != rev)
279                                 orderEV(last, rev);
280                 }
281                 if (lev != rev) {
282                         if (rev == NULL ||
283                                         (lev != NULL && lVal < rVal)) {
284                                 if (rev != NULL)
285                                         orderEV(lev, rev);
286                                 last = lev;
287                                 if (++lindex < lSize) {
288                                         lVal = lset->getElement(lindex);
289                                         NodeValuePair nvpl(left, lVal);
290                                         lev = map.get(&nvpl);
291                                         lev->inComparison = true;
292                                 } else
293                                         lev = NULL;
294                         } else {
295                                 if (lev != NULL)
296                                         orderEV(rev, lev);
297                                 last = rev;
298                                 if (++rindex < rSize) {
299                                         rVal = rset->getElement(rindex);
300                                         NodeValuePair nvpr(right, rVal);
301                                         rev = map.get(&nvpr);
302                                         rev->inComparison = true;
303                                 } else
304                                         rev = NULL;
305                         }
306                 } else {
307                         last = lev;
308                         if (++lindex < lSize) {
309                                 lVal = lset->getElement(lindex);
310                                 NodeValuePair nvpl(left, lVal);
311                                 lev = map.get(&nvpl);
312                                 lev->inComparison = true;
313                         } else
314                                 lev = NULL;
315
316                         if (++rindex < rSize) {
317                                 rVal = rset->getElement(rindex);
318                                 NodeValuePair nvpr(right, rVal);
319                                 rev = map.get(&nvpr);
320                                 rev->inComparison = true;
321                         } else
322                                 rev = NULL;
323                 }
324         }
325 }
326
327 void EncodingSubGraph::computeEncodingValue() {
328         SetIteratorEncodingNode *nodeit = nodes.iterator();
329         while (nodeit->hasNext()) {
330                 EncodingNode *node = nodeit->next();
331                 Set *set = node->s;
332                 uint setSize = set->getSize();
333                 for (uint i = 0; i < setSize; i++) {
334                         uint64_t val = set->getElement(i);
335                         NodeValuePair nvp(node, val);
336                         if (!map.contains(&nvp)) {
337                                 traverseValue(node, val);
338                         }
339                 }
340         }
341         delete nodeit;
342 }
343
344 void EncodingSubGraph::traverseValue(EncodingNode *node, uint64_t value) {
345         EncodingValue *ecv = new EncodingValue(value);
346         values.add(ecv);
347         HashsetEncodingNode discovered;
348         Vector<EncodingNode *> tovisit;
349         tovisit.push(node);
350         discovered.add(node);
351         while (tovisit.getSize() != 0) {
352                 EncodingNode *n = tovisit.last();tovisit.pop();
353                 //Add encoding node to structures
354                 ecv->nodes.add(n);
355                 NodeValuePair *nvp = new NodeValuePair(n, value);
356                 map.put(nvp, ecv);
357                 SetIteratorEncodingEdge *edgeit = node->edges.iterator();
358                 while (edgeit->hasNext()) {
359                         EncodingEdge *ee = edgeit->next();
360                         if (!discovered.contains(ee->left) && nodes.contains(ee->left) && ee->left->s->exists(value)) {
361                                 tovisit.push(ee->left);
362                                 discovered.add(ee->left);
363                         }
364                         if (!discovered.contains(ee->right) && nodes.contains(ee->right) && ee->right->s->exists(value)) {
365                                 tovisit.push(ee->right);
366                                 discovered.add(ee->right);
367                         }
368                 }
369                 delete edgeit;
370         }
371 }
372