Fix bug regarding order translation for removed nodes
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11 #include "elementencoding.h"
12
13 EncodingGraph::EncodingGraph(CSolver *_solver) :
14         solver(_solver) {
15 }
16
17 EncodingGraph::~EncodingGraph() {
18         subgraphs.resetAndDelete();
19         encodingMap.resetAndDeleteVals();
20         edgeMap.resetAndDeleteVals();
21 }
22
23 int sortEncodingEdge(const void *p1, const void *p2) {
24         const EncodingEdge *e1 = *(const EncodingEdge **) p1;
25         const EncodingEdge *e2 = *(const EncodingEdge **) p2;
26         uint64_t v1 = e1->getValue();
27         uint64_t v2 = e2->getValue();
28         if (v1 < v2)
29                 return 1;
30         else if (v1 == v2)
31                 return 0;
32         else
33                 return -1;
34 }
35
36 void EncodingGraph::buildGraph() {
37         ElementIterator it(solver);
38         while (it.hasNext()) {
39                 Element *e = it.next();
40                 switch (e->type) {
41                 case ELEMSET:
42                 case ELEMFUNCRETURN:
43                         processElement(e);
44                         break;
45                 case ELEMCONST:
46                         break;
47                 default:
48                         ASSERT(0);
49                 }
50         }
51         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
52         decideEdges();
53 }
54
55 void EncodingGraph::encode() {
56         SetIteratorEncodingSubGraph *itesg = subgraphs.iterator();
57         while (itesg->hasNext()) {
58                 EncodingSubGraph *sg = itesg->next();
59                 sg->encode();
60         }
61         delete itesg;
62
63         ElementIterator it(solver);
64         while (it.hasNext()) {
65                 Element *e = it.next();
66                 switch (e->type) {
67                 case ELEMSET:
68                 case ELEMFUNCRETURN: {
69                         ElementEncoding *encoding = e->getElementEncoding();
70                         if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
71                                 EncodingNode *n = getNode(e);
72                                 if (n == NULL)
73                                         continue;
74                                 ElementEncodingType encodetype = n->getEncoding();
75                                 encoding->setElementEncodingType(encodetype);
76                                 if (encodetype == UNARY || encodetype == ONEHOT) {
77                                         encoding->encodingArrayInitialization();
78                                 } else if (encodetype == BINARYINDEX) {
79                                         EncodingSubGraph *subgraph = graphMap.get(n);
80                                         if (subgraph == NULL)
81                                                 continue;
82                                         uint encodingSize = subgraph->getEncodingMaxVal(n)+1;
83                                         uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
84                                         encoding->allocInUseArrayElement(paddedSize);
85                                         encoding->allocEncodingArrayElement(paddedSize);
86                                         Set *s = e->getRange();
87                                         for (uint i = 0; i < s->getSize(); i++) {
88                                                 uint64_t value = s->getElement(i);
89                                                 uint encodingIndex = subgraph->getEncoding(n, value);
90                                                 encoding->setInUseElement(encodingIndex);
91                                                 encoding->encodingArray[encodingIndex] = value;
92                                         }
93                                 }
94                         }
95                         break;
96                 }
97                 default:
98                         break;
99                 }
100                 encodeParent(e);
101         }
102 }
103
104 void EncodingGraph::encodeParent(Element *e) {
105         uint size = e->parents.getSize();
106         for (uint i = 0; i < size; i++) {
107                 ASTNode *n = e->parents.get(i);
108                 if (n->type == PREDICATEOP) {
109                         BooleanPredicate *b = (BooleanPredicate *)n;
110                         FunctionEncoding *fenc = b->getFunctionEncoding();
111                         if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
112                                 continue;
113                         Predicate *p = b->getPredicate();
114                         if (p->type == OPERATORPRED) {
115                                 PredicateOperator *po = (PredicateOperator *)p;
116                                 ASSERT(b->inputs.getSize() == 2);
117                                 EncodingNode *left = createNode(b->inputs.get(0));
118                                 EncodingNode *right = createNode(b->inputs.get(1));
119                                 if (left == NULL || right == NULL)
120                                         return;
121                                 EncodingEdge *edge = getEdge(left, right, NULL);
122                                 if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
123                                         fenc->setFunctionEncodingType(CIRCUIT);
124                                 }
125                         }
126                 }
127         }
128 }
129
130 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
131         EncodingSubGraph *graph1 = graphMap.get(first);
132         EncodingSubGraph *graph2 = graphMap.get(second);
133         if (graph1 == NULL)
134                 first->setEncoding(BINARYINDEX);
135         if (graph2 == NULL)
136                 second->setEncoding(BINARYINDEX);
137
138         if (graph1 == NULL && graph2 == NULL) {
139                 graph1 = new EncodingSubGraph();
140                 subgraphs.add(graph1);
141                 graphMap.put(first, graph1);
142                 graph1->addNode(first);
143         }
144         if (graph1 == NULL && graph2 != NULL) {
145                 graph1 = graph2;
146                 graph2 = NULL;
147                 EncodingNode *tmp = second;
148                 second = first;
149                 first = tmp;
150         }
151         if (graph1 != NULL && graph2 != NULL) {
152                 SetIteratorEncodingNode *nodeit = graph2->nodeIterator();
153                 while (nodeit->hasNext()) {
154                         EncodingNode *node = nodeit->next();
155                         graph1->addNode(node);
156                         graphMap.put(node, graph1);
157                 }
158                 subgraphs.remove(graph2);
159                 delete nodeit;
160                 delete graph2;
161         } else {
162                 ASSERT(graph1 != NULL && graph2 == NULL);
163                 graph1->addNode(second);
164                 graphMap.put(second, graph1);
165         }
166 }
167
168 void EncodingGraph::processElement(Element *e) {
169         uint size = e->parents.getSize();
170         for (uint i = 0; i < size; i++) {
171                 ASTNode *n = e->parents.get(i);
172                 switch (n->type) {
173                 case PREDICATEOP:
174                         processPredicate((BooleanPredicate *)n);
175                         break;
176                 case ELEMFUNCRETURN:
177                         processFunction((ElementFunction *)n);
178                         break;
179                 default:
180                         ASSERT(0);
181                 }
182         }
183 }
184
185 void EncodingGraph::processFunction(ElementFunction *ef) {
186         Function *f = ef->getFunction();
187         if (f->type == OPERATORFUNC) {
188                 FunctionOperator *fo = (FunctionOperator *)f;
189                 ASSERT(ef->inputs.getSize() == 2);
190                 EncodingNode *left = createNode(ef->inputs.get(0));
191                 EncodingNode *right = createNode(ef->inputs.get(1));
192                 if (left == NULL && right == NULL)
193                         return;
194                 EncodingNode *dst = createNode(ef);
195                 EncodingEdge *edge = createEdge(left, right, dst);
196                 edge->numArithOps++;
197         }
198 }
199
200 void EncodingGraph::processPredicate(BooleanPredicate *b) {
201         Predicate *p = b->getPredicate();
202         if (p->type == OPERATORPRED) {
203                 PredicateOperator *po = (PredicateOperator *)p;
204                 ASSERT(b->inputs.getSize() == 2);
205                 EncodingNode *left = createNode(b->inputs.get(0));
206                 EncodingNode *right = createNode(b->inputs.get(1));
207                 if (left == NULL || right == NULL)
208                         return;
209                 EncodingEdge *edge = createEdge(left, right, NULL);
210                 CompOp op = po->getOp();
211                 switch (op) {
212                 case SATC_EQUALS:
213                         edge->numEquals++;
214                         break;
215                 case SATC_LT:
216                 case SATC_LTE:
217                 case SATC_GT:
218                 case SATC_GTE:
219                         edge->numComparisons++;
220                         break;
221                 default:
222                         ASSERT(0);
223                 }
224         }
225 }
226
227 uint convertSize(uint cost) {
228         cost = 1.2 * cost;// fudge factor
229         return NEXTPOW2(cost);
230 }
231
232 void EncodingGraph::decideEdges() {
233         uint size = edgeVector.getSize();
234         for (uint i = 0; i < size; i++) {
235                 EncodingEdge *ee = edgeVector.get(i);
236                 EncodingNode *left = ee->left;
237                 EncodingNode *right = ee->right;
238
239                 if (ee->encoding != EDGE_UNASSIGNED ||
240                                 !left->couldBeBinaryIndex() ||
241                                 !right->couldBeBinaryIndex())
242                         continue;
243
244                 uint64_t eeValue = ee->getValue();
245                 if (eeValue == 0)
246                         return;
247
248                 EncodingSubGraph *leftGraph = graphMap.get(left);
249                 EncodingSubGraph *rightGraph = graphMap.get(right);
250                 if (leftGraph == NULL && rightGraph != NULL) {
251                         EncodingNode *tmp = left; left = right; right = tmp;
252                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
253                 }
254
255                 uint leftSize = 0, rightSize = 0, newSize = 0;
256                 uint64_t totalCost = 0;
257                 if (leftGraph == NULL && rightGraph == NULL) {
258                         leftSize = convertSize(left->getSize());
259                         rightSize = convertSize(right->getSize());
260                         newSize = convertSize(left->s->getUnionSize(right->s));
261                         newSize = (leftSize > newSize) ? leftSize : newSize;
262                         newSize = (rightSize > newSize) ? rightSize : newSize;
263                         totalCost = (newSize - leftSize) * left->elements.getSize() +
264                                                                         (newSize - rightSize) * right->elements.getSize();
265                 } else if (leftGraph != NULL && rightGraph == NULL) {
266                         leftSize = convertSize(leftGraph->encodingSize);
267                         rightSize = convertSize(right->getSize());
268                         newSize = convertSize(leftGraph->estimateNewSize(right));
269                         newSize = (leftSize > newSize) ? leftSize : newSize;
270                         newSize = (rightSize > newSize) ? rightSize : newSize;
271                         totalCost = (newSize - leftSize) * leftGraph->numElements +
272                                                                         (newSize - rightSize) * right->elements.getSize();
273                 } else {
274                         //Neither are null
275                         leftSize = convertSize(leftGraph->encodingSize);
276                         rightSize = convertSize(rightGraph->encodingSize);
277                         newSize = convertSize(leftGraph->estimateNewSize(rightGraph));
278                         newSize = (leftSize > newSize) ? leftSize : newSize;
279                         newSize = (rightSize > newSize) ? rightSize : newSize;
280                         totalCost = (newSize - leftSize) * leftGraph->numElements +
281                                                                         (newSize - rightSize) * rightGraph->numElements;
282                 }
283                 double conversionfactor = 0.5;
284                 if ((totalCost * conversionfactor) < eeValue) {
285                         //add the edge
286                         mergeNodes(left, right);
287                 }
288         }
289 }
290
291 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
292
293 EncodingEdge *EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
294         EncodingEdge e(left, right, dst);
295         EncodingEdge *result = edgeMap.get(&e);
296         return result;
297 }
298
299 EncodingEdge *EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
300         EncodingEdge e(left, right, dst);
301         EncodingEdge *result = edgeMap.get(&e);
302         if (result == NULL) {
303                 result = new EncodingEdge(left, right, dst);
304                 VarType v1 = left->getType();
305                 VarType v2 = right->getType();
306                 if (v1 > v2) {
307                         VarType tmp = v2;
308                         v2 = v1;
309                         v1 = tmp;
310                 }
311
312                 if ((left != NULL && left->couldBeBinaryIndex()) &&
313                                 (right != NULL) && right->couldBeBinaryIndex()) {
314                         EdgeEncodingType type = (EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
315                         result->setEncoding(type);
316                         if (type == EDGE_MATCH) {
317                                 mergeNodes(left, right);
318                         }
319                 }
320                 edgeMap.put(result, result);
321                 edgeVector.push(result);
322                 if (left != NULL)
323                         left->edges.add(result);
324                 if (right != NULL)
325                         right->edges.add(result);
326                 if (dst != NULL)
327                         dst->edges.add(result);
328         }
329         return result;
330 }
331
332 EncodingNode::EncodingNode(Set *_s) :
333         s(_s) {
334 }
335
336 uint EncodingNode::getSize() const {
337         return s->getSize();
338 }
339
340 VarType EncodingNode::getType() const {
341         return s->getType();
342 }
343
344 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
345
346 EncodingNode *EncodingGraph::createNode(Element *e) {
347         if (e->type == ELEMCONST)
348                 return NULL;
349         Set *s = e->getRange();
350         EncodingNode *n = encodingMap.get(s);
351         if (n == NULL) {
352                 n = new EncodingNode(s);
353                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
354
355                 encodingMap.put(s, n);
356         }
357         n->addElement(e);
358         return n;
359 }
360
361 EncodingNode *EncodingGraph::getNode(Element *e) {
362         if (e->type == ELEMCONST)
363                 return NULL;
364         Set *s = e->getRange();
365         EncodingNode *n = encodingMap.get(s);
366         return n;
367 }
368
369 void EncodingNode::addElement(Element *e) {
370         elements.add(e);
371 }
372
373 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
374         left(_l),
375         right(_r),
376         dst(NULL),
377         encoding(EDGE_UNASSIGNED),
378         numArithOps(0),
379         numEquals(0),
380         numComparisons(0)
381 {
382 }
383
384 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
385         left(_left),
386         right(_right),
387         dst(_dst),
388         encoding(EDGE_UNASSIGNED),
389         numArithOps(0),
390         numEquals(0),
391         numComparisons(0)
392 {
393 }
394
395 uint hashEncodingEdge(EncodingEdge *edge) {
396         uintptr_t hash = (((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
397         return (uint) hash;
398 }
399
400 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
401         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
402 }
403
404 uint64_t EncodingEdge::getValue() const {
405         uint lSize = (left != NULL) ? left->getSize() : 1;
406         uint rSize = (right != NULL) ? right->getSize() : 1;
407         uint min = (lSize < rSize) ? lSize : rSize;
408         return numEquals * min + numComparisons * lSize * rSize;
409 }
410
411