5416ed0e2d7778444bff22640de8f127dd6ef187
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11 #include "elementencoding.h"
12
13 EncodingGraph::EncodingGraph(CSolver *_solver) :
14         solver(_solver) {
15 }
16
17 int sortEncodingEdge(const void *p1, const void *p2) {
18         const EncodingEdge *e1 = *(const EncodingEdge **) p1;
19         const EncodingEdge *e2 = *(const EncodingEdge **) p2;
20         uint64_t v1 = e1->getValue();
21         uint64_t v2 = e2->getValue();
22         if (v1 < v2)
23                 return 1;
24         else if (v1 == v2)
25                 return 0;
26         else
27                 return -1;
28 }
29
30 void EncodingGraph::buildGraph() {
31         ElementIterator it(solver);
32         while (it.hasNext()) {
33                 Element *e = it.next();
34                 switch (e->type) {
35                 case ELEMSET:
36                 case ELEMFUNCRETURN:
37                         processElement(e);
38                         break;
39                 case ELEMCONST:
40                         break;
41                 default:
42                         ASSERT(0);
43                 }
44         }
45         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
46         decideEdges();
47 }
48
49 void EncodingGraph::encode() {
50         SetIteratorEncodingSubGraph *itesg = subgraphs.iterator();
51         while (itesg->hasNext()) {
52                 EncodingSubGraph *sg = itesg->next();
53                 sg->encode();
54         }
55         delete itesg;
56
57         ElementIterator it(solver);
58         while (it.hasNext()) {
59                 Element *e = it.next();
60                 switch (e->type) {
61                 case ELEMSET:
62                 case ELEMFUNCRETURN: {
63                         ElementEncoding *encoding = e->getElementEncoding();
64                         if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
65                                 EncodingNode *n = getNode(e);
66                                 if (n == NULL)
67                                         continue;
68                                 ElementEncodingType encodetype = n->getEncoding();
69                                 encoding->setElementEncodingType(encodetype);
70                                 if (encodetype == UNARY || encodetype == ONEHOT) {
71                                         encoding->encodingArrayInitialization();
72                                 } else if (encodetype == BINARYINDEX) {
73                                         EncodingSubGraph *subgraph = graphMap.get(n);
74                                         if (subgraph == NULL)
75                                                 continue;
76                                         uint encodingSize = subgraph->getEncodingMaxVal(n)+1;
77                                         uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
78                                         model_print("encoding size=%u\n", encodingSize);
79                                         model_print("padded=%u\n", paddedSize);
80                                         encoding->allocInUseArrayElement(paddedSize);
81                                         encoding->allocEncodingArrayElement(paddedSize);
82                                         Set *s = e->getRange();
83                                         for (uint i = 0; i < s->getSize(); i++) {
84                                                 model_print("index=%u\n", i);
85                                                 uint64_t value = s->getElement(i);
86                                                 uint encodingIndex = subgraph->getEncoding(n, value);
87                                                 encoding->setInUseElement(encodingIndex);
88                                                 encoding->encodingArray[encodingIndex] = value;
89                                         }
90                                 }
91                         }
92                         break;
93                 }
94                 default:
95                         break;
96                 }
97                 encodeParent(e);
98         }
99 }
100
101 void EncodingGraph::encodeParent(Element *e) {
102         uint size = e->parents.getSize();
103         for (uint i = 0; i < size; i++) {
104                 ASTNode *n = e->parents.get(i);
105                 if (n->type == PREDICATEOP) {
106                         BooleanPredicate *b = (BooleanPredicate *)n;
107                         FunctionEncoding *fenc = b->getFunctionEncoding();
108                         if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
109                                 continue;
110                         Predicate *p = b->getPredicate();
111                         if (p->type == OPERATORPRED) {
112                                 PredicateOperator *po = (PredicateOperator *)p;
113                                 ASSERT(b->inputs.getSize() == 2);
114                                 EncodingNode *left = createNode(b->inputs.get(0));
115                                 EncodingNode *right = createNode(b->inputs.get(1));
116                                 if (left == NULL || right == NULL)
117                                         return;
118                                 EncodingEdge *edge = getEdge(left, right, NULL);
119                                 if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
120                                         fenc->setFunctionEncodingType(CIRCUIT);
121                                 }
122                         }
123                 }
124         }
125 }
126
127 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
128         EncodingSubGraph *graph1 = graphMap.get(first);
129         EncodingSubGraph *graph2 = graphMap.get(second);
130         if (graph1 == NULL)
131                 first->setEncoding(BINARYINDEX);
132         if (graph2 == NULL)
133                 second->setEncoding(BINARYINDEX);
134
135         if (graph1 == NULL && graph2 == NULL) {
136                 graph1 = new EncodingSubGraph();
137                 subgraphs.add(graph1);
138                 graphMap.put(first, graph1);
139                 graph1->addNode(first);
140         }
141         if (graph1 == NULL && graph2 != NULL) {
142                 graph1 = graph2;
143                 graph2 = NULL;
144                 EncodingNode *tmp = second;
145                 second = first;
146                 first = tmp;
147         }
148         if (graph1 != NULL && graph2 != NULL) {
149                 SetIteratorEncodingNode *nodeit = graph2->nodeIterator();
150                 while (nodeit->hasNext()) {
151                         EncodingNode *node = nodeit->next();
152                         graph1->addNode(node);
153                         graphMap.put(node, graph1);
154                 }
155                 subgraphs.remove(graph2);
156                 delete nodeit;
157                 delete graph2;
158         } else {
159                 ASSERT(graph1 != NULL && graph2 == NULL);
160                 graph1->addNode(second);
161                 graphMap.put(second, graph1);
162         }
163 }
164
165 void EncodingGraph::processElement(Element *e) {
166         uint size = e->parents.getSize();
167         for (uint i = 0; i < size; i++) {
168                 ASTNode *n = e->parents.get(i);
169                 switch (n->type) {
170                 case PREDICATEOP:
171                         processPredicate((BooleanPredicate *)n);
172                         break;
173                 case ELEMFUNCRETURN:
174                         processFunction((ElementFunction *)n);
175                         break;
176                 default:
177                         ASSERT(0);
178                 }
179         }
180 }
181
182 void EncodingGraph::processFunction(ElementFunction *ef) {
183         Function *f = ef->getFunction();
184         if (f->type == OPERATORFUNC) {
185                 FunctionOperator *fo = (FunctionOperator *)f;
186                 ASSERT(ef->inputs.getSize() == 2);
187                 EncodingNode *left = createNode(ef->inputs.get(0));
188                 EncodingNode *right = createNode(ef->inputs.get(1));
189                 if (left == NULL && right == NULL)
190                         return;
191                 EncodingNode *dst = createNode(ef);
192                 EncodingEdge *edge = createEdge(left, right, dst);
193                 edge->numArithOps++;
194         }
195 }
196
197 void EncodingGraph::processPredicate(BooleanPredicate *b) {
198         Predicate *p = b->getPredicate();
199         if (p->type == OPERATORPRED) {
200                 PredicateOperator *po = (PredicateOperator *)p;
201                 ASSERT(b->inputs.getSize() == 2);
202                 EncodingNode *left = createNode(b->inputs.get(0));
203                 EncodingNode *right = createNode(b->inputs.get(1));
204                 if (left == NULL || right == NULL)
205                         return;
206                 EncodingEdge *edge = createEdge(left, right, NULL);
207                 CompOp op = po->getOp();
208                 switch (op) {
209                 case SATC_EQUALS:
210                         edge->numEquals++;
211                         break;
212                 case SATC_LT:
213                 case SATC_LTE:
214                 case SATC_GT:
215                 case SATC_GTE:
216                         edge->numComparisons++;
217                         break;
218                 default:
219                         ASSERT(0);
220                 }
221         }
222 }
223
224 uint convertSize(uint cost) {
225         cost = 1.2 * cost;// fudge factor
226         return NEXTPOW2(cost);
227 }
228
229 void EncodingGraph::decideEdges() {
230         uint size = edgeVector.getSize();
231         for (uint i = 0; i < size; i++) {
232                 EncodingEdge *ee = edgeVector.get(i);
233                 EncodingNode *left = ee->left;
234                 EncodingNode *right = ee->right;
235
236                 if (ee->encoding != EDGE_UNASSIGNED ||
237                                 !left->couldBeBinaryIndex() ||
238                                 !right->couldBeBinaryIndex())
239                         continue;
240
241                 uint64_t eeValue = ee->getValue();
242                 if (eeValue == 0)
243                         return;
244
245                 EncodingSubGraph *leftGraph = graphMap.get(left);
246                 EncodingSubGraph *rightGraph = graphMap.get(right);
247                 if (leftGraph == NULL && rightGraph != NULL) {
248                         EncodingNode *tmp = left; left = right; right = tmp;
249                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
250                 }
251
252                 uint leftSize = 0, rightSize = 0, newSize = 0;
253                 uint64_t totalCost = 0;
254                 if (leftGraph == NULL && rightGraph == NULL) {
255                         leftSize = convertSize(left->getSize());
256                         rightSize = convertSize(right->getSize());
257                         newSize = convertSize(left->s->getUnionSize(right->s));
258                         newSize = (leftSize > newSize) ? leftSize : newSize;
259                         newSize = (rightSize > newSize) ? rightSize : newSize;
260                         totalCost = (newSize - leftSize) * left->elements.getSize() +
261                                                                         (newSize - rightSize) * right->elements.getSize();
262                 } else if (leftGraph != NULL && rightGraph == NULL) {
263                         leftSize = convertSize(leftGraph->encodingSize);
264                         rightSize = convertSize(right->getSize());
265                         newSize = convertSize(leftGraph->estimateNewSize(right));
266                         newSize = (leftSize > newSize) ? leftSize : newSize;
267                         newSize = (rightSize > newSize) ? rightSize : newSize;
268                         totalCost = (newSize - leftSize) * leftGraph->numElements +
269                                                                         (newSize - rightSize) * right->elements.getSize();
270                 } else {
271                         //Neither are null
272                         leftSize = convertSize(leftGraph->encodingSize);
273                         rightSize = convertSize(rightGraph->encodingSize);
274                         newSize = convertSize(leftGraph->estimateNewSize(rightGraph));
275                         newSize = (leftSize > newSize) ? leftSize : newSize;
276                         newSize = (rightSize > newSize) ? rightSize : newSize;
277                         totalCost = (newSize - leftSize) * leftGraph->numElements +
278                                                                         (newSize - rightSize) * rightGraph->numElements;
279                 }
280                 double conversionfactor = 0.5;
281                 if ((totalCost * conversionfactor) < eeValue) {
282                         //add the edge
283                         mergeNodes(left, right);
284                 }
285         }
286 }
287
288 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
289
290 EncodingEdge *EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
291         EncodingEdge e(left, right, dst);
292         EncodingEdge *result = edgeMap.get(&e);
293         return result;
294 }
295
296 EncodingEdge *EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
297         EncodingEdge e(left, right, dst);
298         EncodingEdge *result = edgeMap.get(&e);
299         if (result == NULL) {
300                 result = new EncodingEdge(left, right, dst);
301                 VarType v1 = left->getType();
302                 VarType v2 = right->getType();
303                 if (v1 > v2) {
304                         VarType tmp = v2;
305                         v2 = v1;
306                         v1 = tmp;
307                 }
308
309                 if ((left != NULL && left->couldBeBinaryIndex()) &&
310                                 (right != NULL) && right->couldBeBinaryIndex()) {
311                         EdgeEncodingType type = (EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
312                         result->setEncoding(type);
313                         if (type == EDGE_MATCH) {
314                                 mergeNodes(left, right);
315                         }
316                 }
317                 edgeMap.put(result, result);
318                 edgeVector.push(result);
319                 if (left != NULL)
320                         left->edges.add(result);
321                 if (right != NULL)
322                         right->edges.add(result);
323                 if (dst != NULL)
324                         dst->edges.add(result);
325         }
326         return result;
327 }
328
329 EncodingNode::EncodingNode(Set *_s) :
330         s(_s) {
331 }
332
333 uint EncodingNode::getSize() const {
334         return s->getSize();
335 }
336
337 VarType EncodingNode::getType() const {
338         return s->getType();
339 }
340
341 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
342
343 EncodingNode *EncodingGraph::createNode(Element *e) {
344         if (e->type == ELEMCONST)
345                 return NULL;
346         Set *s = e->getRange();
347         EncodingNode *n = encodingMap.get(s);
348         if (n == NULL) {
349                 n = new EncodingNode(s);
350                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
351
352                 encodingMap.put(s, n);
353         }
354         n->addElement(e);
355         return n;
356 }
357
358 EncodingNode *EncodingGraph::getNode(Element *e) {
359         if (e->type == ELEMCONST)
360                 return NULL;
361         Set *s = e->getRange();
362         EncodingNode *n = encodingMap.get(s);
363         return n;
364 }
365
366 void EncodingNode::addElement(Element *e) {
367         elements.add(e);
368 }
369
370 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
371         left(_l),
372         right(_r),
373         dst(NULL),
374         encoding(EDGE_UNASSIGNED),
375         numArithOps(0),
376         numEquals(0),
377         numComparisons(0)
378 {
379 }
380
381 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
382         left(_left),
383         right(_right),
384         dst(_dst),
385         encoding(EDGE_UNASSIGNED),
386         numArithOps(0),
387         numEquals(0),
388         numComparisons(0)
389 {
390 }
391
392 uint hashEncodingEdge(EncodingEdge *edge) {
393         uintptr_t hash = (((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
394         return (uint) hash;
395 }
396
397 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
398         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
399 }
400
401 uint64_t EncodingEdge::getValue() const {
402         uint lSize = (left != NULL) ? left->getSize() : 1;
403         uint rSize = (right != NULL) ? right->getSize() : 1;
404         uint min = (lSize < rSize) ? lSize : rSize;
405         return numEquals * min + numComparisons * lSize * rSize;
406 }
407
408