Merge branch 'master' of ssh://plrg.eecs.uci.edu/home/git/constraint_compiler
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11 #include "elementencoding.h"
12 #include "boolean.h"
13
14 EncodingGraph::EncodingGraph(CSolver * _solver) :
15         solver(_solver) {
16 }
17
18 int sortEncodingEdge(const void * p1, const void *p2) {
19         const EncodingEdge * e1 = * (const EncodingEdge **) p1;
20         const EncodingEdge * e2 = * (const EncodingEdge **) p2;
21         uint64_t v1 = e1->getValue();
22         uint64_t v2 = e2->getValue();
23         if (v1 < v2)
24                 return 1;
25         else if (v1 == v2)
26                 return 0;
27         else
28                 return -1;
29 }
30
31 void EncodingGraph::buildGraph() {
32         ElementIterator it(solver);
33         while(it.hasNext()) {
34                 Element * e = it.next();
35                 switch(e->type) {
36                 case ELEMSET:
37                 case ELEMFUNCRETURN:
38                         processElement(e);
39                         break;
40                 case ELEMCONST:
41                         break;
42                 default:
43                         ASSERT(0);
44                 }
45         }
46         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
47         decideEdges();
48 }
49
50 void EncodingGraph::encode() {
51         SetIteratorEncodingSubGraph * itesg=subgraphs.iterator();
52         while(itesg->hasNext()) {
53                 EncodingSubGraph *sg=itesg->next();
54                 sg->encode();
55         }
56         delete itesg;
57
58         ElementIterator it(solver);
59         while(it.hasNext()) {
60                 Element * e = it.next();
61                 switch(e->type) {
62                 case ELEMSET:
63                 case ELEMFUNCRETURN: {
64                         ElementEncoding *encoding=e->getElementEncoding();
65                         if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
66                                 EncodingNode *n = getNode(e);
67                                 ASSERT(n != NULL);
68                                 ElementEncodingType encodetype=n->getEncoding();
69                                 encoding->setElementEncodingType(encodetype);
70                                 if (encodetype == UNARY || encodetype == ONEHOT) {
71                                         encoding->encodingArrayInitialization();
72                                 } else if (encodetype == BINARYINDEX) {
73                                         EncodingSubGraph * subgraph = graphMap.get(n);
74                                         uint encodingSize = subgraph->getEncodingSize(n);
75                                         uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
76                                         encoding->allocInUseArrayElement(paddedSize);
77                                         encoding->allocEncodingArrayElement(paddedSize);
78                                         Set * s=e->getRange();
79                                         for(uint i=0;i<s->getSize();i++) {
80                                                 uint64_t value=s->getElement(i);
81                                                 uint encodingIndex=subgraph->getEncoding(n, value);
82                                                 encoding->setInUseElement(encodingIndex);
83                                                 encoding->encodingArray[encodingIndex] = value;
84                                         }
85                                 }
86                         }
87                         break;
88                 }
89                 default:
90                         break;
91                 }
92                 encodeParent(e);
93         }
94 }
95
96 void EncodingGraph::encodeParent(Element *e) {
97         uint size=e->parents.getSize();
98         for(uint i=0;i<size;i++) {
99                 ASTNode * n = e->parents.get(i);
100                 if (n->type==PREDICATEOP) {
101                         BooleanPredicate *b=(BooleanPredicate *)n;
102                         FunctionEncoding *fenc=b->getFunctionEncoding();
103                         if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
104                                 continue;
105                         Predicate *p=b->getPredicate();
106                         if (p->type==OPERATORPRED) {
107                                 PredicateOperator *po=(PredicateOperator *)p;
108                                 ASSERT(b->inputs.getSize()==2);
109                                 EncodingNode *left=createNode(b->inputs.get(0));
110                                 EncodingNode *right=createNode(b->inputs.get(1));
111                                 if (left == NULL || right == NULL)
112                                         return;
113                                 EncodingEdge *edge=getEdge(left, right, NULL);
114                                 if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
115                                         fenc->setFunctionEncodingType(CIRCUIT);
116                                 }
117                         }
118                 }
119         }
120 }
121
122 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
123         EncodingSubGraph *graph1=graphMap.get(first);
124         EncodingSubGraph *graph2=graphMap.get(second);
125         if (graph1 == NULL && graph2 == NULL) {
126                 graph1 = new EncodingSubGraph();
127                 subgraphs.add(graph1);
128                 graphMap.put(first, graph1);
129                 graph1->addNode(first);
130         }
131         if (graph1 == NULL && graph2 != NULL) {
132                 graph1 = graph2;
133                 graph2 = NULL;
134                 EncodingNode *tmp = second;
135                 second = first;
136                 first = tmp;
137         }
138         if (graph1 != NULL && graph2 != NULL) {
139                 SetIteratorEncodingNode * nodeit=graph2->nodeIterator();
140                 while(nodeit->hasNext()) {
141                         EncodingNode *node=nodeit->next();
142                         graph1->addNode(node);
143                         graphMap.put(node, graph1);
144                 }
145                 subgraphs.remove(graph2);
146                 delete nodeit;
147                 delete graph2;
148         } else {
149                 ASSERT(graph1 != NULL && graph2 == NULL);
150                 graph1->addNode(second);
151                 graphMap.put(second, graph1);
152         }
153 }
154
155 void EncodingGraph::processElement(Element *e) {
156         uint size=e->parents.getSize();
157         for(uint i=0;i<size;i++) {
158                 ASTNode * n = e->parents.get(i);
159                 switch(n->type) {
160                 case PREDICATEOP:
161                         processPredicate((BooleanPredicate *)n);
162                         break;
163                 case ELEMFUNCRETURN:
164                         processFunction((ElementFunction *)n);
165                         break;
166                 default:
167                         ASSERT(0);
168                 }
169         }
170 }
171
172 void EncodingGraph::processFunction(ElementFunction *ef) {
173         Function *f=ef->getFunction();
174         if (f->type==OPERATORFUNC) {
175                 FunctionOperator *fo=(FunctionOperator*)f;
176                 ASSERT(ef->inputs.getSize() == 2);
177                 EncodingNode *left=createNode(ef->inputs.get(0));
178                 EncodingNode *right=createNode(ef->inputs.get(1));
179                 if (left == NULL && right == NULL)
180                         return;
181                 EncodingNode *dst=createNode(ef);
182                 EncodingEdge *edge=createEdge(left, right, dst);
183                 edge->numArithOps++;
184         }
185 }
186
187 void EncodingGraph::processPredicate(BooleanPredicate *b) {
188         Predicate *p=b->getPredicate();
189         if (p->type==OPERATORPRED) {
190                 PredicateOperator *po=(PredicateOperator *)p;
191                 ASSERT(b->inputs.getSize()==2);
192                 EncodingNode *left=createNode(b->inputs.get(0));
193                 EncodingNode *right=createNode(b->inputs.get(1));
194                 if (left == NULL || right == NULL)
195                         return;
196                 EncodingEdge *edge=createEdge(left, right, NULL);
197                 CompOp op=po->getOp();
198                 switch(op) {
199                 case SATC_EQUALS:
200                         edge->numEquals++;
201                         break;
202                 case SATC_LT:
203                 case SATC_LTE:
204                 case SATC_GT:
205                 case SATC_GTE:
206                         edge->numComparisons++;
207                         break;
208                 default:
209                         ASSERT(0);
210                 }
211         }
212 }
213
214 uint convertSize(uint cost) {
215         cost = 1.2 * cost; // fudge factor
216         return NEXTPOW2(cost);
217 }
218
219 void EncodingGraph::decideEdges() {
220         uint size=edgeVector.getSize();
221         for(uint i=0; i<size; i++) {
222                 EncodingEdge *ee = edgeVector.get(i);
223                 EncodingNode *left = ee->left;
224                 EncodingNode *right = ee->right;
225                 
226                 if (ee->encoding != EDGE_UNASSIGNED ||
227                                 left->encoding != BINARYINDEX ||
228                                 right->encoding != BINARYINDEX)
229                         continue;
230                 
231                 uint64_t eeValue = ee->getValue();
232                 if (eeValue == 0)
233                         return;
234
235                 EncodingSubGraph *leftGraph = graphMap.get(left);
236                 EncodingSubGraph *rightGraph = graphMap.get(right);
237                 if (leftGraph == NULL && rightGraph !=NULL) {
238                         EncodingNode *tmp = left; left=right; right=tmp;
239                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
240                 }
241
242                 uint leftSize=0, rightSize=0, newSize=0;
243                 uint64_t totalCost=0;
244                 if (leftGraph == NULL && rightGraph == NULL) {
245                         leftSize=convertSize(left->getSize());
246                         rightSize=convertSize(right->getSize());
247                         newSize=convertSize(left->s->getUnionSize(right->s));
248                         newSize=(leftSize > newSize) ? leftSize: newSize;
249                         newSize=(rightSize > newSize) ? rightSize: newSize;
250                         totalCost = (newSize - leftSize) * left->elements.getSize() +
251                                 (newSize - rightSize) * right->elements.getSize();
252                 } else if (leftGraph != NULL && rightGraph == NULL) {
253                         leftSize=convertSize(leftGraph->encodingSize);
254                         rightSize=convertSize(right->getSize());
255                         newSize=convertSize(leftGraph->estimateNewSize(right));
256                         newSize=(leftSize > newSize) ? leftSize: newSize;
257                         newSize=(rightSize > newSize) ? rightSize: newSize;
258                         totalCost = (newSize - leftSize) * leftGraph->numElements +
259                                 (newSize - rightSize) * right->elements.getSize();
260                 } else {
261                         //Neither are null
262                         leftSize=convertSize(leftGraph->encodingSize);
263                         rightSize=convertSize(rightGraph->encodingSize);
264                         newSize=convertSize(leftGraph->estimateNewSize(rightGraph));
265                         newSize=(leftSize > newSize) ? leftSize: newSize;
266                         newSize=(rightSize > newSize) ? rightSize: newSize;
267                         totalCost = (newSize - leftSize) * leftGraph->numElements +
268                                 (newSize - rightSize) * rightGraph->numElements;
269                 }
270                 double conversionfactor = 0.5;
271                 if ((totalCost * conversionfactor) < eeValue) {
272                         //add the edge
273                         mergeNodes(left, right);
274                 }
275         }
276 }
277
278 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
279
280 EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
281         EncodingEdge e(left, right, dst);
282         EncodingEdge *result = edgeMap.get(&e);
283         return result;
284 }
285
286 EncodingEdge * EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
287         EncodingEdge e(left, right, dst);
288         EncodingEdge *result = edgeMap.get(&e);
289         if (result == NULL) {
290                 result=new EncodingEdge(left, right, dst);
291                 VarType v1=left->getType();
292                 VarType v2=right->getType();
293                 if (v1 > v2) {
294                         VarType tmp=v2;
295                         v2=v1;
296                         v1=tmp;
297                 }
298
299                 if ((left != NULL && left->encoding==BINARYINDEX) &&
300                                 (right != NULL) && right->encoding==BINARYINDEX) {
301                         EdgeEncodingType type=(EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
302                         result->setEncoding(type);
303                         if (type == EDGE_MATCH) {
304                                 mergeNodes(left, right);
305                         }
306                 }
307                 edgeMap.put(result, result);
308                 edgeVector.push(result);
309                 if (left != NULL)
310                         left->edges.add(result);
311                 if (right != NULL)
312                         right->edges.add(result);
313                 if (dst != NULL)
314                         dst->edges.add(result);
315         }
316         return result;
317 }
318
319 EncodingNode::EncodingNode(Set *_s) :
320         s(_s) {
321 }
322
323 uint EncodingNode::getSize() const {
324         return s->getSize();
325 }
326
327 VarType EncodingNode::getType() const {
328         return s->getType();
329 }
330
331 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
332
333 EncodingNode * EncodingGraph::createNode(Element *e) {
334         if (e->type == ELEMCONST)
335                 return NULL;
336         Set *s = e->getRange();
337         EncodingNode *n = encodingMap.get(s);
338         if (n == NULL) {
339                 n = new EncodingNode(s);
340                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
341                 encodingMap.put(s, n);
342         }
343         n->addElement(e);
344         return n;
345 }
346
347 EncodingNode * EncodingGraph::getNode(Element *e) {
348         if (e->type == ELEMCONST)
349                 return NULL;
350         Set *s = e->getRange();
351         EncodingNode *n = encodingMap.get(s);
352         return n;
353 }
354
355 void EncodingNode::addElement(Element *e) {
356         elements.add(e);
357 }
358
359 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
360         left(_l),
361         right(_r),
362         dst(NULL),
363         encoding(EDGE_UNASSIGNED),
364         numArithOps(0),
365         numEquals(0),
366         numComparisons(0)
367 {
368 }
369
370 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
371         left(_left),
372         right(_right),
373         dst(_dst),
374         encoding(EDGE_UNASSIGNED),
375         numArithOps(0),
376         numEquals(0),
377         numComparisons(0)
378 {
379 }
380
381 uint hashEncodingEdge(EncodingEdge *edge) {
382         uintptr_t hash=(((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
383         return (uint) hash;
384 }
385
386 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
387         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
388 }
389
390 uint64_t EncodingEdge::getValue() const {
391         uint lSize = (left != NULL) ? left->getSize() : 1;
392         uint rSize = (right != NULL) ? right->getSize() : 1;
393         uint min = (lSize < rSize) ? lSize : rSize;
394         return numEquals * min + numComparisons * lSize * rSize;
395 }
396
397