Subgraphing code in place
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10
11 EncodingGraph::EncodingGraph(CSolver * _solver) :
12         solver(_solver) {
13 }
14
15 int sortEncodingEdge(const void * p1, const void *p2) {
16         const EncodingEdge * e1 = * (const EncodingEdge **) p1;
17         const EncodingEdge * e2 = * (const EncodingEdge **) p2;
18         uint64_t v1 = e1->getValue();
19         uint64_t v2 = e2->getValue();
20         if (v1 < v2)
21                 return 1;
22         else if (v1 == v2)
23                 return 0;
24         else
25                 return -1;
26 }
27
28 void EncodingGraph::buildGraph() {
29         ElementIterator it(solver);
30         while(it.hasNext()) {
31                 Element * e = it.next();
32                 switch(e->type) {
33                 case ELEMSET:
34                 case ELEMFUNCRETURN:
35                         processElement(e);
36                         break;
37                 case ELEMCONST:
38                         break;
39                 default:
40                         ASSERT(0);
41                 }
42         }
43         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
44 }
45
46 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
47         EncodingSubGraph *graph1=graphMap.get(first);
48         EncodingSubGraph *graph2=graphMap.get(second);
49         if (graph1 == NULL && graph2 == NULL) {
50                 graph1 = new EncodingSubGraph();
51                 graphMap.put(first, graph1);
52                 graph1->addNode(first);
53         }
54         if (graph1 == NULL && graph2 != NULL) {
55                 graph1 = graph2;
56                 graph2 = NULL;
57                 EncodingNode *tmp = second;
58                 second = first;
59                 first = tmp;
60         }
61         if (graph1 != NULL && graph2 != NULL) {
62                 SetIteratorEncodingNode * nodeit=graph2->nodeIterator();
63                 while(nodeit->hasNext()) {
64                         EncodingNode *node=nodeit->next();
65                         graph1->addNode(node);
66                         graphMap.put(node, graph1);
67                 }
68                 delete nodeit;
69                 delete graph2;
70         } else {
71                 ASSERT(graph1 != NULL && graph2 == NULL);
72                 graph1->addNode(second);
73                 graphMap.put(second, graph1);
74         }
75 }
76
77 void EncodingGraph::processElement(Element *e) {
78         uint size=e->parents.getSize();
79         for(uint i=0;i<size;i++) {
80                 ASTNode * n = e->parents.get(i);
81                 switch(n->type) {
82                 case PREDICATEOP:
83                         processPredicate((BooleanPredicate *)n);
84                         break;
85                 case ELEMFUNCRETURN:
86                         processFunction((ElementFunction *)n);
87                         break;
88                 default:
89                         ASSERT(0);
90                 }
91         }
92 }
93
94 void EncodingGraph::processFunction(ElementFunction *ef) {
95         Function *f=ef->getFunction();
96         if (f->type==OPERATORFUNC) {
97                 FunctionOperator *fo=(FunctionOperator*)f;
98                 ASSERT(ef->inputs.getSize() == 2);
99                 EncodingNode *left=createNode(ef->inputs.get(0));
100                 EncodingNode *right=createNode(ef->inputs.get(1));
101                 if (left == NULL && right == NULL)
102                         return;
103                 EncodingNode *dst=createNode(ef);
104                 EncodingEdge *edge=getEdge(left, right, dst);
105                 edge->numArithOps++;
106         }
107 }
108
109 void EncodingGraph::processPredicate(BooleanPredicate *b) {
110         Predicate *p=b->getPredicate();
111         if (p->type==OPERATORPRED) {
112                 PredicateOperator *po=(PredicateOperator *)p;
113                 ASSERT(b->inputs.getSize()==2);
114                 EncodingNode *left=createNode(b->inputs.get(0));
115                 EncodingNode *right=createNode(b->inputs.get(1));
116                 if (left == NULL || right == NULL)
117                         return;
118                 EncodingEdge *edge=getEdge(left, right, NULL);
119                 CompOp op=po->getOp();
120                 switch(op) {
121                 case SATC_EQUALS:
122                         edge->numEquals++;
123                         break;
124                 case SATC_LT:
125                 case SATC_LTE:
126                 case SATC_GT:
127                 case SATC_GTE:
128                         edge->numComparisons++;
129                         break;
130                 default:
131                         ASSERT(0);
132                 }
133         }
134 }
135
136 uint convertSize(uint cost) {
137         cost = 1.2 * cost; // fudge factor
138         return NEXTPOW2(cost);
139 }
140
141 void EncodingGraph::decideEdges() {
142         uint size=edgeVector.getSize();
143         for(uint i=0; i<size; i++) {
144                 EncodingEdge *ee = edgeVector.get(i);
145                 if (ee->encoding != EDGE_UNASSIGNED)
146                         continue;
147                 
148                 uint64_t eeValue = ee->getValue();
149                 if (eeValue == 0)
150                         return;
151                 EncodingNode *left = ee->left;
152                 EncodingNode *right = ee->right;
153                 EncodingSubGraph *leftGraph = graphMap.get(left);
154                 EncodingSubGraph *rightGraph = graphMap.get(right);
155                 if (leftGraph == NULL && rightGraph !=NULL) {
156                         EncodingNode *tmp = left; left=right; right=tmp;
157                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
158                 }
159
160                 uint leftSize=0, rightSize=0, newSize=0;
161                 uint64_t totalCost=0;
162                 if (leftGraph == NULL && rightGraph == NULL) {
163                         leftSize=convertSize(left->getSize());
164                         rightSize=convertSize(right->getSize());
165                         newSize=convertSize(left->s->getUnionSize(right->s));
166                         newSize=(leftSize > newSize) ? leftSize: newSize;
167                         newSize=(rightSize > newSize) ? rightSize: newSize;
168                         totalCost = (newSize - leftSize) * left->elements.getSize() +
169                                 (newSize - rightSize) * right->elements.getSize();
170                 } else if (leftGraph != NULL && rightGraph == NULL) {
171                         leftSize=convertSize(leftGraph->encodingSize);
172                         rightSize=convertSize(right->getSize());
173                         newSize=convertSize(leftGraph->estimateNewSize(right));
174                         newSize=(leftSize > newSize) ? leftSize: newSize;
175                         newSize=(rightSize > newSize) ? rightSize: newSize;
176                         totalCost = (newSize - leftSize) * leftGraph->numElements +
177                                 (newSize - rightSize) * right->elements.getSize();
178                 } else {
179                         //Neither are null
180                         leftSize=convertSize(leftGraph->encodingSize);
181                         rightSize=convertSize(rightGraph->encodingSize);
182                         newSize=convertSize(leftGraph->estimateNewSize(rightGraph));
183                         newSize=(leftSize > newSize) ? leftSize: newSize;
184                         newSize=(rightSize > newSize) ? rightSize: newSize;
185                         totalCost = (newSize - leftSize) * leftGraph->numElements +
186                                 (newSize - rightSize) * rightGraph->numElements;
187                 }
188                 double conversionfactor = 0.5;
189                 if ((totalCost * conversionfactor) < eeValue) {
190                         //add the edge
191                         mergeNodes(left, right);
192                 }
193         }
194 }
195
196 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
197
198 EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
199         EncodingEdge e(left, right, dst);
200         EncodingEdge *result = edgeMap.get(&e);
201         if (result == NULL) {
202                 result=new EncodingEdge(left, right, dst);
203                 VarType v1=left->getType();
204                 VarType v2=right->getType();
205                 if (v1 > v2) {
206                         VarType tmp=v2;
207                         v2=v1;
208                         v1=tmp;
209                 }
210
211                 if ((left != NULL && left->encoding==BINARYINDEX) &&
212                                 (right != NULL) && right->encoding==BINARYINDEX) {
213                         EdgeEncodingType type=(EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
214                         result->setEncoding(type);
215                         if (type == EDGE_MATCH) {
216                                 mergeNodes(left, right);
217                         }
218                 }
219                 edgeMap.put(result, result);
220                 edgeVector.push(result);
221                 if (left != NULL)
222                         left->edges.add(result);
223                 if (right != NULL)
224                         right->edges.add(result);
225                 if (dst != NULL)
226                         dst->edges.add(result);
227         }
228         return result;
229 }
230
231 EncodingNode::EncodingNode(Set *_s) :
232         s(_s) {
233 }
234
235 uint EncodingNode::getSize() const {
236         return s->getSize();
237 }
238
239 VarType EncodingNode::getType() const {
240         return s->getType();
241 }
242
243 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
244
245 EncodingNode * EncodingGraph::createNode(Element *e) {
246         if (e->type == ELEMCONST)
247                 return NULL;
248         Set *s = e->getRange();
249         EncodingNode *n = encodingMap.get(s);
250         if (n == NULL) {
251                 n = new EncodingNode(s);
252                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
253                 encodingMap.put(s, n);
254         }
255         n->addElement(e);
256         return n;
257 }
258
259 void EncodingNode::addElement(Element *e) {
260         elements.add(e);
261 }
262
263 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
264         left(_l),
265         right(_r),
266         dst(NULL),
267         encoding(EDGE_UNASSIGNED),
268         numArithOps(0),
269         numEquals(0),
270         numComparisons(0)
271 {
272 }
273
274 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
275         left(_left),
276         right(_right),
277         dst(_dst),
278         encoding(EDGE_UNASSIGNED),
279         numArithOps(0),
280         numEquals(0),
281         numComparisons(0)
282 {
283 }
284
285 uint hashEncodingEdge(EncodingEdge *edge) {
286         uintptr_t hash=(((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
287         return (uint) hash;
288 }
289
290 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
291         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
292 }
293
294 uint64_t EncodingEdge::getValue() const {
295         uint lSize = (left != NULL) ? left->getSize() : 1;
296         uint rSize = (right != NULL) ? right->getSize() : 1;
297         uint min = (lSize < rSize) ? lSize : rSize;
298         return numEquals * min + numComparisons * lSize * rSize;
299 }
300
301 EncodingSubGraph::EncodingSubGraph() :
302         encodingSize(0),
303         numElements(0) {
304 }
305
306 uint EncodingSubGraph::estimateNewSize(EncodingSubGraph *sg) {
307         uint newSize=0;
308         SetIteratorEncodingNode * nit = sg->nodes.iterator();
309         while(nit->hasNext()) {
310                 EncodingNode *en = nit->next();
311                 uint size=estimateNewSize(en);
312                 if (size > newSize)
313                         newSize = size;
314         }
315         delete nit;
316         return newSize;
317 }
318
319 uint EncodingSubGraph::estimateNewSize(EncodingNode *n) {
320         SetIteratorEncodingEdge * eeit = n->edges.iterator();
321         uint newsize=n->getSize();
322         while(eeit->hasNext()) {
323                 EncodingEdge * ee = eeit->next();
324                 if (ee->left != NULL && ee->left != n && nodes.contains(ee->left)) {
325                         uint intersectSize = n->s->getUnionSize(ee->left->s);
326                         if (intersectSize > newsize)
327                                 newsize = intersectSize;
328                 }
329                 if (ee->right != NULL && ee->right != n && nodes.contains(ee->right)) {
330                         uint intersectSize = n->s->getUnionSize(ee->right->s);
331                         if (intersectSize > newsize)
332                                 newsize = intersectSize;
333                 }
334                 if (ee->dst != NULL && ee->dst != n && nodes.contains(ee->dst)) {
335                         uint intersectSize = n->s->getUnionSize(ee->dst->s);
336                         if (intersectSize > newsize)
337                                 newsize = intersectSize;
338                 }
339         }
340         delete eeit;
341         return newsize;
342 }
343
344 void EncodingSubGraph::addNode(EncodingNode *n) {
345         nodes.add(n);
346         uint newSize=estimateNewSize(n);
347         numElements += n->elements.getSize();
348         if (newSize > encodingSize)
349                 encodingSize=newSize;
350 }
351
352 SetIteratorEncodingNode * EncodingSubGraph::nodeIterator() {
353         return nodes.iterator();
354 }