Merge branch 'hamed' into brian
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         This->satSolver = allocIncrementalSolver();
20         return This;
21 }
22
23 void deleteSATEncoder(SATEncoder *This) {
24         deleteIncrementalSolver(This->satSolver);
25         ourfree(This);
26 }
27
28 Constraint * getElementValueConstraint(SATEncoder* encoder,Element* This, uint64_t value) {
29         generateElementEncodingVariables(encoder, getElementEncoding(This));
30         switch(getElementEncoding(This)->type){
31                 case ONEHOT:
32                         //FIXME
33                         ASSERT(0);
34                         break;
35                 case UNARY:
36                         ASSERT(0);
37                         break;
38                 case BINARYINDEX:
39                         return getElementValueBinaryIndexConstraint(This, value);
40                         break;
41                 case ONEHOTBINARY:
42                         ASSERT(0);
43                         break;
44                 case BINARYVAL:
45                         ASSERT(0);
46                         break;
47                 default:
48                         ASSERT(0);
49                         break;
50         }
51         return NULL;
52 }
53 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
54         ASTNodeType type = GETELEMENTTYPE(This);
55         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
56         ElementEncoding* elemEnc = getElementEncoding(This);
57         for(uint i=0; i<elemEnc->encArraySize; i++){
58                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
59                         return generateBinaryConstraint(elemEnc->numVars,
60                                 elemEnc->variables, i);
61                 }
62         }
63         return NULL;
64 }
65
66 void addConstraintToSATSolver(Constraint *c, IncrementalSolver* satSolver) {
67         VectorConstraint* simplified = simplifyConstraint(c);
68         uint size = getSizeVectorConstraint(simplified);
69         for(uint i=0; i<size; i++) {
70                 Constraint *simp=getVectorConstraint(simplified, i);
71                 if (simp->type==TRUE)
72                         continue;
73                 ASSERT(simp->type!=FALSE);
74                 dumpConstraint(simp, satSolver);
75                 freerecConstraint(simp);
76         }
77         deleteVectorConstraint(simplified);
78 }
79
80 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
81         VectorBoolean *constraints=csolver->constraints;
82         uint size=getSizeVectorBoolean(constraints);
83         for(uint i=0;i<size;i++) {
84                 Boolean *constraint=getVectorBoolean(constraints, i);
85                 Constraint* c= encodeConstraintSATEncoder(This, constraint);
86                 printConstraint(c);
87                 model_print("\n\n");
88                 addConstraintToSATSolver(c, This->satSolver);
89                 //FIXME: When do we want to delete constraints? Should we keep an array of them
90                 // and delete them later, or it would be better to just delete them right away?
91         }
92 }
93
94 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
95         switch(GETBOOLEANTYPE(constraint)) {
96         case ORDERCONST:
97                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
98         case BOOLEANVAR:
99                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
100         case LOGICOP:
101                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
102         case PREDICATEOP:
103                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
104         default:
105                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
106                 exit(-1);
107         }
108 }
109
110 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
111         for(uint i=0;i<num;i++)
112                 carray[i]=getNewVarSATEncoder(encoder);
113 }
114
115 Constraint * getNewVarSATEncoder(SATEncoder *This) {
116         Constraint * var=allocVarConstraint(VAR, This->varcount);
117         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
118         setNegConstraint(var, varneg);
119         setNegConstraint(varneg, var);
120         return var;
121 }
122
123 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
124         if (constraint->var == NULL) {
125                 constraint->var=getNewVarSATEncoder(This);
126         }
127         return constraint->var;
128 }
129
130 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
131         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
132         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
133                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
134
135         switch(constraint->op) {
136         case L_AND:
137                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
138         case L_OR:
139                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
140         case L_NOT:
141                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==1);
142                 return negateConstraint(array[0]);
143         case L_XOR: {
144                 ASSERT( getSizeArrayBoolean(&constraint->inputs)==2);
145                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
146                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
147                 return allocConstraint(OR,
148                                                                                                          allocConstraint(AND, array[0], nright),
149                                                                                                          allocConstraint(AND, nleft, array[1]));
150         }
151         case L_IMPLIES:
152                 ASSERT( getSizeArrayBoolean( &constraint->inputs)==2);
153                 return allocConstraint(IMPLIES, array[0], array[1]);
154         default:
155                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
156                 exit(-1);
157         }
158 }
159
160
161 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
162         switch( constraint->order->type){
163                 case PARTIAL:
164                         return encodePartialOrderSATEncoder(This, constraint);
165                 case TOTAL:
166                         return encodeTotalOrderSATEncoder(This, constraint);
167                 default:
168                         ASSERT(0);
169         }
170         return NULL;
171 }
172
173 Constraint * getPairConstraint(SATEncoder *This, HashTableBoolConst * table, OrderPair * pair) {
174         bool negate = false;
175         OrderPair flipped;
176         if (pair->first > pair->second) {
177                 negate=true;
178                 flipped.first=pair->second;
179                 flipped.second=pair->first;
180                 pair = &flipped;        //FIXME: accessing a local variable from outside of the function?
181         }
182         Constraint * constraint;
183         if (!containsBoolConst(table, pair)) {
184                 constraint = getNewVarSATEncoder(This);
185                 OrderPair * paircopy = allocOrderPair(pair->first, pair->second, constraint);
186                 putBoolConst(table, paircopy, paircopy);
187         } else
188                 constraint = getBoolConst(table, pair)->constraint;
189         if (negate)
190                 return negateConstraint(constraint);
191         else
192                 return constraint;
193         
194 }
195
196 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
197         ASSERT(boolOrder->order->type == TOTAL);
198         if(boolOrder->order->boolsToConstraints == NULL){
199                 initializeOrderHashTable(boolOrder->order);
200                 return createAllTotalOrderConstraintsSATEncoder(This, boolOrder->order);
201         }
202         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
203         OrderPair pair={boolOrder->first, boolOrder->second, NULL};
204         Constraint *constraint = getPairConstraint(This, boolToConsts, & pair);
205         return constraint;
206 }
207
208 Constraint* createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
209         ASSERT(order->type == TOTAL);
210         VectorInt* mems = order->set->members;
211         HashTableBoolConst* table = order->boolsToConstraints;
212         uint size = getSizeVectorInt(mems);
213         Constraint* constraints [size*size];
214         uint csize =0;
215         for(uint i=0; i<size; i++){
216                 uint64_t valueI = getVectorInt(mems, i);
217                 for(uint j=i+1; j<size;j++){
218                         uint64_t valueJ = getVectorInt(mems, j);
219                         OrderPair pairIJ = {valueI, valueJ};
220                         Constraint* constIJ=getPairConstraint(This, table, & pairIJ);
221                         for(uint k=j+1; k<size; k++){
222                                 uint64_t valueK = getVectorInt(mems, k);
223                                 OrderPair pairJK = {valueJ, valueK};
224                                 OrderPair pairIK = {valueI, valueK};
225                                 Constraint* constIK = getPairConstraint(This, table, & pairIK);
226                                 Constraint* constJK = getPairConstraint(This, table, & pairJK);
227                                 constraints[csize++] = generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
228                                 ASSERT(csize < size*size);
229                         }
230                 }
231         }
232         return allocArrayConstraint(AND, csize, constraints);
233 }
234
235 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
236         ASSERT(pair->first!= pair->second);
237         Constraint* constraint= getBoolConst(table, pair)->constraint;
238         if(pair->first > pair->second)
239                 return constraint;
240         else
241                 return negateConstraint(constraint);
242 }
243
244 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
245         //FIXME: first we should add the the constraint to the satsolver!
246         ASSERT(constIJ!= NULL && constJK != NULL && constIK != NULL);
247         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
248         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
249         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
250         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
251         return allocConstraint(AND, loop1, loop2);
252 }
253
/**
 * Encode an order constraint over a PARTIAL order.
 * Not implemented yet: after asserting the order type this always returns
 * NULL, so callers currently receive no constraint for partial orders.
 */
Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
        // FIXME: a possible implementation for partial orders is sketched in
        // the disabled code below. The idea is to generate transitivity only
        // between the order constraints the client actually specified (this
        // could also be used when the client gives sparse constraints for a
        // total order).
        ASSERT(constraint->order->type == PARTIAL);
