Add LTE function for completeness and fix bug in LT
[satune.git] / src / Backend / satfuncopencoder.c
1 #include "satencoder.h"
2 #include "common.h"
3 #include "function.h"
4 #include "ops.h"
5 #include "predicate.h"
6 #include "boolean.h"
7 #include "table.h"
8 #include "tableentry.h"
9 #include "set.h"
10 #include "element.h"
11 #include "common.h"
12 #include "satfuncopencoder.h"
13
14 Edge encodeOperatorPredicateSATEncoder(SATEncoder *This, BooleanPredicate *constraint) {
15         switch (constraint->encoding.type) {
16         case ENUMERATEIMPLICATIONS:
17                 return encodeEnumOperatorPredicateSATEncoder(This, constraint);
18         case CIRCUIT:
19                 return encodeCircuitOperatorPredicateEncoder(This, constraint);
20         default:
21                 ASSERT(0);
22         }
23         exit(-1);
24 }
25
26 Edge encodeEnumOperatorPredicateSATEncoder(SATEncoder *This, BooleanPredicate *constraint) {
27         PredicateOperator *predicate = (PredicateOperator *)constraint->predicate;
28         uint numDomains = getSizeArraySet(&predicate->domains);
29
30         FunctionEncodingType encType = constraint->encoding.type;
31         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
32
33         /* Call base encoders for children */
34         for (uint i = 0; i < numDomains; i++) {
35                 Element *elem = getArrayElement( &constraint->inputs, i);
36                 encodeElementSATEncoder(This, elem);
37         }
38         VectorEdge *clauses = allocDefVectorEdge();     // Setup array of clauses
39
40         uint indices[numDomains];       //setup indices
41         bzero(indices, sizeof(uint) * numDomains);
42
43         uint64_t vals[numDomains];//setup value array
44         for (uint i = 0; i < numDomains; i++) {
45                 Set *set = getArraySet(&predicate->domains, i);
46                 vals[i] = getSetElement(set, indices[i]);
47         }
48
49         bool notfinished = true;
50         while (notfinished) {
51                 Edge carray[numDomains];
52
53                 if (evalPredicateOperator(predicate, vals) ^ generateNegation) {
54                         //Include this in the set of terms
55                         for (uint i = 0; i < numDomains; i++) {
56                                 Element *elem = getArrayElement(&constraint->inputs, i);
57                                 carray[i] = getElementValueConstraint(This, elem, vals[i]);
58                         }
59                         Edge term = constraintAND(This->cnf, numDomains, carray);
60                         pushVectorEdge(clauses, term);
61                 }
62
63                 notfinished = false;
64                 for (uint i = 0; i < numDomains; i++) {
65                         uint index = ++indices[i];
66                         Set *set = getArraySet(&predicate->domains, i);
67
68                         if (index < getSetSize(set)) {
69                                 vals[i] = getSetElement(set, index);
70                                 notfinished = true;
71                                 break;
72                         } else {
73                                 indices[i] = 0;
74                                 vals[i] = getSetElement(set, 0);
75                         }
76                 }
77         }
78         if (getSizeVectorEdge(clauses) == 0) {
79                 deleteVectorEdge(clauses);
80                 return E_False;
81         }
82         Edge cor = constraintOR(This->cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
83         deleteVectorEdge(clauses);
84         return generateNegation ? constraintNegate(cor) : cor;
85 }
86
87
88 void encodeOperatorElementFunctionSATEncoder(SATEncoder *This, ElementFunction *func) {
89 #ifdef TRACE_DEBUG
90         model_print("Operator Function ...\n");
91 #endif
92         FunctionOperator *function = (FunctionOperator *) func->function;
93         uint numDomains = getSizeArrayElement(&func->inputs);
94
95         /* Call base encoders for children */
96         for (uint i = 0; i < numDomains; i++) {
97                 Element *elem = getArrayElement( &func->inputs, i);
98                 encodeElementSATEncoder(This, elem);
99         }
100
101         VectorEdge *clauses = allocDefVectorEdge();     // Setup array of clauses
102
103         uint indices[numDomains];       //setup indices
104         bzero(indices, sizeof(uint) * numDomains);
105
106         uint64_t vals[numDomains];//setup value array
107         for (uint i = 0; i < numDomains; i++) {
108                 Set *set = getElementSet(getArrayElement(&func->inputs, i));
109                 vals[i] = getSetElement(set, indices[i]);
110         }
111
112         Edge overFlowConstraint = ((BooleanVar *) func->overflowstatus)->var;
113
114         bool notfinished = true;
115         while (notfinished) {
116                 Edge carray[numDomains + 1];
117
118                 uint64_t result = applyFunctionOperator(function, numDomains, vals);
119                 bool isInRange = isInRangeFunction((FunctionOperator *)func->function, result);
120                 bool needClause = isInRange;
121                 if (function->overflowbehavior == OVERFLOWSETSFLAG || function->overflowbehavior == FLAGIFFOVERFLOW) {
122                         needClause = true;
123                 }
124
125                 if (needClause) {
126                         //Include this in the set of terms
127                         for (uint i = 0; i < numDomains; i++) {
128                                 Element *elem = getArrayElement(&func->inputs, i);
129                                 carray[i] = getElementValueConstraint(This, elem, vals[i]);
130                         }
131                         if (isInRange) {
132                                 carray[numDomains] = getElementValueConstraint(This, &func->base, result);
133                         }
134
135                         Edge clause;
136                         switch (function->overflowbehavior) {
137                         case IGNORE:
138                         case NOOVERFLOW:
139                         case WRAPAROUND: {
140                                 clause = constraintIMPLIES(This->cnf,constraintAND(This->cnf, numDomains, carray), carray[numDomains]);
141                                 break;
142                         }
143                         case FLAGFORCESOVERFLOW: {
144                                 clause = constraintIMPLIES(This->cnf,constraintAND(This->cnf, numDomains, carray), constraintAND2(This->cnf, carray[numDomains], constraintNegate(overFlowConstraint)));
145                                 break;
146                         }
147                         case OVERFLOWSETSFLAG: {
148                                 if (isInRange) {
149                                         clause = constraintIMPLIES(This->cnf,constraintAND(This->cnf, numDomains, carray), carray[numDomains]);
150                                 } else {
151                                         clause = constraintIMPLIES(This->cnf,constraintAND(This->cnf, numDomains, carray), overFlowConstraint);
152                                 }
153                                 break;
154                         }
155                         case FLAGIFFOVERFLOW: {
156                                 if (isInRange) {
157                                         clause = constraintIMPLIES(This->cnf,constraintAND(This->cnf, numDomains, carray), constraintAND2(This->cnf, carray[numDomains], constraintNegate(overFlowConstraint)));
158                                 } else {
159                                         clause = constraintIMPLIES(This->cnf,constraintAND(This->cnf, numDomains, carray), overFlowConstraint);
160                                 }
161                                 break;
162                         }
163                         default:
164                                 ASSERT(0);
165                         }
166 #ifdef TRACE_DEBUG
167                         model_print("added clause in operator function\n");
168                         printCNF(clause);
169                         model_print("\n");
170 #endif
171                         pushVectorEdge(clauses, clause);
172                 }
173
174                 notfinished = false;
175                 for (uint i = 0; i < numDomains; i++) {
176                         uint index = ++indices[i];
177                         Set *set = getElementSet(getArrayElement(&func->inputs, i));
178
179                         if (index < getSetSize(set)) {
180                                 vals[i] = getSetElement(set, index);
181                                 notfinished = true;
182                                 break;
183                         } else {
184                                 indices[i] = 0;
185                                 vals[i] = getSetElement(set, 0);
186                         }
187                 }
188         }
189         if (getSizeVectorEdge(clauses) == 0) {
190                 deleteVectorEdge(clauses);
191                 return;
192         }
193         Edge cor = constraintAND(This->cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
194         addConstraintCNF(This->cnf, cor);
195         deleteVectorEdge(clauses);
196 }
197
198 Edge encodeCircuitOperatorPredicateEncoder(SATEncoder *This, BooleanPredicate *constraint) {
199         PredicateOperator *predicate = (PredicateOperator *) constraint->predicate;
200         ASSERT(getSizeArraySet(&predicate->domains) == 2);
201         Element *elem0 = getArrayElement( &constraint->inputs, 0);
202         encodeElementSATEncoder(This, elem0);
203         Element *elem1 = getArrayElement( &constraint->inputs, 1);
204         encodeElementSATEncoder(This, elem1);
205         ElementEncoding *ee0 = getElementEncoding(elem0);
206         ElementEncoding *ee1 = getElementEncoding(elem1);
207         ASSERT(ee0->numVars == ee1->numVars);
208         uint numVars = ee0->numVars;
209         switch (predicate->op) {
210                 case EQUALS:
211                         return generateEquivNVConstraint(This->cnf, numVars, ee0->variables, ee1->variables);
212                 case LT:
213                         return generateLTConstraint(This->cnf, numVars, ee0->variables, ee1->variables);
214                 case GT:
215                         return generateLTConstraint(This->cnf, numVars, ee1->variables, ee0->variables);
216                 default:
217                         ASSERT(0);
218         }
219         exit(-1);
220 }
221