Fixing ASSERT statement bugs
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         ourfree(This);
24 }
25
26 Constraint * getElementValueConstraint(Element* This, uint64_t value) {
27         switch(getElementEncoding(This)->type){
28                 case ONEHOT:
29                         ASSERT(0);
30                         break;
31                 case UNARY:
32                         ASSERT(0);
33                         break;
34                 case BINARYINDEX:
35                         ASSERT(0);
36                         break;
37                 case ONEHOTBINARY:
38                         return getElementValueBinaryIndexConstraint(This, value);
39                         break;
40                 case BINARYVAL:
41                         ASSERT(0);
42                         break;
43                 default:
44                         ASSERT(0);
45                         break;
46         }
47         return NULL;
48 }
49 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
50         ASTNodeType type = GETELEMENTTYPE(This);
51         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
52         ElementEncoding* elemEnc = getElementEncoding(This);
53         for(uint i=0; i<elemEnc->encArraySize; i++){
54                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
55                         return generateBinaryConstraint(elemEnc->numVars,
56                                 elemEnc->variables, i);
57                 }
58         }
59         return NULL;
60 }
61
62 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
63         VectorBoolean *constraints=csolver->constraints;
64         uint size=getSizeVectorBoolean(constraints);
65         for(uint i=0;i<size;i++) {
66                 Boolean *constraint=getVectorBoolean(constraints, i);
67                 Constraint* c= encodeConstraintSATEncoder(This, constraint);
68                 printConstraint(c);
69         }
70 }
71
72 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
73         switch(GETBOOLEANTYPE(constraint)) {
74         case ORDERCONST:
75                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
76         case BOOLEANVAR:
77                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
78         case LOGICOP:
79                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
80         case PREDICATEOP:
81                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
82         default:
83                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
84                 exit(-1);
85         }
86 }
87
88 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
89         for(uint i=0;i<num;i++)
90                 carray[i]=getNewVarSATEncoder(encoder);
91 }
92
93 Constraint * getNewVarSATEncoder(SATEncoder *This) {
94         Constraint * var=allocVarConstraint(VAR, This->varcount);
95         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
96         setNegConstraint(var, varneg);
97         setNegConstraint(varneg, var);
98         return var;
99 }
100
101 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
102         if (constraint->var == NULL) {
103                 constraint->var=getNewVarSATEncoder(This);
104         }
105         return constraint->var;
106 }
107
108 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
109         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
110         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
111                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
112
113         switch(constraint->op) {
114         case L_AND:
115                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
116         case L_OR:
117                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
118         case L_NOT:
119                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==1);
120                 return negateConstraint(array[0]);
121         case L_XOR: {
122                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==2);
123                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
124                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
125                 return allocConstraint(OR,
126                                                                                                          allocConstraint(AND, array[0], nright),
127                                                                                                          allocConstraint(AND, nleft, array[1]));
128         }
129         case L_IMPLIES:
130                 ASSERT( getSizeArrayBoolean( &constraint->inputs)==2);
131                 return allocConstraint(IMPLIES, array[0], array[1]);
132         default:
133                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
134                 exit(-1);
135         }
136 }
137
138
139 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
140         switch( constraint->order->type){
141                 case PARTIAL:
142                         return encodePartialOrderSATEncoder(This, constraint);
143                 case TOTAL:
144                         return encodeTotalOrderSATEncoder(This, constraint);
145                 default:
146                         ASSERT(0);
147         }
148         return NULL;
149 }
150
151 Constraint * getPairConstraint(SATEncoder *This, HashTableBoolConst * table, OrderPair * pair) {
152         bool negate = false;
153         OrderPair flipped;
154         if (pair->first > pair->second) {
155                 negate=true;
156                 flipped.first=pair->second;
157                 flipped.second=pair->first;
158                 pair = &flipped;
159         }
160         Constraint * constraint;
161         if (!containsBoolConst(table, pair)) {
162                 constraint = getNewVarSATEncoder(This);
163                 OrderPair * paircopy = allocOrderPair(pair->first, pair->second);
164                 putBoolConst(table, paircopy, constraint);
165         } else
166                 constraint = getBoolConst(table, pair);
167         if (negate)
168                 return negateConstraint(constraint);
169         else
170                 return constraint;
171         
172 }
173
174 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
175         ASSERT(boolOrder->order->type == TOTAL);
176         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
177         OrderPair pair={boolOrder->first, boolOrder->second};
178         Constraint* constraint = getPairConstraint(This, boolToConsts, & pair);
179         ASSERT(constraint != NULL);
180         return constraint;
181 }
182
183 void createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
184         ASSERT(order->type == TOTAL);
185         VectorInt* mems = order->set->members;
186         HashTableBoolConst* table = order->boolsToConstraints;
187         uint size = getSizeVectorInt(mems);
188         for(uint i=0; i<size; i++){
189                 uint64_t valueI = getVectorInt(mems, i);
190                 for(uint j=i+1; j<size;j++){
191                         uint64_t valueJ = getVectorInt(mems, j);
192                         OrderPair pairIJ = {valueI, valueJ};
193                         Constraint* constIJ=getPairConstraint(This, table, & pairIJ);
194                         for(uint k=j+1; k<size; k++){
195                                 uint64_t valueK = getVectorInt(mems, k);
196                                 OrderPair pairJK = {valueJ, valueK};
197                                 OrderPair pairIK = {valueI, valueK};
198                                 Constraint* constIK = getPairConstraint(This, table, & pairIK);
199                                 Constraint* constJK = getPairConstraint(This, table, & pairJK);
200                                 generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
201                         }
202                 }
203         }
204 }
205
206 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
207         ASSERT(pair->first!= pair->second);
208         Constraint* constraint= getBoolConst(table, pair);
209         ASSERT(constraint!= NULL);
210         if(pair->first > pair->second)
211                 return constraint;
212         else
213                 return negateConstraint(constraint);
214 }
215
216 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
217         //FIXME: first we should add the the constraint to the satsolver!
218         ASSERT(constIJ!= NULL && constJK != NULL && constIK != NULL);
219         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
220         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
221         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
222         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
223         return allocConstraint(AND, loop1, loop2);
224 }
225
226 Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
227         // FIXME: we can have this implementation for partial order. Basically,
228         // we compute the transitivity between two order constraints specified by the client! (also can be used
229         // when client specify sparse constraints for the total order!)
230         ASSERT(constraint->order->type == PARTIAL);
231 /*
232         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
233         if( containsBoolConst(boolToConsts, boolOrder) ){
234                 return getBoolConst(boolToConsts, boolOrder);
235         } else {
236                 Constraint* constraint = getNewVarSATEncoder(This); 
237                 putBoolConst(boolToConsts,boolOrder, constraint);
238                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
239                 uint size= getSizeVectorBoolean(orderConstrs);
240                 for(uint i=0; i<size; i++){
241                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
242                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
243                         BooleanOrder* newBool;
244                         Constraint* first, *second;
245                         if(tmp->second==boolOrder->first){
246                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
247                                 first = encodeTotalOrderSATEncoder(This, tmp);
248                                 second = constraint;
249                                 
250                         }else if (boolOrder->second == tmp->first){
251                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
252                                 first = constraint;
253                                 second = encodeTotalOrderSATEncoder(This, tmp);
254                         }else
255                                 continue;
256                         Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
257                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
258                 }
259                 return constraint;
260         }
261 */      
262         return NULL;
263 }
264
265 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
266         switch(GETPREDICATETYPE(constraint->predicate) ){
267                 case TABLEPRED:
268                         return encodeTablePredicateSATEncoder(This, constraint);
269                 case OPERATORPRED:
270                         return encodeOperatorPredicateSATEncoder(This, constraint);
271                 default:
272                         ASSERT(0);
273         }
274         return NULL;
275 }
276
277 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
278         switch(constraint->encoding.type){
279                 case ENUMERATEIMPLICATIONS:
280                 case ENUMERATEIMPLICATIONSNEGATE:
281                         return encodeEnumTablePredicateSATEncoder(This, constraint);
282                 case CIRCUIT:
283                         ASSERT(0);
284                         break;
285                 default:
286                         ASSERT(0);
287         }
288         return NULL;
289 }
290
291 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
292         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
293         FunctionEncodingType encType = constraint->encoding.type;
294         uint size = getSizeVectorTableEntry(entries);
295         Constraint* constraints[size];
296         for(uint i=0; i<size; i++){
297                 TableEntry* entry = getVectorTableEntry(entries, i);
298                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
299                         continue;
300                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
301                         continue;
302                 ArrayElement* inputs = &constraint->inputs;
303                 uint inputNum =getSizeArrayElement(inputs);
304                 Constraint* carray[inputNum];
305                 for(uint j=0; j<inputNum; j++){
306                         Element* el = getArrayElement(inputs, j);
307                         if( GETELEMENTTYPE(el) == ELEMFUNCRETURN)
308                                 encodeFunctionElementSATEncoder(This, (ElementFunction*) el);
309                         carray[j] = getElementValueConstraint(el, entry->inputs[j]);
310                         ASSERT(carray[j]!= NULL);
311                 }
312                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
313         }
314         Constraint* result= allocArrayConstraint(OR, size, constraints);
315         //FIXME: if it didn't match with any entry
316         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
317 }
318
319 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
320         switch(constraint->encoding.type){
321                 case ENUMERATEIMPLICATIONS:
322                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
323                 case CIRCUIT:
324                         ASSERT(0);
325                         break;
326                 default:
327                         ASSERT(0);
328         }
329         return NULL;
330 }
331
332 Constraint * encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
333         ASSERT(GETPREDICATETYPE(constraint)==OPERATORPRED);
334         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
335         ASSERT(predicate->op == EQUALS); //For now, we just only support equals
336         //getting maximum size of in common elements between two sets!
337         uint size=getSizeVectorInt( getArraySet( &predicate->domains, 0)->members);
338         uint64_t commonElements [size];
339         getEqualitySetIntersection(predicate, &size, commonElements);
340         Constraint*  carray[size];
341         Element* elem1 = getArrayElement( &constraint->inputs, 0);
342         if( GETELEMENTTYPE(elem1) == ELEMFUNCRETURN)
343                 encodeFunctionElementSATEncoder(This, (ElementFunction*) elem1);
344         Element* elem2 = getArrayElement( &constraint->inputs, 1);
345         if( GETELEMENTTYPE(elem2) == ELEMFUNCRETURN)
346                 encodeFunctionElementSATEncoder(This, (ElementFunction*) elem2);
347         for(uint i=0; i<size; i++){
348                 Constraint* arg1 = getElementValueConstraint(elem1, commonElements[i]);
349                 ASSERT(arg1!=NULL);
350                 Constraint* arg2 = getElementValueConstraint(elem2, commonElements[i]);
351                 ASSERT(arg2 != NULL);
352                 carray[i] =  allocConstraint(AND, arg1, arg2);
353         }
354         //FIXME: the case when there is no intersection ....
355         return allocArrayConstraint(OR, size, carray);
356 }
357
358 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
359         switch(GETFUNCTIONTYPE(This->function)){
360                 case TABLEFUNC:
361                         return encodeTableElementFunctionSATEncoder(encoder, This);
362                 case OPERATORFUNC:
363                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
364                 default:
365                         ASSERT(0);
366         }
367         return NULL;
368 }
369
370 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
371         switch(getElementFunctionEncoding(This)->type){
372                 case ENUMERATEIMPLICATIONS:
373                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
374                         break;
375                 case CIRCUIT:
376                         ASSERT(0);
377                         break;
378                 default:
379                         ASSERT(0);
380         }
381         return NULL;
382 }
383
384 Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
385         ASSERT(GETFUNCTIONTYPE(This->function) == OPERATORFUNC);
386         ASSERT(getSizeArrayElement(&This->inputs)==2 );
387         ElementEncoding* elem1 = getElementEncoding( getArrayElement(&This->inputs,0) );
388         ElementEncoding* elem2 = getElementEncoding( getArrayElement(&This->inputs,1) );
389         Constraint* carray[elem1->encArraySize*elem2->encArraySize];
390         uint size=0;
391         Constraint* overFlowConstraint = ((BooleanVar*) This->overflowstatus)->var;
392         for(uint i=0; i<elem1->encArraySize; i++){
393                 if(isinUseElement(elem1, i)){
394                         for(uint j=0; j<elem2->encArraySize; j++){
395                                 if(isinUseElement(elem2, j)){
396                                         bool isInRange = false, hasOverFlow=false;
397                                         uint64_t result= applyFunctionOperator((FunctionOperator*)This->function,elem1->encodingArray[i],
398                                                 elem2->encodingArray[j], &isInRange, &hasOverFlow);
399                                         if(!isInRange)
400                                                 break; // Ignoring the cases that result of operation doesn't exist in the code.
401                                         //FIXME: instead of getElementValueConstraint, it might be useful to have another function
402                                         // that doesn't iterate over encodingArray and treats more efficient ...
403                                         Constraint* and1 = getElementValueConstraint(elem1->element, elem1->encodingArray[i]);
404                                         ASSERT(and1 != NULL);
405                                         Constraint* and2 = getElementValueConstraint(elem2->element, elem2->encodingArray[j]);
406                                         ASSERT(and2 != NULL);
407                                         Constraint* imply2 = getElementValueConstraint((Element*) This, result);
408                                         ASSERT(imply2 != NULL);
409                                         Constraint* constraint = allocConstraint(IMPLIES, 
410                                                 allocConstraint(AND, and1, and2) , imply2);
411                                         switch( ((FunctionOperator*)This->function)->overflowbehavior ){
412                                                 case IGNORE:
413                                                         if(!hasOverFlow){
414                                                                 carray[size++] = constraint;
415                                                         }
416                                                 case WRAPAROUND:
417                                                         carray[size++] = constraint;
418                                                 case FLAGFORCESOVERFLOW:
419                                                         if(hasOverFlow){
420                                                                 Constraint* const1 = allocConstraint(IMPLIES, overFlowConstraint, 
421                                                                                 allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
422                                                                                 getElementValueConstraint(elem2->element, elem2->encodingArray[j])));
423                                                                 carray[size++] = allocConstraint(AND, const1, constraint);
424                                                         }
425                                                 case OVERFLOWSETSFLAG:
426                                                         if(hasOverFlow){
427                                                                 Constraint* const1 = allocConstraint(IMPLIES, allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
428                                                                                 getElementValueConstraint(elem2->element, elem2->encodingArray[j])), overFlowConstraint);
429                                                                 carray[size++] = allocConstraint(AND, const1, constraint);
430                                                         } else
431                                                                 carray[size++] = constraint;
432                                                 case FLAGIFFOVERFLOW:
433                                                         if(!hasOverFlow){
434                                                                 carray[size++] = constraint;
435                                                         }else{
436                                                                 Constraint* const1 = allocConstraint(IMPLIES, overFlowConstraint, 
437                                                                                 allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
438                                                                                 getElementValueConstraint(elem2->element, elem2->encodingArray[j])));
439                                                                 Constraint* const2 = allocConstraint(IMPLIES, allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
440                                                                                 getElementValueConstraint(elem2->element, elem2->encodingArray[j])), overFlowConstraint);
441                                                                 Constraint* res [] = {const1, const2, constraint};
442                                                                 carray[size++] = allocArrayConstraint(AND, 3, res);
443                                                         }
444                                                 case NOOVERFLOW:
445                                                         if(hasOverFlow)
446                                                                 ASSERT(0);
447                                                                 carray[size++] = constraint;
448                                                 default:
449                                                         ASSERT(0);
450                                         }
451                                         
452                                 }
453                         }
454                 }
455         }
456         return allocArrayConstraint(AND, size, carray);
457 }
458
459 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
460         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
461         ArrayElement* elements= &This->inputs;
462         Table* table = ((FunctionTable*) (This->function))->table;
463         uint size = getSizeVectorTableEntry(&table->entries);
464         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
465         for(uint i=0; i<size; i++){
466                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
467                 uint inputNum =getSizeArrayElement(elements);
468                 Constraint* carray[inputNum];
469                 for(uint j=0; j<inputNum; j++){
470                         Element* el= getArrayElement(elements, j);
471                         carray[j] = getElementValueConstraint(el, entry->inputs[j]);
472                         ASSERT(carray[j]!= NULL);
473                 }
474                 Constraint* output = getElementValueConstraint((Element*)This, entry->output);
475                 ASSERT(output!= NULL);
476                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray), output);
477                 constraints[i]=row;
478         }
479         Constraint* result = allocArrayConstraint(OR, size, constraints);
480         return result;
481 }