/*
        Disabled sketch. Note that it refers to the parameter as `boolOrder`,
        not `constraint`, so it does not compile against the current signature.

        HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
        if( containsBoolConst(boolToConsts, boolOrder) ){
                return getBoolConst(boolToConsts, boolOrder);
        } else {
                Constraint* constraint = getNewVarSATEncoder(This); 
                putBoolConst(boolToConsts,boolOrder, constraint);
                VectorBoolean* orderConstrs = &boolOrder->order->constraints;
                uint size= getSizeVectorBoolean(orderConstrs);
                for(uint i=0; i<size; i++){
                        ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
                        BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
                        BooleanOrder* newBool;
                        Constraint* first, *second;
                        if(tmp->second==boolOrder->first){
                                newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
                                first = encodeTotalOrderSATEncoder(This, tmp);
                                second = constraint;
                                
                        }else if (boolOrder->second == tmp->first){
                                newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
                                first = constraint;
                                second = encodeTotalOrderSATEncoder(This, tmp);
                        }else
                                continue;
                        Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
                        generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
                }
                return constraint;
        }
*/      
        return NULL;
}
292
293 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
294         switch(GETPREDICATETYPE(constraint->predicate) ){
295                 case TABLEPRED:
296                         return encodeTablePredicateSATEncoder(This, constraint);
297                 case OPERATORPRED:
298                         return encodeOperatorPredicateSATEncoder(This, constraint);
299                 default:
300                         ASSERT(0);
301         }
302         return NULL;
303 }
304
305 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
306         switch(constraint->encoding.type){
307                 case ENUMERATEIMPLICATIONS:
308                 case ENUMERATEIMPLICATIONSNEGATE:
309                         return encodeEnumTablePredicateSATEncoder(This, constraint);
310                 case CIRCUIT:
311                         ASSERT(0);
312                         break;
313                 default:
314                         ASSERT(0);
315         }
316         return NULL;
317 }
318
319 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
320         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
321         FunctionEncodingType encType = constraint->encoding.type;
322         uint size = getSizeVectorTableEntry(entries);
323         Constraint* constraints[size];
324         for(uint i=0; i<size; i++){
325                 TableEntry* entry = getVectorTableEntry(entries, i);
326                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
327                         continue;
328                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
329                         continue;
330                 ArrayElement* inputs = &constraint->inputs;
331                 uint inputNum =getSizeArrayElement(inputs);
332                 Constraint* carray[inputNum];
333                 for(uint j=0; j<inputNum; j++){
334                         Element* el = getArrayElement(inputs, j);
335                         Constraint* tmpc = getElementValueConstraint(This,el, entry->inputs[j]);
336                         ASSERT(tmpc!= NULL);
337                         if( GETELEMENTTYPE(el) == ELEMFUNCRETURN){
338                                 Constraint* func =encodeFunctionElementSATEncoder(This, (ElementFunction*) el);
339                                 ASSERT(func!=NULL);
340                                 carray[j] = allocConstraint(AND, func, tmpc);
341                         } else {
342                                 carray[j] = tmpc;
343                         }
344                         ASSERT(carray[j]!= NULL);
345                 }
346                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
347         }
348         Constraint* result= allocArrayConstraint(OR, size, constraints);
349         //FIXME: if it didn't match with any entry
350         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
351 }
352
353 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
354         switch(constraint->encoding.type){
355                 case ENUMERATEIMPLICATIONS:
356                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
357                 case CIRCUIT:
358                         ASSERT(0);
359                         break;
360                 default:
361                         ASSERT(0);
362         }
363         return NULL;
364 }
365
366 Constraint * encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
367         ASSERT(GETPREDICATETYPE(constraint->predicate)==OPERATORPRED);
368         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
369         ASSERT(predicate->op == EQUALS); //For now, we just only support equals
370         //getting maximum size of in common elements between two sets!
371         uint size=getSizeVectorInt( getArraySet( &predicate->domains, 0)->members);
372         uint64_t commonElements [size];
373         getEqualitySetIntersection(predicate, &size, commonElements);
374         Constraint*  carray[size];
375         Element* elem1 = getArrayElement( &constraint->inputs, 0);
376         Constraint *elemc1 = NULL, *elemc2 = NULL;
377         if( GETELEMENTTYPE(elem1) == ELEMFUNCRETURN)
378                 elemc1 = encodeFunctionElementSATEncoder(This, (ElementFunction*) elem1);
379         Element* elem2 = getArrayElement( &constraint->inputs, 1);
380         if( GETELEMENTTYPE(elem2) == ELEMFUNCRETURN)
381                 elemc2 = encodeFunctionElementSATEncoder(This, (ElementFunction*) elem2);
382         for(uint i=0; i<size; i++){
383                 Constraint* arg1 = getElementValueConstraint(This, elem1, commonElements[i]);
384                 ASSERT(arg1!=NULL);
385                 Constraint* arg2 = getElementValueConstraint(This, elem2, commonElements[i]);
386                 ASSERT(arg2 != NULL);
387                 carray[i] =  allocConstraint(AND, arg1, arg2);
388         }
389         //FIXME: the case when there is no intersection ....
390         Constraint* result = allocArrayConstraint(OR, size, carray);
391         ASSERT(result!= NULL);
392         if(elemc1!= NULL)
393                 result = allocConstraint(AND, result, elemc1);
394         if(elemc2 != NULL)
395                 result = allocConstraint (AND, result, elemc2);
396         return result;
397 }
398
399 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
400         switch(GETFUNCTIONTYPE(This->function)){
401                 case TABLEFUNC:
402                         return encodeTableElementFunctionSATEncoder(encoder, This);
403                 case OPERATORFUNC:
404                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
405                 default:
406                         ASSERT(0);
407         }
408         return NULL;
409 }
410
411 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
412         switch(getElementFunctionEncoding(This)->type){
413                 case ENUMERATEIMPLICATIONS:
414                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
415                         break;
416                 case CIRCUIT:
417                         ASSERT(0);
418                         break;
419                 default:
420                         ASSERT(0);
421         }
422         return NULL;
423 }
424
/**
 * Enumeration encoding of a binary operator function element.
 *
 * For every pair of in-use values (v1, v2) of the two inputs, the operator
 * is applied and an implication "(in1 == v1 AND in2 == v2) -> out == result"
 * is built. What is emitted for the pair depends on the function's
 * overflowbehavior and on the isInRange flag filled in by
 * applyFunctionOperator. All emitted constraints are AND-ed together.
 *
 * Each pair contributes at most one entry to carray, so its capacity of
 * encArraySize(in1) * encArraySize(in2) is sufficient.
 *
 * NOTE(review): the overflow variable is read straight from
 * ((BooleanVar*) overflowstatus)->var, while encodeVarSATEncoder() creates
 * that variable lazily - this may be NULL if the flag was not encoded
 * earlier; confirm.
 * NOTE(review): valConstrIn1/valConstrIn2 end up as children of two
 * different AND nodes in the flag cases (no cloneConstraint as in the XOR
 * encoding), and are dropped without being freed on the `continue` path and
 * on the out-of-range IGNORE/FLAGFORCESOVERFLOW paths; confirm the intended
 * ownership.
 *
 * @param encoder encoder used to create value constraints
 * @param This    function element; its function must be OPERATORFUNC with
 *                exactly two inputs
 * @return AND over all emitted per-pair constraints
 */
Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
        ASSERT(GETFUNCTIONTYPE(This->function) == OPERATORFUNC);
        ASSERT(getSizeArrayElement(&This->inputs)==2 );
        ElementEncoding* elem1 = getElementEncoding( getArrayElement(&This->inputs,0) );
        ElementEncoding* elem2 = getElementEncoding( getArrayElement(&This->inputs,1) );
        Constraint* carray[elem1->encArraySize*elem2->encArraySize];
        uint size=0;
        Constraint* overFlowConstraint = ((BooleanVar*) This->overflowstatus)->var;
        for(uint i=0; i<elem1->encArraySize; i++){
                if(isinUseElement(elem1, i)){
                        for( uint j=0; j<elem2->encArraySize; j++){
                                if(isinUseElement(elem2, j)){
                                        bool isInRange = false;
                                        uint64_t result= applyFunctionOperator((FunctionOperator*)This->function,elem1->encodingArray[i],
                                                elem2->encodingArray[j], &isInRange);
                                        //FIXME: instead of getElementValueConstraint, it might be useful to have another function
                                        // that doesn't iterate over encodingArray and is more efficient ...
                                        Constraint* valConstrIn1 = getElementValueConstraint(encoder, elem1->element, elem1->encodingArray[i]);
                                        ASSERT(valConstrIn1 != NULL);
                                        Constraint* valConstrIn2 = getElementValueConstraint(encoder, elem2->element, elem2->encodingArray[j]);
                                        ASSERT(valConstrIn2 != NULL);
                                        Constraint* valConstrOut = getElementValueConstraint(encoder, (Element*) This, result);
                                        // No value constraint exists for this result in the output element: skip the pair.
                                        if(valConstrOut == NULL)
                                                continue; //FIXME:Should talk to brian about it!
                                        // (in1 == v1 AND in2 == v2) -> out == result
                                        Constraint* OpConstraint = allocConstraint(IMPLIES, 
                                                allocConstraint(AND, valConstrIn1, valConstrIn2) , valConstrOut);
                                        switch( ((FunctionOperator*)This->function)->overflowbehavior ){
                                                case IGNORE:
                                                        // Out-of-range pairs contribute nothing.
                                                        if(isInRange){
                                                                carray[size++] = OpConstraint;
                                                        }
                                                        break;
                                                case WRAPAROUND:
                                                        // Result is constrained whether or not it is in range.
                                                        carray[size++] = OpConstraint;
                                                        break;
                                                case FLAGFORCESOVERFLOW:
                                                        // In range: the pair forces the overflow flag to be false.
                                                        if(isInRange){
                                                                Constraint* const1 = allocConstraint(IMPLIES,
                                                                        allocConstraint(AND, valConstrIn1, valConstrIn2), 
                                                                        negateConstraint(overFlowConstraint));
                                                                carray[size++] = allocConstraint(AND, const1, OpConstraint);
                                                        }
                                                        break;
                                                case OVERFLOWSETSFLAG:
                                                        // Out of range: the pair forces the overflow flag to be true.
                                                        if(isInRange){
                                                                carray[size++] = OpConstraint;
                                                        } else{
                                                                carray[size++] = allocConstraint(IMPLIES,
                                                                        allocConstraint(AND, valConstrIn1, valConstrIn2),
                                                                        overFlowConstraint);
                                                        }
                                                        break;
                                                case FLAGIFFOVERFLOW:
                                                        // Flag is false for in-range pairs and true for out-of-range pairs.
                                                        if(isInRange){
                                                                Constraint* const1 = allocConstraint(IMPLIES,
                                                                        allocConstraint(AND, valConstrIn1, valConstrIn2), 
                                                                        negateConstraint(overFlowConstraint));
                                                                carray[size++] = allocConstraint(AND, const1, OpConstraint);
                                                        }else{
                                                                carray[size++] = allocConstraint(IMPLIES,
                                                                        allocConstraint(AND, valConstrIn1, valConstrIn2),
                                                                        overFlowConstraint);
                                                        }
                                                        break;
                                                case NOOVERFLOW:
                                                        // An out-of-range result is a contract violation here.
                                                        if(!isInRange){
                                                                ASSERT(0);
                                                        }
                                                        carray[size++] = OpConstraint;
                                                        break;
                                                default:
                                                        ASSERT(0);
                                        }
                                        
                                }
                        }
                }
        }
        return allocArrayConstraint(AND, size, carray);
}
505
506 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
507         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
508         ArrayElement* elements= &This->inputs;
509         Table* table = ((FunctionTable*) (This->function))->table;
510         uint size = getSizeVectorTableEntry(&table->entries);
511         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
512         for(uint i=0; i<size; i++){
513                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
514                 uint inputNum =getSizeArrayElement(elements);
515                 Constraint* carray[inputNum];
516                 for(uint j=0; j<inputNum; j++){
517                         Element* el= getArrayElement(elements, j);
518                         carray[j] = getElementValueConstraint(encoder, el, entry->inputs[j]);
519                         ASSERT(carray[j]!= NULL);
520                 }
521                 Constraint* output = getElementValueConstraint(encoder, (Element*)This, entry->output);
522                 ASSERT(output!= NULL);
523                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray), output);
524                 constraints[i]=row;
525         }
526         Constraint* result = allocArrayConstraint(OR, size, constraints);
527         return result;
528 }