8e0bc544cdbdd3a113f588c9543cf93ebf32701d
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11 #include "elementencoding.h"
12
13 EncodingGraph::EncodingGraph(CSolver *_solver) :
14         solver(_solver) {
15 }
16
17 EncodingGraph::~EncodingGraph() {
18         subgraphs.resetAndDelete();
19         encodingMap.resetAndDeleteVals();
20         edgeMap.resetAndDeleteVals();
21 }
22
23 int sortEncodingEdge(const void *p1, const void *p2) {
24         const EncodingEdge *e1 = *(const EncodingEdge **) p1;
25         const EncodingEdge *e2 = *(const EncodingEdge **) p2;
26         uint64_t v1 = e1->getValue();
27         uint64_t v2 = e2->getValue();
28         if (v1 < v2)
29                 return 1;
30         else if (v1 == v2)
31                 return 0;
32         else
33                 return -1;
34 }
35
36 void EncodingGraph::buildGraph() {
37         ElementIterator it(solver);
38         while (it.hasNext()) {
39                 Element *e = it.next();
40                 switch (e->type) {
41                 case ELEMSET:
42                 case ELEMFUNCRETURN:
43                         processElement(e);
44                         break;
45                 case ELEMCONST:
46                         break;
47                 default:
48                         ASSERT(0);
49                 }
50         }
51         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
52         decideEdges();
53 }
54
55
56 void EncodingGraph::validate() {
57         SetIteratorBooleanEdge* it= solver->getConstraints();
58         while(it->hasNext()){
59                 BooleanEdge be = it->next();
60                 if(be->type == PREDICATEOP){
61                         BooleanPredicate *b = (BooleanPredicate *)be.getBoolean();
62                         if(b->predicate->type == OPERATORPRED){
63                                 PredicateOperator* predicate = (PredicateOperator*) b->predicate;
64                                 if(predicate->getOp() == SATC_EQUALS){
65                                         ASSERT(b->inputs.getSize() == 2);
66                                         Element* e1= b->inputs.get(0);
67                                         Element* e2= b->inputs.get(1);
68                                         if(e1->type == ELEMCONST || e1->type == ELEMCONST)
69                                                 continue;
70                                         ElementEncoding *enc1 = e1->getElementEncoding();
71                                         ElementEncoding *enc2 = e2->getElementEncoding();
72                                         ASSERT(enc1->getElementEncodingType() != ELEM_UNASSIGNED);
73                                         ASSERT(enc2->getElementEncodingType() != ELEM_UNASSIGNED);
74                                         if(enc1->getElementEncodingType() == enc2->getElementEncodingType() && enc1->getElementEncodingType() == BINARYINDEX && b->getFunctionEncoding()->type == CIRCUIT){
75                                                 for(uint i=0; i<enc1->encArraySize; i++){
76                                                         if(enc1->isinUseElement(i)){
77                                                                 uint64_t val1 = enc1->encodingArray[i];
78                                                                 if(enc2->isinUseElement(i)){
79                                                                         ASSERT(val1 == enc2->encodingArray[i]);
80                                                                 }else{
81                                                                         for(uint j=0; j< enc2->encArraySize; j++){
82                                                                                 if(enc2->isinUseElement(j)){
83                                                                                         ASSERT(val1 != enc2->encodingArray[j]);
84                                                                                 }
85                                                                         }
86                                                                 }
87                                                         }
88                                                 }
89                                         }
90                                         //Now make sure that all the elements in the set are appeared in the encoding array!
91                                         for(uint k=0; k< b->inputs.getSize(); k++){
92                                                 Element *e = b->inputs.get(k);
93                                                 ElementEncoding *enc = e->getElementEncoding();
94                                                 Set *s = e->getRange();
95                                                 for (uint i = 0; i < s->getSize(); i++) {
96                                                         uint64_t value = s->getElement(i);
97                                                         bool exist=false;
98                                                         for(uint j=0; j< enc->encArraySize; j++){
99                                                                 if(enc->isinUseElement(j) && enc->encodingArray[j] == value){
100                                                                         exist = true;
101                                                                         break;
102                                                                 }
103                                                         }
104                                                         ASSERT(exist);
105                                                 }
106                                         }
107                                 }
108                         }
109                 }
110         }
111         delete it;
112 }
113
114
115 void EncodingGraph::encode() {
116         SetIteratorEncodingSubGraph *itesg = subgraphs.iterator();
117         model_print("#SubGraph = %u", subgraphs.getSize());
118         while (itesg->hasNext()) {
119                 EncodingSubGraph *sg = itesg->next();
120                 sg->encode();
121         }
122         delete itesg;
123
124         ElementIterator it(solver);
125         while (it.hasNext()) {
126                 Element *e = it.next();
127                 switch (e->type) {
128                 case ELEMSET:
129                 case ELEMFUNCRETURN: {
130                         ElementEncoding *encoding = e->getElementEncoding();
131                         if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
132                                 EncodingNode *n = getNode(e);
133                                 if (n == NULL)
134                                         continue;
135                                 ElementEncodingType encodetype = n->getEncoding();
136                                 encoding->setElementEncodingType(encodetype);
137                                 if (encodetype == UNARY || encodetype == ONEHOT) {
138                                         encoding->encodingArrayInitialization();
139                                 } else if (encodetype == BINARYINDEX) {
140                                         EncodingSubGraph *subgraph = graphMap.get(n);
141                                         DEBUG("graphMap.get(subgraph=%p, n=%p)\n", subgraph, n);
142                                         if (subgraph == NULL) {
143                                                 encoding->encodingArrayInitialization();
144                                                 continue;
145                                         }
146                                         uint encodingSize = subgraph->getEncodingMaxVal(n) + 1;
147                                         uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
148                                         encoding->allocInUseArrayElement(paddedSize);
149                                         encoding->allocEncodingArrayElement(paddedSize);
150                                         Set *s = e->getRange();
151                                         for (uint i = 0; i < s->getSize(); i++) {
152                                                 uint64_t value = s->getElement(i);
153                                                 uint encodingIndex = subgraph->getEncoding(n, value);
154                                                 encoding->setInUseElement(encodingIndex);
155                                                 ASSERT(encoding->isinUseElement(encodingIndex));
156                                                 encoding->encodingArray[encodingIndex] = value;
157                                         }
158                                 }
159                         }
160                         break;
161                 }
162                 default:
163                         break;
164                 }
165                 encodeParent(e);
166         }
167 }
168
169 void EncodingGraph::encodeParent(Element *e) {
170         uint size = e->parents.getSize();
171         for (uint i = 0; i < size; i++) {
172                 ASTNode *n = e->parents.get(i);
173                 if (n->type == PREDICATEOP) {
174                         BooleanPredicate *b = (BooleanPredicate *)n;
175                         FunctionEncoding *fenc = b->getFunctionEncoding();
176                         if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
177                                 continue;
178                         Predicate *p = b->getPredicate();
179                         if (p->type == OPERATORPRED) {
180                                 PredicateOperator *po = (PredicateOperator *)p;
181                                 ASSERT(b->inputs.getSize() == 2);
182                                 EncodingNode *left = createNode(b->inputs.get(0));
183                                 EncodingNode *right = createNode(b->inputs.get(1));
184                                 if (left == NULL || right == NULL)
185                                         return;
186                                 EncodingEdge *edge = getEdge(left, right, NULL);
187                                 if (edge != NULL) {
188                                         EncodingSubGraph *leftGraph = graphMap.get(left);
189                                         if (leftGraph != NULL && leftGraph == graphMap.get(right)) {
190                                                 fenc->setFunctionEncodingType(CIRCUIT);
191                                         }
192                                 }
193                         }
194                 }
195         }
196 }
197
198 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
199         EncodingSubGraph *graph1 = graphMap.get(first);
200         DEBUG("graphMap.get(first=%p, graph1=%p)\n", first, graph1);
201         EncodingSubGraph *graph2 = graphMap.get(second);
202         DEBUG("graphMap.get(second=%p, graph2=%p)\n", second, graph2);
203         if (graph1 == NULL)
204                 first->setEncoding(BINARYINDEX);
205         if (graph2 == NULL)
206                 second->setEncoding(BINARYINDEX);
207         
208         if (graph1 == NULL && graph2 == NULL) {
209                 graph1 = new EncodingSubGraph();
210                 subgraphs.add(graph1);
211                 DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
212                 graphMap.put(first, graph1);
213                 graph1->addNode(first);
214         }
215         if (graph1 == NULL && graph2 != NULL) {
216                 graph1 = graph2;
217                 graph2 = NULL;
218                 EncodingNode *tmp = second;
219                 second = first;
220                 first = tmp;
221         }
222         if (graph1 != NULL && graph2 != NULL) {
223                 if (graph1 == graph2)
224                         return;
225
226                 SetIteratorEncodingNode *nodeit = graph2->nodeIterator();
227                 while (nodeit->hasNext()) {
228                         EncodingNode *node = nodeit->next();
229                         graph1->addNode(node);
230                         DEBUG("graphMap.put(node=%p, graph1=%p)\n", node, graph1);
231                         graphMap.put(node, graph1);
232                 }
233                 subgraphs.remove(graph2);
234                 delete nodeit;
235                 DEBUG("Deleting graph2 =%p \n", graph2);
236                 delete graph2;
237         } else {
238                 ASSERT(graph1 != NULL && graph2 == NULL);
239                 graph1->addNode(second);
240                 DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
241                 graphMap.put(second, graph1);
242         }
243 }
244
245 void EncodingGraph::processElement(Element *e) {
246         uint size = e->parents.getSize();
247         for (uint i = 0; i < size; i++) {
248                 ASTNode *n = e->parents.get(i);
249                 switch (n->type) {
250                 case PREDICATEOP:
251                         processPredicate((BooleanPredicate *)n);
252                         break;
253                 case ELEMFUNCRETURN:
254                         processFunction((ElementFunction *)n);
255                         break;
256                 default:
257                         ASSERT(0);
258                 }
259         }
260 }
261
262 void EncodingGraph::processFunction(ElementFunction *ef) {
263         Function *f = ef->getFunction();
264         if (f->type == OPERATORFUNC) {
265                 FunctionOperator *fo = (FunctionOperator *)f;
266                 ASSERT(ef->inputs.getSize() == 2);
267                 EncodingNode *left = createNode(ef->inputs.get(0));
268                 EncodingNode *right = createNode(ef->inputs.get(1));
269                 if (left == NULL && right == NULL)
270                         return;
271                 EncodingNode *dst = createNode(ef);
272                 EncodingEdge *edge = createEdge(left, right, dst);
273                 edge->numArithOps++;
274         }
275 }
276
277 void EncodingGraph::processPredicate(BooleanPredicate *b) {
278         Predicate *p = b->getPredicate();
279         if (p->type == OPERATORPRED) {
280                 PredicateOperator *po = (PredicateOperator *)p;
281                 ASSERT(b->inputs.getSize() == 2);
282                 EncodingNode *left = createNode(b->inputs.get(0));
283                 EncodingNode *right = createNode(b->inputs.get(1));
284                 if (left == NULL || right == NULL)
285                         return;
286                 EncodingEdge *edge = createEdge(left, right, NULL);
287                 CompOp op = po->getOp();
288                 switch (op) {
289                 case SATC_EQUALS:
290                         edge->numEquals++;
291                         break;
292                 case SATC_LT:
293                 case SATC_LTE:
294                 case SATC_GT:
295                 case SATC_GTE:
296                         edge->numComparisons++;
297                         break;
298                 default:
299                         ASSERT(0);
300                 }
301         }
302 }
303
304 uint convertSize(uint cost) {
305         cost = FUDGEFACTOR * cost;// fudge factor
306         return NEXTPOW2(cost);
307 }
308
309 void EncodingGraph::decideEdges() {
310         uint size = edgeVector.getSize();
311         for (uint i = 0; i < size; i++) {
312                 EncodingEdge *ee = edgeVector.get(i);
313                 EncodingNode *left = ee->left;
314                 EncodingNode *right = ee->right;
315
316                 if (ee->encoding != EDGE_UNASSIGNED ||
317                                 !left->couldBeBinaryIndex() ||
318                                 !right->couldBeBinaryIndex())
319                         continue;
320
321                 uint64_t eeValue = ee->getValue();
322                 if (eeValue == 0)
323                         return;
324
325                 EncodingSubGraph *leftGraph = graphMap.get(left);
326                 DEBUG("graphMap.get(left=%p, leftgraph=%p)\n", left, leftGraph);
327                 EncodingSubGraph *rightGraph = graphMap.get(right);
328                 DEBUG("graphMap.get(right=%p, rightgraph=%p)\n", right, rightGraph);
329                 if (leftGraph == NULL && rightGraph != NULL) {
330                         EncodingNode *tmp = left; left = right; right = tmp;
331                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
332                 }
333
334                 uint leftSize = 0, rightSize = 0, newSize = 0;
335                 uint64_t totalCost = 0;
336                 bool merge = false;
337 //              model_print("**************decideEdge*************\n");
338 //              model_print("LeftNode Size = %u\n", left->getSize());
339 //              model_print("rightNode Size = %u\n", right->getSize());
340 //              model_print("UnionSize = %u\n", left->s->getUnionSize(right->s));
341                         
342                 if (leftGraph == NULL && rightGraph == NULL) {
343                         leftSize = convertSize(left->getSize());
344                         rightSize = convertSize(right->getSize());
345                         newSize = convertSize(left->s->getUnionSize(right->s));
346                         newSize = (leftSize > newSize) ? leftSize : newSize;
347                         newSize = (rightSize > newSize) ? rightSize : newSize;
348                         totalCost = (newSize - leftSize) * left->elements.getSize() +
349                                                                         (newSize - rightSize) * right->elements.getSize();
350                         if(leftSize == newSize && rightSize == newSize){
351                                 merge = true;
352                         }
353                 } else if (leftGraph != NULL && rightGraph == NULL) {
354                         leftSize = convertSize(leftGraph->encodingSize);
355                         rightSize = convertSize(right->getSize());
356                         newSize = convertSize(leftGraph->estimateNewSize(right));
357                         newSize = (leftSize > newSize) ? leftSize : newSize;
358                         newSize = (rightSize > newSize) ? rightSize : newSize;
359                         totalCost = (newSize - leftSize) * leftGraph->numElements +
360                                                                         (newSize - rightSize) * right->elements.getSize();
361                         if(leftSize == newSize && rightSize == newSize){
362                                 merge = true;
363                         }
364                 } else {
365                         //Neither are null
366                         leftSize = convertSize(leftGraph->encodingSize);
367                         rightSize = convertSize(rightGraph->encodingSize);
368                         newSize = convertSize(leftGraph->estimateNewSize(rightGraph));
369                         newSize = (leftSize > newSize) ? leftSize : newSize;
370                         newSize = (rightSize > newSize) ? rightSize : newSize;
371                         totalCost = (newSize - leftSize) * leftGraph->numElements +
372                                                                         (newSize - rightSize) * rightGraph->numElements;
373 //                      model_print("LeftGraph size=%u\n", leftGraph->encodingSize);
374 //                      model_print("RightGraph size=%u\n", rightGraph->encodingSize);
375 //                      model_print("UnionGraph size = %u\n", leftGraph->estimateNewSize(rightGraph));
376                         if(rightSize < 64 && leftSize < 64){
377                                 merge = true;
378                         }
379                 }
380 //              model_print("******************************\n");
381                 if (merge) {
382                         //add the edge
383                         mergeNodes(left, right);
384                 }
385         }
386 }
387
388 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
389
390 EncodingEdge *EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
391         EncodingEdge e(left, right, dst);
392         EncodingEdge *result = edgeMap.get(&e);
393         return result;
394 }
395
396 EncodingEdge *EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
397         EncodingEdge e(left, right, dst);
398         EncodingEdge *result = edgeMap.get(&e);
399         if (result == NULL) {
400                 result = new EncodingEdge(left, right, dst);
401                 VarType v1 = left->getType();
402                 VarType v2 = right->getType();
403                 if (v1 > v2) {
404                         VarType tmp = v2;
405                         v2 = v1;
406                         v1 = tmp;
407                 }
408
409                 if ((left != NULL && left->couldBeBinaryIndex()) &&
410                                 (right != NULL) && right->couldBeBinaryIndex()) {
411                         EdgeEncodingType type = (EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
412                         result->setEncoding(type);
413                         if (type == EDGE_MATCH) {
414                                 mergeNodes(left, right);
415                         }
416                 }
417                 edgeMap.put(result, result);
418                 edgeVector.push(result);
419                 if (left != NULL)
420                         left->edges.add(result);
421                 if (right != NULL)
422                         right->edges.add(result);
423                 if (dst != NULL)
424                         dst->edges.add(result);
425         }
426         return result;
427 }
428
429 EncodingNode::EncodingNode(Set *_s) :
430         s(_s) {
431 }
432
433 uint EncodingNode::getSize() const {
434         return s->getSize();
435 }
436
437 VarType EncodingNode::getType() const {
438         return s->getType();
439 }
440
441 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
442
443 EncodingNode *EncodingGraph::createNode(Element *e) {
444         if (e->type == ELEMCONST)
445                 return NULL;
446         Set *s = e->getRange();
447         EncodingNode *n = encodingMap.get(s);
448         if (n == NULL) {
449                 n = new EncodingNode(s);
450                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
451
452                 encodingMap.put(s, n);
453         }
454         n->addElement(e);
455         return n;
456 }
457
458 EncodingNode *EncodingGraph::getNode(Element *e) {
459         if (e->type == ELEMCONST)
460                 return NULL;
461         Set *s = e->getRange();
462         EncodingNode *n = encodingMap.get(s);
463         return n;
464 }
465
466 void EncodingNode::addElement(Element *e) {
467         elements.add(e);
468 }
469
470 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
471         left(_l),
472         right(_r),
473         dst(NULL),
474         encoding(EDGE_UNASSIGNED),
475         numArithOps(0),
476         numEquals(0),
477         numComparisons(0)
478 {
479 }
480
481 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
482         left(_left),
483         right(_right),
484         dst(_dst),
485         encoding(EDGE_UNASSIGNED),
486         numArithOps(0),
487         numEquals(0),
488         numComparisons(0)
489 {
490 }
491
492 uint hashEncodingEdge(EncodingEdge *edge) {
493         uintptr_t hash = (((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
494         return (uint) hash;
495 }
496
497 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
498         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
499 }
500
501 uint64_t EncodingEdge::getValue() const {
502         uint lSize = (left != NULL) ? left->getSize() : 1;
503         uint rSize = (right != NULL) ? right->getSize() : 1;
504         uint min = (lSize < rSize) ? lSize : rSize;
505         return numEquals * min + numComparisons * lSize * rSize;
506 }
507
508