Merge branch 'master' of ssh://plrg.eecs.uci.edu/home/git/constraint_compiler
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11 #include "elementencoding.h"
12
13 EncodingGraph::EncodingGraph(CSolver * _solver) :
14         solver(_solver) {
15 }
16
17 int sortEncodingEdge(const void * p1, const void *p2) {
18         const EncodingEdge * e1 = * (const EncodingEdge **) p1;
19         const EncodingEdge * e2 = * (const EncodingEdge **) p2;
20         uint64_t v1 = e1->getValue();
21         uint64_t v2 = e2->getValue();
22         if (v1 < v2)
23                 return 1;
24         else if (v1 == v2)
25                 return 0;
26         else
27                 return -1;
28 }
29
30 void EncodingGraph::buildGraph() {
31         ElementIterator it(solver);
32         while(it.hasNext()) {
33                 Element * e = it.next();
34                 switch(e->type) {
35                 case ELEMSET:
36                 case ELEMFUNCRETURN:
37                         processElement(e);
38                         break;
39                 case ELEMCONST:
40                         break;
41                 default:
42                         ASSERT(0);
43                 }
44         }
45         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
46         decideEdges();
47 }
48
49 void EncodingGraph::encode() {
50         SetIteratorEncodingSubGraph * itesg=subgraphs.iterator();
51         while(itesg->hasNext()) {
52                 EncodingSubGraph *sg=itesg->next();
53                 sg->encode();
54         }
55         delete itesg;
56
57         ElementIterator it(solver);
58         while(it.hasNext()) {
59                 Element * e = it.next();
60                 switch(e->type) {
61                 case ELEMSET:
62                 case ELEMFUNCRETURN: {
63                         ElementEncoding *encoding=e->getElementEncoding();
64                         if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
65                                 EncodingNode *n = getNode(e);
66                                 if (n == NULL)
67                                         continue;
68                                 ElementEncodingType encodetype=n->getEncoding();
69                                 encoding->setElementEncodingType(encodetype);
70                                 if (encodetype == UNARY || encodetype == ONEHOT) {
71                                         encoding->encodingArrayInitialization();
72                                 } else if (encodetype == BINARYINDEX) {
73                                         EncodingSubGraph * subgraph = graphMap.get(n);
74                                         if (subgraph == NULL)
75                                                 continue;
76                                         uint encodingSize = subgraph->getEncodingSize(n);
77                                         uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
78                                         encoding->allocInUseArrayElement(paddedSize);
79                                         encoding->allocEncodingArrayElement(paddedSize);
80                                         Set * s=e->getRange();
81                                         for(uint i=0;i<s->getSize();i++) {
82                                                 uint64_t value=s->getElement(i);
83                                                 uint encodingIndex=subgraph->getEncoding(n, value);
84                                                 encoding->setInUseElement(encodingIndex);
85                                                 encoding->encodingArray[encodingIndex] = value;
86                                         }
87                                 }
88                         }
89                         break;
90                 }
91                 default:
92                         break;
93                 }
94                 encodeParent(e);
95         }
96 }
97
98 void EncodingGraph::encodeParent(Element *e) {
99         uint size=e->parents.getSize();
100         for(uint i=0;i<size;i++) {
101                 ASTNode * n = e->parents.get(i);
102                 if (n->type==PREDICATEOP) {
103                         BooleanPredicate *b=(BooleanPredicate *)n;
104                         FunctionEncoding *fenc=b->getFunctionEncoding();
105                         if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
106                                 continue;
107                         Predicate *p=b->getPredicate();
108                         if (p->type==OPERATORPRED) {
109                                 PredicateOperator *po=(PredicateOperator *)p;
110                                 ASSERT(b->inputs.getSize()==2);
111                                 EncodingNode *left=createNode(b->inputs.get(0));
112                                 EncodingNode *right=createNode(b->inputs.get(1));
113                                 if (left == NULL || right == NULL)
114                                         return;
115                                 EncodingEdge *edge=getEdge(left, right, NULL);
116                                 if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
117                                         fenc->setFunctionEncodingType(CIRCUIT);
118                                 }
119                         }
120                 }
121         }
122 }
123
124 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
125         EncodingSubGraph *graph1=graphMap.get(first);
126         EncodingSubGraph *graph2=graphMap.get(second);
127         if (graph1 == NULL)
128                 first->setEncoding(BINARYINDEX);
129         if (graph2 == NULL)
130                 second->setEncoding(BINARYINDEX);
131
132         if (graph1 == NULL && graph2 == NULL) {
133                 graph1 = new EncodingSubGraph();
134                 subgraphs.add(graph1);
135                 graphMap.put(first, graph1);
136                 graph1->addNode(first);
137         }
138         if (graph1 == NULL && graph2 != NULL) {
139                 graph1 = graph2;
140                 graph2 = NULL;
141                 EncodingNode *tmp = second;
142                 second = first;
143                 first = tmp;
144         }
145         if (graph1 != NULL && graph2 != NULL) {
146                 SetIteratorEncodingNode * nodeit=graph2->nodeIterator();
147                 while(nodeit->hasNext()) {
148                         EncodingNode *node=nodeit->next();
149                         graph1->addNode(node);
150                         graphMap.put(node, graph1);
151                 }
152                 subgraphs.remove(graph2);
153                 delete nodeit;
154                 delete graph2;
155         } else {
156                 ASSERT(graph1 != NULL && graph2 == NULL);
157                 graph1->addNode(second);
158                 graphMap.put(second, graph1);
159         }
160 }
161
162 void EncodingGraph::processElement(Element *e) {
163         uint size=e->parents.getSize();
164         for(uint i=0;i<size;i++) {
165                 ASTNode * n = e->parents.get(i);
166                 switch(n->type) {
167                 case PREDICATEOP:
168                         processPredicate((BooleanPredicate *)n);
169                         break;
170                 case ELEMFUNCRETURN:
171                         processFunction((ElementFunction *)n);
172                         break;
173                 default:
174                         ASSERT(0);
175                 }
176         }
177 }
178
179 void EncodingGraph::processFunction(ElementFunction *ef) {
180         Function *f=ef->getFunction();
181         if (f->type==OPERATORFUNC) {
182                 FunctionOperator *fo=(FunctionOperator*)f;
183                 ASSERT(ef->inputs.getSize() == 2);
184                 EncodingNode *left=createNode(ef->inputs.get(0));
185                 EncodingNode *right=createNode(ef->inputs.get(1));
186                 if (left == NULL && right == NULL)
187                         return;
188                 EncodingNode *dst=createNode(ef);
189                 EncodingEdge *edge=createEdge(left, right, dst);
190                 edge->numArithOps++;
191         }
192 }
193
194 void EncodingGraph::processPredicate(BooleanPredicate *b) {
195         Predicate *p=b->getPredicate();
196         if (p->type==OPERATORPRED) {
197                 PredicateOperator *po=(PredicateOperator *)p;
198                 ASSERT(b->inputs.getSize()==2);
199                 EncodingNode *left=createNode(b->inputs.get(0));
200                 EncodingNode *right=createNode(b->inputs.get(1));
201                 if (left == NULL || right == NULL)
202                         return;
203                 EncodingEdge *edge=createEdge(left, right, NULL);
204                 CompOp op=po->getOp();
205                 switch(op) {
206                 case SATC_EQUALS:
207                         edge->numEquals++;
208                         break;
209                 case SATC_LT:
210                 case SATC_LTE:
211                 case SATC_GT:
212                 case SATC_GTE:
213                         edge->numComparisons++;
214                         break;
215                 default:
216                         ASSERT(0);
217                 }
218         }
219 }
220
221 uint convertSize(uint cost) {
222         cost = 1.2 * cost; // fudge factor
223         return NEXTPOW2(cost);
224 }
225
226 void EncodingGraph::decideEdges() {
227         uint size=edgeVector.getSize();
228         for(uint i=0; i<size; i++) {
229                 EncodingEdge *ee = edgeVector.get(i);
230                 EncodingNode *left = ee->left;
231                 EncodingNode *right = ee->right;
232                 
233                 if (ee->encoding != EDGE_UNASSIGNED ||
234                                 !left->couldBeBinaryIndex() ||
235                                 !right->couldBeBinaryIndex())
236                         continue;
237                 
238                 uint64_t eeValue = ee->getValue();
239                 if (eeValue == 0)
240                         return;
241
242                 EncodingSubGraph *leftGraph = graphMap.get(left);
243                 EncodingSubGraph *rightGraph = graphMap.get(right);
244                 if (leftGraph == NULL && rightGraph !=NULL) {
245                         EncodingNode *tmp = left; left=right; right=tmp;
246                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
247                 }
248
249                 uint leftSize=0, rightSize=0, newSize=0;
250                 uint64_t totalCost=0;
251                 if (leftGraph == NULL && rightGraph == NULL) {
252                         leftSize=convertSize(left->getSize());
253                         rightSize=convertSize(right->getSize());
254                         newSize=convertSize(left->s->getUnionSize(right->s));
255                         newSize=(leftSize > newSize) ? leftSize: newSize;
256                         newSize=(rightSize > newSize) ? rightSize: newSize;
257                         totalCost = (newSize - leftSize) * left->elements.getSize() +
258                                 (newSize - rightSize) * right->elements.getSize();
259                 } else if (leftGraph != NULL && rightGraph == NULL) {
260                         leftSize=convertSize(leftGraph->encodingSize);
261                         rightSize=convertSize(right->getSize());
262                         newSize=convertSize(leftGraph->estimateNewSize(right));
263                         newSize=(leftSize > newSize) ? leftSize: newSize;
264                         newSize=(rightSize > newSize) ? rightSize: newSize;
265                         totalCost = (newSize - leftSize) * leftGraph->numElements +
266                                 (newSize - rightSize) * right->elements.getSize();
267                 } else {
268                         //Neither are null
269                         leftSize=convertSize(leftGraph->encodingSize);
270                         rightSize=convertSize(rightGraph->encodingSize);
271                         newSize=convertSize(leftGraph->estimateNewSize(rightGraph));
272                         newSize=(leftSize > newSize) ? leftSize: newSize;
273                         newSize=(rightSize > newSize) ? rightSize: newSize;
274                         totalCost = (newSize - leftSize) * leftGraph->numElements +
275                                 (newSize - rightSize) * rightGraph->numElements;
276                 }
277                 double conversionfactor = 0.5;
278                 if ((totalCost * conversionfactor) < eeValue) {
279                         //add the edge
280                         mergeNodes(left, right);
281                 }
282         }
283 }
284
285 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
286
287 EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
288         EncodingEdge e(left, right, dst);
289         EncodingEdge *result = edgeMap.get(&e);
290         return result;
291 }
292
293 EncodingEdge * EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
294         EncodingEdge e(left, right, dst);
295         EncodingEdge *result = edgeMap.get(&e);
296         if (result == NULL) {
297                 result=new EncodingEdge(left, right, dst);
298                 VarType v1=left->getType();
299                 VarType v2=right->getType();
300                 if (v1 > v2) {
301                         VarType tmp=v2;
302                         v2=v1;
303                         v1=tmp;
304                 }
305
306                 if ((left != NULL && left->couldBeBinaryIndex()) &&
307                                 (right != NULL) && right->couldBeBinaryIndex()) {
308                         EdgeEncodingType type=(EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
309                         result->setEncoding(type);
310                         if (type == EDGE_MATCH) {
311                                 mergeNodes(left, right);
312                         }
313                 }
314                 edgeMap.put(result, result);
315                 edgeVector.push(result);
316                 if (left != NULL)
317                         left->edges.add(result);
318                 if (right != NULL)
319                         right->edges.add(result);
320                 if (dst != NULL)
321                         dst->edges.add(result);
322         }
323         return result;
324 }
325
326 EncodingNode::EncodingNode(Set *_s) :
327         s(_s) {
328 }
329
330 uint EncodingNode::getSize() const {
331         return s->getSize();
332 }
333
334 VarType EncodingNode::getType() const {
335         return s->getType();
336 }
337
338 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
339
340 EncodingNode * EncodingGraph::createNode(Element *e) {
341         if (e->type == ELEMCONST)
342                 return NULL;
343         Set *s = e->getRange();
344         EncodingNode *n = encodingMap.get(s);
345         if (n == NULL) {
346                 n = new EncodingNode(s);
347                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
348                 
349                 encodingMap.put(s, n);
350         }
351         n->addElement(e);
352         return n;
353 }
354
355 EncodingNode * EncodingGraph::getNode(Element *e) {
356         if (e->type == ELEMCONST)
357                 return NULL;
358         Set *s = e->getRange();
359         EncodingNode *n = encodingMap.get(s);
360         return n;
361 }
362
363 void EncodingNode::addElement(Element *e) {
364         elements.add(e);
365 }
366
367 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
368         left(_l),
369         right(_r),
370         dst(NULL),
371         encoding(EDGE_UNASSIGNED),
372         numArithOps(0),
373         numEquals(0),
374         numComparisons(0)
375 {
376 }
377
378 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
379         left(_left),
380         right(_right),
381         dst(_dst),
382         encoding(EDGE_UNASSIGNED),
383         numArithOps(0),
384         numEquals(0),
385         numComparisons(0)
386 {
387 }
388
389 uint hashEncodingEdge(EncodingEdge *edge) {
390         uintptr_t hash=(((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
391         return (uint) hash;
392 }
393
394 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
395         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
396 }
397
398 uint64_t EncodingEdge::getValue() const {
399         uint lSize = (left != NULL) ? left->getSize() : 1;
400         uint rSize = (right != NULL) ? right->getSize() : 1;
401         uint min = (lSize < rSize) ? lSize : rSize;
402         return numEquals * min + numComparisons * lSize * rSize;
403 }
404
405