Adding function operator handler + omitting some redundant code
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         ourfree(This);
24 }
25
26 Constraint * getElementValueConstraint(Element* This, uint64_t value) {
27         switch(getElementEncoding(This)->type){
28                 case ONEHOT:
29                         ASSERT(0);
30                         break;
31                 case UNARY:
32                         ASSERT(0);
33                         break;
34                 case BINARYINDEX:
35                         ASSERT(0);
36                         break;
37                 case ONEHOTBINARY:
38                         return getElementValueBinaryIndexConstraint(This, value);
39                         break;
40                 case BINARYVAL:
41                         ASSERT(0);
42                         break;
43                 default:
44                         ASSERT(0);
45                         break;
46         }
47         return NULL;
48 }
49 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
50         ASTNodeType type = GETELEMENTTYPE(This);
51         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
52         ElementEncoding* elemEnc = getElementEncoding(This);
53         for(uint i=0; i<elemEnc->encArraySize; i++){
54                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
55                         return generateBinaryConstraint(elemEnc->numVars,
56                                 elemEnc->variables, i);
57                 }
58         }
59         return NULL;
60 }
61
62 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
63         VectorBoolean *constraints=csolver->constraints;
64         uint size=getSizeVectorBoolean(constraints);
65         for(uint i=0;i<size;i++) {
66                 Boolean *constraint=getVectorBoolean(constraints, i);
67                 encodeConstraintSATEncoder(This, constraint);
68         }
69 }
70
71 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
72         switch(GETBOOLEANTYPE(constraint)) {
73         case ORDERCONST:
74                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
75         case BOOLEANVAR:
76                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
77         case LOGICOP:
78                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
79         case PREDICATEOP:
80                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
81         default:
82                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
83                 exit(-1);
84         }
85 }
86
87 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
88         for(uint i=0;i<num;i++)
89                 carray[i]=getNewVarSATEncoder(encoder);
90 }
91
92 Constraint * getNewVarSATEncoder(SATEncoder *This) {
93         Constraint * var=allocVarConstraint(VAR, This->varcount);
94         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
95         setNegConstraint(var, varneg);
96         setNegConstraint(varneg, var);
97         return var;
98 }
99
100 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
101         if (constraint->var == NULL) {
102                 constraint->var=getNewVarSATEncoder(This);
103         }
104         return constraint->var;
105 }
106
107 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
108         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
109         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
110                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
111
112         switch(constraint->op) {
113         case L_AND:
114                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
115         case L_OR:
116                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
117         case L_NOT:
118                 ASSERT(constraint->numArray==1);
119                 return negateConstraint(array[0]);
120         case L_XOR: {
121                 ASSERT(constraint->numArray==2);
122                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
123                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
124                 return allocConstraint(OR,
125                                                                                                          allocConstraint(AND, array[0], nright),
126                                                                                                          allocConstraint(AND, nleft, array[1]));
127         }
128         case L_IMPLIES:
129                 ASSERT(constraint->numArray==2);
130                 return allocConstraint(IMPLIES, array[0], array[1]);
131         default:
132                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
133                 exit(-1);
134         }
135 }
136
137
138 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
139         switch( constraint->order->type){
140                 case PARTIAL:
141                         return encodePartialOrderSATEncoder(This, constraint);
142                 case TOTAL:
143                         return encodeTotalOrderSATEncoder(This, constraint);
144                 default:
145                         ASSERT(0);
146         }
147         return NULL;
148 }
149
150 Constraint * getPairConstraint(SATEncoder *This, HashTableBoolConst * table, OrderPair * pair) {
151         bool negate = false;
152         OrderPair flipped;
153         if (pair->first > pair->second) {
154                 negate=true;
155                 flipped.first=pair->second;
156                 flipped.second=pair->first;
157                 pair = &flipped;
158         }
159         Constraint * constraint;
160         if (!containsBoolConst(table, pair)) {
161                 constraint = getNewVarSATEncoder(This);
162                 OrderPair * paircopy = allocOrderPair(pair->first, pair->second);
163                 putBoolConst(table, paircopy, constraint);
164         } else
165                 constraint = getBoolConst(table, pair);
166         if (negate)
167                 return negateConstraint(constraint);
168         else
169                 return constraint;
170         
171 }
172
173 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
174         ASSERT(boolOrder->order->type == TOTAL);
175         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
176         OrderPair pair={boolOrder->first, boolOrder->second};
177         Constraint* constraint = getPairConstraint(This, boolToConsts, & pair);
178         return constraint;
179 }
180
181 void createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
182         ASSERT(order->type == TOTAL);
183         VectorInt* mems = order->set->members;
184         HashTableBoolConst* table = order->boolsToConstraints;
185         uint size = getSizeVectorInt(mems);
186         for(uint i=0; i<size; i++){
187                 uint64_t valueI = getVectorInt(mems, i);
188                 for(uint j=i+1; j<size;j++){
189                         uint64_t valueJ = getVectorInt(mems, j);
190                         OrderPair pairIJ = {valueI, valueJ};
191                         Constraint* constIJ=getPairConstraint(This, table, & pairIJ);
192                         for(uint k=j+1; k<size; k++){
193                                 uint64_t valueK = getVectorInt(mems, k);
194                                 OrderPair pairJK = {valueJ, valueK};
195                                 OrderPair pairIK = {valueI, valueK};
196                                 Constraint* constIK = getPairConstraint(This, table, & pairIK);
197                                 Constraint* constJK = getPairConstraint(This, table, & pairJK);
198                                 generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
199                         }
200                 }
201         }
202 }
203
204 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
205         ASSERT(pair->first!= pair->second);
206         Constraint* constraint= getBoolConst(table, pair);
207         ASSERT(constraint!= NULL);
208         if(pair->first > pair->second)
209                 return constraint;
210         else
211                 return negateConstraint(constraint);
212 }
213
214 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
215         //FIXME: first we should add the the constraint to the satsolver!
216         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
217         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
218         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
219         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
220         return allocConstraint(AND, loop1, loop2);
221 }
222
223 Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
224         // FIXME: we can have this implementation for partial order. Basically,
225         // we compute the transitivity between two order constraints specified by the client! (also can be used
226         // when client specify sparse constraints for the total order!)
227         ASSERT(boolOrder->order->type == PARTIAL);
228 /*
229         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
230         if( containsBoolConst(boolToConsts, boolOrder) ){
231                 return getBoolConst(boolToConsts, boolOrder);
232         } else {
233                 Constraint* constraint = getNewVarSATEncoder(This); 
234                 putBoolConst(boolToConsts,boolOrder, constraint);
235                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
236                 uint size= getSizeVectorBoolean(orderConstrs);
237                 for(uint i=0; i<size; i++){
238                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
239                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
240                         BooleanOrder* newBool;
241                         Constraint* first, *second;
242                         if(tmp->second==boolOrder->first){
243                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
244                                 first = encodeTotalOrderSATEncoder(This, tmp);
245                                 second = constraint;
246                                 
247                         }else if (boolOrder->second == tmp->first){
248                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
249                                 first = constraint;
250                                 second = encodeTotalOrderSATEncoder(This, tmp);
251                         }else
252                                 continue;
253                         Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
254                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
255                 }
256                 return constraint;
257         }
258 */      
259         return NULL;
260 }
261
262 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
263         switch(GETPREDICATETYPE(constraint->predicate) ){
264                 case TABLEPRED:
265                         return encodeTablePredicateSATEncoder(This, constraint);
266                 case OPERATORPRED:
267                         return encodeOperatorPredicateSATEncoder(This, constraint);
268                 default:
269                         ASSERT(0);
270         }
271         return NULL;
272 }
273
274 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
275         switch(constraint->encoding.type){
276                 case ENUMERATEIMPLICATIONS:
277                 case ENUMERATEIMPLICATIONSNEGATE:
278                         return encodeEnumTablePredicateSATEncoder(This, constraint);
279                 case CIRCUIT:
280                         ASSERT(0);
281                         break;
282                 default:
283                         ASSERT(0);
284         }
285         return NULL;
286 }
287
288 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
289         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
290         FunctionEncodingType encType = constraint->encoding.type;
291         uint size = getSizeVectorTableEntry(entries);
292         Constraint* constraints[size];
293         for(uint i=0; i<size; i++){
294                 TableEntry* entry = getVectorTableEntry(entries, i);
295                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
296                         continue;
297                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
298                         continue;
299                 ArrayElement* inputs = &constraint->inputs;
300                 uint inputNum =getSizeArrayElement(inputs);
301                 Constraint* carray[inputNum];
302                 for(uint j=0; j<inputNum; j++){
303                         Element* el = getArrayElement(inputs, j);
304                         if( GETELEMENTTYPE(el) == ELEMFUNCRETURN)
305                                 encodeFunctionElementSATEncoder(This, (ElementFunction*) el);
306                         carray[j] = getElementValueConstraint(el, entry->inputs[j]);
307                 }
308                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
309         }
310         Constraint* result= allocArrayConstraint(OR, size, constraints);
311         //FIXME: if it didn't match with any entry
312         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
313 }
314
315 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
316         switch(constraint->encoding.type){
317                 case ENUMERATEIMPLICATIONS:
318                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
319                 case CIRCUIT:
320                         ASSERT(0);
321                         break;
322                 default:
323                         ASSERT(0);
324         }
325         return NULL;
326 }
327
328 Constraint * encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
329         ASSERT(GETPREDICATETYPE(constraint)==OPERATORPRED);
330         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
331         ASSERT(predicate->op == EQUALS); //For now, we just only support equals
332         //getting maximum size of in common elements between two sets!
333         uint size=getSizeVectorInt( getArraySet( &predicate->domains, 0)->members);
334         uint64_t commonElements [size];
335         getEqualitySetIntersection(predicate, &size, commonElements);
336         Constraint*  carray[size];
337         Element* elem1 = getArrayElement( &constraint->inputs, 0);
338         if( GETELEMENTTYPE(elem1) == ELEMFUNCRETURN)
339                 encodeFunctionElementSATEncoder(This, (ElementFunction*) elem1);
340         Element* elem2 = getArrayElement( &constraint->inputs, 1);
341         if( GETELEMENTTYPE(elem2) == ELEMFUNCRETURN)
342                 encodeFunctionElementSATEncoder(This, (ElementFunction*) elem2);
343         for(uint i=0; i<size; i++){
344                 carray[i] =  allocConstraint(AND, getElementValueConstraint(elem1, commonElements[i]),
345                         getElementValueConstraint(elem2, commonElements[i]) );
346         }
347         //FIXME: the case when there is no intersection ....
348         return allocArrayConstraint(OR, size, carray);
349 }
350
351 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
352         switch(GETFUNCTIONTYPE(This->function)){
353                 case TABLEFUNC:
354                         return encodeTableElementFunctionSATEncoder(encoder, This);
355                 case OPERATORFUNC:
356                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
357                 default:
358                         ASSERT(0);
359         }
360         return NULL;
361 }
362
363 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
364         switch(getElementFunctionEncoding(This)->type){
365                 case ENUMERATEIMPLICATIONS:
366                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
367                         break;
368                 case CIRCUIT:
369                         ASSERT(0);
370                         break;
371                 default:
372                         ASSERT(0);
373         }
374         return NULL;
375 }
376
377 Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
378         ASSERT(GETFUNCTIONTYPE(This->function) == OPERATORFUNC);
379         ASSERT(getSizeArrayElement(&This->inputs)=2 );
380         ElementEncoding* elem1 = getElementEncoding( getArrayElement(&This->inputs,0) );
381         ElementEncoding* elem2 = getElementEncoding( getArrayElement(&This->inputs,1) );
382         Constraint* carray[elem1->encArraySize*elem2->encArraySize];
383         uint size=0;
384         Constraint* overFlowConstraint = ((BooleanVar*) This->overflowstatus)->var;
385         for(uint i=0; i<elem1->encArraySize; i++){
386                 if(isinUseElement(elem1, i)){
387                         for(uint j=0; j<elem2->encArraySize; j++){
388                                 if(isinUseElement(elem2, j)){
389                                         bool isInRange = false, hasOverFlow=false;
390                                         uint64_t result= applyFunctionOperator((FunctionOperator*)This->function,elem1->encodingArray[i],
391                                                 elem2->encodingArray[j], &isInRange, &hasOverFlow);
392                                         if(!isInRange)
393                                                 break; // Ignoring the cases that result of operation doesn't exist in the code.
394                                         //FIXME: instead of getElementValueConstraint, it might be useful to have another function
395                                         // that doesn't iterate over encodingArray and treats more efficient ...
396                                         Constraint* constraint = allocConstraint(IMPLIES, 
397                                                 allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
398                                                         getElementValueConstraint(elem2->element, elem2->encodingArray[j])) ,
399                                                 getElementValueConstraint((Element*) This, result) );
400                                         switch( ((FunctionOperator*)This->function)->overflowbehavior ){
401                                                 case IGNORE:
402                                                         if(!hasOverFlow){
403                                                                 carray[size++] = constraint;
404                                                         }
405                                                 case WRAPAROUND:
406                                                         carray[size++] = constraint;
407                                                 case FLAGFORCESOVERFLOW:
408                                                         if(hasOverFlow){
409                                                                 Constraint* const1 = allocConstraint(IMPLIES, overFlowConstraint, 
410                                                                                 allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
411                                                                                 getElementValueConstraint(elem2->element, elem2->encodingArray[j])));
412                                                                 carray[size++] = allocConstraint(AND, const1, constraint);
413                                                         }
414                                                 case OVERFLOWSETSFLAG:
415                                                         if(hasOverFlow){
416                                                                 Constraint* const1 = allocConstraint(IMPLIES, allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
417                                                                                 getElementValueConstraint(elem2->element, elem2->encodingArray[j])), overFlowConstraint);
418                                                                 carray[size++] = allocConstraint(AND, const1, constraint);
419                                                         } else
420                                                                 carray[size++] = constraint;
421                                                 case FLAGIFFOVERFLOW:
422                                                         if(!hasOverFlow){
423                                                                 carray[size++] = constraint;
424                                                         }else{
425                                                                 Constraint* const1 = allocConstraint(IMPLIES, overFlowConstraint, 
426                                                                                 allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
427                                                                                 getElementValueConstraint(elem2->element, elem2->encodingArray[j])));
428                                                                 Constraint* const2 = allocConstraint(IMPLIES, allocConstraint(AND, getElementValueConstraint(elem1->element, elem1->encodingArray[i]),
429                                                                                 getElementValueConstraint(elem2->element, elem2->encodingArray[j])), overFlowConstraint);
430                                                                 Constraint* res [] = {const1, const2, constraint};
431                                                                 carray[size++] = allocArrayConstraint(AND, 3, res);
432                                                         }
433                                                 case NOOVERFLOW:
434                                                         if(hasOverFlow)
435                                                                 ASSERT(0);
436                                                                 carray[size++] = constraint;
437                                                 default:
438                                                         ASSERT(0);
439                                         }
440                                         
441                                 }
442                         }
443                 }
444         }
445         return allocArrayConstraint(AND, size, carray);
446 }
447
448 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
449         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
450         ArrayElement* elements= &This->inputs;
451         Table* table = ((FunctionTable*) (This->function))->table;
452         uint size = getSizeVectorTableEntry(&table->entries);
453         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
454         for(uint i=0; i<size; i++){
455                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
456                 uint inputNum =getSizeArrayElement(elements);
457                 Constraint* carray[inputNum];
458                 for(uint j=0; j<inputNum; j++){
459                         Element* el= getArrayElement(elements, j);
460                         carray[j] = getElementValueConstraint(el, entry->inputs[j]);
461                 }
462                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray),
463                         getElementValueBinaryIndexConstraint((Element*)This, entry->output));
464                 constraints[i]=row;
465         }
466         Constraint* result = allocArrayConstraint(OR, size, constraints);
467         return result;
468 }