Merge branch 'hamed' of ssh://plrg.eecs.uci.edu/home/git/constraint_compiler
[satune.git] / src / Encoders / orderencoder.cc
1 #include "orderencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "ordergraph.h"
6 #include "order.h"
7 #include "ordernode.h"
8 #include "rewriter.h"
9 #include "mutableset.h"
10 #include "tunable.h"
11
12 void DFS(OrderGraph *graph, VectorOrderNode *finishNodes) {
13         HSIteratorOrderNode *iterator = iteratorOrderNode(graph->nodes);
14         while (hasNextOrderNode(iterator)) {
15                 OrderNode *node = nextOrderNode(iterator);
16                 if (node->status == NOTVISITED) {
17                         node->status = VISITED;
18                         DFSNodeVisit(node, finishNodes, false, false, 0);
19                         node->status = FINISHED;
20                         pushVectorOrderNode(finishNodes, node);
21                 }
22         }
23         deleteIterOrderNode(iterator);
24 }
25
26 void DFSReverse(OrderGraph *graph, VectorOrderNode *finishNodes) {
27         uint size = getSizeVectorOrderNode(finishNodes);
28         uint sccNum = 1;
29         for (int i = size - 1; i >= 0; i--) {
30                 OrderNode *node = getVectorOrderNode(finishNodes, i);
31                 if (node->status == NOTVISITED) {
32                         node->status = VISITED;
33                         DFSNodeVisit(node, NULL, true, false, sccNum);
34                         node->sccNum = sccNum;
35                         node->status = FINISHED;
36                         sccNum++;
37                 }
38         }
39 }
40
41 void DFSNodeVisit(OrderNode *node, VectorOrderNode *finishNodes, bool isReverse, bool mustvisit, uint sccNum) {
42         HSIteratorOrderEdge *iterator = isReverse ? iteratorOrderEdge(node->inEdges) : iteratorOrderEdge(node->outEdges);
43         while (hasNextOrderEdge(iterator)) {
44                 OrderEdge *edge = nextOrderEdge(iterator);
45                 if (mustvisit) {
46                         if (!edge->mustPos)
47                                 continue;
48                 } else
49                         if (!edge->polPos && !edge->pseudoPos)//Ignore edges that do not have positive polarity
50                                 continue;
51
52                 OrderNode *child = isReverse ? edge->source : edge->sink;
53
54                 if (child->status == NOTVISITED) {
55                         child->status = VISITED;
56                         DFSNodeVisit(child, finishNodes, isReverse, mustvisit, sccNum);
57                         child->status = FINISHED;
58                         if (finishNodes != NULL)
59                                 pushVectorOrderNode(finishNodes, child);
60                         if (isReverse)
61                                 child->sccNum = sccNum;
62                 }
63         }
64         deleteIterOrderEdge(iterator);
65 }
66
67 void resetNodeInfoStatusSCC(OrderGraph *graph) {
68         HSIteratorOrderNode *iterator = iteratorOrderNode(graph->nodes);
69         while (hasNextOrderNode(iterator)) {
70                 nextOrderNode(iterator)->status = NOTVISITED;
71         }
72         deleteIterOrderNode(iterator);
73 }
74
75 void computeStronglyConnectedComponentGraph(OrderGraph *graph) {
76         VectorOrderNode finishNodes;
77         initDefVectorOrderNode(&finishNodes);
78         DFS(graph, &finishNodes);
79         resetNodeInfoStatusSCC(graph);
80         DFSReverse(graph, &finishNodes);
81         resetNodeInfoStatusSCC(graph);
82         deleteVectorArrayOrderNode(&finishNodes);
83 }
84
85 bool isMustBeTrueNode(OrderNode* node){
86         HSIteratorOrderEdge* iterator = iteratorOrderEdge(node->inEdges);
87         while(hasNextOrderEdge(iterator)){
88                 OrderEdge* edge = nextOrderEdge(iterator);
89                 if(!edge->mustPos)
90                         return false;
91         }
92         deleteIterOrderEdge(iterator);
93         iterator = iteratorOrderEdge(node->outEdges);
94         while(hasNextOrderEdge(iterator)){
95                 OrderEdge* edge = nextOrderEdge(iterator);
96                 if(!edge->mustPos)
97                         return false;
98         }
99         deleteIterOrderEdge(iterator);
100         return true;
101 }
102
103 void bypassMustBeTrueNode(CSolver *This, OrderGraph* graph, OrderNode* node){
104         HSIteratorOrderEdge* iterin = iteratorOrderEdge(node->inEdges);
105         while(hasNextOrderEdge(iterin)){
106                 OrderEdge* inEdge = nextOrderEdge(iterin);
107                 OrderNode* srcNode = inEdge->source;
108                 removeHashSetOrderEdge(srcNode->outEdges, inEdge);
109                 HSIteratorOrderEdge* iterout = iteratorOrderEdge(node->outEdges);
110                 while(hasNextOrderEdge(iterout)){
111                         OrderEdge* outEdge = nextOrderEdge(iterout);
112                         OrderNode* sinkNode = outEdge->sink;
113                         removeHashSetOrderEdge(sinkNode->inEdges, outEdge);
114                         //Adding new edge to new sink and src nodes ...
115                         OrderEdge *newEdge =getOrderEdgeFromOrderGraph(graph, srcNode, sinkNode);
116                         newEdge->mustPos = true;
117                         newEdge->polPos = true;
118                         if (newEdge->mustNeg)
119                                 This->unsat = true;
120                         addHashSetOrderEdge(srcNode->outEdges, newEdge);
121                         addHashSetOrderEdge(sinkNode->inEdges, newEdge);
122                 }
123                 deleteIterOrderEdge(iterout);
124         }
125         deleteIterOrderEdge(iterin);
126 }
127
128 void removeMustBeTrueNodes(CSolver *This, OrderGraph *graph) {
129         HSIteratorOrderNode* iterator = iteratorOrderNode(graph->nodes);
130         while(hasNextOrderNode(iterator)){
131                 OrderNode* node = nextOrderNode(iterator);
132                 if(isMustBeTrueNode(node)){
133                         bypassMustBeTrueNode(This,graph, node);
134                 }
135         }
136         deleteIterOrderNode(iterator);
137 }
138
139 /** This function computes a source set for every nodes, the set of
140                 nodes that can reach that node via pospolarity edges.  It then
141                 looks for negative polarity edges from nodes in the the source set
142                 to determine whether we need to generate pseudoPos edges. */
143
144 void completePartialOrderGraph(OrderGraph *graph) {
145         VectorOrderNode finishNodes;
146         initDefVectorOrderNode(&finishNodes);
147         DFS(graph, &finishNodes);
148         resetNodeInfoStatusSCC(graph);
149         HashTableNodeToNodeSet *table = allocHashTableNodeToNodeSet(128, 0.25);
150
151         VectorOrderNode sccNodes;
152         initDefVectorOrderNode(&sccNodes);
153         
154         uint size = getSizeVectorOrderNode(&finishNodes);
155         uint sccNum = 1;
156         for (int i = size - 1; i >= 0; i--) {
157                 OrderNode *node = getVectorOrderNode(&finishNodes, i);
158                 HashSetOrderNode *sources = allocHashSetOrderNode(4, 0.25);
159                 putNodeToNodeSet(table, node, sources);
160                 
161                 if (node->status == NOTVISITED) {
162                         //Need to do reverse traversal here...
163                         node->status = VISITED;
164                         DFSNodeVisit(node, &sccNodes, true, false, sccNum);
165                         node->status = FINISHED;
166                         node->sccNum = sccNum;
167                         sccNum++;
168                         pushVectorOrderNode(&sccNodes, node);
169
170                         //Compute in set for entire SCC
171                         uint rSize = getSizeVectorOrderNode(&sccNodes);
172                         for (uint j = 0; j < rSize; j++) {
173                                 OrderNode *rnode = getVectorOrderNode(&sccNodes, j);
174                                 //Compute source sets
175                                 HSIteratorOrderEdge *iterator = iteratorOrderEdge(rnode->inEdges);
176                                 while (hasNextOrderEdge(iterator)) {
177                                         OrderEdge *edge = nextOrderEdge(iterator);
178                                         OrderNode *parent = edge->source;
179                                         if (edge->polPos) {
180                                                 addHashSetOrderNode(sources, parent);
181                                                 HashSetOrderNode *parent_srcs = (HashSetOrderNode *)getNodeToNodeSet(table, parent);
182                                                 addAllHashSetOrderNode(sources, parent_srcs);
183                                         }
184                                 }
185                                 deleteIterOrderEdge(iterator);
186                         }
187                         for (uint j=0; j < rSize; j++) {
188                                 //Copy in set of entire SCC
189                                 OrderNode *rnode = getVectorOrderNode(&sccNodes, j);
190                                 HashSetOrderNode * set = (j==0) ? sources : copyHashSetOrderNode(sources);
191                                 putNodeToNodeSet(table, rnode, set);
192
193                                 //Use source sets to compute pseudoPos edges
194                                 HSIteratorOrderEdge *iterator = iteratorOrderEdge(rnode->inEdges);
195                                 while (hasNextOrderEdge(iterator)) {
196                                         OrderEdge *edge = nextOrderEdge(iterator);
197                                         OrderNode *parent = edge->source;
198                                         ASSERT(parent != rnode);
199                                         if (edge->polNeg && parent->sccNum != rnode->sccNum &&
200                                                         containsHashSetOrderNode(sources, parent)) {
201                                                 OrderEdge *newedge = getOrderEdgeFromOrderGraph(graph, rnode, parent);
202                                                 newedge->pseudoPos = true;
203                                         }
204                                 }
205                                 deleteIterOrderEdge(iterator);
206                         }
207                         
208                         clearVectorOrderNode(&sccNodes);
209                 }
210         }
211
212         resetAndDeleteHashTableNodeToNodeSet(table);
213         deleteHashTableNodeToNodeSet(table);
214         resetNodeInfoStatusSCC(graph);
215         deleteVectorArrayOrderNode(&sccNodes);
216         deleteVectorArrayOrderNode(&finishNodes);
217 }
218
219 void DFSMust(OrderGraph *graph, VectorOrderNode *finishNodes) {
220         HSIteratorOrderNode *iterator = iteratorOrderNode(graph->nodes);
221         while (hasNextOrderNode(iterator)) {
222                 OrderNode *node = nextOrderNode(iterator);
223                 if (node->status == NOTVISITED) {
224                         node->status = VISITED;
225                         DFSNodeVisit(node, finishNodes, false, true, 0);
226                         node->status = FINISHED;
227                         pushVectorOrderNode(finishNodes, node);
228                 }
229         }
230         deleteIterOrderNode(iterator);
231 }
232
233 void DFSClearContradictions(CSolver *solver, OrderGraph *graph, VectorOrderNode *finishNodes, bool computeTransitiveClosure) {
234         uint size = getSizeVectorOrderNode(finishNodes);
235         HashTableNodeToNodeSet *table = allocHashTableNodeToNodeSet(128, 0.25);
236
237         for (int i = size - 1; i >= 0; i--) {
238                 OrderNode *node = getVectorOrderNode(finishNodes, i);
239                 HashSetOrderNode *sources = allocHashSetOrderNode(4, 0.25);
240                 putNodeToNodeSet(table, node, sources);
241
242                 {
243                         //Compute source sets
244                         HSIteratorOrderEdge *iterator = iteratorOrderEdge(node->inEdges);
245                         while (hasNextOrderEdge(iterator)) {
246                                 OrderEdge *edge = nextOrderEdge(iterator);
247                                 OrderNode *parent = edge->source;
248                                 if (edge->mustPos) {
249                                         addHashSetOrderNode(sources, parent);
250                                         HashSetOrderNode *parent_srcs = (HashSetOrderNode *)getNodeToNodeSet(table, parent);
251                                         addAllHashSetOrderNode(sources, parent_srcs);
252                                 }
253                         }
254                         deleteIterOrderEdge(iterator);
255                 }
256                 if (computeTransitiveClosure) {
257                         //Compute full transitive closure for nodes
258                         HSIteratorOrderNode *srciterator = iteratorOrderNode(sources);
259                         while (hasNextOrderNode(srciterator)) {
260                                 OrderNode *srcnode = nextOrderNode(srciterator);
261                                 OrderEdge *newedge = getOrderEdgeFromOrderGraph(graph, srcnode, node);
262                                 newedge->mustPos = true;
263                                 newedge->polPos = true;
264                                 if (newedge->mustNeg)
265                                         solver->unsat = true;
266                                 addHashSetOrderEdge(srcnode->outEdges,newedge);
267                                 addHashSetOrderEdge(node->inEdges,newedge);
268                         }
269                         deleteIterOrderNode(srciterator);
270                 }
271                 {
272                         //Use source sets to compute mustPos edges
273                         HSIteratorOrderEdge *iterator = iteratorOrderEdge(node->inEdges);
274                         while (hasNextOrderEdge(iterator)) {
275                                 OrderEdge *edge = nextOrderEdge(iterator);
276                                 OrderNode *parent = edge->source;
277                                 if (!edge->mustPos && containsHashSetOrderNode(sources, parent)) {
278                                         edge->mustPos = true;
279                                         edge->polPos = true;
280                                         if (edge->mustNeg)
281                                                 solver->unsat = true;
282                                 }
283                         }
284                         deleteIterOrderEdge(iterator);
285                 }
286                 {
287                         //Use source sets to compute mustNeg for edges that would introduce cycle if true
288                         HSIteratorOrderEdge *iterator = iteratorOrderEdge(node->outEdges);
289                         while (hasNextOrderEdge(iterator)) {
290                                 OrderEdge *edge = nextOrderEdge(iterator);
291                                 OrderNode *child = edge->sink;
292                                 if (!edge->mustNeg && containsHashSetOrderNode(sources, child)) {
293                                         edge->mustNeg = true;
294                                         edge->polNeg = true;
295                                         if (edge->mustPos)
296                                                 solver->unsat = true;
297                                 }
298                         }
299                         deleteIterOrderEdge(iterator);
300                 }
301         }
302
303         resetAndDeleteHashTableNodeToNodeSet(table);
304         deleteHashTableNodeToNodeSet(table);
305 }
306
307 /* This function finds edges that would form a cycle with must edges
308    and forces them to be mustNeg.  It also decides whether an edge
309    must be true because of transitivity from other must be true
310    edges. */
311
312 void reachMustAnalysis(CSolver * solver, OrderGraph *graph, bool computeTransitiveClosure) {
313         VectorOrderNode finishNodes;
314         initDefVectorOrderNode(&finishNodes);
315         //Topologically sort the mustPos edge graph
316         DFSMust(graph, &finishNodes);
317         resetNodeInfoStatusSCC(graph);
318
319         //Find any backwards edges that complete cycles and force them to be mustNeg
320         DFSClearContradictions(solver, graph, &finishNodes, computeTransitiveClosure);
321         deleteVectorArrayOrderNode(&finishNodes);
322 }
323
324 /* This function finds edges that must be positive and forces the
325    inverse edge to be negative (and clears its positive polarity if it
326    had one). */
327
328 void localMustAnalysisTotal(CSolver *solver, OrderGraph *graph) {
329         HSIteratorOrderEdge *iterator = iteratorOrderEdge(graph->edges);
330         while (hasNextOrderEdge(iterator)) {
331                 OrderEdge *edge = nextOrderEdge(iterator);
332                 if (edge->mustPos) {
333                         OrderEdge *invEdge = getInverseOrderEdge(graph, edge);
334                         if (invEdge != NULL) {
335                                 if (!invEdge->mustPos) {
336                                         invEdge->polPos = false;
337                                 } else {
338                                         solver->unsat = true;
339                                 }
340                                 invEdge->mustNeg = true;
341                                 invEdge->polNeg = true;
342                         }
343                 }
344         }
345         deleteIterOrderEdge(iterator);
346 }
347
348 /** This finds edges that must be positive and forces the inverse edge
349     to be negative.  It also clears the negative flag of this edge.
350     It also finds edges that must be negative and clears the positive
351     polarity. */
352
353 void localMustAnalysisPartial(CSolver *solver, OrderGraph *graph) {
354         HSIteratorOrderEdge *iterator = iteratorOrderEdge(graph->edges);
355         while (hasNextOrderEdge(iterator)) {
356                 OrderEdge *edge = nextOrderEdge(iterator);
357                 if (edge->mustPos) {
358                         if (!edge->mustNeg) {
359                                 edge->polNeg = false;
360                         } else
361                                 solver->unsat = true;
362
363                         OrderEdge *invEdge = getInverseOrderEdge(graph, edge);
364                         if (invEdge != NULL) {
365                                 if (!invEdge->mustPos)
366                                         invEdge->polPos = false;
367                                 else
368                                         solver->unsat = true;
369                                 invEdge->mustNeg = true;
370                                 invEdge->polNeg = true;
371                         }
372                 }
373                 if (edge->mustNeg && !edge->mustPos) {
374                         edge->polPos = false;
375                 }
376         }
377         deleteIterOrderEdge(iterator);
378 }
379
380 void decomposeOrder(CSolver *This, Order *order, OrderGraph *graph) {
381         VectorOrder ordervec;
382         VectorOrder partialcandidatevec;
383         initDefVectorOrder(&ordervec);
384         initDefVectorOrder(&partialcandidatevec);
385         uint size = getSizeVectorBooleanOrder(&order->constraints);
386         for (uint i = 0; i < size; i++) {
387                 BooleanOrder *orderconstraint = getVectorBooleanOrder(&order->constraints, i);
388                 OrderNode *from = getOrderNodeFromOrderGraph(graph, orderconstraint->first);
389                 OrderNode *to = getOrderNodeFromOrderGraph(graph, orderconstraint->second);
390                 model_print("from->sccNum:%u\tto->sccNum:%u\n", from->sccNum, to->sccNum);
391                 if (from->sccNum != to->sccNum) {
392                         OrderEdge *edge = getOrderEdgeFromOrderGraph(graph, from, to);                  
393                         if (edge->polPos) {
394                                 replaceBooleanWithTrue(This, (Boolean *)orderconstraint);
395                         } else if (edge->polNeg) {
396                                 replaceBooleanWithFalse(This, (Boolean *)orderconstraint);
397                         } else {
398                                 //This case should only be possible if constraint isn't in AST
399                                 ASSERT(0);
400                         }
401                 } else {
402                         //Build new order and change constraint's order
403                         Order *neworder = NULL;
404                         if (getSizeVectorOrder(&ordervec) > from->sccNum)
405                                 neworder = getVectorOrder(&ordervec, from->sccNum);
406                         if (neworder == NULL) {
407                                 Set *set = (Set *) allocMutableSet(order->set->type);
408                                 neworder = allocOrder(order->type, set);
409                                 pushVectorOrder(This->allOrders, neworder);
410                                 setExpandVectorOrder(&ordervec, from->sccNum, neworder);
411                                 if (order->type == PARTIAL)
412                                         setExpandVectorOrder(&partialcandidatevec, from->sccNum, neworder);
413                                 else
414                                         setExpandVectorOrder(&partialcandidatevec, from->sccNum, NULL);
415                         }
416                         if (from->status != ADDEDTOSET) {
417                                 from->status = ADDEDTOSET;
418                                 addElementMSet((MutableSet *)neworder->set, from->id);
419                         }
420                         if (to->status != ADDEDTOSET) {
421                                 to->status = ADDEDTOSET;
422                                 addElementMSet((MutableSet *)neworder->set, to->id);
423                         }
424                         if (order->type == PARTIAL) {
425                                 OrderEdge *edge = getOrderEdgeFromOrderGraph(graph, from, to);
426                                 if (edge->polNeg)
427                                         setExpandVectorOrder(&partialcandidatevec, from->sccNum, NULL);
428                         }
429                         orderconstraint->order = neworder;
430                         addOrderConstraint(neworder, orderconstraint);
431                 }
432         }
433
434         uint pcvsize=getSizeVectorOrder(&partialcandidatevec);
435         for(uint i=0;i<pcvsize;i++) {
436                 Order * neworder=getVectorOrder(&partialcandidatevec, i);
437                 if (neworder != NULL){
438                         neworder->type = TOTAL;
439                         model_print("i=%u\t", i);
440                 }
441         }
442         
443         deleteVectorArrayOrder(&ordervec);
444         deleteVectorArrayOrder(&partialcandidatevec);
445 }
446
447 void orderAnalysis(CSolver *This) {
448         uint size = getSizeVectorOrder(This->allOrders);
449         for (uint i = 0; i < size; i++) {
450                 Order *order = getVectorOrder(This->allOrders, i);
451                 bool doDecompose=GETVARTUNABLE(This->tuner, order->type, DECOMPOSEORDER, &onoff);
452                 if (!doDecompose)
453                         continue;
454                 
455                 OrderGraph *graph = buildOrderGraph(order);
456                 if (order->type == PARTIAL) {
457                         //Required to do SCC analysis for partial order graphs.  It
458                         //makes sure we don't incorrectly optimize graphs with negative
459                         //polarity edges
460                         completePartialOrderGraph(graph);
461                 }
462
463
464                 bool mustReachGlobal=GETVARTUNABLE(This->tuner, order->type, MUSTREACHGLOBAL, &onoff);
465
466                 if (mustReachGlobal)
467                         reachMustAnalysis(This, graph, false);
468
469                 bool mustReachLocal=GETVARTUNABLE(This->tuner, order->type, MUSTREACHLOCAL, &onoff);
470                 
471                 if (mustReachLocal) {
472                         //This pair of analysis is also optional
473                         if (order->type == PARTIAL) {
474                                 localMustAnalysisPartial(This, graph);
475                         } else {
476                                 localMustAnalysisTotal(This, graph);
477                         }
478                 }
479
480                 bool mustReachPrune=GETVARTUNABLE(This->tuner, order->type, MUSTREACHPRUNE, &onoff);
481                 
482                 if (mustReachPrune)
483                         removeMustBeTrueNodes(This, graph);
484                 
485                 //This is needed for splitorder
486                 computeStronglyConnectedComponentGraph(graph);
487                 
488                 decomposeOrder(This, order, graph);
489                 
490                 deleteOrderGraph(graph);
491         }
492 }