edits
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10
11 EncodingGraph::EncodingGraph(CSolver * _solver) :
12         solver(_solver) {
13 }
14
15 int sortEncodingEdge(const void * p1, const void *p2) {
16         const EncodingEdge * e1 = * (const EncodingEdge **) p1;
17         const EncodingEdge * e2 = * (const EncodingEdge **) p2;
18         uint64_t v1 = e1->getValue();
19         uint64_t v2 = e2->getValue();
20         if (v1 < v2)
21                 return 1;
22         else if (v1 == v2)
23                 return 0;
24         else
25                 return -1;
26 }
27
28 void EncodingGraph::buildGraph() {
29         ElementIterator it(solver);
30         while(it.hasNext()) {
31                 Element * e = it.next();
32                 switch(e->type) {
33                 case ELEMSET:
34                 case ELEMFUNCRETURN:
35                         processElement(e);
36                         break;
37                 case ELEMCONST:
38                         break;
39                 default:
40                         ASSERT(0);
41                 }
42         }
43         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
44         decideEdges();
45 }
46
47 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
48         EncodingSubGraph *graph1=graphMap.get(first);
49         EncodingSubGraph *graph2=graphMap.get(second);
50         if (graph1 == NULL && graph2 == NULL) {
51                 graph1 = new EncodingSubGraph();
52                 graphMap.put(first, graph1);
53                 graph1->addNode(first);
54         }
55         if (graph1 == NULL && graph2 != NULL) {
56                 graph1 = graph2;
57                 graph2 = NULL;
58                 EncodingNode *tmp = second;
59                 second = first;
60                 first = tmp;
61         }
62         if (graph1 != NULL && graph2 != NULL) {
63                 SetIteratorEncodingNode * nodeit=graph2->nodeIterator();
64                 while(nodeit->hasNext()) {
65                         EncodingNode *node=nodeit->next();
66                         graph1->addNode(node);
67                         graphMap.put(node, graph1);
68                 }
69                 delete nodeit;
70                 delete graph2;
71         } else {
72                 ASSERT(graph1 != NULL && graph2 == NULL);
73                 graph1->addNode(second);
74                 graphMap.put(second, graph1);
75         }
76 }
77
78 void EncodingGraph::processElement(Element *e) {
79         uint size=e->parents.getSize();
80         for(uint i=0;i<size;i++) {
81                 ASTNode * n = e->parents.get(i);
82                 switch(n->type) {
83                 case PREDICATEOP:
84                         processPredicate((BooleanPredicate *)n);
85                         break;
86                 case ELEMFUNCRETURN:
87                         processFunction((ElementFunction *)n);
88                         break;
89                 default:
90                         ASSERT(0);
91                 }
92         }
93 }
94
95 void EncodingGraph::processFunction(ElementFunction *ef) {
96         Function *f=ef->getFunction();
97         if (f->type==OPERATORFUNC) {
98                 FunctionOperator *fo=(FunctionOperator*)f;
99                 ASSERT(ef->inputs.getSize() == 2);
100                 EncodingNode *left=createNode(ef->inputs.get(0));
101                 EncodingNode *right=createNode(ef->inputs.get(1));
102                 if (left == NULL && right == NULL)
103                         return;
104                 EncodingNode *dst=createNode(ef);
105                 EncodingEdge *edge=getEdge(left, right, dst);
106                 edge->numArithOps++;
107         }
108 }
109
110 void EncodingGraph::processPredicate(BooleanPredicate *b) {
111         Predicate *p=b->getPredicate();
112         if (p->type==OPERATORPRED) {
113                 PredicateOperator *po=(PredicateOperator *)p;
114                 ASSERT(b->inputs.getSize()==2);
115                 EncodingNode *left=createNode(b->inputs.get(0));
116                 EncodingNode *right=createNode(b->inputs.get(1));
117                 if (left == NULL || right == NULL)
118                         return;
119                 EncodingEdge *edge=getEdge(left, right, NULL);
120                 CompOp op=po->getOp();
121                 switch(op) {
122                 case SATC_EQUALS:
123                         edge->numEquals++;
124                         break;
125                 case SATC_LT:
126                 case SATC_LTE:
127                 case SATC_GT:
128                 case SATC_GTE:
129                         edge->numComparisons++;
130                         break;
131                 default:
132                         ASSERT(0);
133                 }
134         }
135 }
136
137 uint convertSize(uint cost) {
138         cost = 1.2 * cost; // fudge factor
139         return NEXTPOW2(cost);
140 }
141
142 void EncodingGraph::decideEdges() {
143         uint size=edgeVector.getSize();
144         for(uint i=0; i<size; i++) {
145                 EncodingEdge *ee = edgeVector.get(i);
146                 EncodingNode *left = ee->left;
147                 EncodingNode *right = ee->right;
148                 
149                 if (ee->encoding != EDGE_UNASSIGNED ||
150                                 left->encoding != BINARYINDEX ||
151                                 right->encoding != BINARYINDEX)
152                         continue;
153                 
154                 uint64_t eeValue = ee->getValue();
155                 if (eeValue == 0)
156                         return;
157
158                 EncodingSubGraph *leftGraph = graphMap.get(left);
159                 EncodingSubGraph *rightGraph = graphMap.get(right);
160                 if (leftGraph == NULL && rightGraph !=NULL) {
161                         EncodingNode *tmp = left; left=right; right=tmp;
162                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
163                 }
164
165                 uint leftSize=0, rightSize=0, newSize=0;
166                 uint64_t totalCost=0;
167                 if (leftGraph == NULL && rightGraph == NULL) {
168                         leftSize=convertSize(left->getSize());
169                         rightSize=convertSize(right->getSize());
170                         newSize=convertSize(left->s->getUnionSize(right->s));
171                         newSize=(leftSize > newSize) ? leftSize: newSize;
172                         newSize=(rightSize > newSize) ? rightSize: newSize;
173                         totalCost = (newSize - leftSize) * left->elements.getSize() +
174                                 (newSize - rightSize) * right->elements.getSize();
175                 } else if (leftGraph != NULL && rightGraph == NULL) {
176                         leftSize=convertSize(leftGraph->encodingSize);
177                         rightSize=convertSize(right->getSize());
178                         newSize=convertSize(leftGraph->estimateNewSize(right));
179                         newSize=(leftSize > newSize) ? leftSize: newSize;
180                         newSize=(rightSize > newSize) ? rightSize: newSize;
181                         totalCost = (newSize - leftSize) * leftGraph->numElements +
182                                 (newSize - rightSize) * right->elements.getSize();
183                 } else {
184                         //Neither are null
185                         leftSize=convertSize(leftGraph->encodingSize);
186                         rightSize=convertSize(rightGraph->encodingSize);
187                         newSize=convertSize(leftGraph->estimateNewSize(rightGraph));
188                         newSize=(leftSize > newSize) ? leftSize: newSize;
189                         newSize=(rightSize > newSize) ? rightSize: newSize;
190                         totalCost = (newSize - leftSize) * leftGraph->numElements +
191                                 (newSize - rightSize) * rightGraph->numElements;
192                 }
193                 double conversionfactor = 0.5;
194                 if ((totalCost * conversionfactor) < eeValue) {
195                         //add the edge
196                         mergeNodes(left, right);
197                 }
198         }
199 }
200
201 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
202
203 EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
204         EncodingEdge e(left, right, dst);
205         EncodingEdge *result = edgeMap.get(&e);
206         if (result == NULL) {
207                 result=new EncodingEdge(left, right, dst);
208                 VarType v1=left->getType();
209                 VarType v2=right->getType();
210                 if (v1 > v2) {
211                         VarType tmp=v2;
212                         v2=v1;
213                         v1=tmp;
214                 }
215
216                 if ((left != NULL && left->encoding==BINARYINDEX) &&
217                                 (right != NULL) && right->encoding==BINARYINDEX) {
218                         EdgeEncodingType type=(EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
219                         result->setEncoding(type);
220                         if (type == EDGE_MATCH) {
221                                 mergeNodes(left, right);
222                         }
223                 }
224                 edgeMap.put(result, result);
225                 edgeVector.push(result);
226                 if (left != NULL)
227                         left->edges.add(result);
228                 if (right != NULL)
229                         right->edges.add(result);
230                 if (dst != NULL)
231                         dst->edges.add(result);
232         }
233         return result;
234 }
235
236 EncodingNode::EncodingNode(Set *_s) :
237         s(_s) {
238 }
239
240 uint EncodingNode::getSize() const {
241         return s->getSize();
242 }
243
244 VarType EncodingNode::getType() const {
245         return s->getType();
246 }
247
248 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
249
250 EncodingNode * EncodingGraph::createNode(Element *e) {
251         if (e->type == ELEMCONST)
252                 return NULL;
253         Set *s = e->getRange();
254         EncodingNode *n = encodingMap.get(s);
255         if (n == NULL) {
256                 n = new EncodingNode(s);
257                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
258                 encodingMap.put(s, n);
259         }
260         n->addElement(e);
261         return n;
262 }
263
264 void EncodingNode::addElement(Element *e) {
265         elements.add(e);
266 }
267
268 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
269         left(_l),
270         right(_r),
271         dst(NULL),
272         encoding(EDGE_UNASSIGNED),
273         numArithOps(0),
274         numEquals(0),
275         numComparisons(0)
276 {
277 }
278
279 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
280         left(_left),
281         right(_right),
282         dst(_dst),
283         encoding(EDGE_UNASSIGNED),
284         numArithOps(0),
285         numEquals(0),
286         numComparisons(0)
287 {
288 }
289
290 uint hashEncodingEdge(EncodingEdge *edge) {
291         uintptr_t hash=(((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
292         return (uint) hash;
293 }
294
295 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
296         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
297 }
298
299 uint64_t EncodingEdge::getValue() const {
300         uint lSize = (left != NULL) ? left->getSize() : 1;
301         uint rSize = (right != NULL) ? right->getSize() : 1;
302         uint min = (lSize < rSize) ? lSize : rSize;
303         return numEquals * min + numComparisons * lSize * rSize;
304 }
305
306 EncodingSubGraph::EncodingSubGraph() :
307         encodingSize(0),
308         numElements(0) {
309 }
310
311 uint EncodingSubGraph::estimateNewSize(EncodingSubGraph *sg) {
312         uint newSize=0;
313         SetIteratorEncodingNode * nit = sg->nodes.iterator();
314         while(nit->hasNext()) {
315                 EncodingNode *en = nit->next();
316                 uint size=estimateNewSize(en);
317                 if (size > newSize)
318                         newSize = size;
319         }
320         delete nit;
321         return newSize;
322 }
323
324 uint EncodingSubGraph::estimateNewSize(EncodingNode *n) {
325         SetIteratorEncodingEdge * eeit = n->edges.iterator();
326         uint newsize=n->getSize();
327         while(eeit->hasNext()) {
328                 EncodingEdge * ee = eeit->next();
329                 if (ee->left != NULL && ee->left != n && nodes.contains(ee->left)) {
330                         uint intersectSize = n->s->getUnionSize(ee->left->s);
331                         if (intersectSize > newsize)
332                                 newsize = intersectSize;
333                 }
334                 if (ee->right != NULL && ee->right != n && nodes.contains(ee->right)) {
335                         uint intersectSize = n->s->getUnionSize(ee->right->s);
336                         if (intersectSize > newsize)
337                                 newsize = intersectSize;
338                 }
339                 if (ee->dst != NULL && ee->dst != n && nodes.contains(ee->dst)) {
340                         uint intersectSize = n->s->getUnionSize(ee->dst->s);
341                         if (intersectSize > newsize)
342                                 newsize = intersectSize;
343                 }
344         }
345         delete eeit;
346         return newsize;
347 }
348
349 void EncodingSubGraph::addNode(EncodingNode *n) {
350         nodes.add(n);
351         uint newSize=estimateNewSize(n);
352         numElements += n->elements.getSize();
353         if (newSize > encodingSize)
354                 encodingSize=newSize;
355 }
356
357 SetIteratorEncodingNode * EncodingSubGraph::nodeIterator() {
358         return nodes.iterator();
359 }