6ebea56771771644a8bbe33782adde4d49ebbf3c
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11 #include "elementencoding.h"
12
13 EncodingGraph::EncodingGraph(CSolver *_solver) :
14         solver(_solver) {
15 }
16
17 EncodingGraph::~EncodingGraph() {
18         subgraphs.resetAndDelete();
19         encodingMap.resetAndDeleteVals();
20         edgeMap.resetAndDeleteVals();
21 }
22
23 int sortEncodingEdge(const void *p1, const void *p2) {
24         const EncodingEdge *e1 = *(const EncodingEdge **) p1;
25         const EncodingEdge *e2 = *(const EncodingEdge **) p2;
26         uint64_t v1 = e1->getValue();
27         uint64_t v2 = e2->getValue();
28         if (v1 < v2)
29                 return 1;
30         else if (v1 == v2)
31                 return 0;
32         else
33                 return -1;
34 }
35
36 void EncodingGraph::buildGraph() {
37         ElementIterator it(solver);
38         while (it.hasNext()) {
39                 Element *e = it.next();
40                 switch (e->type) {
41                 case ELEMSET:
42                 case ELEMFUNCRETURN:
43                         processElement(e);
44                         break;
45                 case ELEMCONST:
46                         break;
47                 default:
48                         ASSERT(0);
49                 }
50         }
51         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
52         decideEdges();
53 }
54
55
56 void EncodingGraph::validate() {
57         SetIteratorBooleanEdge *it = solver->getConstraints();
58         while (it->hasNext()) {
59                 BooleanEdge be = it->next();
60                 if (be->type == PREDICATEOP) {
61                         BooleanPredicate *b = (BooleanPredicate *)be.getBoolean();
62                         if (b->predicate->type == OPERATORPRED) {
63                                 PredicateOperator *predicate = (PredicateOperator *) b->predicate;
64                                 if (predicate->getOp() == SATC_EQUALS) {
65                                         ASSERT(b->inputs.getSize() == 2);
66                                         Element *e1 = b->inputs.get(0);
67                                         Element *e2 = b->inputs.get(1);
68                                         if (e1->type == ELEMCONST || e1->type == ELEMCONST)
69                                                 continue;
70                                         ElementEncoding *enc1 = e1->getElementEncoding();
71                                         ElementEncoding *enc2 = e2->getElementEncoding();
72                                         ASSERT(enc1->getElementEncodingType() != ELEM_UNASSIGNED);
73                                         ASSERT(enc2->getElementEncodingType() != ELEM_UNASSIGNED);
74                                         if (enc1->getElementEncodingType() == enc2->getElementEncodingType() && enc1->getElementEncodingType() == BINARYINDEX && b->getFunctionEncoding()->type == CIRCUIT) {
75                                                 for (uint i = 0; i < enc1->encArraySize; i++) {
76                                                         if (enc1->isinUseElement(i)) {
77                                                                 uint64_t val1 = enc1->encodingArray[i];
78                                                                 if (enc2->isinUseElement(i)) {
79                                                                         ASSERT(val1 == enc2->encodingArray[i]);
80                                                                 } else {
81                                                                         for (uint j = 0; j < enc2->encArraySize; j++) {
82                                                                                 if (enc2->isinUseElement(j)) {
83                                                                                         ASSERT(val1 != enc2->encodingArray[j]);
84                                                                                 }
85                                                                         }
86                                                                 }
87                                                         }
88                                                 }
89                                         }
90                                         //Now make sure that all the elements in the set are appeared in the encoding array!
91                                         for (uint k = 0; k < b->inputs.getSize(); k++) {
92                                                 Element *e = b->inputs.get(k);
93                                                 ElementEncoding *enc = e->getElementEncoding();
94                                                 Set *s = e->getRange();
95                                                 for (uint i = 0; i < s->getSize(); i++) {
96                                                         uint64_t value = s->getElement(i);
97                                                         bool exist = false;
98                                                         for (uint j = 0; j < enc->encArraySize; j++) {
99                                                                 if (enc->isinUseElement(j) && enc->encodingArray[j] == value) {
100                                                                         exist = true;
101                                                                         break;
102                                                                 }
103                                                         }
104                                                         ASSERT(exist);
105                                                 }
106                                         }
107                                 }
108                         }
109                 }
110         }
111         delete it;
112 }
113
114
115 void EncodingGraph::encode() {
116         if (solver->getTuner()->getTunable(ENCODINGGRAPHOPT, &offon) == 0)
117                 return;
118         buildGraph();
119         SetIteratorEncodingSubGraph *itesg = subgraphs.iterator();
120         model_print("#SubGraph = %u\n", subgraphs.getSize());
121         while (itesg->hasNext()) {
122                 EncodingSubGraph *sg = itesg->next();
123                 sg->encode();
124         }
125         delete itesg;
126
127         ElementIterator it(solver);
128         while (it.hasNext()) {
129                 Element *e = it.next();
130                 switch (e->type) {
131                 case ELEMSET:
132                 case ELEMFUNCRETURN: {
133                         ElementEncoding *encoding = e->getElementEncoding();
134                         if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
135                                 EncodingNode *n = getNode(e);
136                                 if (n == NULL)
137                                         continue;
138                                 ElementEncodingType encodetype = n->getEncoding();
139                                 encoding->setElementEncodingType(encodetype);
140                                 if (encodetype == UNARY || encodetype == ONEHOT) {
141                                         encoding->encodingArrayInitialization();
142                                 } else if (encodetype == BINARYINDEX) {
143                                         EncodingSubGraph *subgraph = graphMap.get(n);
144                                         DEBUG("graphMap.get(subgraph=%p, n=%p)\n", subgraph, n);
145                                         if (subgraph == NULL) {
146                                                 encoding->encodingArrayInitialization();
147                                                 continue;
148                                         }
149                                         uint encodingSize = subgraph->getEncodingMaxVal(n) + 1;
150                                         uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
151                                         encoding->allocInUseArrayElement(paddedSize);
152                                         encoding->allocEncodingArrayElement(paddedSize);
153                                         Set *s = e->getRange();
154                                         for (uint i = 0; i < s->getSize(); i++) {
155                                                 uint64_t value = s->getElement(i);
156                                                 uint encodingIndex = subgraph->getEncoding(n, value);
157                                                 encoding->setInUseElement(encodingIndex);
158                                                 ASSERT(encoding->isinUseElement(encodingIndex));
159                                                 encoding->encodingArray[encodingIndex] = value;
160                                         }
161                                 }
162                         }
163                         break;
164                 }
165                 default:
166                         break;
167                 }
168                 encodeParent(e);
169         }
170 }
171
172 void EncodingGraph::encodeParent(Element *e) {
173         uint size = e->parents.getSize();
174         for (uint i = 0; i < size; i++) {
175                 ASTNode *n = e->parents.get(i);
176                 if (n->type == PREDICATEOP) {
177                         BooleanPredicate *b = (BooleanPredicate *)n;
178                         FunctionEncoding *fenc = b->getFunctionEncoding();
179                         if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
180                                 continue;
181                         Predicate *p = b->getPredicate();
182                         if (p->type == OPERATORPRED) {
183                                 PredicateOperator *po = (PredicateOperator *)p;
184                                 ASSERT(b->inputs.getSize() == 2);
185                                 EncodingNode *left = createNode(b->inputs.get(0));
186                                 EncodingNode *right = createNode(b->inputs.get(1));
187                                 if (left == NULL || right == NULL)
188                                         return;
189                                 EncodingEdge *edge = getEdge(left, right, NULL);
190                                 if (edge != NULL) {
191                                         EncodingSubGraph *leftGraph = graphMap.get(left);
192                                         if (leftGraph != NULL && leftGraph == graphMap.get(right)) {
193                                                 fenc->setFunctionEncodingType(CIRCUIT);
194                                         }
195                                 }
196                         }
197                 }
198         }
199 }
200
201 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
202         EncodingSubGraph *graph1 = graphMap.get(first);
203         DEBUG("graphMap.get(first=%p, graph1=%p)\n", first, graph1);
204         EncodingSubGraph *graph2 = graphMap.get(second);
205         DEBUG("graphMap.get(second=%p, graph2=%p)\n", second, graph2);
206         if (graph1 == NULL)
207                 first->setEncoding(BINARYINDEX);
208         if (graph2 == NULL)
209                 second->setEncoding(BINARYINDEX);
210
211         if (graph1 == NULL && graph2 == NULL) {
212                 graph1 = new EncodingSubGraph();
213                 subgraphs.add(graph1);
214                 DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
215                 graphMap.put(first, graph1);
216                 graph1->addNode(first);
217         }
218         if (graph1 == NULL && graph2 != NULL) {
219                 graph1 = graph2;
220                 graph2 = NULL;
221                 EncodingNode *tmp = second;
222                 second = first;
223                 first = tmp;
224         }
225         if (graph1 != NULL && graph2 != NULL) {
226                 if (graph1 == graph2)
227                         return;
228
229                 SetIteratorEncodingNode *nodeit = graph2->nodeIterator();
230                 while (nodeit->hasNext()) {
231                         EncodingNode *node = nodeit->next();
232                         graph1->addNode(node);
233                         DEBUG("graphMap.put(node=%p, graph1=%p)\n", node, graph1);
234                         graphMap.put(node, graph1);
235                 }
236                 subgraphs.remove(graph2);
237                 delete nodeit;
238                 DEBUG("Deleting graph2 =%p \n", graph2);
239                 delete graph2;
240         } else {
241                 ASSERT(graph1 != NULL && graph2 == NULL);
242                 graph1->addNode(second);
243                 DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
244                 graphMap.put(second, graph1);
245         }
246 }
247
248 void EncodingGraph::processElement(Element *e) {
249         uint size = e->parents.getSize();
250         for (uint i = 0; i < size; i++) {
251                 ASTNode *n = e->parents.get(i);
252                 switch (n->type) {
253                 case PREDICATEOP:
254                         processPredicate((BooleanPredicate *)n);
255                         break;
256                 case ELEMFUNCRETURN:
257                         processFunction((ElementFunction *)n);
258                         break;
259                 default:
260                         ASSERT(0);
261                 }
262         }
263 }
264
265 void EncodingGraph::processFunction(ElementFunction *ef) {
266         Function *f = ef->getFunction();
267         if (f->type == OPERATORFUNC) {
268                 FunctionOperator *fo = (FunctionOperator *)f;
269                 ASSERT(ef->inputs.getSize() == 2);
270                 EncodingNode *left = createNode(ef->inputs.get(0));
271                 EncodingNode *right = createNode(ef->inputs.get(1));
272                 if (left == NULL && right == NULL)
273                         return;
274                 EncodingNode *dst = createNode(ef);
275                 EncodingEdge *edge = createEdge(left, right, dst);
276                 edge->numArithOps++;
277         }
278 }
279
280 void EncodingGraph::processPredicate(BooleanPredicate *b) {
281         Predicate *p = b->getPredicate();
282         if (p->type == OPERATORPRED) {
283                 PredicateOperator *po = (PredicateOperator *)p;
284                 ASSERT(b->inputs.getSize() == 2);
285                 EncodingNode *left = createNode(b->inputs.get(0));
286                 EncodingNode *right = createNode(b->inputs.get(1));
287                 if (left == NULL || right == NULL)
288                         return;
289                 EncodingEdge *edge = createEdge(left, right, NULL);
290                 CompOp op = po->getOp();
291                 switch (op) {
292                 case SATC_EQUALS:
293                         edge->numEquals++;
294                         break;
295                 case SATC_LT:
296                 case SATC_LTE:
297                 case SATC_GT:
298                 case SATC_GTE:
299                         edge->numComparisons++;
300                         break;
301                 default:
302                         ASSERT(0);
303                 }
304         }
305 }
306
307 uint convertSize(uint cost) {
308         cost = FUDGEFACTOR * cost;// fudge factor
309         return NEXTPOW2(cost);
310 }
311
312 void EncodingGraph::decideEdges() {
313         uint size = edgeVector.getSize();
314         for (uint i = 0; i < size; i++) {
315                 EncodingEdge *ee = edgeVector.get(i);
316                 EncodingNode *left = ee->left;
317                 EncodingNode *right = ee->right;
318
319                 if (ee->encoding != EDGE_UNASSIGNED ||
320                                 !left->couldBeBinaryIndex() ||
321                                 !right->couldBeBinaryIndex())
322                         continue;
323
324                 uint64_t eeValue = ee->getValue();
325                 if (eeValue == 0)
326                         return;
327
328                 EncodingSubGraph *leftGraph = graphMap.get(left);
329                 DEBUG("graphMap.get(left=%p, leftgraph=%p)\n", left, leftGraph);
330                 EncodingSubGraph *rightGraph = graphMap.get(right);
331                 DEBUG("graphMap.get(right=%p, rightgraph=%p)\n", right, rightGraph);
332                 if (leftGraph == NULL && rightGraph != NULL) {
333                         EncodingNode *tmp = left; left = right; right = tmp;
334                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
335                 }
336
337                 uint leftSize = 0, rightSize = 0, newSize = 0, max = 0;
338                 bool merge = false;
339                 if (leftGraph == NULL && rightGraph == NULL) {
340                         leftSize = convertSize(left->getSize());
341                         rightSize = convertSize(right->getSize());
342                         newSize = convertSize(left->s->getUnionSize(right->s));
343                         newSize = (leftSize > newSize) ? leftSize : newSize;
344                         newSize = (rightSize > newSize) ? rightSize : newSize;
345                         max = rightSize > leftSize ? rightSize : leftSize;
346                         merge = left->measureSimilarity(right) > 1.5 || max == newSize;
347                 } else if (leftGraph != NULL && rightGraph == NULL) {
348                         leftSize = convertSize(leftGraph->encodingSize);
349                         rightSize = convertSize(right->getSize());
350                         newSize = convertSize(leftGraph->estimateNewSize(right));
351                         newSize = (leftSize > newSize) ? leftSize : newSize;
352                         newSize = (rightSize > newSize) ? rightSize : newSize;
353                         max = rightSize > leftSize ? rightSize : leftSize;
354 //                      model_print("Merge=%s\tsimilarity=%f\n", max==newSize?"TRUE":"FALSE", left->measureSimilarity(right));
355                         merge = left->measureSimilarity(right) > 1.5 || max == newSize;
356                 } else {
357                         //Neither are null
358                         leftSize = convertSize(leftGraph->encodingSize);
359                         rightSize = convertSize(rightGraph->encodingSize);
360                         newSize = convertSize(leftGraph->estimateNewSize(rightGraph));
361 //                      model_print("MergingSubGraphs: left=%u\tright=%u\tnewSize=%u\n", leftSize, rightSize, newSize);
362                         newSize = (leftSize > newSize) ? leftSize : newSize;
363                         newSize = (rightSize > newSize) ? rightSize : newSize;
364                         max = rightSize > leftSize ? rightSize : leftSize;
365 //                      model_print("Merge=%s\tsimilarity=%f\n", max==newSize?"TRUE":"FALSE", leftGraph->measureSimilarity(right));
366                         merge = leftGraph->measureSimilarity(right) > 1.5 || max == newSize;
367                 }
368                 if (merge) {
369                         //add the edge
370                         mergeNodes(left, right);
371                 }
372         }
373 }
374
// Tunable descriptor consulted by createEdge() when querying EDGEENCODING for a
// pair of variable types. NOTE(review): the arguments look like
// (low, high, default), i.e. a range EDGE_UNASSIGNED..EDGE_MATCH defaulting to
// EDGE_UNASSIGNED; confirm against TunableDesc's constructor.
static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
376
377 EncodingEdge *EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
378         EncodingEdge e(left, right, dst);
379         EncodingEdge *result = edgeMap.get(&e);
380         return result;
381 }
382
383 EncodingEdge *EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
384         EncodingEdge e(left, right, dst);
385         EncodingEdge *result = edgeMap.get(&e);
386         if (result == NULL) {
387                 result = new EncodingEdge(left, right, dst);
388                 VarType v1 = left->getType();
389                 VarType v2 = right->getType();
390                 if (v1 > v2) {
391                         VarType tmp = v2;
392                         v2 = v1;
393                         v1 = tmp;
394                 }
395
396                 if ((left != NULL && left->couldBeBinaryIndex()) &&
397                                 (right != NULL) && right->couldBeBinaryIndex()) {
398                         EdgeEncodingType type = (EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
399                         result->setEncoding(type);
400                         if (type == EDGE_MATCH) {
401                                 mergeNodes(left, right);
402                         }
403                 }
404                 edgeMap.put(result, result);
405                 edgeVector.push(result);
406                 if (left != NULL)
407                         left->edges.add(result);
408                 if (right != NULL)
409                         right->edges.add(result);
410                 if (dst != NULL)
411                         dst->edges.add(result);
412         }
413         return result;
414 }
415
416 EncodingNode::EncodingNode(Set *_s) :
417         s(_s) {
418 }
419
420 uint EncodingNode::getSize() const {
421         return s->getSize();
422 }
423
424 uint64_t EncodingNode::getIndex(uint index) {
425         return s->getElement(index);
426 }
427
428 VarType EncodingNode::getType() const {
429         return s->getType();
430 }
431
432 bool EncodingNode::itemExists(uint64_t item) {
433         for (uint i = 0; i < s->getSize(); i++) {
434                 if (item == s->getElement(i))
435                         return true;
436         }
437         return false;
438 }
439
440 double EncodingNode::measureSimilarity(EncodingNode *node) {
441         uint common = 0;
442         for (uint i = 0; i < s->getSize(); i++) {
443                 uint64_t item = s->getElement(i);
444                 if (node->itemExists(item)) {
445                         common++;
446                 }
447         }
448 //      model_print("common=%u\tsize1=%u\tsize2=%u\tsim1=%f\tsim2=%f\n", common, s->getSize(), node->getSize(), 1.0*common/s->getSize(), 1.0*common/node->getSize());
449         return common * 1.0 / s->getSize() + common * 1.0 / node->getSize();
450 }
451
452 EncodingNode *EncodingGraph::createNode(Element *e) {
453         if (e->type == ELEMCONST)
454                 return NULL;
455         Set *s = e->getRange();
456         EncodingNode *n = encodingMap.get(s);
457         if (n == NULL) {
458                 n = new EncodingNode(s);
459                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
460
461                 encodingMap.put(s, n);
462         }
463         n->addElement(e);
464         return n;
465 }
466
467 EncodingNode *EncodingGraph::getNode(Element *e) {
468         if (e->type == ELEMCONST)
469                 return NULL;
470         Set *s = e->getRange();
471         EncodingNode *n = encodingMap.get(s);
472         return n;
473 }
474
475 void EncodingNode::addElement(Element *e) {
476         elements.add(e);
477 }
478
479 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
480         left(_l),
481         right(_r),
482         dst(NULL),
483         encoding(EDGE_UNASSIGNED),
484         numArithOps(0),
485         numEquals(0),
486         numComparisons(0)
487 {
488 }
489
490 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
491         left(_left),
492         right(_right),
493         dst(_dst),
494         encoding(EDGE_UNASSIGNED),
495         numArithOps(0),
496         numEquals(0),
497         numComparisons(0)
498 {
499 }
500
501 uint hashEncodingEdge(EncodingEdge *edge) {
502         uintptr_t hash = (((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
503         return (uint) hash;
504 }
505
506 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
507         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
508 }
509
510 uint64_t EncodingEdge::getValue() const {
511         uint lSize = (left != NULL) ? left->getSize() : 1;
512         uint rSize = (right != NULL) ? right->getSize() : 1;
513         uint min = (lSize < rSize) ? lSize : rSize;
514         return numEquals * min + numComparisons * lSize * rSize;
515 }
516
517