Merge branch 'master' into brian
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         This->cnf=createCNF();
20         return This;
21 }
22
23 void deleteSATEncoder(SATEncoder *This) {
24         deleteCNF(This->cnf);
25         ourfree(This);
26 }
27
28 Edge getElementValueConstraint(SATEncoder* This, Element* elem, uint64_t value) {
29         generateElementEncodingVariables(This, getElementEncoding(elem));
30         switch(getElementEncoding(elem)->type){
31                 case ONEHOT:
32                         //FIXME
33                         ASSERT(0);
34                         break;
35                 case UNARY:
36                         ASSERT(0);
37                         break;
38                 case BINARYINDEX:
39                         return getElementValueBinaryIndexConstraint(This, elem, value);
40                         break;
41                 case ONEHOTBINARY:
42                         ASSERT(0);
43                         break;
44                 case BINARYVAL:
45                         ASSERT(0);
46                         break;
47                 default:
48                         ASSERT(0);
49                         break;
50         }
51         return E_BOGUS;
52 }
53
54 Edge getElementValueBinaryIndexConstraint(SATEncoder * This, Element* elem, uint64_t value) {
55         ASTNodeType type = GETELEMENTTYPE(elem);
56         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
57         ElementEncoding* elemEnc = getElementEncoding(elem);
58         for(uint i=0; i<elemEnc->encArraySize; i++){
59                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
60                         return generateBinaryConstraint(This->cnf, elemEnc->numVars, elemEnc->variables, i);
61                 }
62         }
63         return E_BOGUS;
64 }
65
66 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
67         VectorBoolean *constraints=csolver->constraints;
68         uint size=getSizeVectorBoolean(constraints);
69         for(uint i=0;i<size;i++) {
70                 Boolean *constraint=getVectorBoolean(constraints, i);
71                 Edge c= encodeConstraintSATEncoder(This, constraint);
72                 printCNF(c);
73                 printf("\n");
74                 addConstraint(This->cnf, c);
75         }
76 }
77
78 Edge encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
79         switch(GETBOOLEANTYPE(constraint)) {
80         case ORDERCONST:
81                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
82         case BOOLEANVAR:
83                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
84         case LOGICOP:
85                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
86         case PREDICATEOP:
87                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
88         default:
89                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
90                 exit(-1);
91         }
92 }
93
94 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Edge * carray) {
95         for(uint i=0;i<num;i++)
96                 carray[i]=getNewVarSATEncoder(encoder);
97 }
98
99 Edge getNewVarSATEncoder(SATEncoder *This) {
100         return constraintNewVar(This->cnf);
101 }
102
103 Edge encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
104         if (edgeIsNull(constraint->var)) {
105                 constraint->var=getNewVarSATEncoder(This);
106         }
107         return constraint->var;
108 }
109
110 Edge encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
111         Edge array[getSizeArrayBoolean(&constraint->inputs)];
112         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
113                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
114
115         switch(constraint->op) {
116         case L_AND:
117                 return constraintAND(This->cnf, getSizeArrayBoolean(&constraint->inputs), array);
118         case L_OR:
119                 return constraintOR(This->cnf, getSizeArrayBoolean(&constraint->inputs), array);
120         case L_NOT:
121                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==1);
122                 return constraintNegate(array[0]);
123         case L_XOR:
124                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==2);
125                 return constraintXOR(This->cnf, array[0], array[1]);
126         case L_IMPLIES:
127                 ASSERT( getSizeArrayBoolean( &constraint->inputs)==2);
128                 return constraintIMPLIES(This->cnf, array[0], array[1]);
129         default:
130                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
131                 exit(-1);
132         }
133 }
134
135
136 Edge encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
137         switch( constraint->order->type){
138                 case PARTIAL:
139                         return encodePartialOrderSATEncoder(This, constraint);
140                 case TOTAL:
141                         return encodeTotalOrderSATEncoder(This, constraint);
142                 default:
143                         ASSERT(0);
144         }
145         return E_BOGUS;
146 }
147
148 Edge getPairConstraint(SATEncoder *This, HashTableBoolConst * table, OrderPair * pair) {
149         bool negate = false;
150         OrderPair flipped;
151         if (pair->first > pair->second) {
152                 negate=true;
153                 flipped.first=pair->second;
154                 flipped.second=pair->first;
155                 pair = &flipped;        //FIXME: accessing a local variable from outside of the function?
156         }
157         Edge constraint;
158         if (!containsBoolConst(table, pair)) {
159                 constraint = getNewVarSATEncoder(This);
160                 OrderPair * paircopy = allocOrderPair(pair->first, pair->second, constraint);
161                 putBoolConst(table, paircopy, paircopy);
162         } else
163                 constraint = getBoolConst(table, pair)->constraint;
164         if (negate)
165                 return constraintNegate(constraint);
166         else
167                 return constraint;
168         
169 }
170
171 Edge encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
172         ASSERT(boolOrder->order->type == TOTAL);
173         if(boolOrder->order->boolsToConstraints == NULL){
174                 initializeOrderHashTable(boolOrder->order);
175                 createAllTotalOrderConstraintsSATEncoder(This, boolOrder->order);
176         }
177         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
178         OrderPair pair={boolOrder->first, boolOrder->second, E_NULL};
179         Edge constraint = getPairConstraint(This, boolToConsts, & pair);
180         return constraint;
181 }
182
183 void createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
184         ASSERT(order->type == TOTAL);
185         VectorInt* mems = order->set->members;
186         HashTableBoolConst* table = order->boolsToConstraints;
187         uint size = getSizeVectorInt(mems);
188         uint csize =0;
189         for(uint i=0; i<size; i++){
190                 uint64_t valueI = getVectorInt(mems, i);
191                 for(uint j=i+1; j<size;j++){
192                         uint64_t valueJ = getVectorInt(mems, j);
193                         OrderPair pairIJ = {valueI, valueJ};
194                         Edge constIJ=getPairConstraint(This, table, & pairIJ);
195                         for(uint k=j+1; k<size; k++){
196                                 uint64_t valueK = getVectorInt(mems, k);
197                                 OrderPair pairJK = {valueJ, valueK};
198                                 OrderPair pairIK = {valueI, valueK};
199                                 Edge constIK = getPairConstraint(This, table, & pairIK);
200                                 Edge constJK = getPairConstraint(This, table, & pairJK);
201                                 addConstraint(This->cnf, generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK)); 
202                         }
203                 }
204         }
205 }
206
207 Edge getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
208         ASSERT(pair->first!= pair->second);
209         Edge constraint = getBoolConst(table, pair)->constraint;
210         if(pair->first > pair->second)
211                 return constraint;
212         else
213                 return constraintNegate(constraint);
214 }
215
216 Edge generateTransOrderConstraintSATEncoder(SATEncoder *This, Edge constIJ,Edge constJK,Edge constIK){
217         Edge carray[] = {constIJ, constJK, constraintNegate(constIK)};
218         Edge loop1= constraintOR(This->cnf, 3, carray);
219         Edge carray2[] = {constraintNegate(constIJ), constraintNegate(constJK), constIK};
220         Edge loop2= constraintOR(This->cnf, 3, carray2 );
221         return constraintAND2(This->cnf, loop1, loop2);
222 }
223
224 Edge encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
225         // FIXME: we can have this implementation for partial order. Basically,
226         // we compute the transitivity between two order constraints specified by the client! (also can be used
227         // when client specify sparse constraints for the total order!)
228         ASSERT(constraint->order->type == PARTIAL);
229 /*
230         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
231         if( containsBoolConst(boolToConsts, boolOrder) ){
232                 return getBoolConst(boolToConsts, boolOrder);
233         } else {
234                 Edge constraint = getNewVarSATEncoder(This); 
235                 putBoolConst(boolToConsts,boolOrder, constraint);
236                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
237                 uint size= getSizeVectorBoolean(orderConstrs);
238                 for(uint i=0; i<size; i++){
239                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
240                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
241                         BooleanOrder* newBool;
242                         Edge first, second;
243                         if(tmp->second==boolOrder->first){
244                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
245                                 first = encodeTotalOrderSATEncoder(This, tmp);
246                                 second = constraint;
247                                 
248                         }else if (boolOrder->second == tmp->first){
249                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
250                                 first = constraint;
251                                 second = encodeTotalOrderSATEncoder(This, tmp);
252                         }else
253                                 continue;
254                         Edge transConstr= encodeTotalOrderSATEncoder(This, newBool);
255                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
256                 }
257                 return constraint;
258         }
259 */      
260         return E_BOGUS;
261 }
262
263 Edge encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
264         switch(GETPREDICATETYPE(constraint->predicate) ){
265                 case TABLEPRED:
266                         return encodeTablePredicateSATEncoder(This, constraint);
267                 case OPERATORPRED:
268                         return encodeOperatorPredicateSATEncoder(This, constraint);
269                 default:
270                         ASSERT(0);
271         }
272         return E_BOGUS;
273 }
274
275 Edge encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
276         switch(constraint->encoding.type){
277                 case ENUMERATEIMPLICATIONS:
278                 case ENUMERATEIMPLICATIONSNEGATE:
279                         return encodeEnumTablePredicateSATEncoder(This, constraint);
280                 case CIRCUIT:
281                         ASSERT(0);
282                         break;
283                 default:
284                         ASSERT(0);
285         }
286         return E_BOGUS;
287 }
288
289 Edge encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
290         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
291         FunctionEncodingType encType = constraint->encoding.type;
292         uint size = getSizeVectorTableEntry(entries);
293         Edge constraints[size];
294         for(uint i=0; i<size; i++){
295                 TableEntry* entry = getVectorTableEntry(entries, i);
296                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
297                         continue;
298                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
299                         continue;
300                 ArrayElement* inputs = &constraint->inputs;
301                 uint inputNum =getSizeArrayElement(inputs);
302                 Edge carray[inputNum];
303                 for(uint j=0; j<inputNum; j++){
304                         Element* el = getArrayElement(inputs, j);
305                         Edge tmpc = getElementValueConstraint(This, el, entry->inputs[j]);
306                         if( GETELEMENTTYPE(el) == ELEMFUNCRETURN){
307                                 Edge func =encodeFunctionElementSATEncoder(This, (ElementFunction*) el);
308                                 carray[j] = constraintAND2(This->cnf, func, tmpc);
309                         } else {
310                                 carray[j] = tmpc;
311                         }
312                 }
313                 constraints[i]=constraintAND(This->cnf, inputNum, carray);
314         }
315         Edge result=constraintOR(This->cnf, size, constraints);
316         //FIXME: if it didn't match with any entry
317         return encType==ENUMERATEIMPLICATIONS? result: constraintNegate(result);
318 }
319
320 Edge encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
321         switch(constraint->encoding.type){
322                 case ENUMERATEIMPLICATIONS:
323                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
324                 case CIRCUIT:
325                         ASSERT(0);
326                         break;
327                 default:
328                         ASSERT(0);
329         }
330         return E_BOGUS;
331 }
332
333 Edge encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
334         ASSERT(GETPREDICATETYPE(constraint->predicate)==OPERATORPRED);
335         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
336         ASSERT(predicate->op == EQUALS); //For now, we just only support equals
337         //getting maximum size of in common elements between two sets!
338         uint size=getSizeVectorInt( getArraySet( &predicate->domains, 0)->members);
339         uint64_t commonElements [size];
340         getEqualitySetIntersection(predicate, &size, commonElements);
341         Edge  carray[size];
342         Element* elem1 = getArrayElement( &constraint->inputs, 0);
343         Edge elemc1 = E_NULL, elemc2 = E_NULL;
344         if( GETELEMENTTYPE(elem1) == ELEMFUNCRETURN)
345                 elemc1 = encodeFunctionElementSATEncoder(This, (ElementFunction*) elem1);
346         Element* elem2 = getArrayElement( &constraint->inputs, 1);
347         if( GETELEMENTTYPE(elem2) == ELEMFUNCRETURN)
348                 elemc2 = encodeFunctionElementSATEncoder(This, (ElementFunction*) elem2);
349         for(uint i=0; i<size; i++){
350                 Edge arg1 = getElementValueConstraint(This, elem1, commonElements[i]);
351                 Edge arg2 = getElementValueConstraint(This, elem2, commonElements[i]);
352                 carray[i] =  constraintAND2(This->cnf, arg1, arg2);
353         }
354         //FIXME: the case when there is no intersection ....
355         Edge result = constraintOR(This->cnf, size, carray);
356         if (!edgeIsNull(elemc1))
357                 result = constraintAND2(This->cnf, result, elemc1);
358         if (!edgeIsNull(elemc2))
359                 result = constraintAND2(This->cnf, result, elemc2);
360         return result;
361 }
362
363 Edge encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
364         switch(GETFUNCTIONTYPE(This->function)){
365                 case TABLEFUNC:
366                         return encodeTableElementFunctionSATEncoder(encoder, This);
367                 case OPERATORFUNC:
368                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
369                 default:
370                         ASSERT(0);
371         }
372         return E_BOGUS;
373 }
374
375 Edge encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
376         switch(getElementFunctionEncoding(This)->type){
377                 case ENUMERATEIMPLICATIONS:
378                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
379                         break;
380                 case CIRCUIT:
381                         ASSERT(0);
382                         break;
383                 default:
384                         ASSERT(0);
385         }
386         return E_BOGUS;
387 }
388
389 Edge encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
390         ASSERT(GETFUNCTIONTYPE(This->function) == OPERATORFUNC);
391         ASSERT(getSizeArrayElement(&This->inputs)==2 );
392         ElementEncoding* elem1 = getElementEncoding( getArrayElement(&This->inputs,0) );
393         ElementEncoding* elem2 = getElementEncoding( getArrayElement(&This->inputs,1) );
394         Edge carray[elem1->encArraySize*elem2->encArraySize];
395         uint size=0;
396         Edge overFlowConstraint = ((BooleanVar*) This->overflowstatus)->var;
397         for(uint i=0; i<elem1->encArraySize; i++){
398                 if(isinUseElement(elem1, i)){
399                         for( uint j=0; j<elem2->encArraySize; j++){
400                                 if(isinUseElement(elem2, j)){
401                                         bool isInRange = false;
402                                         uint64_t result= applyFunctionOperator((FunctionOperator*)This->function,elem1->encodingArray[i],
403                                                 elem2->encodingArray[j], &isInRange);
404                                         //FIXME: instead of getElementValueConstraint, it might be useful to have another function
405                                         // that doesn't iterate over encodingArray and treats more efficient ...
406                                         Edge valConstrIn1 = getElementValueConstraint(encoder, elem1->element, elem1->encodingArray[i]);
407                                         Edge valConstrIn2 = getElementValueConstraint(encoder, elem2->element, elem2->encodingArray[j]);
408                                         Edge valConstrOut = getElementValueConstraint(encoder, (Element*) This, result);
409                                         if(edgeIsNull(valConstrOut))
410                                                 continue; //FIXME:Should talk to brian about it!
411                                         Edge OpConstraint = constraintIMPLIES(encoder->cnf, constraintAND2(encoder->cnf, valConstrIn1, valConstrIn2), valConstrOut);
412                                         switch( ((FunctionOperator*)This->function)->overflowbehavior ){
413                                                 case IGNORE:
414                                                         if(isInRange){
415                                                                 carray[size++] = OpConstraint;
416                                                         }
417                                                         break;
418                                                 case WRAPAROUND:
419                                                         carray[size++] = OpConstraint;
420                                                         break;
421                                                 case FLAGFORCESOVERFLOW:
422                                                         if(isInRange){
423                                                                 Edge const1 = constraintIMPLIES(encoder->cnf, constraintAND2(encoder->cnf, valConstrIn1, valConstrIn2), constraintNegate(overFlowConstraint));
424                                                                 carray[size++] = constraintAND2(encoder->cnf, const1, OpConstraint);
425                                                         }
426                                                         break;
427                                                 case OVERFLOWSETSFLAG:
428                                                         if(isInRange){
429                                                                 carray[size++] = OpConstraint;
430                                                         } else{
431                                                                 carray[size++] = constraintIMPLIES(encoder->cnf, constraintAND2(encoder->cnf, valConstrIn1, valConstrIn2), overFlowConstraint);
432                                                         }
433                                                         break;
434                                                 case FLAGIFFOVERFLOW:
435                                                         if(isInRange){
436                                                                 Edge const1 = constraintIMPLIES(encoder->cnf, constraintAND2(encoder->cnf, valConstrIn1, valConstrIn2), constraintNegate(overFlowConstraint));
437                                                                 carray[size++] = constraintAND2(encoder->cnf, const1, OpConstraint);
438                                                         } else {
439                                                                 carray[size++] = constraintIMPLIES(encoder->cnf, constraintAND2(encoder->cnf, valConstrIn1, valConstrIn2), overFlowConstraint);
440                                                         }
441                                                         break;
442                                                 case NOOVERFLOW:
443                                                         if(!isInRange){
444                                                                 ASSERT(0);
445                                                         }
446                                                         carray[size++] = OpConstraint;
447                                                         break;
448                                                 default:
449                                                         ASSERT(0);
450                                         }
451                                         
452                                 }
453                         }
454                 }
455         }
456         return constraintAND(encoder->cnf, size, carray);
457 }
458
459 Edge encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
460         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
461         ArrayElement* elements= &This->inputs;
462         Table* table = ((FunctionTable*) (This->function))->table;
463         uint size = getSizeVectorTableEntry(&table->entries);
464         Edge constraints[size]; //FIXME: should add a space for the case that didn't match any entries
465         for(uint i=0; i<size; i++) {
466                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
467                 uint inputNum = getSizeArrayElement(elements);
468                 Edge carray[inputNum];
469                 for(uint j=0; j<inputNum; j++){
470                         Element* el= getArrayElement(elements, j);
471                         carray[j] = getElementValueConstraint(encoder, el, entry->inputs[j]);
472                 }
473                 Edge output = getElementValueConstraint(encoder, (Element*)This, entry->output);
474                 Edge row= constraintIMPLIES(encoder->cnf, constraintAND(encoder->cnf, inputNum, carray), output);
475                 constraints[i]=row;
476         }
477         Edge result = constraintOR(encoder->cnf, size, constraints);
478         return result;
479 }
480