adc7ad78c31534ef0b1094935e40414c915e38de
[satune.git] / src / ASTAnalyses / Encoding / subgraph.cc
1 #include "subgraph.h"
2 #include "encodinggraph.h"
3 #include "set.h"
4 #include "qsort.h"
5
// Constructs an empty subgraph. Only the size counters are initialized
// explicitly (all to zero); the node, value and map containers are
// default-constructed.
EncodingSubGraph::EncodingSubGraph() :
	encodingSize(0),
	numElements(0),
	maxEncodingVal(0) {
}
11
// Releases the heap objects owned by the subgraph.
// resetAndDeleteKeys() frees the NodeValuePair keys that traverseValue()
// allocates for `map`. The EncodingValue objects those keys map to are the
// same objects stored in `values`, which resetAndDelete() frees, so each
// EncodingValue is deleted exactly once.
EncodingSubGraph::~EncodingSubGraph() {
	map.resetAndDeleteKeys();
	values.resetAndDelete();
}
16
17 uint hashNodeValuePair(NodeValuePair *nvp) {
18         return (uint) (nvp->value ^ ((uintptr_t)nvp->node));
19 }
20
21 bool equalsNodeValuePair(NodeValuePair *nvp1, NodeValuePair *nvp2) {
22         return nvp1->value == nvp2->value && nvp1->node == nvp2->node;
23 }
24
25 int sortEncodingValue(const void *p1, const void *p2) {
26         const EncodingValue *e1 = *(const EncodingValue **) p1;
27         const EncodingValue *e2 = *(const EncodingValue **) p2;
28         uint se1 = e1->notequals.getSize();
29         uint se2 = e2->notequals.getSize();
30         if (se1 > se2)
31                 return -1;
32         else if (se2 == se1)
33                 return 0;
34         else
35                 return 1;
36 }
37
38 uint EncodingSubGraph::getEncoding(EncodingNode *n, uint64_t val) {
39         NodeValuePair nvp(n, val);
40         EncodingValue *ev = map.get(&nvp);
41         return ev->encoding;
42 }
43
44 void EncodingSubGraph::solveEquals() {
45         Vector<EncodingValue *> toEncode;
46         Vector<bool> encodingArray;
47         SetIteratorEncodingValue *valIt = values.iterator();
48         while (valIt->hasNext()) {
49                 EncodingValue *ev = valIt->next();
50                 if (!ev->inComparison)
51                         toEncode.push(ev);
52                 else
53                         ev->assigned = true;
54         }
55         delete valIt;
56         bsdqsort(toEncode.expose(), toEncode.getSize(), sizeof(EncodingValue *), sortEncodingValue);
57         uint toEncodeSize = toEncode.getSize();
58         for (uint i = 0; i < toEncodeSize; i++) {
59                 EncodingValue *ev = toEncode.get(i);
60                 encodingArray.clear();
61                 SetIteratorEncodingValue *conflictIt = ev->notequals.iterator();
62                 while (conflictIt->hasNext()) {
63                         EncodingValue *conflict = conflictIt->next();
64                         ASSERT(conflict->value != ev->value);
65                         if (conflict->assigned) {
66                                 encodingArray.setExpand(conflict->encoding, true);
67                         }
68                 }
69                 delete conflictIt;
70                 uint encoding = 0;
71                 for (; encoding < encodingArray.getSize(); encoding++) {
72                         //See if this is unassigned
73                         if (!encodingArray.get(encoding))
74                                 break;
75                 }
76                 if (encoding > maxEncodingVal)
77                         maxEncodingVal = encoding;
78                 ev->encoding = encoding;
79                 ev->assigned = true;
80         }
81 }
82
83 void EncodingSubGraph::solveComparisons() {
84         HashsetEncodingValue discovered;
85         Vector<EncodingValue *> tovisit;
86         SetIteratorEncodingValue *valIt = values.iterator();
87         while (valIt->hasNext()) {
88                 EncodingValue *ev = valIt->next();
89                 if (discovered.add(ev)) {
90                         tovisit.push(ev);
91                         while (tovisit.getSize() != 0) {
92                                 EncodingValue *val = tovisit.last(); tovisit.pop();
93                                 SetIteratorEncodingValue *nextIt = val->larger.iterator();
94                                 uint minVal = val->encoding + 1;
95                                 while (nextIt->hasNext()) {
96                                         EncodingValue *nextVal = nextIt->next();
97                                         if (nextVal->encoding < minVal) {
98                                                 if (minVal > maxEncodingVal)
99                                                         maxEncodingVal = minVal;
100                                                 nextVal->encoding = minVal;
101                                                 discovered.add(nextVal);
102                                                 tovisit.push(nextVal);
103                                         }
104                                 }
105                                 delete nextIt;
106                         }
107                 }
108         }
109         delete valIt;
110 }
111
112 uint EncodingSubGraph::estimateNewSize(EncodingSubGraph *sg) {
113         uint newSize = 0;
114         SetIteratorEncodingNode *nit = sg->nodes.iterator();
115         while (nit->hasNext()) {
116                 EncodingNode *en = nit->next();
117                 uint size = estimateNewSize(en);
118                 if (size > newSize)
119                         newSize = size;
120         }
121         delete nit;
122         return newSize;
123 }
124
125 uint EncodingSubGraph::estimateNewSize(EncodingNode *n) {
126         SetIteratorEncodingEdge *eeit = n->edges.iterator();
127         uint newsize = n->getSize();
128         while (eeit->hasNext()) {
129                 EncodingEdge *ee = eeit->next();
130                 if (ee->left != NULL && ee->left != n && nodes.contains(ee->left)) {
131                         uint intersectSize = n->s->getUnionSize(ee->left->s);
132                         if (intersectSize > newsize)
133                                 newsize = intersectSize;
134                 }
135                 if (ee->right != NULL && ee->right != n && nodes.contains(ee->right)) {
136                         uint intersectSize = n->s->getUnionSize(ee->right->s);
137                         if (intersectSize > newsize)
138                                 newsize = intersectSize;
139                 }
140                 if (ee->dst != NULL && ee->dst != n && nodes.contains(ee->dst)) {
141                         uint intersectSize = n->s->getUnionSize(ee->dst->s);
142                         if (intersectSize > newsize)
143                                 newsize = intersectSize;
144                 }
145         }
146         delete eeit;
147         return newsize;
148 }
149
150 void EncodingSubGraph::addNode(EncodingNode *n) {
151         nodes.add(n);
152         uint newSize = estimateNewSize(n);
153         numElements += n->elements.getSize();
154         if (newSize > encodingSize)
155                 encodingSize = newSize;
156 }
157
// Returns a fresh iterator over the nodes of this subgraph, as produced by
// nodes.iterator(). Iterators obtained this way are deleted by their users
// elsewhere in this file, so callers presumably own it and must delete it.
SetIteratorEncodingNode *EncodingSubGraph::nodeIterator() {
	return nodes.iterator();
}
161
// Computes an encoding for every value of the subgraph. The phases must
// run in this order:
//  1. computeEncodingValue: create one EncodingValue per value, shared by
//     all connected subgraph nodes containing that value;
//  2. computeComparisons: flag values used in comparisons (inComparison)
//     and record their ordering;
//  3. computeEqualities: add != constraints, or orderings when both values
//     are inComparison; this reads the flags set in step 2;
//  4. solveComparisons: assign encodings that satisfy the orderings;
//  5. solveEquals: greedily encode the remaining values around the
//     encodings already taken.
void EncodingSubGraph::encode() {
	computeEncodingValue();
	computeComparisons();
	computeEqualities();
	solveComparisons();
	solveEquals();
}
169
170 void EncodingSubGraph::computeEqualities() {
171         SetIteratorEncodingNode *nodeit = nodes.iterator();
172         while (nodeit->hasNext()) {
173                 EncodingNode *node = nodeit->next();
174                 generateEquals(node, node);
175
176                 SetIteratorEncodingEdge *edgeit = node->edges.iterator();
177                 while (edgeit->hasNext()) {
178                         EncodingEdge *edge = edgeit->next();
179                         //skip over comparisons as we have already handled them
180                         if (edge->numComparisons != 0)
181                                 continue;
182                         if (edge->numEquals == 0)
183                                 continue;
184                         if (edge->left == NULL || !nodes.contains(edge->left))
185                                 continue;
186                         if (edge->right == NULL || !nodes.contains(edge->right))
187                                 continue;
188                         //examine only once
189                         if (edge->left != node)
190                                 continue;
191                         //We have a comparison edge between two nodes in the subgraph
192                         //For now we don't support multiple encoding values with the same encoding....
193                         //So we enforce != constraints for every Set...
194                         if (edge->left != edge->right)
195                                 generateEquals(edge->left, edge->right);
196                 }
197                 delete edgeit;
198         }
199         delete nodeit;
200 }
201
202 void EncodingSubGraph::computeComparisons() {
203         SetIteratorEncodingNode *nodeit = nodes.iterator();
204         while (nodeit->hasNext()) {
205                 EncodingNode *node = nodeit->next();
206                 SetIteratorEncodingEdge *edgeit = node->edges.iterator();
207                 while (edgeit->hasNext()) {
208                         EncodingEdge *edge = edgeit->next();
209                         if (edge->numComparisons == 0)
210                                 continue;
211                         if (edge->left == NULL || !nodes.contains(edge->left))
212                                 continue;
213                         if (edge->right == NULL || !nodes.contains(edge->right))
214                                 continue;
215                         //examine only once
216                         if (edge->left != node)
217                                 continue;
218                         //We have a comparison edge between two nodes in the subgraph
219                         generateComparison(edge->left, edge->right);
220                 }
221                 delete edgeit;
222         }
223         delete nodeit;
224 }
225
// Records that `earlier` must receive a smaller encoding than `later`, by
// adding `later` to earlier->larger. solveComparisons() later enforces
// later->encoding >= earlier->encoding + 1 along these edges.
void EncodingSubGraph::orderEV(EncodingValue *earlier, EncodingValue *later) {
	earlier->larger.add(later);
}
229
230 void EncodingSubGraph::generateEquals(EncodingNode *left, EncodingNode *right) {
231         Set *lset = left->s;
232         Set *rset = right->s;
233         uint lSize = lset->getSize(), rSize = rset->getSize();
234         for (uint lindex = 0; lindex < lSize; lindex++) {
235                 for (uint rindex = 0; rindex < rSize; rindex++) {
236                         uint64_t lVal = lset->getElement(lindex);
237                         NodeValuePair nvp1(left, lVal);
238                         EncodingValue *lev = map.get(&nvp1);
239                         uint64_t rVal = rset->getElement(rindex);
240                         NodeValuePair nvp2(right, rVal);
241                         EncodingValue *rev = map.get(&nvp2);
242                         if (lev != rev) {
243                                 ASSERT(lVal != rVal);
244                                 if (lev->inComparison && rev->inComparison) {
245                                         //Need to assign during comparison stage...
246                                         //Thus promote to comparison
247                                         if (lVal < rVal) {
248                                                 orderEV(lev, rev);
249                                         } else {
250                                                 orderEV(rev, lev);
251                                         }
252                                 } else {
253                                         lev->notequals.add(rev);
254                                         rev->notequals.add(lev);
255                                 }
256                         }
257                 }
258         }
259 }
260
261 void EncodingSubGraph::generateComparison(EncodingNode *left, EncodingNode *right) {
262         Set *lset = left->s;
263         Set *rset = right->s;
264         uint lindex = 0, rindex = 0;
265         uint lSize = lset->getSize(), rSize = rset->getSize();
266         uint64_t lVal = lset->getElement(lindex);
267         NodeValuePair nvp1(left, lVal);
268         EncodingValue *lev = map.get(&nvp1);
269         lev->inComparison = true;
270         uint64_t rVal = rset->getElement(rindex);
271         NodeValuePair nvp2(right, rVal);
272         EncodingValue *rev = map.get(&nvp2);
273         rev->inComparison = true;
274         EncodingValue *last = NULL;
275
276         while (lindex < lSize || rindex < rSize) {
277                 if (last != NULL) {
278                         if (lev != NULL)
279                                 orderEV(last, lev);
280                         if (rev != NULL && lev != rev)
281                                 orderEV(last, rev);
282                 }
283                 if (lev != rev) {
284                         if (rev == NULL ||
285                                         (lev != NULL && lVal < rVal)) {
286                                 if (rev != NULL)
287                                         orderEV(lev, rev);
288                                 last = lev;
289                                 if (++lindex < lSize) {
290                                         lVal = lset->getElement(lindex);
291                                         NodeValuePair nvpl(left, lVal);
292                                         lev = map.get(&nvpl);
293                                         lev->inComparison = true;
294                                 } else
295                                         lev = NULL;
296                         } else {
297                                 if (lev != NULL)
298                                         orderEV(rev, lev);
299                                 last = rev;
300                                 if (++rindex < rSize) {
301                                         rVal = rset->getElement(rindex);
302                                         NodeValuePair nvpr(right, rVal);
303                                         rev = map.get(&nvpr);
304                                         rev->inComparison = true;
305                                 } else
306                                         rev = NULL;
307                         }
308                 } else {
309                         last = lev;
310                         if (++lindex < lSize) {
311                                 lVal = lset->getElement(lindex);
312                                 NodeValuePair nvpl(left, lVal);
313                                 lev = map.get(&nvpl);
314                                 lev->inComparison = true;
315                         } else
316                                 lev = NULL;
317
318                         if (++rindex < rSize) {
319                                 rVal = rset->getElement(rindex);
320                                 NodeValuePair nvpr(right, rVal);
321                                 rev = map.get(&nvpr);
322                                 rev->inComparison = true;
323                         } else
324                                 rev = NULL;
325                 }
326         }
327 }
328
329 void EncodingSubGraph::computeEncodingValue() {
330         SetIteratorEncodingNode *nodeit = nodes.iterator();
331         while (nodeit->hasNext()) {
332                 EncodingNode *node = nodeit->next();
333                 Set *set = node->s;
334                 uint setSize = set->getSize();
335                 for (uint i = 0; i < setSize; i++) {
336                         uint64_t val = set->getElement(i);
337                         NodeValuePair nvp(node, val);
338                         if (!map.contains(&nvp)) {
339                                 traverseValue(node, val);
340                         }
341                 }
342         }
343         delete nodeit;
344 }
345
346 void EncodingSubGraph::traverseValue(EncodingNode *node, uint64_t value) {
347         EncodingValue *ecv = new EncodingValue(value);
348         values.add(ecv);
349         HashsetEncodingNode discovered;
350         Vector<EncodingNode *> tovisit;
351         tovisit.push(node);
352         discovered.add(node);
353         while (tovisit.getSize() != 0) {
354                 EncodingNode *n = tovisit.last();tovisit.pop();
355                 //Add encoding node to structures
356                 NodeValuePair *nvp = new NodeValuePair(n, value);
357                 if(map.contains(nvp))
358                         continue;
359                 ecv->nodes.add(n);
360                 map.put(nvp, ecv);
361                 ASSERT(node != NULL);
362                 SetIteratorEncodingEdge *edgeit = n->edges.iterator();
363                 while (edgeit->hasNext()) {
364                         EncodingEdge *ee = edgeit->next();
365                         if (!discovered.contains(ee->left) && nodes.contains(ee->left) && ee->left->s->exists(value)) {
366                                 tovisit.push(ee->left);
367                                 discovered.add(ee->left);
368                                 ASSERT(discovered.contains(ee->left));
369                         }
370                         if (!discovered.contains(ee->right) && nodes.contains(ee->right) && ee->right->s->exists(value)) {
371                                 tovisit.push(ee->right);
372                                 discovered.add(ee->right);
373                                 ASSERT(discovered.contains(ee->right));
374                         }
375                 }
376                 delete edgeit;
377         }
378 }
379