Unary encoding of predicates
[satune.git] / src / Backend / satfuncopencoder.cc
1 #include "satencoder.h"
2 #include "common.h"
3 #include "function.h"
4 #include "ops.h"
5 #include "predicate.h"
6 #include "boolean.h"
7 #include "table.h"
8 #include "tableentry.h"
9 #include "set.h"
10 #include "element.h"
11 #include "common.h"
12
13 Edge SATEncoder::encodeOperatorPredicateSATEncoder(BooleanPredicate *constraint) {
14         switch (constraint->encoding.type) {
15         case ENUMERATEIMPLICATIONS:
16                 return encodeEnumOperatorPredicateSATEncoder(constraint);
17         case CIRCUIT:
18                 return encodeCircuitOperatorPredicateEncoder(constraint);
19         default:
20                 ASSERT(0);
21         }
22         exit(-1);
23 }
24
25 Edge SATEncoder::encodeEnumEqualsPredicateSATEncoder(BooleanPredicate *constraint) {
26         Polarity polarity = constraint->polarity;
27
28         /* Call base encoders for children */
29         for (uint i = 0; i < 2; i++) {
30                 Element *elem = constraint->inputs.get(i);
31                 encodeElementSATEncoder(elem);
32         }
33         VectorEdge *clauses = vector;
34
35         Set *set0 = constraint->inputs.get(0)->getRange();
36         uint size0 = set0->getSize();
37
38         Set *set1 = constraint->inputs.get(1)->getRange();
39         uint size1 = set1->getSize();
40
41         uint64_t val0 = set0->getElement(0);
42         uint64_t val1 = set1->getElement(0);
43         if (size0 != 0 && size1 != 0)
44                 for (uint i = 0, j = 0; true; ) {
45                         if (val0 == val1) {
46                                 Edge carray[2];
47                                 carray[0] = getElementValueConstraint(constraint->inputs.get(0), polarity, val0);
48                                 carray[1] = getElementValueConstraint(constraint->inputs.get(1), polarity, val0);
49                                 Edge term = constraintAND(cnf, 2, carray);
50                                 pushVectorEdge(clauses, term);
51                                 i++; j++;
52                                 if (i < size0)
53                                         val0 = set0->getElement(i);
54                                 else
55                                         break;
56                                 if (j < size1)
57                                         val1 = set1->getElement(j);
58                                 else
59                                         break;
60                         } else if (val0 < val1) {
61                                 i++;
62                                 if (i < size0)
63                                         val0 = set0->getElement(i);
64                                 else
65                                         break;
66                         } else {
67                                 j++;
68                                 if (j < size1)
69                                         val1 = set1->getElement(j);
70                                 else
71                                         break;
72                         }
73                 }
74         if (getSizeVectorEdge(clauses) == 0) {
75                 return E_False;
76         }
77         Edge cor = constraintOR(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
78         clearVectorEdge(clauses);
79         return cor;
80 }
81
82 Edge SATEncoder::encodeUnaryPredicateSATEncoder(BooleanPredicate *constraint) {
83         Polarity polarity = constraint->polarity;
84         PredicateOperator *predicate = (PredicateOperator *)constraint->predicate;
85         CompOp op = predicate->getOp();
86
87         /* Call base encoders for children */
88         for (uint i = 0; i < 2; i++) {
89                 Element *elem = constraint->inputs.get(i);
90                 encodeElementSATEncoder(elem);
91         }
92         VectorEdge *clauses = vector;
93
94         Element *elem0 = constraint->inputs.get(0);
95         Element *elem1 = constraint->inputs.get(1);
96
97         //Eliminate symmetric cases
98         if (op == SATC_GT) {
99                 op = SATC_LT;
100                 Element *tmp = elem0;
101                 elem0 = elem1;
102                 elem1 = elem0;
103         } else if (op == SATC_GTE) {
104                 op = SATC_LTE;
105                 Element *tmp = elem0;
106                 elem0 = elem1;
107                 elem1 = elem0;
108         }
109
110         Set *set0 = elem0->getRange();
111         uint size0 = set0->getSize();
112         Edge *vars0 = elem0->getElementEncoding()->variables;
113
114         Set *set1 = elem1->getRange();
115         uint size1 = set1->getSize();
116         Edge *vars1 = elem1->getElementEncoding()->variables;
117
118
119         uint64_t val0 = set0->getElement(0);
120         uint64_t val1 = set1->getElement(0);
121         if (size0 != 0 && size1 != 0) {
122                 for (uint i = 0, j = 0; true; ) {
123                         if (val0 > val1 || (op == SATC_LT && val0 == val1)) {
124                                 j++;
125                                 if (j == size1) {
126                                         //need to assert val0 isn't this big
127                                         if (i == 0)
128                                                 return E_False;//Can't satisfy this constraint
129                                         pushVectorEdge(clauses, constraintNegate(vars0[i - 1]));
130                                         break;
131                                 }
132                                 val1 = set1->getElement(j);
133                         } else {
134                                 if (i == 0) {
135                                         if (j != 0) {
136                                                 pushVectorEdge(clauses, vars1[j - 1]);
137                                         }
138                                 } else {
139                                         if (j != 0) {
140                                                 Edge term = constraintIMPLIES(cnf, vars0[i - 1], vars1[j - 1]);
141                                                 pushVectorEdge(clauses, term);
142                                         }
143                                 }
144                                 i++;
145                                 if (i == size0)
146                                         break;
147                                 val0 = set0->getElement(i);
148                         }
149                 }
150         }
151         //Trivially true constraint
152         if (getSizeVectorEdge(clauses) == 0)
153                 return E_True;
154
155         Edge cand = constraintAND(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
156         clearVectorEdge(clauses);
157         return cand;
158 }
159
160 Edge SATEncoder::encodeEnumOperatorPredicateSATEncoder(BooleanPredicate *constraint) {
161         PredicateOperator *predicate = (PredicateOperator *)constraint->predicate;
162         uint numDomains = constraint->inputs.getSize();
163         Polarity polarity = constraint->polarity;
164         FunctionEncodingType encType = constraint->encoding.type;
165         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
166         if (generateNegation)
167                 polarity = negatePolarity(polarity);
168
169         CompOp op = predicate->getOp();
170         if (!generateNegation && op == SATC_EQUALS)
171                 return encodeEnumEqualsPredicateSATEncoder(constraint);
172
173         if (!generateNegation && numDomains == 2 &&
174                         (op == SATC_LT || op == SATC_GT || op == SATC_LTE || op == SATC_GTE) &&
175                         constraint->inputs.get(0)->encoding.type == UNARY &&
176                         constraint->inputs.get(1)->encoding.type == UNARY) {
177                 return encodeUnaryPredicateSATEncoder(constraint);
178         }
179
180         /* Call base encoders for children */
181         for (uint i = 0; i < numDomains; i++) {
182                 Element *elem = constraint->inputs.get(i);
183                 encodeElementSATEncoder(elem);
184         }
185         VectorEdge *clauses = vector;
186
187         uint indices[numDomains];       //setup indices
188         bzero(indices, sizeof(uint) * numDomains);
189
190         uint64_t vals[numDomains];//setup value array
191         for (uint i = 0; i < numDomains; i++) {
192                 Set *set = constraint->inputs.get(i)->getRange();
193                 vals[i] = set->getElement(indices[i]);
194         }
195
196         bool notfinished = true;
197         Edge carray[numDomains];
198         while (notfinished) {
199                 if (predicate->evalPredicateOperator(vals) != generateNegation) {
200                         //Include this in the set of terms
201                         for (uint i = 0; i < numDomains; i++) {
202                                 Element *elem = constraint->inputs.get(i);
203                                 carray[i] = getElementValueConstraint(elem, polarity, vals[i]);
204                         }
205                         Edge term = constraintAND(cnf, numDomains, carray);
206                         pushVectorEdge(clauses, term);
207                         ASSERT(getSizeVectorEdge(clauses) > 0);
208                 }
209
210                 notfinished = false;
211                 for (uint i = 0; i < numDomains; i++) {
212                         uint index = ++indices[i];
213                         Set *set = constraint->inputs.get(i)->getRange();
214
215                         if (index < set->getSize()) {
216                                 vals[i] = set->getElement(index);
217                                 notfinished = true;
218                                 break;
219                         } else {
220                                 indices[i] = 0;
221                                 vals[i] = set->getElement(0);
222                         }
223                 }
224         }
225         if (getSizeVectorEdge(clauses) == 0) {
226                 return E_False;
227         }
228         Edge cor = constraintOR(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
229         clearVectorEdge(clauses);
230         return generateNegation ? constraintNegate(cor) : cor;
231 }
232
233
234 void SATEncoder::encodeOperatorElementFunctionSATEncoder(ElementFunction *func) {
235 #ifdef TRACE_DEBUG
236         model_print("Operator Function ...\n");
237 #endif
238         FunctionOperator *function = (FunctionOperator *) func->getFunction();
239         uint numDomains = func->inputs.getSize();
240
241         /* Call base encoders for children */
242         for (uint i = 0; i < numDomains; i++) {
243                 Element *elem = func->inputs.get(i);
244                 encodeElementSATEncoder(elem);
245         }
246
247         VectorEdge *clauses = allocDefVectorEdge();     // Setup array of clauses
248
249         uint indices[numDomains];       //setup indices
250         bzero(indices, sizeof(uint) * numDomains);
251
252         uint64_t vals[numDomains];//setup value array
253         for (uint i = 0; i < numDomains; i++) {
254                 Set *set = func->inputs.get(i)->getRange();
255                 vals[i] = set->getElement(indices[i]);
256         }
257
258         bool notfinished = true;
259         Edge carray[numDomains + 1];
260         while (notfinished) {
261                 uint64_t result = function->applyFunctionOperator(numDomains, vals);
262                 bool isInRange = ((FunctionOperator *)func->getFunction())->isInRangeFunction(result);
263                 bool needClause = isInRange;
264                 if (function->overflowbehavior == SATC_OVERFLOWSETSFLAG || function->overflowbehavior == SATC_FLAGIFFOVERFLOW) {
265                         needClause = true;
266                 }
267
268                 if (needClause) {
269                         //Include this in the set of terms
270                         for (uint i = 0; i < numDomains; i++) {
271                                 Element *elem = func->inputs.get(i);
272                                 carray[i] = getElementValueConstraint(elem, P_FALSE, vals[i]);
273                         }
274                         if (isInRange) {
275                                 carray[numDomains] = getElementValueConstraint(func, P_TRUE, result);
276                         }
277
278                         Edge clause;
279                         switch (function->overflowbehavior) {
280                         case SATC_IGNORE:
281                         case SATC_NOOVERFLOW:
282                         case SATC_WRAPAROUND: {
283                                 clause = constraintIMPLIES(cnf, constraintAND(cnf, numDomains, carray), carray[numDomains]);
284                                 break;
285                         }
286                         case SATC_FLAGFORCESOVERFLOW: {
287                                 Edge overFlowConstraint = encodeConstraintSATEncoder(func->overflowstatus);
288                                 clause = constraintIMPLIES(cnf,constraintAND(cnf, numDomains, carray), constraintAND2(cnf, carray[numDomains], constraintNegate(overFlowConstraint)));
289                                 break;
290                         }
291                         case SATC_OVERFLOWSETSFLAG: {
292                                 if (isInRange) {
293                                         clause = constraintIMPLIES(cnf, constraintAND(cnf, numDomains, carray), carray[numDomains]);
294                                 } else {
295                                         Edge overFlowConstraint = encodeConstraintSATEncoder(func->overflowstatus);
296                                         clause = constraintIMPLIES(cnf,constraintAND(cnf, numDomains, carray), overFlowConstraint);
297                                 }
298                                 break;
299                         }
300                         case SATC_FLAGIFFOVERFLOW: {
301                                 Edge overFlowConstraint = encodeConstraintSATEncoder(func->overflowstatus);
302                                 if (isInRange) {
303                                         clause = constraintIMPLIES(cnf, constraintAND(cnf, numDomains, carray), constraintAND2(cnf, carray[numDomains], constraintNegate(overFlowConstraint)));
304                                 } else {
305                                         clause = constraintIMPLIES(cnf, constraintAND(cnf, numDomains, carray), overFlowConstraint);
306                                 }
307                                 break;
308                         }
309                         default:
310                                 ASSERT(0);
311                         }
312 #ifdef TRACE_DEBUG
313                         model_print("added clause in operator function\n");
314                         printCNF(clause);
315                         model_print("\n");
316 #endif
317                         pushVectorEdge(clauses, clause);
318                 }
319
320                 notfinished = false;
321                 for (uint i = 0; i < numDomains; i++) {
322                         uint index = ++indices[i];
323                         Set *set = func->inputs.get(i)->getRange();
324
325                         if (index < set->getSize()) {
326                                 vals[i] = set->getElement(index);
327                                 notfinished = true;
328                                 break;
329                         } else {
330                                 indices[i] = 0;
331                                 vals[i] = set->getElement(0);
332                         }
333                 }
334         }
335         if (getSizeVectorEdge(clauses) == 0) {
336                 deleteVectorEdge(clauses);
337                 return;
338         }
339         Edge cand = constraintAND(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
340         addConstraintCNF(cnf, cand);
341         deleteVectorEdge(clauses);
342 }
343
344 Edge SATEncoder::encodeCircuitOperatorPredicateEncoder(BooleanPredicate *constraint) {
345         PredicateOperator *predicate = (PredicateOperator *) constraint->predicate;
346         Element *elem0 = constraint->inputs.get(0);
347         encodeElementSATEncoder(elem0);
348         Element *elem1 = constraint->inputs.get(1);
349         encodeElementSATEncoder(elem1);
350         ElementEncoding *ee0 = elem0->getElementEncoding();
351         ElementEncoding *ee1 = elem1->getElementEncoding();
352         ASSERT(ee0->numVars == ee1->numVars);
353         uint numVars = ee0->numVars;
354         switch (predicate->getOp()) {
355         case SATC_EQUALS:
356                 return generateEquivNVConstraint(cnf, numVars, ee0->variables, ee1->variables);
357         case SATC_LT:
358                 return generateLTConstraint(cnf, numVars, ee0->variables, ee1->variables);
359         case SATC_GT:
360                 return generateLTConstraint(cnf, numVars, ee1->variables, ee0->variables);
361         default:
362                 ASSERT(0);
363         }
364         exit(-1);
365 }
366