5809a599aeeb29603a13a32cf6af7810024a3835
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "set.h"
14
15 SATEncoder * allocSATEncoder() {
16         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
17         This->varcount=1;
18         This->cnf=createCNF();
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         deleteCNF(This->cnf);
24         ourfree(This);
25 }
26
27 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
28         VectorBoolean *constraints=csolver->constraints;
29         uint size=getSizeVectorBoolean(constraints);
30         for(uint i=0;i<size;i++) {
31                 Boolean *constraint=getVectorBoolean(constraints, i);
32                 Edge c= encodeConstraintSATEncoder(This, constraint);
33                 printCNF(c);
34                 printf("\n\n");
35                 addConstraintCNF(This->cnf, c);
36         }
37 }
38
39 Edge encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
40         switch(GETBOOLEANTYPE(constraint)) {
41         case ORDERCONST:
42                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
43         case BOOLEANVAR:
44                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
45         case LOGICOP:
46                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
47         case PREDICATEOP:
48                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
49         default:
50                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
51                 exit(-1);
52         }
53 }
54
55 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Edge * carray) {
56         for(uint i=0;i<num;i++)
57                 carray[i]=getNewVarSATEncoder(encoder);
58 }
59
60 Edge getNewVarSATEncoder(SATEncoder *This) {
61         return constraintNewVar(This->cnf);
62 }
63
64 Edge encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
65         if (edgeIsNull(constraint->var)) {
66                 constraint->var=getNewVarSATEncoder(This);
67         }
68         return constraint->var;
69 }
70
71 Edge encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
72         Edge array[getSizeArrayBoolean(&constraint->inputs)];
73         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
74                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
75
76         switch(constraint->op) {
77         case L_AND:
78                 return constraintAND(This->cnf, getSizeArrayBoolean(&constraint->inputs), array);
79         case L_OR:
80                 return constraintOR(This->cnf, getSizeArrayBoolean(&constraint->inputs), array);
81         case L_NOT:
82                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==1);
83                 return constraintNegate(array[0]);
84         case L_XOR:
85                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==2);
86                 return constraintXOR(This->cnf, array[0], array[1]);
87         case L_IMPLIES:
88                 ASSERT( getSizeArrayBoolean( &constraint->inputs)==2);
89                 return constraintIMPLIES(This->cnf, array[0], array[1]);
90         default:
91                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
92                 exit(-1);
93         }
94 }
95
96 Edge encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
97         switch(GETPREDICATETYPE(constraint->predicate) ){
98                 case TABLEPRED:
99                         return encodeTablePredicateSATEncoder(This, constraint);
100                 case OPERATORPRED:
101                         return encodeOperatorPredicateSATEncoder(This, constraint);
102                 default:
103                         ASSERT(0);
104         }
105         return E_BOGUS;
106 }
107
108 Edge encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
109         switch(constraint->encoding.type){
110                 case ENUMERATEIMPLICATIONS:
111                 case ENUMERATEIMPLICATIONSNEGATE:
112                         return encodeEnumTablePredicateSATEncoder(This, constraint);
113                 case CIRCUIT:
114                         ASSERT(0);
115                         break;
116                 default:
117                         ASSERT(0);
118         }
119         return E_BOGUS;
120 }
121
122 Edge encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
123         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
124         FunctionEncodingType encType = constraint->encoding.type;
125         ArrayElement* inputs = &constraint->inputs;
126         uint inputNum =getSizeArrayElement(inputs);
127         //Encode all the inputs first ...
128         for(uint i=0; i<inputNum; i++){
129                 encodeElementSATEncoder(This, getArrayElement(inputs, i));
130         }
131         
132         //WARNING: THIS ASSUMES PREDICATE TABLE IS COMPLETE...SEEMS UNLIKELY TO BE SAFE IN MANY CASES...
133         //WONDER WHAT BEST WAY TO HANDLE THIS IS...
134         
135         uint size = getSizeVectorTableEntry(entries);
136         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
137         Edge constraints[size];
138         for(uint i=0; i<size; i++){
139                 TableEntry* entry = getVectorTableEntry(entries, i);
140                 if(generateNegation == entry->output) {
141                         //Skip the irrelevant entries
142                         continue;
143                 }
144                 Edge carray[inputNum];
145                 for(uint j=0; j<inputNum; j++){
146                         Element* el = getArrayElement(inputs, j);
147                         carray[j] = getElementValueConstraint(This, el, entry->inputs[j]);
148                 }
149                 constraints[i]=constraintAND(This->cnf, inputNum, carray);
150         }
151         Edge result=constraintOR(This->cnf, size, constraints);
152
153         return generateNegation ? result: constraintNegate(result);
154 }
155
156 Edge encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
157         switch(constraint->encoding.type){
158                 case ENUMERATEIMPLICATIONS:
159                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
160                 case CIRCUIT:
161                         ASSERT(0);
162                         break;
163                 default:
164                         ASSERT(0);
165         }
166         return E_BOGUS;
167 }
168
169 Edge encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
170         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
171         uint numDomains=getSizeArraySet(&predicate->domains);
172
173         FunctionEncodingType encType = constraint->encoding.type;
174         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
175
176         /* Call base encoders for children */
177         for(uint i=0;i<numDomains;i++) {
178                 Element *elem = getArrayElement( &constraint->inputs, i);
179                 encodeElementSATEncoder(This, elem);
180         }
181         VectorEdge * clauses=allocDefVectorEdge(); // Setup array of clauses
182         
183         uint indices[numDomains]; //setup indices
184         bzero(indices, sizeof(uint)*numDomains);
185         
186         uint64_t vals[numDomains]; //setup value array
187         for(uint i=0;i<numDomains; i++) {
188                 Set * set=getArraySet(&predicate->domains, i);
189                 vals[i]=getSetElement(set, indices[i]);
190         }
191         
192         bool notfinished=true;
193         while(notfinished) {
194                 Edge carray[numDomains];
195
196                 if (evalPredicateOperator(predicate, vals) ^ generateNegation) {
197                         //Include this in the set of terms
198                         for(uint i=0;i<numDomains;i++) {
199                                 Element * elem = getArrayElement(&constraint->inputs, i);
200                                 carray[i] = getElementValueConstraint(This, elem, vals[i]);
201                         }
202                         pushVectorEdge(clauses, constraintAND(This->cnf, numDomains, carray));
203                 }
204                 
205                 notfinished=false;
206                 for(uint i=0;i<numDomains; i++) {
207                         uint index=++indices[i];
208                         Set * set=getArraySet(&predicate->domains, i);
209
210                         if (index < getSetSize(set)) {
211                                 vals[i]=getSetElement(set, index);
212                                 notfinished=true;
213                                 break;
214                         } else {
215                                 indices[i]=0;
216                                 vals[i]=getSetElement(set, 0);
217                         }
218                 }
219         }
220
221         Edge cor=constraintOR(This->cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
222         deleteVectorEdge(clauses);
223         return generateNegation ? cor : constraintNegate(cor);
224 }
225
226 void encodeElementSATEncoder(SATEncoder* encoder, Element *This){
227         switch( GETELEMENTTYPE(This) ){
228                 case ELEMFUNCRETURN:
229                         addConstraintCNF(encoder->cnf, encodeElementFunctionSATEncoder(encoder, (ElementFunction*) This));
230                         break;
231                 case ELEMSET:
232                         return;
233                 default:
234                         ASSERT(0);
235         }
236 }
237
238 Edge encodeElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction *This){
239         switch(GETFUNCTIONTYPE(This->function)){
240                 case TABLEFUNC:
241                         return encodeTableElementFunctionSATEncoder(encoder, This);
242                 case OPERATORFUNC:
243                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
244                 default:
245                         ASSERT(0);
246         }
247         return E_BOGUS;
248 }
249
250 Edge encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
251         switch(getElementFunctionEncoding(This)->type){
252                 case ENUMERATEIMPLICATIONS:
253                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
254                         break;
255                 case CIRCUIT:
256                         ASSERT(0);
257                         break;
258                 default:
259                         ASSERT(0);
260         }
261         return E_BOGUS;
262 }
263
264 Edge encodeOperatorElementFunctionSATEncoder(SATEncoder* This, ElementFunction* func) {
265         FunctionOperator * function = (FunctionOperator *) func->function;
266         uint numDomains=getSizeArrayElement(&func->inputs);
267
268         /* Call base encoders for children */
269         for(uint i=0;i<numDomains;i++) {
270                 Element *elem = getArrayElement( &func->inputs, i);
271                 encodeElementSATEncoder(This, elem);
272         }
273
274         VectorEdge * clauses=allocDefVectorEdge(); // Setup array of clauses
275         
276         uint indices[numDomains]; //setup indices
277         bzero(indices, sizeof(uint)*numDomains);
278         
279         uint64_t vals[numDomains]; //setup value array
280         for(uint i=0;i<numDomains; i++) {
281                 Set * set=getElementSet(getArrayElement(&func->inputs, i));
282                 vals[i]=getSetElement(set, indices[i]);
283         }
284
285         Edge overFlowConstraint = ((BooleanVar*) func->overflowstatus)->var;
286         
287         bool notfinished=true;
288         while(notfinished) {
289                 Edge carray[numDomains+2];
290
291                 uint64_t result=applyFunctionOperator(function, numDomains, vals);
292                 bool isInRange = isInRangeFunction((FunctionOperator*)func->function, result);
293                 bool needClause = isInRange;
294                 if (function->overflowbehavior == OVERFLOWSETSFLAG || function->overflowbehavior == FLAGIFFOVERFLOW) {
295                         needClause=true;
296                 }
297                 
298                 if (needClause) {
299                         //Include this in the set of terms
300                         for(uint i=0;i<numDomains;i++) {
301                                 Element * elem = getArrayElement(&func->inputs, i);
302                                 carray[i] = getElementValueConstraint(This, elem, vals[i]);
303                         }
304                         if (isInRange) {
305                                 carray[numDomains] = getElementValueConstraint(This, &func->base, result);
306                         }
307
308                         Edge clause;
309                         switch(function->overflowbehavior) {
310                         case IGNORE:
311                         case NOOVERFLOW:
312                         case WRAPAROUND: {
313                                 clause=constraintAND(This->cnf, numDomains+1, carray);
314                                 break;
315                         }
316                         case FLAGFORCESOVERFLOW: {
317                                 carray[numDomains+1]=constraintNegate(overFlowConstraint);
318                                 clause=constraintAND(This->cnf, numDomains+2, carray);
319                                 break;
320                         }
321                         case OVERFLOWSETSFLAG: {
322                                 if (isInRange) {
323                                         clause=constraintAND(This->cnf, numDomains+1, carray);
324                                 } else {
325                                         carray[numDomains+1]=overFlowConstraint;
326                                         clause=constraintAND(This->cnf, numDomains+1, carray);
327                                 }
328                                 break;
329                         }
330                         case FLAGIFFOVERFLOW: {
331                                 if (isInRange) {
332                                 carray[numDomains+1]=constraintNegate(overFlowConstraint);
333                                         clause=constraintAND(This->cnf, numDomains+2, carray);
334                                 } else {
335                                         carray[numDomains+1]=overFlowConstraint;
336                                         clause=constraintAND(This->cnf, numDomains+1, carray);
337                                 }
338                                 break;
339                         }
340                         default:
341                                 ASSERT(0);
342                         }
343                         pushVectorEdge(clauses, clause);
344                 }
345                 
346                 notfinished=false;
347                 for(uint i=0;i<numDomains; i++) {
348                         uint index=++indices[i];
349                         Set * set=getElementSet(getArrayElement(&func->inputs, i));
350
351                         if (index < getSetSize(set)) {
352                                 vals[i]=getSetElement(set, index);
353                                 notfinished=true;
354                                 break;
355                         } else {
356                                 indices[i]=0;
357                                 vals[i]=getSetElement(set, 0);
358                         }
359                 }
360         }
361
362         Edge cor=constraintOR(This->cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
363         deleteVectorEdge(clauses);
364         return cor;
365 }
366
367 Edge encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
368         //FIXME: HANDLE UNDEFINED BEHAVIORS
369         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
370         ArrayElement* elements= &This->inputs;
371         Table* table = ((FunctionTable*) (This->function))->table;
372         uint size = getSizeVectorTableEntry(&table->entries);
373         Edge constraints[size]; //FIXME: should add a space for the case that didn't match any entries
374         for(uint i=0; i<size; i++) {
375                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
376                 uint inputNum = getSizeArrayElement(elements);
377                 Edge carray[inputNum];
378                 for(uint j=0; j<inputNum; j++){
379                         Element* el= getArrayElement(elements, j);
380                         carray[j] = getElementValueConstraint(encoder, el, entry->inputs[j]);
381                 }
382                 Edge output = getElementValueConstraint(encoder, (Element*)This, entry->output);
383                 Edge row= constraintIMPLIES(encoder->cnf, constraintAND(encoder->cnf, inputNum, carray), output);
384                 constraints[i]=row;
385         }
386         return constraintOR(encoder->cnf, size, constraints);
387 }