Fixing header bugs
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
1 #include "encodinggraph.h"
2 #include "iterator.h"
3 #include "element.h"
4 #include "function.h"
5 #include "predicate.h"
6 #include "set.h"
7 #include "csolver.h"
8 #include "tunable.h"
9 #include "boolean.h"
10
11 EncodingGraph::EncodingGraph(CSolver * _solver) :
12         solver(_solver) {
13         
14
15 }
16
17 void EncodingGraph::buildGraph() {
18         ElementIterator it(solver);
19         while(it.hasNext()) {
20                 Element * e = it.next();
21                 switch(e->type) {
22                 case ELEMSET:
23                 case ELEMFUNCRETURN:
24                         processElement(e);
25                         break;
26                 case ELEMCONST:
27                         break;
28                 default:
29                         ASSERT(0);
30                 }
31         }
32 }
33
34 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
35         EncodingSubGraph *graph1=graphMap.get(first);
36         EncodingSubGraph *graph2=graphMap.get(second);
37         if (graph1 == NULL && graph2 == NULL) {
38                 graph1 = new EncodingSubGraph();
39                 graphMap.put(first, graph1);
40                 graph1->addNode(first);
41         }
42         if (graph1 == NULL && graph2 != NULL) {
43                 graph1 = graph2;
44                 graph2 = NULL;
45                 EncodingNode *tmp = second;
46                 second = first;
47                 first = tmp;
48         }
49         if (graph1 != NULL && graph2 != NULL) {
50                 SetIteratorEncodingNode * nodeit=graph2->nodeIterator();
51                 while(nodeit->hasNext()) {
52                         EncodingNode *node=nodeit->next();
53                         graph1->addNode(node);
54                         graphMap.put(node, graph1);
55                 }
56                 delete nodeit;
57                 delete graph2;
58         } else {
59                 ASSERT(graph1 != NULL && graph2 == NULL);
60                 graph1->addNode(second);
61                 graphMap.put(second, graph1);
62         }
63 }
64
65 void EncodingGraph::processElement(Element *e) {
66         uint size=e->parents.getSize();
67         for(uint i=0;i<size;i++) {
68                 ASTNode * n = e->parents.get(i);
69                 switch(n->type) {
70                 case PREDICATEOP:
71                         processPredicate((BooleanPredicate *)n);
72                         break;
73                 case ELEMFUNCRETURN:
74                         processFunction((ElementFunction *)n);
75                         break;
76                 default:
77                         ASSERT(0);
78                 }
79         }
80 }
81
82 void EncodingGraph::processFunction(ElementFunction *ef) {
83         Function *f=ef->getFunction();
84         if (f->type==OPERATORFUNC) {
85                 FunctionOperator *fo=(FunctionOperator*)f;
86                 ASSERT(ef->inputs.getSize() == 2);
87                 EncodingNode *left=createNode(ef->inputs.get(0));
88                 EncodingNode *right=createNode(ef->inputs.get(1));
89                 if (left == NULL && right == NULL)
90                         return;
91                 EncodingNode *dst=createNode(ef);
92                 EncodingEdge *edge=getEdge(left, right, dst);
93                 edge->numArithOps++;
94         }
95 }
96
97 void EncodingGraph::processPredicate(BooleanPredicate *b) {
98         Predicate *p=b->getPredicate();
99         if (p->type==OPERATORPRED) {
100                 PredicateOperator *po=(PredicateOperator *)p;
101                 ASSERT(b->inputs.getSize()==2);
102                 EncodingNode *left=createNode(b->inputs.get(0));
103                 EncodingNode *right=createNode(b->inputs.get(1));
104                 if (left == NULL || right == NULL)
105                         return;
106                 EncodingEdge *edge=getEdge(left, right, NULL);
107                 CompOp op=po->getOp();
108                 switch(op) {
109                 case SATC_EQUALS:
110                         edge->numEquals++;
111                         break;
112                 case SATC_LT:
113                 case SATC_LTE:
114                 case SATC_GT:
115                 case SATC_GTE:
116                         edge->numComparisons++;
117                         break;
118                 default:
119                         ASSERT(0);
120                 }
121         }
122 }
123
124 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
125
126 EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
127         EncodingEdge e(left, right, dst);
128         EncodingEdge *result = edgeMap.get(&e);
129         if (result == NULL) {
130                 result=new EncodingEdge(left, right, dst);
131                 VarType v1=left->getType();
132                 VarType v2=right->getType();
133                 if (v1 > v2) {
134                         VarType tmp=v2;
135                         v2=v1;
136                         v1=tmp;
137                 }
138                 result->setEncoding((EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc));
139                 edgeMap.put(result, result);
140         }
141         return result;
142 }
143
144 EncodingNode::EncodingNode(Set *_s) :
145         s(_s),
146         numElements(0) {
147 }
148
149 uint EncodingNode::getSize() {
150         return s->getSize();
151 }
152
153 VarType EncodingNode::getType() {
154         return s->getType();
155 }
156
157 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
158
159 EncodingNode * EncodingGraph::createNode(Element *e) {
160         if (e->type == ELEMCONST)
161                 return NULL;
162         Set *s = e->getRange();
163         EncodingNode *n = encodingMap.get(s);
164         if (n == NULL) {
165                 n = new EncodingNode(s);
166                 n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
167                 encodingMap.put(s, n);
168         }
169         n->addElement(e);
170         if (discovered.add(e))
171                 n->numElements++;
172         return n;
173 }
174
175 void EncodingNode::addElement(Element *e) {
176         elements.add(e);
177 }
178
179 EncodingEdge::EncodingEdge(EncodingNode *_l, EncodingNode *_r) :
180         left(_l),
181         right(_r),
182         dst(NULL),
183         encoding(EDGE_UNASSIGNED),
184         numArithOps(0),
185         numEquals(0),
186         numComparisons(0)
187 {
188 }
189
190 EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNode *_dst) :
191         left(_left),
192         right(_right),
193         dst(_dst),
194         encoding(EDGE_UNASSIGNED),
195         numArithOps(0),
196         numEquals(0),
197         numComparisons(0)
198 {
199 }
200
201 uint hashEncodingEdge(EncodingEdge *edge) {
202         uintptr_t hash=(((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
203         return (uint) hash;
204 }
205
206 bool equalsEncodingEdge(EncodingEdge *e1, EncodingEdge *e2) {
207         return e1->left == e2->left && e1->right == e2->right && e1->dst == e2->dst;
208 }
209
210 EncodingSubGraph::EncodingSubGraph() {
211 }
212
213 void EncodingSubGraph::addNode(EncodingNode *n) {
214         nodes.add(n);
215         Set *s=n->s;
216         uint size=s->getSize();
217         for(uint i=0; i<size; i++) {
218                 uint64_t val=s->getElement(i);
219                 values.add(val);
220         }
221 }
222
223 SetIteratorEncodingNode * EncodingSubGraph::nodeIterator() {
224         return nodes.iterator();
225 }
226
227 uint EncodingSubGraph::computeIntersection(Set *s) {
228         uint intersect=0;
229         uint size=s->getSize();
230         for(uint i=0; i<size; i++) {
231                 uint64_t val=s->getElement(i);
232                 if (values.contains(val))
233                         intersect++;
234         }
235         return intersect;
236 }
237
238 uint EncodingSubGraph::computeIntersection(EncodingSubGraph *g) {
239         if (g->values.getSize() > values.getSize()) {
240                 //iterator over smaller set
241                 return g->computeIntersection(this);
242         }
243         
244         uint intersect=0;
245         SetIterator64Int * iter=g->values.iterator();
246         while(iter->hasNext()) {
247                 uint64_t val=iter->next();
248                 if (values.contains(val))
249                         intersect++;
250         }
251         delete iter;
252         return intersect;
253 }