e78cb3f634764c3d65fd68b971b313f087745c61
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         ourfree(This);
24 }
25
26 Constraint * getElementValueConstraint(SATEncoder* encoder,Element* This, uint64_t value) {
27         generateElementEncodingVariables(encoder, getElementEncoding(This));
28         switch(getElementEncoding(This)->type){
29                 case ONEHOT:
30                         ASSERT(0);
31                         break;
32                 case UNARY:
33                         ASSERT(0);
34                         break;
35                 case BINARYINDEX:
36                         return getElementValueBinaryIndexConstraint(This, value);
37                         break;
38                 case ONEHOTBINARY:
39                         ASSERT(0);
40                         break;
41                 case BINARYVAL:
42                         ASSERT(0);
43                         break;
44                 default:
45                         ASSERT(0);
46                         break;
47         }
48         return NULL;
49 }
50 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
51         ASTNodeType type = GETELEMENTTYPE(This);
52         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
53         ElementEncoding* elemEnc = getElementEncoding(This);
54         for(uint i=0; i<elemEnc->encArraySize; i++){
55                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
56                         return generateBinaryConstraint(elemEnc->numVars,
57                                 elemEnc->variables, i);
58                 }
59         }
60         return NULL;
61 }
62
63 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
64         VectorBoolean *constraints=csolver->constraints;
65         uint size=getSizeVectorBoolean(constraints);
66         for(uint i=0;i<size;i++) {
67                 Boolean *constraint=getVectorBoolean(constraints, i);
68                 Constraint* c= encodeConstraintSATEncoder(This, constraint);
69                 printConstraint(c);
70                 model_print("\n\n");
71         }
72 }
73
74 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
75         switch(GETBOOLEANTYPE(constraint)) {
76         case ORDERCONST:
77                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
78         case BOOLEANVAR:
79                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
80         case LOGICOP:
81                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
82         case PREDICATEOP:
83                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
84         default:
85                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
86                 exit(-1);
87         }
88 }
89
90 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
91         for(uint i=0;i<num;i++)
92                 carray[i]=getNewVarSATEncoder(encoder);
93 }
94
95 Constraint * getNewVarSATEncoder(SATEncoder *This) {
96         Constraint * var=allocVarConstraint(VAR, This->varcount);
97         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
98         setNegConstraint(var, varneg);
99         setNegConstraint(varneg, var);
100         return var;
101 }
102
103 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
104         if (constraint->var == NULL) {
105                 constraint->var=getNewVarSATEncoder(This);
106         }
107         return constraint->var;
108 }
109
110 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
111         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
112         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
113                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
114
115         switch(constraint->op) {
116         case L_AND:
117                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
118         case L_OR:
119                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
120         case L_NOT:
121                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==1);
122                 return negateConstraint(array[0]);
123         case L_XOR: {
124                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==2);
125                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
126                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
127                 return allocConstraint(OR,
128                                                                                                          allocConstraint(AND, array[0], nright),
129                                                                                                          allocConstraint(AND, nleft, array[1]));
130         }
131         case L_IMPLIES:
132                 ASSERT( getSizeArrayBoolean( &constraint->inputs)==2);
133                 return allocConstraint(IMPLIES, array[0], array[1]);
134         default:
135                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
136                 exit(-1);
137         }
138 }
139
140
141 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
142         switch( constraint->order->type){
143                 case PARTIAL:
144                         return encodePartialOrderSATEncoder(This, constraint);
145                 case TOTAL:
146                         return encodeTotalOrderSATEncoder(This, constraint);
147                 default:
148                         ASSERT(0);
149         }
150         return NULL;
151 }
152
153 Constraint * getPairConstraint(SATEncoder *This, HashTableBoolConst * table, OrderPair * pair) {
154         bool negate = false;
155         OrderPair flipped;
156         if (pair->first > pair->second) {
157                 negate=true;
158                 flipped.first=pair->second;
159                 flipped.second=pair->first;
160                 pair = &flipped;        //FIXME: accessing a local variable from outside of the function?
161         }
162         Constraint * constraint;
163         if (!containsBoolConst(table, pair)) {
164                 constraint = getNewVarSATEncoder(This);
165                 OrderPair * paircopy = allocOrderPair(pair->first, pair->second);
166                 putBoolConst(table, paircopy, constraint);
167         } else
168                 constraint = getBoolConst(table, pair);
169         if (negate)
170                 return negateConstraint(constraint);
171         else
172                 return constraint;
173         
174 }
175
176 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
177         ASSERT(boolOrder->order->type == TOTAL);
178         if(boolOrder->order->boolsToConstraints == NULL){
179                 initializeOrderHashTable(boolOrder->order);
180                 return createAllTotalOrderConstraintsSATEncoder(This, boolOrder->order);
181         }
182         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
183         OrderPair pair={boolOrder->first, boolOrder->second};
184         Constraint* constraint = getPairConstraint(This, boolToConsts, & pair);
185         ASSERT(constraint != NULL);
186         return constraint;
187 }
188
189 Constraint* createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
190         ASSERT(order->type == TOTAL);
191         VectorInt* mems = order->set->members;
192         HashTableBoolConst* table = order->boolsToConstraints;
193         uint size = getSizeVectorInt(mems);
194         Constraint* constraints [size*size];
195         uint csize =0;
196         for(uint i=0; i<size; i++){
197                 uint64_t valueI = getVectorInt(mems, i);
198                 for(uint j=i+1; j<size;j++){
199                         uint64_t valueJ = getVectorInt(mems, j);
200                         OrderPair pairIJ = {valueI, valueJ};
201                         Constraint* constIJ=getPairConstraint(This, table, & pairIJ);
202                         for(uint k=j+1; k<size; k++){
203                                 uint64_t valueK = getVectorInt(mems, k);
204                                 OrderPair pairJK = {valueJ, valueK};
205                                 OrderPair pairIK = {valueI, valueK};
206                                 Constraint* constIK = getPairConstraint(This, table, & pairIK);
207                                 Constraint* constJK = getPairConstraint(This, table, & pairJK);
208                                 constraints[csize++] = generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
209                                 ASSERT(csize < size*size);
210                         }
211                 }
212         }
213         return allocArrayConstraint(AND, csize, constraints);
214 }
215
216 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
217         ASSERT(pair->first!= pair->second);
218         Constraint* constraint= getBoolConst(table, pair);
219         ASSERT(constraint!= NULL);
220         if(pair->first > pair->second)
221                 return constraint;
222         else
223                 return negateConstraint(constraint);
224 }
225
226 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
227         //FIXME: first we should add the the constraint to the satsolver!
228         ASSERT(constIJ!= NULL && constJK != NULL && constIK != NULL);
229         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
230         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
231         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
232         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
233         return allocConstraint(AND, loop1, loop2);
234 }
235
236 Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
237         // FIXME: we can have this implementation for partial order. Basically,
238         // we compute the transitivity between two order constraints specified by the client! (also can be used
239         // when client specify sparse constraints for the total order!)
240         ASSERT(constraint->order->type == PARTIAL);
241 /*
242         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
243         if( containsBoolConst(boolToConsts, boolOrder) ){
244                 return getBoolConst(boolToConsts, boolOrder);
245         } else {
246                 Constraint* constraint = getNewVarSATEncoder(This); 
247                 putBoolConst(boolToConsts,boolOrder, constraint);
248                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
249                 uint size= getSizeVectorBoolean(orderConstrs);
250                 for(uint i=0; i<size; i++){
251                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
252                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
253                         BooleanOrder* newBool;
254                         Constraint* first, *second;
255                         if(tmp->second==boolOrder->first){
256                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
257                                 first = encodeTotalOrderSATEncoder(This, tmp);
258                                 second = constraint;
259                                 
260                         }else if (boolOrder->second == tmp->first){
261                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
262                                 first = constraint;
263                                 second = encodeTotalOrderSATEncoder(This, tmp);
264                         }else
265                                 continue;
266                         Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
267                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
268                 }
269                 return constraint;
270         }
271 */      
272         return NULL;
273 }
274
275 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
276         switch(GETPREDICATETYPE(constraint->predicate) ){
277                 case TABLEPRED:
278                         return encodeTablePredicateSATEncoder(This, constraint);
279                 case OPERATORPRED:
280                         return encodeOperatorPredicateSATEncoder(This, constraint);
281                 default:
282                         ASSERT(0);
283         }
284         return NULL;
285 }
286
287 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
288         switch(constraint->encoding.type){
289                 case ENUMERATEIMPLICATIONS:
290                 case ENUMERATEIMPLICATIONSNEGATE:
291                         return encodeEnumTablePredicateSATEncoder(This, constraint);
292                 case CIRCUIT:
293                         ASSERT(0);
294                         break;
295                 default:
296                         ASSERT(0);
297         }
298         return NULL;
299 }
300
301 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
302         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
303         FunctionEncodingType encType = constraint->encoding.type;
304         uint size = getSizeVectorTableEntry(entries);
305         Constraint* constraints[size];
306         for(uint i=0; i<size; i++){
307                 TableEntry* entry = getVectorTableEntry(entries, i);
308                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
309                         continue;
310                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
311                         continue;
312                 ArrayElement* inputs = &constraint->inputs;
313                 uint inputNum =getSizeArrayElement(inputs);
314                 Constraint* carray[inputNum];
315                 for(uint j=0; j<inputNum; j++){
316                         Element* el = getArrayElement(inputs, j);
317                         Constraint* tmpc = getElementValueConstraint(This,el, entry->inputs[j]);
318                         ASSERT(tmpc!= NULL);
319                         if( GETELEMENTTYPE(el) == ELEMFUNCRETURN){
320                                 Constraint* func =encodeFunctionElementSATEncoder(This, (ElementFunction*) el);
321                                 ASSERT(func!=NULL);
322                                 carray[j] = allocConstraint(AND, func, tmpc);
323                         } else {
324                                 carray[j] = tmpc;
325                         }
326                         ASSERT(carray[j]!= NULL);
327                 }
328                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
329         }
330         Constraint* result= allocArrayConstraint(OR, size, constraints);
331         //FIXME: if it didn't match with any entry
332         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
333 }
334
335 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
336         switch(constraint->encoding.type){
337                 case ENUMERATEIMPLICATIONS:
338                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
339                 case CIRCUIT:
340                         ASSERT(0);
341                         break;
342                 default:
343                         ASSERT(0);
344         }
345         return NULL;
346 }
347
348 Constraint * encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
349         ASSERT(GETPREDICATETYPE(constraint->predicate)==OPERATORPRED);
350         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
351         ASSERT(predicate->op == EQUALS); //For now, we just only support equals
352         //getting maximum size of in common elements between two sets!
353         uint size=getSizeVectorInt( getArraySet( &predicate->domains, 0)->members);
354         uint64_t commonElements [size];
355         getEqualitySetIntersection(predicate, &size, commonElements);
356         Constraint*  carray[size];
357         Element* elem1 = getArrayElement( &constraint->inputs, 0);
358         Constraint *elemc1 = NULL, *elemc2 = NULL;
359         if( GETELEMENTTYPE(elem1) == ELEMFUNCRETURN)
360                 elemc1 = encodeFunctionElementSATEncoder(This, (ElementFunction*) elem1);
361         Element* elem2 = getArrayElement( &constraint->inputs, 1);
362         if( GETELEMENTTYPE(elem2) == ELEMFUNCRETURN)
363                 elemc2 = encodeFunctionElementSATEncoder(This, (ElementFunction*) elem2);
364         for(uint i=0; i<size; i++){
365                 Constraint* arg1 = getElementValueConstraint(This, elem1, commonElements[i]);
366                 ASSERT(arg1!=NULL);
367                 Constraint* arg2 = getElementValueConstraint(This, elem2, commonElements[i]);
368                 ASSERT(arg2 != NULL);
369                 carray[i] =  allocConstraint(AND, arg1, arg2);
370         }
371         //FIXME: the case when there is no intersection ....
372         Constraint* result = allocArrayConstraint(OR, size, carray);
373         ASSERT(result!= NULL);
374         if(elemc1!= NULL)
375                 result = allocConstraint(AND, result, elemc1);
376         if(elemc2 != NULL)
377                 result = allocConstraint (AND, result, elemc2);
378         return result;
379 }
380
381 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
382         switch(GETFUNCTIONTYPE(This->function)){
383                 case TABLEFUNC:
384                         return encodeTableElementFunctionSATEncoder(encoder, This);
385                 case OPERATORFUNC:
386                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
387                 default:
388                         ASSERT(0);
389         }
390         return NULL;
391 }
392
393 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
394         switch(getElementFunctionEncoding(This)->type){
395                 case ENUMERATEIMPLICATIONS:
396                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
397                         break;
398                 case CIRCUIT:
399                         ASSERT(0);
400                         break;
401                 default:
402                         ASSERT(0);
403         }
404         return NULL;
405 }
406
407 Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
408         ASSERT(GETFUNCTIONTYPE(This->function) == OPERATORFUNC);
409         ASSERT(getSizeArrayElement(&This->inputs)==2 );
410         ElementEncoding* elem1 = getElementEncoding( getArrayElement(&This->inputs,0) );
411         ElementEncoding* elem2 = getElementEncoding( getArrayElement(&This->inputs,1) );
412         Constraint* carray[elem1->encArraySize*elem2->encArraySize];
413         uint size=0;
414         Constraint* overFlowConstraint = ((BooleanVar*) This->overflowstatus)->var;
415         for(uint i=0; i<elem1->encArraySize; i++){
416                 if(isinUseElement(elem1, i)){
417                         for( uint j=0; j<elem2->encArraySize; j++){
418                                 if(isinUseElement(elem2, j)){
419                                         bool isInRange = false;
420                                         uint64_t result= applyFunctionOperator((FunctionOperator*)This->function,elem1->encodingArray[i],
421                                                 elem2->encodingArray[j], &isInRange);
422                                         //FIXME: instead of getElementValueConstraint, it might be useful to have another function
423                                         // that doesn't iterate over encodingArray and treats more efficient ...
424                                         Constraint* valConstrIn1 = getElementValueConstraint(encoder, elem1->element, elem1->encodingArray[i]);
425                                         ASSERT(valConstrIn1 != NULL);
426                                         Constraint* valConstrIn2 = getElementValueConstraint(encoder, elem2->element, elem2->encodingArray[j]);
427                                         ASSERT(valConstrIn2 != NULL);
428                                         Constraint* valConstrOut = getElementValueConstraint(encoder, (Element*) This, result);
429                                         if(valConstrOut == NULL)
430                                                 continue; //FIXME:Should talk to brian about it!
431                                         Constraint* OpConstraint = allocConstraint(IMPLIES, 
432                                                 allocConstraint(AND, valConstrIn1, valConstrIn2) , valConstrOut);
433                                         switch( ((FunctionOperator*)This->function)->overflowbehavior ){
434                                                 case IGNORE:
435                                                         if(isInRange){
436                                                                 carray[size++] = OpConstraint;
437                                                         }
438                                                         break;
439                                                 case WRAPAROUND:
440                                                         carray[size++] = OpConstraint;
441                                                         break;
442                                                 case FLAGFORCESOVERFLOW:
443                                                         if(isInRange){
444                                                                 Constraint* const1 = allocConstraint(IMPLIES,
445                                                                         allocConstraint(AND, valConstrIn1, valConstrIn2), 
446                                                                         negateConstraint(overFlowConstraint));
447                                                                 carray[size++] = allocConstraint(AND, const1, OpConstraint);
448                                                         }
449                                                         break;
450                                                 case OVERFLOWSETSFLAG:
451                                                         if(isInRange){
452                                                                 carray[size++] = OpConstraint;
453                                                         } else{
454                                                                 carray[size++] = allocConstraint(IMPLIES,
455                                                                         allocConstraint(AND, valConstrIn1, valConstrIn2),
456                                                                         overFlowConstraint);
457                                                         }
458                                                         break;
459                                                 case FLAGIFFOVERFLOW:
460                                                         if(isInRange){
461                                                                 Constraint* const1 = allocConstraint(IMPLIES,
462                                                                         allocConstraint(AND, valConstrIn1, valConstrIn2), 
463                                                                         negateConstraint(overFlowConstraint));
464                                                                 carray[size++] = allocConstraint(AND, const1, OpConstraint);
465                                                         }else{
466                                                                 carray[size++] = allocConstraint(IMPLIES,
467                                                                         allocConstraint(AND, valConstrIn1, valConstrIn2),
468                                                                         overFlowConstraint);
469                                                         }
470                                                         break;
471                                                 case NOOVERFLOW:
472                                                         if(!isInRange){
473                                                                 ASSERT(0);
474                                                         }
475                                                         carray[size++] = OpConstraint;
476                                                         break;
477                                                 default:
478                                                         ASSERT(0);
479                                         }
480                                         
481                                 }
482                         }
483                 }
484         }
485         return allocArrayConstraint(AND, size, carray);
486 }
487
488 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
489         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
490         ArrayElement* elements= &This->inputs;
491         Table* table = ((FunctionTable*) (This->function))->table;
492         uint size = getSizeVectorTableEntry(&table->entries);
493         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
494         for(uint i=0; i<size; i++){
495                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
496                 uint inputNum =getSizeArrayElement(elements);
497                 Constraint* carray[inputNum];
498                 for(uint j=0; j<inputNum; j++){
499                         Element* el= getArrayElement(elements, j);
500                         carray[j] = getElementValueConstraint(encoder, el, entry->inputs[j]);
501                         ASSERT(carray[j]!= NULL);
502                 }
503                 Constraint* output = getElementValueConstraint(encoder, (Element*)This, entry->output);
504                 ASSERT(output!= NULL);
505                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray), output);
506                 constraints[i]=row;
507         }
508         Constraint* result = allocArrayConstraint(OR, size, constraints);
509         return result;
510 }