Adding ASTTransform ...
[satune.git] / src / ASTAnalyses / orderencoder.cc
1 #include "orderencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "ordergraph.h"
6 #include "order.h"
7 #include "ordernode.h"
8 #include "rewriter.h"
9 #include "mutableset.h"
10 #include "tunable.h"
11
12 void DFS(OrderGraph *graph, Vector<OrderNode *> *finishNodes) {
13         HSIteratorOrderNode *iterator = graph->nodes->iterator();
14         while (iterator->hasNext()) {
15                 OrderNode *node = iterator->next();
16                 if (node->status == NOTVISITED) {
17                         node->status = VISITED;
18                         DFSNodeVisit(node, finishNodes, false, false, 0);
19                         node->status = FINISHED;
20                         finishNodes->push(node);
21                 }
22         }
23         delete iterator;
24 }
25
26 void DFSReverse(OrderGraph *graph, Vector<OrderNode *> *finishNodes) {
27         uint size = finishNodes->getSize();
28         uint sccNum = 1;
29         for (int i = size - 1; i >= 0; i--) {
30                 OrderNode *node = finishNodes->get(i);
31                 if (node->status == NOTVISITED) {
32                         node->status = VISITED;
33                         DFSNodeVisit(node, NULL, true, false, sccNum);
34                         node->sccNum = sccNum;
35                         node->status = FINISHED;
36                         sccNum++;
37                 }
38         }
39 }
40
41 void DFSNodeVisit(OrderNode *node, Vector<OrderNode *> *finishNodes, bool isReverse, bool mustvisit, uint sccNum) {
42         HSIteratorOrderEdge *iterator = isReverse ? node->inEdges->iterator() : node->outEdges->iterator();
43         while (iterator->hasNext()) {
44                 OrderEdge *edge = iterator->next();
45                 if (mustvisit) {
46                         if (!edge->mustPos)
47                                 continue;
48                 } else
49                         if (!edge->polPos && !edge->pseudoPos)//Ignore edges that do not have positive polarity
50                                 continue;
51
52                 OrderNode *child = isReverse ? edge->source : edge->sink;
53
54                 if (child->status == NOTVISITED) {
55                         child->status = VISITED;
56                         DFSNodeVisit(child, finishNodes, isReverse, mustvisit, sccNum);
57                         child->status = FINISHED;
58                         if (finishNodes != NULL)
59                                 finishNodes->push(child);
60                         if (isReverse)
61                                 child->sccNum = sccNum;
62                 }
63         }
64         delete iterator;
65 }
66
67 void resetNodeInfoStatusSCC(OrderGraph *graph) {
68         HSIteratorOrderNode *iterator = graph->nodes->iterator();
69         while (iterator->hasNext()) {
70                 iterator->next()->status = NOTVISITED;
71         }
72         delete iterator;
73 }
74
75 void computeStronglyConnectedComponentGraph(OrderGraph *graph) {
76         Vector<OrderNode *> finishNodes;
77         DFS(graph, &finishNodes);
78         resetNodeInfoStatusSCC(graph);
79         DFSReverse(graph, &finishNodes);
80         resetNodeInfoStatusSCC(graph);
81 }
82
83 bool isMustBeTrueNode(OrderNode* node){
84         HSIteratorOrderEdge* iterator = node->inEdges->iterator();
85         while(iterator->hasNext()){
86                 OrderEdge* edge = iterator->next();
87                 if(!edge->mustPos)
88                         return false;
89         }
90         delete iterator;
91         iterator = node->outEdges->iterator();
92         while(iterator->hasNext()){
93                 OrderEdge* edge = iterator->next();
94                 if(!edge->mustPos)
95                         return false;
96         }
97         delete iterator;
98         return true;
99 }
100
101 void bypassMustBeTrueNode(CSolver *This, OrderGraph* graph, OrderNode* node){
102         HSIteratorOrderEdge* iterin = node->inEdges->iterator();
103         while(iterin->hasNext()){
104                 OrderEdge* inEdge = iterin->next();
105                 OrderNode* srcNode = inEdge->source;
106                 srcNode->outEdges->remove(inEdge);
107                 HSIteratorOrderEdge* iterout = node->outEdges->iterator();
108                 while(iterout->hasNext()){
109                         OrderEdge* outEdge = iterout->next();
110                         OrderNode* sinkNode = outEdge->sink;
111                         sinkNode->inEdges->remove(outEdge);
112                         //Adding new edge to new sink and src nodes ...
113                         OrderEdge *newEdge =getOrderEdgeFromOrderGraph(graph, srcNode, sinkNode);
114                         newEdge->mustPos = true;
115                         newEdge->polPos = true;
116                         if (newEdge->mustNeg)
117                                 This->unsat = true;
118                         srcNode->outEdges->add(newEdge);
119                         sinkNode->inEdges->add(newEdge);
120                 }
121                 delete iterout;
122         }
123         delete iterin;
124 }
125
126 void removeMustBeTrueNodes(CSolver *This, OrderGraph *graph) {
127         HSIteratorOrderNode* iterator = graph->nodes->iterator();
128         while(iterator->hasNext()) {
129                 OrderNode* node = iterator->next();
130                 if(isMustBeTrueNode(node)){
131                         bypassMustBeTrueNode(This,graph, node);
132                 }
133         }
134         delete iterator;
135 }
136
137 /** This function computes a source set for every nodes, the set of
138                 nodes that can reach that node via pospolarity edges.  It then
139                 looks for negative polarity edges from nodes in the the source set
140                 to determine whether we need to generate pseudoPos edges. */
141
142 void completePartialOrderGraph(OrderGraph *graph) {
143         Vector<OrderNode *> finishNodes;
144         DFS(graph, &finishNodes);
145         resetNodeInfoStatusSCC(graph);
146         HashTableNodeToNodeSet *table = new HashTableNodeToNodeSet(128, 0.25);
147
148         Vector<OrderNode *> sccNodes;
149         
150         uint size = finishNodes.getSize();
151         uint sccNum = 1;
152         for (int i = size - 1; i >= 0; i--) {
153                 OrderNode *node = finishNodes.get(i);
154                 HashSetOrderNode *sources = new HashSetOrderNode(4, 0.25);
155                 table->put(node, sources);
156                 
157                 if (node->status == NOTVISITED) {
158                         //Need to do reverse traversal here...
159                         node->status = VISITED;
160                         DFSNodeVisit(node, &sccNodes, true, false, sccNum);
161                         node->status = FINISHED;
162                         node->sccNum = sccNum;
163                         sccNum++;
164                         sccNodes.push(node);
165
166                         //Compute in set for entire SCC
167                         uint rSize = sccNodes.getSize();
168                         for (uint j = 0; j < rSize; j++) {
169                                 OrderNode *rnode = sccNodes.get(j);
170                                 //Compute source sets
171                                 HSIteratorOrderEdge *iterator = rnode->inEdges->iterator();
172                                 while (iterator->hasNext()) {
173                                         OrderEdge *edge = iterator->next();
174                                         OrderNode *parent = edge->source;
175                                         if (edge->polPos) {
176                                                 sources->add(parent);
177                                                 HashSetOrderNode *parent_srcs = (HashSetOrderNode *)table->get(parent);
178                                                 sources->addAll(parent_srcs);
179                                         }
180                                 }
181                                 delete iterator;
182                         }
183                         for (uint j=0; j < rSize; j++) {
184                                 //Copy in set of entire SCC
185                                 OrderNode *rnode = sccNodes.get(j);
186                                 HashSetOrderNode * set = (j==0) ? sources : sources->copy();
187                                 table->put(rnode, set);
188
189                                 //Use source sets to compute pseudoPos edges
190                                 HSIteratorOrderEdge *iterator = node->inEdges->iterator();
191                                 while (iterator->hasNext()) {
192                                         OrderEdge *edge = iterator->next();
193                                         OrderNode *parent = edge->source;
194                                         ASSERT(parent != rnode);
195                                         if (edge->polNeg && parent->sccNum != rnode->sccNum &&
196                                                         sources->contains(parent)) {
197                                                 OrderEdge *newedge = getOrderEdgeFromOrderGraph(graph, rnode, parent);
198                                                 newedge->pseudoPos = true;
199                                         }
200                                 }
201                                 delete iterator;
202                         }
203                         
204                         sccNodes.clear();
205                 }
206         }
207
208         table->resetanddelete();
209         delete table;
210         resetNodeInfoStatusSCC(graph);
211 }
212
213 void DFSMust(OrderGraph *graph, Vector<OrderNode *> *finishNodes) {
214         HSIteratorOrderNode *iterator = graph->nodes->iterator();
215         while (iterator->hasNext()) {
216                 OrderNode *node = iterator->next();
217                 if (node->status == NOTVISITED) {
218                         node->status = VISITED;
219                         DFSNodeVisit(node, finishNodes, false, true, 0);
220                         node->status = FINISHED;
221                         finishNodes->push(node);
222                 }
223         }
224         delete iterator;
225 }
226
227 void DFSClearContradictions(CSolver *solver, OrderGraph *graph, Vector<OrderNode *> *finishNodes, bool computeTransitiveClosure) {
228         uint size = finishNodes->getSize();
229         HashTableNodeToNodeSet *table = new HashTableNodeToNodeSet(128, 0.25);
230
231         for (int i = size - 1; i >= 0; i--) {
232                 OrderNode *node = finishNodes->get(i);
233                 HashSetOrderNode *sources = new HashSetOrderNode(4, 0.25);
234                 table->put(node, sources);
235
236                 {
237                         //Compute source sets
238                         HSIteratorOrderEdge *iterator = node->inEdges->iterator();
239                         while (iterator->hasNext()) {
240                                 OrderEdge *edge = iterator->next();
241                                 OrderNode *parent = edge->source;
242                                 if (edge->mustPos) {
243                                         sources->add(parent);
244                                         HashSetOrderNode *parent_srcs = (HashSetOrderNode *) table->get(parent);
245                                         sources->addAll(parent_srcs);
246                                 }
247                         }
248                         delete iterator;
249                 }
250                 if (computeTransitiveClosure) {
251                         //Compute full transitive closure for nodes
252                         HSIteratorOrderNode *srciterator = sources->iterator();
253                         while (srciterator->hasNext()) {
254                                 OrderNode *srcnode = srciterator->next();
255                                 OrderEdge *newedge = getOrderEdgeFromOrderGraph(graph, srcnode, node);
256                                 newedge->mustPos = true;
257                                 newedge->polPos = true;
258                                 if (newedge->mustNeg)
259                                         solver->unsat = true;
260                                 srcnode->outEdges->add(newedge);
261                                 node->inEdges->add(newedge);
262                         }
263                         delete srciterator;
264                 }
265                 {
266                         //Use source sets to compute mustPos edges
267                         HSIteratorOrderEdge *iterator =node->inEdges->iterator();
268                         while (iterator->hasNext()) {
269                                 OrderEdge *edge = iterator->next();
270                                 OrderNode *parent = edge->source;
271                                 if (!edge->mustPos && sources->contains(parent)) {
272                                         edge->mustPos = true;
273                                         edge->polPos = true;
274                                         if (edge->mustNeg)
275                                                 solver->unsat = true;
276                                 }
277                         }
278                         delete iterator;
279                 }
280                 {
281                         //Use source sets to compute mustNeg for edges that would introduce cycle if true
282                         HSIteratorOrderEdge *iterator = node->outEdges->iterator();
283                         while (iterator->hasNext()) {
284                                 OrderEdge *edge = iterator->next();
285                                 OrderNode *child = edge->sink;
286                                 if (!edge->mustNeg && sources->contains(child)) {
287                                         edge->mustNeg = true;
288                                         edge->polNeg = true;
289                                         if (edge->mustPos)
290                                                 solver->unsat = true;
291                                 }
292                         }
293                         delete iterator;
294                 }
295         }
296
297         table->resetanddelete();
298         delete table;
299 }
300
301 /* This function finds edges that would form a cycle with must edges
302    and forces them to be mustNeg.  It also decides whether an edge
303    must be true because of transitivity from other must be true
304    edges. */
305
306 void reachMustAnalysis(CSolver * solver, OrderGraph *graph, bool computeTransitiveClosure) {
307         Vector<OrderNode *> finishNodes;
308         //Topologically sort the mustPos edge graph
309         DFSMust(graph, &finishNodes);
310         resetNodeInfoStatusSCC(graph);
311
312         //Find any backwards edges that complete cycles and force them to be mustNeg
313         DFSClearContradictions(solver, graph, &finishNodes, computeTransitiveClosure);
314 }
315
316 /* This function finds edges that must be positive and forces the
317    inverse edge to be negative (and clears its positive polarity if it
318    had one). */
319
320 void localMustAnalysisTotal(CSolver *solver, OrderGraph *graph) {
321         HSIteratorOrderEdge *iterator = graph->edges->iterator();
322         while (iterator->hasNext()) {
323                 OrderEdge *edge = iterator->next();
324                 if (edge->mustPos) {
325                         OrderEdge *invEdge = getInverseOrderEdge(graph, edge);
326                         if (invEdge != NULL) {
327                                 if (!invEdge->mustPos) {
328                                         invEdge->polPos = false;
329                                 } else {
330                                         solver->unsat = true;
331                                 }
332                                 invEdge->mustNeg = true;
333                                 invEdge->polNeg = true;
334                         }
335                 }
336         }
337         delete iterator;
338 }
339
340 /** This finds edges that must be positive and forces the inverse edge
341     to be negative.  It also clears the negative flag of this edge.
342     It also finds edges that must be negative and clears the positive
343     polarity. */
344
345 void localMustAnalysisPartial(CSolver *solver, OrderGraph *graph) {
346         HSIteratorOrderEdge *iterator = graph->edges->iterator();
347         while (iterator->hasNext()) {
348                 OrderEdge *edge = iterator->next();
349                 if (edge->mustPos) {
350                         if (!edge->mustNeg) {
351                                 edge->polNeg = false;
352                         } else
353                                 solver->unsat = true;
354
355                         OrderEdge *invEdge = getInverseOrderEdge(graph, edge);
356                         if (invEdge != NULL) {
357                                 if (!invEdge->mustPos)
358                                         invEdge->polPos = false;
359                                 else
360                                         solver->unsat = true;
361                                 invEdge->mustNeg = true;
362                                 invEdge->polNeg = true;
363                         }
364                 }
365                 if (edge->mustNeg && !edge->mustPos) {
366                         edge->polPos = false;
367                 }
368         }
369         delete iterator;
370 }
371
372 void orderAnalysis(CSolver *This) {
373         uint size = This->allOrders.getSize();
374         for (uint i = 0; i < size; i++) {
375                 Order *order = This->allOrders.get(i);
376                 
377                 OrderGraph *graph;
378                 if(order->graph == NULL){
379                         graph= buildOrderGraph(order);
380                         if (order->type == PARTIAL) {
381                                 //Required to do SCC analysis for partial order graphs.  It
382                                 //makes sure we don't incorrectly optimize graphs with negative
383                                 //polarity edges
384                                 completePartialOrderGraph(graph);
385                         }
386                 }else
387                         graph = order->graph;
388                 
389                 bool mustReachGlobal=GETVARTUNABLE(This->tuner, order->type, MUSTREACHGLOBAL, &onoff);
390
391                 if (mustReachGlobal)
392                         reachMustAnalysis(This, graph, false);
393
394                 bool mustReachLocal=GETVARTUNABLE(This->tuner, order->type, MUSTREACHLOCAL, &onoff);
395                 
396                 if (mustReachLocal) {
397                         //This pair of analysis is also optional
398                         if (order->type == PARTIAL) {
399                                 localMustAnalysisPartial(This, graph);
400                         } else {
401                                 localMustAnalysisTotal(This, graph);
402                         }
403                 }
404
405                 bool mustReachPrune=GETVARTUNABLE(This->tuner, order->type, MUSTREACHPRUNE, &onoff);
406                 
407                 if (mustReachPrune)
408                         removeMustBeTrueNodes(This, graph);
409                 
410         }
411 }