Fixing some bugs
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         ourfree(This);
24 }
25
26 void initializeConstraintVars(CSolver* csolver, SATEncoder* This){
27         /** We really should not walk the free list to generate constraint
28                         variables...walk the constraint tree instead.  Or even better
29                         yet, just do this as you need to during the encodeAllSATEncoder
30                         walk.  */
31
32 //      FIXME!!!!(); // Make sure Hamed sees comment above
33
34         uint size = getSizeVectorElement(csolver->allElements);
35         for(uint i=0; i<size; i++){
36                 Element* element = getVectorElement(csolver->allElements, i);
37                 generateElementEncodingVariables(This,getElementEncoding(element));
38         }
39 }
40
41
42 Constraint * getElementValueConstraint(Element* This, uint64_t value) {
43         switch(getElementEncoding(This)->type){
44                 case ONEHOT:
45                         ASSERT(0);
46                         break;
47                 case UNARY:
48                         ASSERT(0);
49                         break;
50                 case BINARYINDEX:
51                         ASSERT(0);
52                         break;
53                 case ONEHOTBINARY:
54                         return getElementValueBinaryIndexConstraint(This, value);
55                         break;
56                 case BINARYVAL:
57                         ASSERT(0);
58                         break;
59                 default:
60                         ASSERT(0);
61                         break;
62         }
63         return NULL;
64 }
65 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
66         ASTNodeType type = GETELEMENTTYPE(This);
67         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
68         ElementEncoding* elemEnc = getElementEncoding(This);
69         for(uint i=0; i<elemEnc->encArraySize; i++){
70                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
71                         return generateBinaryConstraint(elemEnc->numVars,
72                                 elemEnc->variables, i);
73                 }
74         }
75         return NULL;
76 }
77
78 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
79         VectorBoolean *constraints=csolver->constraints;
80         uint size=getSizeVectorBoolean(constraints);
81         for(uint i=0;i<size;i++) {
82                 Boolean *constraint=getVectorBoolean(constraints, i);
83                 encodeConstraintSATEncoder(This, constraint);
84         }
85 }
86
87 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
88         switch(GETBOOLEANTYPE(constraint)) {
89         case ORDERCONST:
90                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
91         case BOOLEANVAR:
92                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
93         case LOGICOP:
94                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
95         case PREDICATEOP:
96                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
97         default:
98                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
99                 exit(-1);
100         }
101 }
102
103 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
104         for(uint i=0;i<num;i++)
105                 carray[i]=getNewVarSATEncoder(encoder);
106 }
107
108 Constraint * getNewVarSATEncoder(SATEncoder *This) {
109         Constraint * var=allocVarConstraint(VAR, This->varcount);
110         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
111         setNegConstraint(var, varneg);
112         setNegConstraint(varneg, var);
113         return var;
114 }
115
116 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
117         if (constraint->var == NULL) {
118                 constraint->var=getNewVarSATEncoder(This);
119         }
120         return constraint->var;
121 }
122
123 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
124         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
125         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
126                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
127
128         switch(constraint->op) {
129         case L_AND:
130                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
131         case L_OR:
132                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
133         case L_NOT:
134                 ASSERT(constraint->numArray==1);
135                 return negateConstraint(array[0]);
136         case L_XOR: {
137                 ASSERT(constraint->numArray==2);
138                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
139                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
140                 return allocConstraint(OR,
141                                                                                                          allocConstraint(AND, array[0], nright),
142                                                                                                          allocConstraint(AND, nleft, array[1]));
143         }
144         case L_IMPLIES:
145                 ASSERT(constraint->numArray==2);
146                 return allocConstraint(IMPLIES, array[0], array[1]);
147         default:
148                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
149                 exit(-1);
150         }
151 }
152
153
154 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
155         switch( constraint->order->type){
156                 case PARTIAL:
157                         return encodePartialOrderSATEncoder(This, constraint);
158                 case TOTAL:
159                         return encodeTotalOrderSATEncoder(This, constraint);
160                 default:
161                         ASSERT(0);
162         }
163         return NULL;
164 }
165
166 Constraint * getPairConstraint(SATEncoder *This, HashTableBoolConst * table, OrderPair * pair) {
167         bool negate = false;
168         OrderPair flipped;
169         if (pair->first > pair->second) {
170                 negate=true;
171                 flipped.first=pair->second;
172                 flipped.second=pair->first;
173                 pair = &flipped;
174         }
175         Constraint * constraint;
176         if (!containsBoolConst(table, pair)) {
177                 constraint = getNewVarSATEncoder(This);
178                 OrderPair * paircopy = allocOrderPair(pair->first, pair->second);
179                 putBoolConst(table, paircopy, constraint);
180         } else
181                 constraint = getBoolConst(table, pair);
182         if (negate)
183                 return negateConstraint(constraint);
184         else
185                 return constraint;
186         
187 }
188
189 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
190         ASSERT(boolOrder->order->type == TOTAL);
191         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
192         OrderPair pair={boolOrder->first, boolOrder->second};
193         Constraint* constraint = getPairConstraint(This, boolToConsts, & pair);
194         return constraint;
195 }
196
197 void createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
198         ASSERT(order->type == TOTAL);
199         VectorInt* mems = order->set->members;
200         HashTableBoolConst* table = order->boolsToConstraints;
201         uint size = getSizeVectorInt(mems);
202         for(uint i=0; i<size; i++){
203                 uint64_t valueI = getVectorInt(mems, i);
204                 for(uint j=i+1; j<size;j++){
205                         uint64_t valueJ = getVectorInt(mems, j);
206                         OrderPair pairIJ = {valueI, valueJ};
207                         Constraint* constIJ=getPairConstraint(This, table, & pairIJ);
208                         for(uint k=j+1; k<size; k++){
209                                 uint64_t valueK = getVectorInt(mems, k);
210                                 OrderPair pairJK = {valueJ, valueK};
211                                 OrderPair pairIK = {valueI, valueK};
212                                 Constraint* constIK = getPairConstraint(This, table, & pairIK);
213                                 Constraint* constJK = getPairConstraint(This, table, & pairJK);
214                                 generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
215                         }
216                 }
217         }
218 }
219
220 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
221         ASSERT(pair->first!= pair->second);
222         Constraint* constraint= getBoolConst(table, pair);
223         ASSERT(constraint!= NULL);
224         if(pair->first > pair->second)
225                 return constraint;
226         else
227                 return negateConstraint(constraint);
228 }
229
230 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
231         //FIXME: first we should add the the constraint to the satsolver!
232         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
233         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
234         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
235         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
236         return allocConstraint(AND, loop1, loop2);
237 }
238
239 Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
240         // FIXME: we can have this implementation for partial order. Basically,
241         // we compute the transitivity between two order constraints specified by the client! (also can be used
242         // when client specify sparse constraints for the total order!)
243         ASSERT(boolOrder->order->type == PARTIAL);
244 /*
245         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
246         if( containsBoolConst(boolToConsts, boolOrder) ){
247                 return getBoolConst(boolToConsts, boolOrder);
248         } else {
249                 Constraint* constraint = getNewVarSATEncoder(This); 
250                 putBoolConst(boolToConsts,boolOrder, constraint);
251                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
252                 uint size= getSizeVectorBoolean(orderConstrs);
253                 for(uint i=0; i<size; i++){
254                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
255                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
256                         BooleanOrder* newBool;
257                         Constraint* first, *second;
258                         if(tmp->second==boolOrder->first){
259                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
260                                 first = encodeTotalOrderSATEncoder(This, tmp);
261                                 second = constraint;
262                                 
263                         }else if (boolOrder->second == tmp->first){
264                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
265                                 first = constraint;
266                                 second = encodeTotalOrderSATEncoder(This, tmp);
267                         }else
268                                 continue;
269                         Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
270                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
271                 }
272                 return constraint;
273         }
274 */      
275         return NULL;
276 }
277
278 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
279         switch(GETPREDICATETYPE(constraint->predicate) ){
280                 case TABLEPRED:
281                         return encodeTablePredicateSATEncoder(This, constraint);
282                 case OPERATORPRED:
283                         return encodeOperatorPredicateSATEncoder(This, constraint);
284                 default:
285                         ASSERT(0);
286         }
287         return NULL;
288 }
289
290 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
291         switch(constraint->encoding.type){
292                 case ENUMERATEIMPLICATIONS:
293                 case ENUMERATEIMPLICATIONSNEGATE:
294                         return encodeEnumTablePredicateSATEncoder(This, constraint);
295                 case CIRCUIT:
296                         ASSERT(0);
297                         break;
298                 default:
299                         ASSERT(0);
300         }
301         return NULL;
302 }
303
304 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
305         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
306         FunctionEncodingType encType = constraint->encoding.type;
307         uint size = getSizeVectorTableEntry(entries);
308         Constraint* constraints[size];
309         for(uint i=0; i<size; i++){
310                 TableEntry* entry = getVectorTableEntry(entries, i);
311                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
312                         continue;
313                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
314                         continue;
315                 ArrayElement* inputs = &constraint->inputs;
316                 uint inputNum =getSizeArrayElement(inputs);
317                 Constraint* carray[inputNum];
318                 for(uint j=0; j<inputNum; j++){
319                         Element* el = getArrayElement(inputs, j);
320                         if( GETELEMENTTYPE(el) == ELEMFUNCRETURN)
321                                 encodeFunctionElementSATEncoder(This, (ElementFunction*) el);
322                         carray[j] = getElementValueConstraint(el, entry->inputs[j]);
323                 }
324                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
325         }
326         Constraint* result= allocArrayConstraint(OR, size, constraints);
327         //FIXME: if it didn't match with any entry
328         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
329 }
330
331 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
332         switch(constraint->encoding.type){
333                 case ENUMERATEIMPLICATIONS:
334                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
335                 case CIRCUIT:
336                         ASSERT(0);
337                         break;
338                 default:
339                         ASSERT(0);
340         }
341         return NULL;
342 }
343
344 Constraint * encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
345         ASSERT(GETPREDICATETYPE(constraint)==OPERATORPRED);
346         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
347         ASSERT(predicate->op == EQUALS); //For now, we just only support equals
348         //getting maximum size of in common elements between two sets!
349         uint size=getSizeVectorInt( getArraySet( &predicate->domains, 0)->members);
350         uint64_t commonElements [size];
351         getEqualitySetIntersection(predicate, &size, commonElements);
352         Constraint*  carray[size];
353         Element* elem1 = getArrayElement( &constraint->inputs, 0);
354         if( GETELEMENTTYPE(elem1) == ELEMFUNCRETURN)
355                 encodeFunctionElementSATEncoder(This, (ElementFunction*) elem1);
356         Element* elem2 = getArrayElement( &constraint->inputs, 1);
357         if( GETELEMENTTYPE(elem2) == ELEMFUNCRETURN)
358                 encodeFunctionElementSATEncoder(This, (ElementFunction*) elem2);
359         for(uint i=0; i<size; i++){
360                 carray[i] =  allocConstraint(AND, getElementValueConstraint(elem1, commonElements[i]),
361                         getElementValueConstraint(elem2, commonElements[i]) );
362         }
363         //FIXME: the case when there is no intersection ....
364         return allocArrayConstraint(OR, size, carray);
365 }
366
367 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
368         switch(GETFUNCTIONTYPE(This->function)){
369                 case TABLEFUNC:
370                         return encodeTableElementFunctionSATEncoder(encoder, This);
371                 case OPERATORFUNC:
372                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
373                 default:
374                         ASSERT(0);
375         }
376         return NULL;
377 }
378
379 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
380         switch(getElementFunctionEncoding(This)->type){
381                 case ENUMERATEIMPLICATIONS:
382                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
383                         break;
384                 case CIRCUIT:
385                         ASSERT(0);
386                         break;
387                 default:
388                         ASSERT(0);
389         }
390         return NULL;
391 }
392
393 Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
394         //FIXME: for now it just adds/substracts inputs exhustively
395         return NULL;
396 }
397
398 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
399         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
400         ArrayElement* elements= &This->inputs;
401         Table* table = ((FunctionTable*) (This->function))->table;
402         uint size = getSizeVectorTableEntry(&table->entries);
403         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
404         for(uint i=0; i<size; i++){
405                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
406                 uint inputNum =getSizeArrayElement(elements);
407                 Constraint* carray[inputNum];
408                 for(uint j=0; j<inputNum; j++){
409                         Element* el= getArrayElement(elements, j);
410                         carray[j] = getElementValueConstraint(el, entry->inputs[j]);
411                 }
412                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray),
413                         getElementValueBinaryIndexConstraint((Element*)This, entry->output));
414                 constraints[i]=row;
415         }
416         Constraint* result = allocArrayConstraint(OR, size, constraints);
417         return result;
418 }