revert Hamed's changes to encoding graph
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11 #include "elementencoding.h"
12
13 EncodingGraph::EncodingGraph(CSolver *_solver) :
14         solver(_solver) {
15 }
16
17 EncodingGraph::~EncodingGraph() {
18         subgraphs.resetAndDelete();
19         encodingMap.resetAndDeleteVals();
20         edgeMap.resetAndDeleteVals();
21 }
22
23 int sortEncodingEdge(const void *p1, const void *p2) {
24         const EncodingEdge *e1 = *(const EncodingEdge **) p1;
25         const EncodingEdge *e2 = *(const EncodingEdge **) p2;
26         uint64_t v1 = e1->getValue();
27         uint64_t v2 = e2->getValue();
28         if (v1 < v2)
29                 return 1;
30         else if (v1 == v2)
31                 return 0;
32         else
33                 return -1;
34 }
35
36 void EncodingGraph::buildGraph() {
37         ElementIterator it(solver);
38         while (it.hasNext()) {
39                 Element *e = it.next();
40                 switch (e->type) {
41                 case ELEMSET:
42                 case ELEMFUNCRETURN:
43                         processElement(e);
44                         break;
45                 case ELEMCONST:
46                         break;
47                 default:
48                         ASSERT(0);
49                 }
50         }
51         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
52         decideEdges();
53 }
54
55 void EncodingGraph::encode() {
56         SetIteratorEncodingSubGraph *itesg = subgraphs.iterator();
57         while (itesg->hasNext()) {
58                 EncodingSubGraph *sg = itesg->next();
59                 sg->encode();
60         }
61         delete itesg;
62
63         ElementIterator it(solver);
64         while (it.hasNext()) {
65                 Element *e = it.next();
66                 switch (e->type) {
67                 case ELEMSET:
68                 case ELEMFUNCRETURN: {
69                         ElementEncoding *encoding = e->getElementEncoding();
70                         if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
71                                 EncodingNode *n = getNode(e);
72                                 if (n == NULL)
73                                         continue;
74                                 ElementEncodingType encodetype = n->getEncoding();
75                                 encoding->setElementEncodingType(encodetype);
76                                 if (encodetype == UNARY || encodetype == ONEHOT) {
77                                         encoding->encodingArrayInitialization();
78                                 } else if (encodetype == BINARYINDEX) {
79                                         EncodingSubGraph *subgraph = graphMap.get(n);
80                                         DEBUG("graphMap.get(subgraph=%p, n=%p)\n", subgraph, n);
81                                         if (subgraph == NULL) {
82                                                 continue;
83                                         }
84                                         uint encodingSize = subgraph->getEncodingMaxVal(n) + 1;
85                                         uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
86                                         encoding->allocInUseArrayElement(paddedSize);
87                                         encoding->allocEncodingArrayElement(paddedSize);
88                                         Set *s = e->getRange();
89                                         for (uint i = 0; i < s->getSize(); i++) {
90                                                 uint64_t value = s->getElement(i);
91                                                 uint encodingIndex = subgraph->getEncoding(n, value);
92                                                 encoding->setInUseElement(encodingIndex);
93                                                 encoding->encodingArray[encodingIndex] = value;
94                                         }
95                                 }
96                         }
97                         break;
98                 }
99                 default:
100                         break;
101                 }
102                 encodeParent(e);
103         }
104 }
105
106 void EncodingGraph::encodeParent(Element *e) {
107         uint size = e->parents.getSize();
108         for (uint i = 0; i < size; i++) {
109                 ASTNode *n = e->parents.get(i);
110                 if (n->type == PREDICATEOP) {
111                         BooleanPredicate *b = (BooleanPredicate *)n;
112                         FunctionEncoding *fenc = b->getFunctionEncoding();
113                         if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
114                                 continue;
115                         Predicate *p = b->getPredicate();
116                         if (p->type == OPERATORPRED) {
117                                 PredicateOperator *po = (PredicateOperator *)p;
118                                 ASSERT(b->inputs.getSize() == 2);
119                                 EncodingNode *left = createNode(b->inputs.get(0));
120                                 EncodingNode *right = createNode(b->inputs.get(1));
121                                 if (left == NULL || right == NULL)
122                                         return;
123                                 EncodingEdge *edge = getEdge(left, right, NULL);
124                                 if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
125                                         fenc->setFunctionEncodingType(CIRCUIT);
126                                 }
127                         }
128                 }
129         }
130 }
131
132 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
133         EncodingSubGraph *graph1 = graphMap.get(first);
134         DEBUG("graphMap.get(first=%p, graph1=%p)\n", first, graph1);
135         EncodingSubGraph *graph2 = graphMap.get(second);
136         DEBUG("graphMap.get(second=%p, graph2=%p)\n", second, graph2);
137         if (graph1 == NULL)
138                 first->setEncoding(BINARYINDEX);
139         if (graph2 == NULL)
140                 second->setEncoding(BINARYINDEX);
141
142         if (graph1 == NULL && graph2 == NULL) {
143                 graph1 = new EncodingSubGraph();
144                 subgraphs.add(graph1);
145                 DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
146                 graphMap.put(first, graph1);
147                 graph1->addNode(first);
148         }
149         if (graph1 == NULL && graph2 != NULL) {
150                 graph1 = graph2;
151                 graph2 = NULL;
152                 EncodingNode *tmp = second;
153                 second = first;
154                 first = tmp;
155         }
156         if (graph1 != NULL && graph2 != NULL) {
157                 SetIteratorEncodingNode *nodeit = graph2->nodeIterator();
158                 while (nodeit->hasNext()) {
159                         EncodingNode *node = nodeit->next();
160                         graph1->addNode(node);
161                         DEBUG("graphMap.put(node=%p, graph1=%p)\n", node, graph1);
162                         graphMap.put(node, graph1);
163                 }
164                 subgraphs.remove(graph2);
165                 delete nodeit;
166                 DEBUG("Deleting graph2 =%p \n", graph2);
167                 delete graph2;
168         } else {
169                 ASSERT(graph1 != NULL && graph2 == NULL);
170                 graph1->addNode(second);
171                 DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
172                 graphMap.put(second, graph1);
173         }
174 }
175
176 void EncodingGraph::processElement(Element *e) {
177         uint size = e->parents.getSize();
178         for (uint i = 0; i < size; i++) {
179                 ASTNode *n = e->parents.get(i);
180                 switch (n->type) {
181                 case PREDICATEOP:
182                         processPredicate((BooleanPredicate *)n);
183                         break;
184                 case ELEMFUNCRETURN:
185                         processFunction((ElementFunction *)n);
186                         break;
187                 default:
188                         ASSERT(0);
189                 }
190         }
191 }
192
193 void EncodingGraph::processFunction(ElementFunction *ef) {
194         Function *f = ef->getFunction();
195         if (f->type == OPERATORFUNC) {
196                 FunctionOperator *fo = (FunctionOperator *)f;
197                 ASSERT(ef->inputs.getSize() == 2);
198                 EncodingNode *left = createNode(ef->inputs.get(0));
199                 EncodingNode *right = createNode(ef->inputs.get(1));
200                 if (left == NULL && right == NULL)
201                         return;
202                 EncodingNode *dst = createNode(ef);
203                 EncodingEdge *edge = createEdge(left, right, dst);
204                 edge->numArithOps++;
205         }
206 }
207
208 void EncodingGraph::processPredicate(BooleanPredicate *b) {
209         Predicate *p = b->getPredicate();
210         if (p->type == OPERATORPRED) {
211                 PredicateOperator *po = (PredicateOperator *)p;
212                 ASSERT(b->inputs.getSize() == 2);
213                 EncodingNode *left = createNode(b->inputs.get(0));
214                 EncodingNode *right = createNode(b->inputs.get(1));
215                 if (left == NULL || right == NULL)
216                         return;
217                 EncodingEdge *edge = createEdge(left, right, NULL);
218                 CompOp op = po->getOp();
219                 switch (op) {
220                 case SATC_EQUALS:
221                         edge->numEquals++;
222                         break;
223                 case SATC_LT:
224                 case SATC_LTE:
225                 case SATC_GT:
226                 case SATC_GTE:
227                         edge->numComparisons++;
228                         break;
229                 default:
230                         ASSERT(0);
231                 }
232         }
233 }
234
235 uint convertSize(uint cost) {
236         cost = 1.2 * cost;// fudge factor
237         return NEXTPOW2(cost);
238 }
239
240 void EncodingGraph::decideEdges() {
241         uint size = edgeVector.getSize();
242         for (uint i = 0; i < size; i++) {
243                 EncodingEdge *ee = edgeVector.get(i);
244                 EncodingNode *left = ee->left;
245                 EncodingNode *right = ee->right;
246
247                 if (ee->encoding != EDGE_UNASSIGNED ||
248                                 !left->couldBeBinaryIndex() ||
249                                 !right->couldBeBinaryIndex())
250                         continue;
251
252                 uint64_t eeValue = ee->getValue();
253                 if (eeValue == 0)
254                         return;
255
256                 EncodingSubGraph *leftGraph = graphMap.get(left);
257                 DEBUG("graphMap.get(left=%p, leftgraph=%p)\n", left, leftGraph);
258                 EncodingSubGraph *rightGraph = graphMap.get(right);
259                 DEBUG("graphMap.get(right=%p, rightgraph=%p)\n", right, rightGraph);
260                 if (leftGraph == NULL && rightGraph != NULL) {
261                         EncodingNode *tmp = left; left = right; right = tmp;
262                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
263                 }
264
265                 uint leftSize = 0, rightSize = 0, newSize = 0;
266                 uint64_t totalCost = 0;
267                 if (leftGraph == NULL && rightGraph == NULL) {
268                         leftSize = convertSize(left->getSize());
269                         rightSize = convertSize(right->getSize());
270                         newSize = convertSize(left->s->getUnionSize(right->s));
271                         newSize = (leftSize > newSize) ? leftSize : newSize;
272                         newSize = (rightSize > newSize) ? rightSize : newSize;
273                         totalCost = (newSize - leftSize) * left->elements.getSize() +
274                                                                         (newSize - rightSize) * right->elements.getSize();
275                 } else if (leftGraph != NULL && rightGraph == NULL) {
276                         leftSize = convertSize(leftGraph->encodingSize);
277                         rightSize = convertSize(right->getSize());
278                         newSize = convertSize(leftGraph->estimateNewSize(right));
279                         newSize = (leftSize > newSize) ? leftSize : newSize;
280                         newSize = (rightSize > newSize) ? rightSize : newSize;
281                         totalCost = (newSize - leftSize) * leftGraph->numElements +
282                                                                         (newSize - rightSize) * right->elements.getSize();
283                 } else {
284                         //Neither are null
285                         leftSize = convertSize(leftGraph->encodingSize);
286                         rightSize = convertSize(rightGraph->encodingSize);
287                         newSize = convertSize(leftGraph->estimateNewSize(rightGraph));
288                         newSize = (leftSize > newSize) ? leftSize : newSize;
289                         newSize = (rightSize > newSize) ? rightSize : newSize;
290                         totalCost = (newSize - leftSize) * leftGraph->numElements +
291                                                                         (newSize - rightSize) * rightGraph->numElements;
292                 }
293                 double conversionfactor = 0.5;
294                 if ((totalCost * conversionfactor) < eeValue) {
295                         //add the edge
296                         mergeNodes(left, right);
297                 }
298         }
299 }
300
301 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
302
303 EncodingEdge *EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
304         EncodingEdge e(left, right, dst);
305         EncodingEdge *result = edgeMap.get(&e);
306         return result;
307 }
308
309 EncodingEdge *EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
310         EncodingEdge e(left, right, dst);
311         EncodingEdge *result = edgeMap.get(&e);
312         if (result == NULL) {
313                 result = new EncodingEdge(left, right, dst);
314                 VarType v1 = left->getType();
315                 VarType v2 = right->getType();
316                 if (v1 > v2) {
317                         VarType tmp = v2;
318                         v2 = v1;
319                         v1 = tmp;
320                 }
321
322                 if ((left != NULL && left->couldBeBinaryIndex()) &&
323                                 (right != NULL) && right->couldBeBinaryIndex()) {
324                         EdgeEncodingType type = (EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
325                         result->setEncoding(type);
326                         if (type == EDGE_MATCH) {
327                                 mergeNodes(left, right);
328                         }
329                 }
330                 edgeMap.put(result, result);
331                 edgeVector.push(result);
332                 if (left != NULL)
333                         left->edges.add(result);
334                 if (right != NULL)
335                         right->edges.add(result);
336                 if (dst != NULL)
337                         dst->edges.add(result);
338         }
339         return result;
340 }
341
342 EncodingNode::EncodingNode(Set *_s) :
343         s(_s) {
344 }
345
346 uint EncodingNode::getSize() const {
347         return s->getSize();
348 }
349
350 VarType EncodingNode::getType() const {
351         return s->getType();
352 }
353
354 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
355
356 EncodingNode *EncodingGraph::createNode(Element *e) {
357         if (e->type == ELEMCONST)
358                 return NULL;
359         Set *s = e->getRange();
360         EncodingNode *n = encodingMap.get(s);
361         if (n == NULL) {
362                 n = new EncodingNode(s);
363                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
364
365                 encodingMap.put(s, n);
366         }
367         n->addElement(e);
368         return n;
369 }
370
371 EncodingNode *EncodingGraph::getNode(Element *e) {
372         if (e->type == ELEMCONST)
373                 return NULL;
374         Set *s = e->getRange();
375         EncodingNode *n = encodingMap.get(s);
376         return n;
377 }
378
379 void EncodingNode::addElement(Element *e) {
380         elements.add(e);
381 }
382
383 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
384         left(_l),
385         right(_r),
386         dst(NULL),
387         encoding(EDGE_UNASSIGNED),
388         numArithOps(0),
389         numEquals(0),
390         numComparisons(0)
391 {
392 }
393
394 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
395         left(_left),
396         right(_right),
397         dst(_dst),
398         encoding(EDGE_UNASSIGNED),
399         numArithOps(0),
400         numEquals(0),
401         numComparisons(0)
402 {
403 }
404
405 uint hashEncodingEdge(EncodingEdge *edge) {
406         uintptr_t hash = (((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
407         return (uint) hash;
408 }
409
410 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
411         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
412 }
413
414 uint64_t EncodingEdge::getValue() const {
415         uint lSize = (left != NULL) ? left->getSize() : 1;
416         uint rSize = (right != NULL) ? right->getSize() : 1;
417         uint min = (lSize < rSize) ? lSize : rSize;
418         return numEquals * min + numComparisons * lSize * rSize;
419 }
420
421