some bug fixes
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         ourfree(This);
24 }
25
26 void initializeConstraintVars(CSolver* csolver, SATEncoder* This){
27         /** We really should not walk the free list to generate constraint
28                         variables...walk the constraint tree instead.  Or even better
29                         yet, just do this as you need to during the encodeAllSATEncoder
30                         walk.  */
31
32 //      FIXME!!!!(); // Make sure Hamed sees comment above
33
34         uint size = getSizeVectorElement(csolver->allElements);
35         for(uint i=0; i<size; i++){
36                 Element* element = getVectorElement(csolver->allElements, i);
37                 generateElementEncodingVariables(This,getElementEncoding(element));
38         }
39 }
40
41
42 Constraint * getElementValueConstraint(Element* This, uint64_t value) {
43         switch(getElementEncoding(This)->type){
44                 case ONEHOT:
45                         ASSERT(0);
46                         break;
47                 case UNARY:
48                         ASSERT(0);
49                         break;
50                 case BINARYINDEX:
51                         ASSERT(0);
52                         break;
53                 case ONEHOTBINARY:
54                         return getElementValueBinaryIndexConstraint(This, value);
55                         break;
56                 case BINARYVAL:
57                         ASSERT(0);
58                         break;
59                 default:
60                         ASSERT(0);
61                         break;
62         }
63         return NULL;
64 }
65 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
66         ASTNodeType type = GETELEMENTTYPE(This);
67         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
68         ElementEncoding* elemEnc = getElementEncoding(This);
69         for(uint i=0; i<elemEnc->encArraySize; i++){
70                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
71                         return generateBinaryConstraint(elemEnc->numVars,
72                                 elemEnc->variables, i);
73                 }
74         }
75         return NULL;
76 }
77
78 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
79         VectorBoolean *constraints=csolver->constraints;
80         uint size=getSizeVectorBoolean(constraints);
81         for(uint i=0;i<size;i++) {
82                 Boolean *constraint=getVectorBoolean(constraints, i);
83                 encodeConstraintSATEncoder(This, constraint);
84         }
85         
86 //      FIXME: Following line for element!
87 //      size = getSizeVectorElement(csolver->allElements);
88 //      for(uint i=0; i<size; i++){
89 //              Element* element = getVectorElement(csolver->allElements, i);
90 //              switch(GETELEMENTTYPE(element)){
91 //                      case ELEMFUNCRETURN: 
92 //                              encodeFunctionElementSATEncoder(This, (ElementFunction*) element);
93 //                              break;
94 //                      default:        
95 //                              continue;
96 //                              //ElementSets that aren't used in any constraints/functions
97 //                              //will be eliminated.
98 //              }
99 //      }
100 }
101
102 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
103         switch(GETBOOLEANTYPE(constraint)) {
104         case ORDERCONST:
105                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
106         case BOOLEANVAR:
107                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
108         case LOGICOP:
109                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
110         case PREDICATEOP:
111                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
112         default:
113                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
114                 exit(-1);
115         }
116 }
117
118 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
119         for(uint i=0;i<num;i++)
120                 carray[i]=getNewVarSATEncoder(encoder);
121 }
122
123 Constraint * getNewVarSATEncoder(SATEncoder *This) {
124         Constraint * var=allocVarConstraint(VAR, This->varcount);
125         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
126         setNegConstraint(var, varneg);
127         setNegConstraint(varneg, var);
128         return var;
129 }
130
131 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
132         if (constraint->var == NULL) {
133                 constraint->var=getNewVarSATEncoder(This);
134         }
135         return constraint->var;
136 }
137
138 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
139         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
140         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
141                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
142
143         switch(constraint->op) {
144         case L_AND:
145                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
146         case L_OR:
147                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
148         case L_NOT:
149                 ASSERT(constraint->numArray==1);
150                 return negateConstraint(array[0]);
151         case L_XOR: {
152                 ASSERT(constraint->numArray==2);
153                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
154                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
155                 return allocConstraint(OR,
156                                                                                                          allocConstraint(AND, array[0], nright),
157                                                                                                          allocConstraint(AND, nleft, array[1]));
158         }
159         case L_IMPLIES:
160                 ASSERT(constraint->numArray==2);
161                 return allocConstraint(IMPLIES, array[0], array[1]);
162         default:
163                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
164                 exit(-1);
165         }
166 }
167
168
169 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
170         switch( constraint->order->type){
171                 case PARTIAL:
172                         return encodePartialOrderSATEncoder(This, constraint);
173                 case TOTAL:
174                         return encodeTotalOrderSATEncoder(This, constraint);
175                 default:
176                         ASSERT(0);
177         }
178         return NULL;
179 }
180
181 Constraint * getPairConstraint(SATEncoder *This, HashTableBoolConst * table, OrderPair * pair) {
182         ASSERT(pair->first < pair->second);
183         if (!containsBoolConst(table, pair)) {
184                 Constraint *constraint = getNewVarSATEncoder(This);
185                 OrderPair * paircopy = allocOrderPair(pair->first, pair->second);
186                 putBoolConst(table, paircopy, constraint);
187                 return constraint;
188         } else
189                 return getBoolConst(table, pair);
190 }
191
192 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
193         ASSERT(boolOrder->order->type == TOTAL);
194         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
195         OrderPair pair;
196         if (boolOrder->first < boolOrder->second) {
197                 pair.first=boolOrder->first;
198                 pair.second=boolOrder->second;
199         } else {
200                 pair.first=boolOrder->second;
201                 pair.second=boolOrder->first;
202         }
203         Constraint* constraint = getPairConstraint(This, boolToConsts, & pair);
204         return constraint;
205 }
206
207 void createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
208         ASSERT(order->type == TOTAL);
209         VectorInt* mems = order->set->members;
210         HashTableBoolConst* table = order->boolsToConstraints;
211         uint size = getSizeVectorInt(mems);
212         for(uint i=0; i<size; i++){
213                 uint64_t valueI = getVectorInt(mems, i);
214                 for(uint j=i+1; j<size;j++){
215                         uint64_t valueJ = getVectorInt(mems, j);
216                         OrderPair pairIJ = {valueI, valueJ};
217                         Constraint* constIJ=getPairConstraint(This, table, & pairIJ);
218                         for(uint k=j+1; k<size; k++){
219                                 uint64_t valueK = getVectorInt(mems, k);
220                                 OrderPair pairJK = {valueJ, valueK};
221                                 OrderPair pairIK = {valueI, valueK};
222                                 Constraint* constIK = getPairConstraint(This, table, & pairIK);
223                                 Constraint* constJK = getPairConstraint(This, table, & pairJK);
224                                 generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
225                         }
226                 }
227         }
228 }
229
230 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
231         ASSERT(pair->first!= pair->second);
232         Constraint* constraint= getBoolConst(table, pair);
233         ASSERT(constraint!= NULL);
234         if(pair->first > pair->second)
235                 return constraint;
236         else
237                 return negateConstraint(constraint);
238 }
239
240 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
241         //FIXME: first we should add the the constraint to the satsolver!
242         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
243         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
244         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
245         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
246         return allocConstraint(AND, loop1, loop2);
247 }
248
249 Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
250         // FIXME: we can have this implementation for partial order. Basically,
251         // we compute the transitivity between two order constraints specified by the client! (also can be used
252         // when client specify sparse constraints for the total order!)
253         ASSERT(boolOrder->order->type == PARTIAL);
254 /*
255         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
256         if( containsBoolConst(boolToConsts, boolOrder) ){
257                 return getBoolConst(boolToConsts, boolOrder);
258         } else {
259                 Constraint* constraint = getNewVarSATEncoder(This); 
260                 putBoolConst(boolToConsts,boolOrder, constraint);
261                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
262                 uint size= getSizeVectorBoolean(orderConstrs);
263                 for(uint i=0; i<size; i++){
264                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
265                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
266                         BooleanOrder* newBool;
267                         Constraint* first, *second;
268                         if(tmp->second==boolOrder->first){
269                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
270                                 first = encodeTotalOrderSATEncoder(This, tmp);
271                                 second = constraint;
272                                 
273                         }else if (boolOrder->second == tmp->first){
274                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
275                                 first = constraint;
276                                 second = encodeTotalOrderSATEncoder(This, tmp);
277                         }else
278                                 continue;
279                         Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
280                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
281                 }
282                 return constraint;
283         }
284 */      
285         return NULL;
286 }
287
288 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
289         switch(GETPREDICATETYPE(constraint) ){
290                 case TABLEPRED:
291                         return encodeTablePredicateSATEncoder(This, constraint);
292                 case OPERATORPRED:
293                         return encodeOperatorPredicateSATEncoder(This, constraint);
294                 default:
295                         ASSERT(0);
296         }
297         return NULL;
298 }
299
300 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
301         switch(constraint->encoding.type){
302                 case ENUMERATEIMPLICATIONS:
303                 case ENUMERATEIMPLICATIONSNEGATE:
304                         return encodeEnumTablePredicateSATEncoder(This, constraint);
305                 case CIRCUIT:
306                         ASSERT(0);
307                         break;
308                 default:
309                         ASSERT(0);
310         }
311         return NULL;
312 }
313
314 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
315         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
316         FunctionEncodingType encType = constraint->encoding.type;
317         uint size = getSizeVectorTableEntry(entries);
318         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
319         for(uint i=0; i<size; i++){
320                 TableEntry* entry = getVectorTableEntry(entries, i);
321                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
322                         continue;
323                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
324                         continue;
325                 ArrayElement* inputs = &constraint->inputs;
326                 uint inputNum =getSizeArrayElement(inputs);
327                 Constraint* carray[inputNum];
328                 Element* el = getArrayElement(inputs, i);
329                 for(uint j=0; j<inputNum; j++){
330                         carray[inputNum] = getElementValueConstraint(el, entry->inputs[j]);
331                 }
332                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
333         }
334         Constraint* result= allocArrayConstraint(OR, size, constraints);
335         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
336 }
337
338 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
339         switch(constraint->encoding.type){
340                 case ENUMERATEIMPLICATIONS:
341                         break;
342                 case CIRCUIT:
343                         break;
344                 default:
345                         ASSERT(0);
346         }
347         return NULL;
348 }
349
350 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
351         switch(GETFUNCTIONTYPE(This->function)){
352                 case TABLEFUNC:
353                         return encodeTableElementFunctionSATEncoder(encoder, This);
354                 case OPERATORFUNC:
355                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
356                 default:
357                         ASSERT(0);
358         }
359         return NULL;
360 }
361
362 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
363         switch(getElementFunctionEncoding(This)->type){
364                 case ENUMERATEIMPLICATIONS:
365                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
366                         break;
367                 case CIRCUIT:
368                         ASSERT(0);
369                         break;
370                 default:
371                         ASSERT(0);
372         }
373         return NULL;
374 }
375
376 Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
377         //FIXME: for now it just adds/substracts inputs exhustively
378         return NULL;
379 }
380
381 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
382         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
383         ArrayElement* elements= &This->inputs;
384         Table* table = ((FunctionTable*) (This->function))->table;
385         uint size = getSizeVectorTableEntry(&table->entries);
386         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
387         for(uint i=0; i<size; i++){
388                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
389                 uint inputNum =getSizeArrayElement(elements);
390                 Element* el= getArrayElement(elements, i);
391                 Constraint* carray[inputNum];
392                 for(uint j=0; j<inputNum; j++){
393                         carray[inputNum] = getElementValueConstraint(el, entry->inputs[j]);
394                 }
395                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray),
396                         getElementValueBinaryIndexConstraint((Element*)This, entry->output));
397                 constraints[i]=row;
398         }
399         Constraint* result = allocArrayConstraint(OR, size, constraints);
400         return result;
401 }