Bug fixes
[satune.git] / src / ASTAnalyses / orderanalysis.cc
1 #include "orderanalysis.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "ordergraph.h"
6 #include "order.h"
7 #include "ordernode.h"
8 #include "rewriter.h"
9 #include "mutableset.h"
10 #include "tunable.h"
11
12 void DFS(OrderGraph *graph, Vector<OrderNode *> *finishNodes) {
13         HSIteratorOrderNode *iterator = graph->getNodes();
14         while (iterator->hasNext()) {
15                 OrderNode *node = iterator->next();
16                 if (node->status == NOTVISITED) {
17                         node->status = VISITED;
18                         DFSNodeVisit(node, finishNodes, false, false, 0);
19                         node->status = FINISHED;
20                         finishNodes->push(node);
21                 }
22         }
23         delete iterator;
24 }
25
26 void DFSReverse(OrderGraph *graph, Vector<OrderNode *> *finishNodes) {
27         uint size = finishNodes->getSize();
28         uint sccNum = 1;
29         for (int i = size - 1; i >= 0; i--) {
30                 OrderNode *node = finishNodes->get(i);
31                 if (node->status == NOTVISITED) {
32                         node->status = VISITED;
33                         DFSNodeVisit(node, NULL, true, false, sccNum);
34                         node->sccNum = sccNum;
35                         node->status = FINISHED;
36                         sccNum++;
37                 }
38         }
39 }
40
41 void DFSNodeVisit(OrderNode *node, Vector<OrderNode *> *finishNodes, bool isReverse, bool mustvisit, uint sccNum) {
42         HSIteratorOrderEdge *iterator = isReverse ? node->inEdges.iterator() : node->outEdges.iterator();
43         while (iterator->hasNext()) {
44                 OrderEdge *edge = iterator->next();
45                 if (mustvisit) {
46                         if (!edge->mustPos)
47                                 continue;
48                 } else
49                 if (!edge->polPos && !edge->pseudoPos)  //Ignore edges that do not have positive polarity
50                         continue;
51
52                 OrderNode *child = isReverse ? edge->source : edge->sink;
53
54                 if (child->status == NOTVISITED) {
55                         child->status = VISITED;
56                         DFSNodeVisit(child, finishNodes, isReverse, mustvisit, sccNum);
57                         child->status = FINISHED;
58                         if (finishNodes != NULL)
59                                 finishNodes->push(child);
60                         if (isReverse)
61                                 child->sccNum = sccNum;
62                 }
63         }
64         delete iterator;
65 }
66
67 void resetNodeInfoStatusSCC(OrderGraph *graph) {
68         HSIteratorOrderNode *iterator = graph->getNodes();
69         while (iterator->hasNext()) {
70                 iterator->next()->status = NOTVISITED;
71         }
72         delete iterator;
73 }
74
75 void computeStronglyConnectedComponentGraph(OrderGraph *graph) {
76         Vector<OrderNode *> finishNodes;
77         DFS(graph, &finishNodes);
78         resetNodeInfoStatusSCC(graph);
79         DFSReverse(graph, &finishNodes);
80         resetNodeInfoStatusSCC(graph);
81 }
82
83 bool isMustBeTrueNode(OrderNode *node) {
84         HSIteratorOrderEdge *iterator = node->inEdges.iterator();
85         while (iterator->hasNext()) {
86                 OrderEdge *edge = iterator->next();
87                 if (!edge->mustPos) {
88                         delete iterator;
89                         return false;
90                 }
91         }
92         delete iterator;
93         iterator = node->outEdges.iterator();
94         while (iterator->hasNext()) {
95                 OrderEdge *edge = iterator->next();
96                 if (!edge->mustPos) {
97                         delete iterator;
98                         return false;
99                 }
100         }
101         delete iterator;
102         return true;
103 }
104
105 void bypassMustBeTrueNode(CSolver *This, OrderGraph *graph, OrderNode *node) {
106         HSIteratorOrderEdge *iterin = node->inEdges.iterator();
107         while (iterin->hasNext()) {
108                 OrderEdge *inEdge = iterin->next();
109                 OrderNode *srcNode = inEdge->source;
110                 srcNode->outEdges.remove(inEdge);
111                 HSIteratorOrderEdge *iterout = node->outEdges.iterator();
112                 while (iterout->hasNext()) {
113                         OrderEdge *outEdge = iterout->next();
114                         OrderNode *sinkNode = outEdge->sink;
115                         sinkNode->inEdges.remove(outEdge);
116                         //Adding new edge to new sink and src nodes ...
117                         OrderEdge *newEdge = graph->getOrderEdgeFromOrderGraph(srcNode, sinkNode);
118                         newEdge->mustPos = true;
119                         newEdge->polPos = true;
120                         if (newEdge->mustNeg)
121                                 This->setUnSAT();
122                         srcNode->outEdges.add(newEdge);
123                         sinkNode->inEdges.add(newEdge);
124                 }
125                 delete iterout;
126         }
127         delete iterin;
128 }
129
130 void removeMustBeTrueNodes(CSolver *This, OrderGraph *graph) {
131         HSIteratorOrderNode *iterator = graph->getNodes();
132         while (iterator->hasNext()) {
133                 OrderNode *node = iterator->next();
134                 if (isMustBeTrueNode(node)) {
135                         bypassMustBeTrueNode(This, graph, node);
136                 }
137         }
138         delete iterator;
139 }
140
141 /** This function computes a source set for every nodes, the set of
142     nodes that can reach that node via pospolarity edges.  It then
143     looks for negative polarity edges from nodes in the the source set
144     to determine whether we need to generate pseudoPos edges. */
145
146 void completePartialOrderGraph(OrderGraph *graph) {
147         Vector<OrderNode *> finishNodes;
148         DFS(graph, &finishNodes);
149         resetNodeInfoStatusSCC(graph);
150         HashTableNodeToNodeSet *table = new HashTableNodeToNodeSet(128, 0.25);
151
152         Vector<OrderNode *> sccNodes;
153
154         uint size = finishNodes.getSize();
155         uint sccNum = 1;
156         for (int i = size - 1; i >= 0; i--) {
157                 OrderNode *node = finishNodes.get(i);
158                 HashSetOrderNode *sources = new HashSetOrderNode(4, 0.25);
159                 table->put(node, sources);
160
161                 if (node->status == NOTVISITED) {
162                         //Need to do reverse traversal here...
163                         node->status = VISITED;
164                         DFSNodeVisit(node, &sccNodes, true, false, sccNum);
165                         node->status = FINISHED;
166                         node->sccNum = sccNum;
167                         sccNum++;
168                         sccNodes.push(node);
169
170                         //Compute in set for entire SCC
171                         uint rSize = sccNodes.getSize();
172                         for (uint j = 0; j < rSize; j++) {
173                                 OrderNode *rnode = sccNodes.get(j);
174                                 //Compute source sets
175                                 HSIteratorOrderEdge *iterator = rnode->inEdges.iterator();
176                                 while (iterator->hasNext()) {
177                                         OrderEdge *edge = iterator->next();
178                                         OrderNode *parent = edge->source;
179                                         if (edge->polPos) {
180                                                 sources->add(parent);
181                                                 HashSetOrderNode *parent_srcs = (HashSetOrderNode *)table->get(parent);
182                                                 sources->addAll(parent_srcs);
183                                         }
184                                 }
185                                 delete iterator;
186                         }
187                         for (uint j = 0; j < rSize; j++) {
188                                 //Copy in set of entire SCC
189                                 OrderNode *rnode = sccNodes.get(j);
190                                 HashSetOrderNode *set = (j == 0) ? sources : sources->copy();
191                                 table->put(rnode, set);
192
193                                 //Use source sets to compute pseudoPos edges
194                                 HSIteratorOrderEdge *iterator = node->inEdges.iterator();
195                                 while (iterator->hasNext()) {
196                                         OrderEdge *edge = iterator->next();
197                                         OrderNode *parent = edge->source;
198                                         ASSERT(parent != rnode);
199                                         if (edge->polNeg && parent->sccNum != rnode->sccNum &&
200                                                         sources->contains(parent)) {
201                                                 OrderEdge *newedge = graph->getOrderEdgeFromOrderGraph(rnode, parent);
202                                                 newedge->pseudoPos = true;
203                                         }
204                                 }
205                                 delete iterator;
206                         }
207
208                         sccNodes.clear();
209                 }
210         }
211
212         table->resetanddelete();
213         delete table;
214         resetNodeInfoStatusSCC(graph);
215 }
216
217 void DFSMust(OrderGraph *graph, Vector<OrderNode *> *finishNodes) {
218         HSIteratorOrderNode *iterator = graph->getNodes();
219         while (iterator->hasNext()) {
220                 OrderNode *node = iterator->next();
221                 if (node->status == NOTVISITED) {
222                         node->status = VISITED;
223                         DFSNodeVisit(node, finishNodes, false, true, 0);
224                         node->status = FINISHED;
225                         finishNodes->push(node);
226                 }
227         }
228         delete iterator;
229 }
230
231 void DFSClearContradictions(CSolver *solver, OrderGraph *graph, Vector<OrderNode *> *finishNodes, bool computeTransitiveClosure) {
232         uint size = finishNodes->getSize();
233         HashTableNodeToNodeSet *table = new HashTableNodeToNodeSet(128, 0.25);
234
235         for (int i = size - 1; i >= 0; i--) {
236                 OrderNode *node = finishNodes->get(i);
237                 HashSetOrderNode *sources = new HashSetOrderNode(4, 0.25);
238                 table->put(node, sources);
239
240                 {
241                         //Compute source sets
242                         HSIteratorOrderEdge *iterator = node->inEdges.iterator();
243                         while (iterator->hasNext()) {
244                                 OrderEdge *edge = iterator->next();
245                                 OrderNode *parent = edge->source;
246                                 if (edge->mustPos) {
247                                         sources->add(parent);
248                                         HashSetOrderNode *parent_srcs = (HashSetOrderNode *) table->get(parent);
249                                         sources->addAll(parent_srcs);
250                                 }
251                         }
252                         delete iterator;
253                 }
254                 if (computeTransitiveClosure) {
255                         //Compute full transitive closure for nodes
256                         HSIteratorOrderNode *srciterator = sources->iterator();
257                         while (srciterator->hasNext()) {
258                                 OrderNode *srcnode = srciterator->next();
259                                 OrderEdge *newedge = graph->getOrderEdgeFromOrderGraph(srcnode, node);
260                                 newedge->mustPos = true;
261                                 newedge->polPos = true;
262                                 if (newedge->mustNeg)
263                                         solver->setUnSAT();
264                                 srcnode->outEdges.add(newedge);
265                                 node->inEdges.add(newedge);
266                         }
267                         delete srciterator;
268                 }
269                 {
270                         //Use source sets to compute mustPos edges
271                         HSIteratorOrderEdge *iterator = node->inEdges.iterator();
272                         while (iterator->hasNext()) {
273                                 OrderEdge *edge = iterator->next();
274                                 OrderNode *parent = edge->source;
275                                 if (!edge->mustPos && sources->contains(parent)) {
276                                         edge->mustPos = true;
277                                         edge->polPos = true;
278                                         if (edge->mustNeg)
279                                                 solver->setUnSAT();
280                                 }
281                         }
282                         delete iterator;
283                 }
284                 {
285                         //Use source sets to compute mustNeg for edges that would introduce cycle if true
286                         HSIteratorOrderEdge *iterator = node->outEdges.iterator();
287                         while (iterator->hasNext()) {
288                                 OrderEdge *edge = iterator->next();
289                                 OrderNode *child = edge->sink;
290                                 if (!edge->mustNeg && sources->contains(child)) {
291                                         edge->mustNeg = true;
292                                         edge->polNeg = true;
293                                         if (edge->mustPos)
294                                                 solver->setUnSAT();
295                                 }
296                         }
297                         delete iterator;
298                 }
299         }
300
301         table->resetanddelete();
302         delete table;
303 }
304
305 /* This function finds edges that would form a cycle with must edges
306    and forces them to be mustNeg.  It also decides whether an edge
307    must be true because of transitivity from other must be true
308    edges. */
309
310 void reachMustAnalysis(CSolver *solver, OrderGraph *graph, bool computeTransitiveClosure) {
311         Vector<OrderNode *> finishNodes;
312         //Topologically sort the mustPos edge graph
313         DFSMust(graph, &finishNodes);
314         resetNodeInfoStatusSCC(graph);
315
316         //Find any backwards edges that complete cycles and force them to be mustNeg
317         DFSClearContradictions(solver, graph, &finishNodes, computeTransitiveClosure);
318 }
319
320 /* This function finds edges that must be positive and forces the
321    inverse edge to be negative (and clears its positive polarity if it
322    had one). */
323
324 void localMustAnalysisTotal(CSolver *solver, OrderGraph *graph) {
325         HSIteratorOrderEdge *iterator = graph->getEdges();
326         while (iterator->hasNext()) {
327                 OrderEdge *edge = iterator->next();
328                 if (edge->mustPos) {
329                         OrderEdge *invEdge = graph->getInverseOrderEdge(edge);
330                         if (invEdge != NULL) {
331                                 if (!invEdge->mustPos) {
332                                         invEdge->polPos = false;
333                                 } else {
334                                         solver->setUnSAT();
335                                 }
336                                 invEdge->mustNeg = true;
337                                 invEdge->polNeg = true;
338                         }
339                 }
340         }
341         delete iterator;
342 }
343
344 /** This finds edges that must be positive and forces the inverse edge
345     to be negative.  It also clears the negative flag of this edge.
346     It also finds edges that must be negative and clears the positive
347     polarity. */
348
349 void localMustAnalysisPartial(CSolver *solver, OrderGraph *graph) {
350         HSIteratorOrderEdge *iterator = graph->getEdges();
351         while (iterator->hasNext()) {
352                 OrderEdge *edge = iterator->next();
353                 if (edge->mustPos) {
354                         if (!edge->mustNeg) {
355                                 edge->polNeg = false;
356                         } else
357                                 solver->setUnSAT();
358
359                         OrderEdge *invEdge = graph->getInverseOrderEdge(edge);
360                         if (invEdge != NULL) {
361                                 if (!invEdge->mustPos)
362                                         invEdge->polPos = false;
363                                 else
364                                         solver->setUnSAT();
365
366                                 invEdge->mustNeg = true;
367                                 invEdge->polNeg = true;
368                         }
369                 }
370                 if (edge->mustNeg && !edge->mustPos) {
371                         edge->polPos = false;
372                 }
373         }
374         delete iterator;
375 }