backout changes
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11 #include "elementencoding.h"
12
13 EncodingGraph::EncodingGraph(CSolver * _solver) :
14         solver(_solver) {
15 }
16
17 int sortEncodingEdge(const void * p1, const void *p2) {
18         const EncodingEdge * e1 = * (const EncodingEdge **) p1;
19         const EncodingEdge * e2 = * (const EncodingEdge **) p2;
20         uint64_t v1 = e1->getValue();
21         uint64_t v2 = e2->getValue();
22         if (v1 < v2)
23                 return 1;
24         else if (v1 == v2)
25                 return 0;
26         else
27                 return -1;
28 }
29
30 void EncodingGraph::buildGraph() {
31         ElementIterator it(solver);
32         while(it.hasNext()) {
33                 Element * e = it.next();
34                 switch(e->type) {
35                 case ELEMSET:
36                 case ELEMFUNCRETURN:
37                         processElement(e);
38                         break;
39                 case ELEMCONST:
40                         break;
41                 default:
42                         ASSERT(0);
43                 }
44         }
45         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
46         decideEdges();
47 }
48
49 void EncodingGraph::encode() {
50         SetIteratorEncodingSubGraph * itesg=subgraphs.iterator();
51         while(itesg->hasNext()) {
52                 EncodingSubGraph *sg=itesg->next();
53                 sg->encode();
54         }
55         delete itesg;
56
57         ElementIterator it(solver);
58         while(it.hasNext()) {
59                 Element * e = it.next();
60                 switch(e->type) {
61                 case ELEMSET:
62                 case ELEMFUNCRETURN: {
63                         ElementEncoding *encoding=e->getElementEncoding();
64                         if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
65                                 EncodingNode *n = getNode(e);
66                                 ASSERT(n != NULL);
67                                 ElementEncodingType encodetype=n->getEncoding();
68                                 encoding->setElementEncodingType(encodetype);
69                                 if (encodetype == UNARY || encodetype == ONEHOT) {
70                                         encoding->encodingArrayInitialization();
71                                 } else if (encodetype == BINARYINDEX) {
72                                         EncodingSubGraph * subgraph = graphMap.get(n);
73                                         uint encodingSize = subgraph->getEncodingSize(n);
74                                         uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
75                                         encoding->allocInUseArrayElement(paddedSize);
76                                         encoding->allocEncodingArrayElement(paddedSize);
77                                         Set * s=e->getRange();
78                                         for(uint i=0;i<s->getSize();i++) {
79                                                 uint64_t value=s->getElement(i);
80                                                 uint encodingIndex=subgraph->getEncoding(n, value);
81                                                 encoding->setInUseElement(encodingIndex);
82                                                 encoding->encodingArray[encodingIndex] = value;
83                                         }
84                                 }
85                         }
86                         break;
87                 }
88                 default:
89                         break;
90                 }
91                 encodeParent(e);
92         }
93 }
94
95 void EncodingGraph::encodeParent(Element *e) {
96         uint size=e->parents.getSize();
97         for(uint i=0;i<size;i++) {
98                 ASTNode * n = e->parents.get(i);
99                 if (n->type==PREDICATEOP) {
100                         BooleanPredicate *b=(BooleanPredicate *)n;
101                         FunctionEncoding *fenc=b->getFunctionEncoding();
102                         if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
103                                 continue;
104                         Predicate *p=b->getPredicate();
105                         if (p->type==OPERATORPRED) {
106                                 PredicateOperator *po=(PredicateOperator *)p;
107                                 ASSERT(b->inputs.getSize()==2);
108                                 EncodingNode *left=createNode(b->inputs.get(0));
109                                 EncodingNode *right=createNode(b->inputs.get(1));
110                                 if (left == NULL || right == NULL)
111                                         return;
112                                 EncodingEdge *edge=getEdge(left, right, NULL);
113                                 if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
114                                         fenc->setFunctionEncodingType(CIRCUIT);
115                                 }
116                         }
117                 }
118         }
119 }
120
121 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
122         EncodingSubGraph *graph1=graphMap.get(first);
123         EncodingSubGraph *graph2=graphMap.get(second);
124         if (graph1 == NULL && graph2 == NULL) {
125                 graph1 = new EncodingSubGraph();
126                 subgraphs.add(graph1);
127                 graphMap.put(first, graph1);
128                 graph1->addNode(first);
129         }
130         if (graph1 == NULL && graph2 != NULL) {
131                 graph1 = graph2;
132                 graph2 = NULL;
133                 EncodingNode *tmp = second;
134                 second = first;
135                 first = tmp;
136         }
137         if (graph1 != NULL && graph2 != NULL) {
138                 SetIteratorEncodingNode * nodeit=graph2->nodeIterator();
139                 while(nodeit->hasNext()) {
140                         EncodingNode *node=nodeit->next();
141                         graph1->addNode(node);
142                         graphMap.put(node, graph1);
143                 }
144                 subgraphs.remove(graph2);
145                 delete nodeit;
146                 delete graph2;
147         } else {
148                 ASSERT(graph1 != NULL && graph2 == NULL);
149                 graph1->addNode(second);
150                 graphMap.put(second, graph1);
151         }
152 }
153
154 void EncodingGraph::processElement(Element *e) {
155         uint size=e->parents.getSize();
156         for(uint i=0;i<size;i++) {
157                 ASTNode * n = e->parents.get(i);
158                 switch(n->type) {
159                 case PREDICATEOP:
160                         processPredicate((BooleanPredicate *)n);
161                         break;
162                 case ELEMFUNCRETURN:
163                         processFunction((ElementFunction *)n);
164                         break;
165                 default:
166                         ASSERT(0);
167                 }
168         }
169 }
170
171 void EncodingGraph::processFunction(ElementFunction *ef) {
172         Function *f=ef->getFunction();
173         if (f->type==OPERATORFUNC) {
174                 FunctionOperator *fo=(FunctionOperator*)f;
175                 ASSERT(ef->inputs.getSize() == 2);
176                 EncodingNode *left=createNode(ef->inputs.get(0));
177                 EncodingNode *right=createNode(ef->inputs.get(1));
178                 if (left == NULL && right == NULL)
179                         return;
180                 EncodingNode *dst=createNode(ef);
181                 EncodingEdge *edge=createEdge(left, right, dst);
182                 edge->numArithOps++;
183         }
184 }
185
186 void EncodingGraph::processPredicate(BooleanPredicate *b) {
187         Predicate *p=b->getPredicate();
188         if (p->type==OPERATORPRED) {
189                 PredicateOperator *po=(PredicateOperator *)p;
190                 ASSERT(b->inputs.getSize()==2);
191                 EncodingNode *left=createNode(b->inputs.get(0));
192                 EncodingNode *right=createNode(b->inputs.get(1));
193                 if (left == NULL || right == NULL)
194                         return;
195                 EncodingEdge *edge=createEdge(left, right, NULL);
196                 CompOp op=po->getOp();
197                 switch(op) {
198                 case SATC_EQUALS:
199                         edge->numEquals++;
200                         break;
201                 case SATC_LT:
202                 case SATC_LTE:
203                 case SATC_GT:
204                 case SATC_GTE:
205                         edge->numComparisons++;
206                         break;
207                 default:
208                         ASSERT(0);
209                 }
210         }
211 }
212
213 uint convertSize(uint cost) {
214         cost = 1.2 * cost; // fudge factor
215         return NEXTPOW2(cost);
216 }
217
218 void EncodingGraph::decideEdges() {
219         uint size=edgeVector.getSize();
220         for(uint i=0; i<size; i++) {
221                 EncodingEdge *ee = edgeVector.get(i);
222                 EncodingNode *left = ee->left;
223                 EncodingNode *right = ee->right;
224                 
225                 if (ee->encoding != EDGE_UNASSIGNED ||
226                                 left->encoding != BINARYINDEX ||
227                                 right->encoding != BINARYINDEX)
228                         continue;
229                 
230                 uint64_t eeValue = ee->getValue();
231                 if (eeValue == 0)
232                         return;
233
234                 EncodingSubGraph *leftGraph = graphMap.get(left);
235                 EncodingSubGraph *rightGraph = graphMap.get(right);
236                 if (leftGraph == NULL && rightGraph !=NULL) {
237                         EncodingNode *tmp = left; left=right; right=tmp;
238                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
239                 }
240
241                 uint leftSize=0, rightSize=0, newSize=0;
242                 uint64_t totalCost=0;
243                 if (leftGraph == NULL && rightGraph == NULL) {
244                         leftSize=convertSize(left->getSize());
245                         rightSize=convertSize(right->getSize());
246                         newSize=convertSize(left->s->getUnionSize(right->s));
247                         newSize=(leftSize > newSize) ? leftSize: newSize;
248                         newSize=(rightSize > newSize) ? rightSize: newSize;
249                         totalCost = (newSize - leftSize) * left->elements.getSize() +
250                                 (newSize - rightSize) * right->elements.getSize();
251                 } else if (leftGraph != NULL && rightGraph == NULL) {
252                         leftSize=convertSize(leftGraph->encodingSize);
253                         rightSize=convertSize(right->getSize());
254                         newSize=convertSize(leftGraph->estimateNewSize(right));
255                         newSize=(leftSize > newSize) ? leftSize: newSize;
256                         newSize=(rightSize > newSize) ? rightSize: newSize;
257                         totalCost = (newSize - leftSize) * leftGraph->numElements +
258                                 (newSize - rightSize) * right->elements.getSize();
259                 } else {
260                         //Neither are null
261                         leftSize=convertSize(leftGraph->encodingSize);
262                         rightSize=convertSize(rightGraph->encodingSize);
263                         newSize=convertSize(leftGraph->estimateNewSize(rightGraph));
264                         newSize=(leftSize > newSize) ? leftSize: newSize;
265                         newSize=(rightSize > newSize) ? rightSize: newSize;
266                         totalCost = (newSize - leftSize) * leftGraph->numElements +
267                                 (newSize - rightSize) * rightGraph->numElements;
268                 }
269                 double conversionfactor = 0.5;
270                 if ((totalCost * conversionfactor) < eeValue) {
271                         //add the edge
272                         mergeNodes(left, right);
273                 }
274         }
275 }
276
277 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
278
279 EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
280         EncodingEdge e(left, right, dst);
281         EncodingEdge *result = edgeMap.get(&e);
282         return result;
283 }
284
285 EncodingEdge * EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
286         EncodingEdge e(left, right, dst);
287         EncodingEdge *result = edgeMap.get(&e);
288         if (result == NULL) {
289                 result=new EncodingEdge(left, right, dst);
290                 VarType v1=left->getType();
291                 VarType v2=right->getType();
292                 if (v1 > v2) {
293                         VarType tmp=v2;
294                         v2=v1;
295                         v1=tmp;
296                 }
297
298                 if ((left != NULL && left->encoding==BINARYINDEX) &&
299                                 (right != NULL) && right->encoding==BINARYINDEX) {
300                         EdgeEncodingType type=(EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
301                         result->setEncoding(type);
302                         if (type == EDGE_MATCH) {
303                                 mergeNodes(left, right);
304                         }
305                 }
306                 edgeMap.put(result, result);
307                 edgeVector.push(result);
308                 if (left != NULL)
309                         left->edges.add(result);
310                 if (right != NULL)
311                         right->edges.add(result);
312                 if (dst != NULL)
313                         dst->edges.add(result);
314         }
315         return result;
316 }
317
318 EncodingNode::EncodingNode(Set *_s) :
319         s(_s) {
320 }
321
322 uint EncodingNode::getSize() const {
323         return s->getSize();
324 }
325
326 VarType EncodingNode::getType() const {
327         return s->getType();
328 }
329
330 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
331
332 EncodingNode * EncodingGraph::createNode(Element *e) {
333         if (e->type == ELEMCONST)
334                 return NULL;
335         Set *s = e->getRange();
336         EncodingNode *n = encodingMap.get(s);
337         if (n == NULL) {
338                 n = new EncodingNode(s);
339                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
340                 encodingMap.put(s, n);
341         }
342         n->addElement(e);
343         return n;
344 }
345
346 EncodingNode * EncodingGraph::getNode(Element *e) {
347         if (e->type == ELEMCONST)
348                 return NULL;
349         Set *s = e->getRange();
350         EncodingNode *n = encodingMap.get(s);
351         return n;
352 }
353
354 void EncodingNode::addElement(Element *e) {
355         elements.add(e);
356 }
357
358 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
359         left(_l),
360         right(_r),
361         dst(NULL),
362         encoding(EDGE_UNASSIGNED),
363         numArithOps(0),
364         numEquals(0),
365         numComparisons(0)
366 {
367 }
368
369 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
370         left(_left),
371         right(_right),
372         dst(_dst),
373         encoding(EDGE_UNASSIGNED),
374         numArithOps(0),
375         numEquals(0),
376         numComparisons(0)
377 {
378 }
379
380 uint hashEncodingEdge(EncodingEdge *edge) {
381         uintptr_t hash=(((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
382         return (uint) hash;
383 }
384
385 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
386         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
387 }
388
389 uint64_t EncodingEdge::getValue() const {
390         uint lSize = (left != NULL) ? left->getSize() : 1;
391         uint rSize = (right != NULL) ? right->getSize() : 1;
392         uint min = (lSize < rSize) ? lSize : rSize;
393         return numEquals * min + numComparisons * lSize * rSize;
394 }
395
396