// satune.git: src/ASTAnalyses/Encoding/encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "qsort.h"
10 #include "subgraph.h"
11
12 EncodingGraph::EncodingGraph(CSolver * _solver) :
13         solver(_solver) {
14 }
15
16 int sortEncodingEdge(const void * p1, const void *p2) {
17         const EncodingEdge * e1 = * (const EncodingEdge **) p1;
18         const EncodingEdge * e2 = * (const EncodingEdge **) p2;
19         uint64_t v1 = e1->getValue();
20         uint64_t v2 = e2->getValue();
21         if (v1 < v2)
22                 return 1;
23         else if (v1 == v2)
24                 return 0;
25         else
26                 return -1;
27 }
28
29 void EncodingGraph::buildGraph() {
30         ElementIterator it(solver);
31         while(it.hasNext()) {
32                 Element * e = it.next();
33                 switch(e->type) {
34                 case ELEMSET:
35                 case ELEMFUNCRETURN:
36                         processElement(e);
37                         break;
38                 case ELEMCONST:
39                         break;
40                 default:
41                         ASSERT(0);
42                 }
43         }
44         bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
45         decideEdges();
46 }
47
48 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
49         EncodingSubGraph *graph1=graphMap.get(first);
50         EncodingSubGraph *graph2=graphMap.get(second);
51         if (graph1 == NULL && graph2 == NULL) {
52                 graph1 = new EncodingSubGraph();
53                 graphMap.put(first, graph1);
54                 graph1->addNode(first);
55         }
56         if (graph1 == NULL && graph2 != NULL) {
57                 graph1 = graph2;
58                 graph2 = NULL;
59                 EncodingNode *tmp = second;
60                 second = first;
61                 first = tmp;
62         }
63         if (graph1 != NULL && graph2 != NULL) {
64                 SetIteratorEncodingNode * nodeit=graph2->nodeIterator();
65                 while(nodeit->hasNext()) {
66                         EncodingNode *node=nodeit->next();
67                         graph1->addNode(node);
68                         graphMap.put(node, graph1);
69                 }
70                 delete nodeit;
71                 delete graph2;
72         } else {
73                 ASSERT(graph1 != NULL && graph2 == NULL);
74                 graph1->addNode(second);
75                 graphMap.put(second, graph1);
76         }
77 }
78
79 void EncodingGraph::processElement(Element *e) {
80         uint size=e->parents.getSize();
81         for(uint i=0;i<size;i++) {
82                 ASTNode * n = e->parents.get(i);
83                 switch(n->type) {
84                 case PREDICATEOP:
85                         processPredicate((BooleanPredicate *)n);
86                         break;
87                 case ELEMFUNCRETURN:
88                         processFunction((ElementFunction *)n);
89                         break;
90                 default:
91                         ASSERT(0);
92                 }
93         }
94 }
95
96 void EncodingGraph::processFunction(ElementFunction *ef) {
97         Function *f=ef->getFunction();
98         if (f->type==OPERATORFUNC) {
99                 FunctionOperator *fo=(FunctionOperator*)f;
100                 ASSERT(ef->inputs.getSize() == 2);
101                 EncodingNode *left=createNode(ef->inputs.get(0));
102                 EncodingNode *right=createNode(ef->inputs.get(1));
103                 if (left == NULL && right == NULL)
104                         return;
105                 EncodingNode *dst=createNode(ef);
106                 EncodingEdge *edge=getEdge(left, right, dst);
107                 edge->numArithOps++;
108         }
109 }
110
111 void EncodingGraph::processPredicate(BooleanPredicate *b) {
112         Predicate *p=b->getPredicate();
113         if (p->type==OPERATORPRED) {
114                 PredicateOperator *po=(PredicateOperator *)p;
115                 ASSERT(b->inputs.getSize()==2);
116                 EncodingNode *left=createNode(b->inputs.get(0));
117                 EncodingNode *right=createNode(b->inputs.get(1));
118                 if (left == NULL || right == NULL)
119                         return;
120                 EncodingEdge *edge=getEdge(left, right, NULL);
121                 CompOp op=po->getOp();
122                 switch(op) {
123                 case SATC_EQUALS:
124                         edge->numEquals++;
125                         break;
126                 case SATC_LT:
127                 case SATC_LTE:
128                 case SATC_GT:
129                 case SATC_GTE:
130                         edge->numComparisons++;
131                         break;
132                 default:
133                         ASSERT(0);
134                 }
135         }
136 }
137
138 uint convertSize(uint cost) {
139         cost = 1.2 * cost; // fudge factor
140         return NEXTPOW2(cost);
141 }
142
143 void EncodingGraph::decideEdges() {
144         uint size=edgeVector.getSize();
145         for(uint i=0; i<size; i++) {
146                 EncodingEdge *ee = edgeVector.get(i);
147                 EncodingNode *left = ee->left;
148                 EncodingNode *right = ee->right;
149                 
150                 if (ee->encoding != EDGE_UNASSIGNED ||
151                                 left->encoding != BINARYINDEX ||
152                                 right->encoding != BINARYINDEX)
153                         continue;
154                 
155                 uint64_t eeValue = ee->getValue();
156                 if (eeValue == 0)
157                         return;
158
159                 EncodingSubGraph *leftGraph = graphMap.get(left);
160                 EncodingSubGraph *rightGraph = graphMap.get(right);
161                 if (leftGraph == NULL && rightGraph !=NULL) {
162                         EncodingNode *tmp = left; left=right; right=tmp;
163                         EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
164                 }
165
166                 uint leftSize=0, rightSize=0, newSize=0;
167                 uint64_t totalCost=0;
168                 if (leftGraph == NULL && rightGraph == NULL) {
169                         leftSize=convertSize(left->getSize());
170                         rightSize=convertSize(right->getSize());
171                         newSize=convertSize(left->s->getUnionSize(right->s));
172                         newSize=(leftSize > newSize) ? leftSize: newSize;
173                         newSize=(rightSize > newSize) ? rightSize: newSize;
174                         totalCost = (newSize - leftSize) * left->elements.getSize() +
175                                 (newSize - rightSize) * right->elements.getSize();
176                 } else if (leftGraph != NULL && rightGraph == NULL) {
177                         leftSize=convertSize(leftGraph->encodingSize);
178                         rightSize=convertSize(right->getSize());
179                         newSize=convertSize(leftGraph->estimateNewSize(right));
180                         newSize=(leftSize > newSize) ? leftSize: newSize;
181                         newSize=(rightSize > newSize) ? rightSize: newSize;
182                         totalCost = (newSize - leftSize) * leftGraph->numElements +
183                                 (newSize - rightSize) * right->elements.getSize();
184                 } else {
185                         //Neither are null
186                         leftSize=convertSize(leftGraph->encodingSize);
187                         rightSize=convertSize(rightGraph->encodingSize);
188                         newSize=convertSize(leftGraph->estimateNewSize(rightGraph));
189                         newSize=(leftSize > newSize) ? leftSize: newSize;
190                         newSize=(rightSize > newSize) ? rightSize: newSize;
191                         totalCost = (newSize - leftSize) * leftGraph->numElements +
192                                 (newSize - rightSize) * rightGraph->numElements;
193                 }
194                 double conversionfactor = 0.5;
195                 if ((totalCost * conversionfactor) < eeValue) {
196                         //add the edge
197                         mergeNodes(left, right);
198                 }
199         }
200 }
201
202 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
203
204 EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
205         EncodingEdge e(left, right, dst);
206         EncodingEdge *result = edgeMap.get(&e);
207         if (result == NULL) {
208                 result=new EncodingEdge(left, right, dst);
209                 VarType v1=left->getType();
210                 VarType v2=right->getType();
211                 if (v1 > v2) {
212                         VarType tmp=v2;
213                         v2=v1;
214                         v1=tmp;
215                 }
216
217                 if ((left != NULL && left->encoding==BINARYINDEX) &&
218                                 (right != NULL) && right->encoding==BINARYINDEX) {
219                         EdgeEncodingType type=(EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
220                         result->setEncoding(type);
221                         if (type == EDGE_MATCH) {
222                                 mergeNodes(left, right);
223                         }
224                 }
225                 edgeMap.put(result, result);
226                 edgeVector.push(result);
227                 if (left != NULL)
228                         left->edges.add(result);
229                 if (right != NULL)
230                         right->edges.add(result);
231                 if (dst != NULL)
232                         dst->edges.add(result);
233         }
234         return result;
235 }
236
237 EncodingNode::EncodingNode(Set *_s) :
238         s(_s) {
239 }
240
241 uint EncodingNode::getSize() const {
242         return s->getSize();
243 }
244
245 VarType EncodingNode::getType() const {
246         return s->getType();
247 }
248
249 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
250
251 EncodingNode * EncodingGraph::createNode(Element *e) {
252         if (e->type == ELEMCONST)
253                 return NULL;
254         Set *s = e->getRange();
255         EncodingNode *n = encodingMap.get(s);
256         if (n == NULL) {
257                 n = new EncodingNode(s);
258                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
259                 encodingMap.put(s, n);
260         }
261         n->addElement(e);
262         return n;
263 }
264
265 void EncodingNode::addElement(Element *e) {
266         elements.add(e);
267 }
268
269 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
270         left(_l),
271         right(_r),
272         dst(NULL),
273         encoding(EDGE_UNASSIGNED),
274         numArithOps(0),
275         numEquals(0),
276         numComparisons(0)
277 {
278 }
279
280 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
281         left(_left),
282         right(_right),
283         dst(_dst),
284         encoding(EDGE_UNASSIGNED),
285         numArithOps(0),
286         numEquals(0),
287         numComparisons(0)
288 {
289 }
290
291 uint hashEncodingEdge(EncodingEdge *edge) {
292         uintptr_t hash=(((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
293         return (uint) hash;
294 }
295
296 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
297         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
298 }
299
300 uint64_t EncodingEdge::getValue() const {
301         uint lSize = (left != NULL) ? left->getSize() : 1;
302         uint rSize = (right != NULL) ? right->getSize() : 1;
303         uint min = (lSize < rSize) ? lSize : rSize;
304         return numEquals * min + numComparisons * lSize * rSize;
305 }
306
307