fa124a68f741e7044cd10c2b3cbe0ec0bdbdc494
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         ourfree(This);
24 }
25
26 void initializeConstraintVars(CSolver* csolver, SATEncoder* This){
27         /** We really should not walk the free list to generate constraint
28                         variables...walk the constraint tree instead.  Or even better
29                         yet, just do this as you need to during the encodeAllSATEncoder
30                         walk.  */
31
32 //      FIXME!!!!(); // Make sure Hamed sees comment above
33
34         uint size = getSizeVectorElement(csolver->allElements);
35         for(uint i=0; i<size; i++){
36                 Element* element = getVectorElement(csolver->allElements, i);
37                 generateElementEncodingVariables(This,getElementEncoding(element));
38         }
39 }
40
41
42 Constraint * getElementValueConstraint(Element* This, uint64_t value) {
43         switch(getElementEncoding(This)->type){
44                 case ONEHOT:
45                         ASSERT(0);
46                         break;
47                 case UNARY:
48                         ASSERT(0);
49                         break;
50                 case BINARYINDEX:
51                         ASSERT(0);
52                         break;
53                 case ONEHOTBINARY:
54                         return getElementValueBinaryIndexConstraint(This, value);
55                         break;
56                 case BINARYVAL:
57                         ASSERT(0);
58                         break;
59                 default:
60                         ASSERT(0);
61                         break;
62         }
63         return NULL;
64 }
65 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
66         ASTNodeType type = GETELEMENTTYPE(This);
67         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
68         ElementEncoding* elemEnc = getElementEncoding(This);
69         for(uint i=0; i<elemEnc->encArraySize; i++){
70                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
71                         return generateBinaryConstraint(elemEnc->numVars,
72                                 elemEnc->variables, i);
73                 }
74         }
75         return NULL;
76 }
77
78 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
79         VectorBoolean *constraints=csolver->constraints;
80         uint size=getSizeVectorBoolean(constraints);
81         for(uint i=0;i<size;i++) {
82                 Boolean *constraint=getVectorBoolean(constraints, i);
83                 encodeConstraintSATEncoder(This, constraint);
84         }
85         
86 //      FIXME: Following line for element!
87 //      size = getSizeVectorElement(csolver->allElements);
88 //      for(uint i=0; i<size; i++){
89 //              Element* element = getVectorElement(csolver->allElements, i);
90 //              switch(GETELEMENTTYPE(element)){
91 //                      case ELEMFUNCRETURN: 
92 //                              encodeFunctionElementSATEncoder(This, (ElementFunction*) element);
93 //                              break;
94 //                      default:        
95 //                              continue;
96 //                              //ElementSets that aren't used in any constraints/functions
97 //                              //will be eliminated.
98 //              }
99 //      }
100 }
101
102 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
103         switch(GETBOOLEANTYPE(constraint)) {
104         case ORDERCONST:
105                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
106         case BOOLEANVAR:
107                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
108         case LOGICOP:
109                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
110         case PREDICATEOP:
111                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
112         default:
113                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
114                 exit(-1);
115         }
116 }
117
118 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
119         for(uint i=0;i<num;i++)
120                 carray[i]=getNewVarSATEncoder(encoder);
121 }
122
123 Constraint * getNewVarSATEncoder(SATEncoder *This) {
124         Constraint * var=allocVarConstraint(VAR, This->varcount);
125         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
126         setNegConstraint(var, varneg);
127         setNegConstraint(varneg, var);
128         return var;
129 }
130
131 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
132         if (constraint->var == NULL) {
133                 constraint->var=getNewVarSATEncoder(This);
134         }
135         return constraint->var;
136 }
137
138 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
139         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
140         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
141                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
142
143         switch(constraint->op) {
144         case L_AND:
145                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
146         case L_OR:
147                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
148         case L_NOT:
149                 ASSERT(constraint->numArray==1);
150                 return negateConstraint(array[0]);
151         case L_XOR: {
152                 ASSERT(constraint->numArray==2);
153                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
154                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
155                 return allocConstraint(OR,
156                                                                                                          allocConstraint(AND, array[0], nright),
157                                                                                                          allocConstraint(AND, nleft, array[1]));
158         }
159         case L_IMPLIES:
160                 ASSERT(constraint->numArray==2);
161                 return allocConstraint(IMPLIES, array[0], array[1]);
162         default:
163                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
164                 exit(-1);
165         }
166 }
167
168
169 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
170         switch( constraint->order->type){
171                 case PARTIAL:
172                         return encodePartialOrderSATEncoder(This, constraint);
173                 case TOTAL:
174                         return encodeTotalOrderSATEncoder(This, constraint);
175                 default:
176                         ASSERT(0);
177         }
178         return NULL;
179 }
180
181
182 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
183         ASSERT(boolOrder->order->type == TOTAL);
184         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
185         OrderPair* pair = allocOrderPair(boolOrder->first, boolOrder->second);
186         if( !containsBoolConst(boolToConsts, pair) ){
187                 createAllTotalOrderConstraintsSATEncoder(This, boolOrder->order);
188         }
189         Constraint* constraint = getOrderConstraint(boolToConsts, pair);
190         deleteOrderPair(pair);
191         return constraint;
192 }
193
194 void createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
195         ASSERT(order->type == TOTAL);
196         VectorInt* mems = order->set->members;
197         HashTableBoolConst* table = order->boolsToConstraints;
198         uint size = getSizeVectorInt(mems);
199         for(uint i=0; i<size; i++){
200                 uint64_t valueI = getVectorInt(mems, i);
201                 for(uint j=i+1; j<size;j++){
202                         uint64_t valueJ = getVectorInt(mems, j);
203                         OrderPair* pairIJ = allocOrderPair(valueI, valueJ);
204                         ASSERT(!containsBoolConst(table, pairIJ));
205                         Constraint* constIJ = getNewVarSATEncoder(This);
206                         putBoolConst(table, pairIJ, constIJ);
207                         for(uint k=j+1; k<size; k++){
208                                 uint64_t valueK = getVectorInt(mems, k);
209                                 OrderPair* pairJK = allocOrderPair(valueJ, valueK);
210                                 OrderPair* pairIK = allocOrderPair(valueI, valueK);
211                                 Constraint* constJK, *constIK;
212                                 if(!containsBoolConst(table, pairJK)){
213                                         constJK = getNewVarSATEncoder(This);
214                                         putBoolConst(table, pairJK, constJK);
215                                 }
216                                 if(!containsBoolConst(table, pairIK)){
217                                         constIK = getNewVarSATEncoder(This);
218                                         putBoolConst(table, pairIK, constIK);
219                                 }
220                                 generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
221                         }
222                 }
223         }
224 }
225
226 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
227         ASSERT(pair->first!= pair->second);
228         Constraint* constraint= getBoolConst(table, pair);
229         ASSERT(constraint!= NULL);
230         if(pair->first > pair->second)
231                 return constraint;
232         else
233                 return negateConstraint(constraint);
234 }
235
236 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
237         //FIXME: first we should add the the constraint to the satsolver!
238         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
239         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
240         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
241         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
242         return allocConstraint(AND, loop1, loop2);
243 }
244
245 Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
246         // FIXME: we can have this implementation for partial order. Basically,
247         // we compute the transitivity between two order constraints specified by the client! (also can be used
248         // when client specify sparse constraints for the total order!)
249         ASSERT(boolOrder->order->type == PARTIAL);
250 /*
251         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
252         if( containsBoolConst(boolToConsts, boolOrder) ){
253                 return getBoolConst(boolToConsts, boolOrder);
254         } else {
255                 Constraint* constraint = getNewVarSATEncoder(This); 
256                 putBoolConst(boolToConsts,boolOrder, constraint);
257                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
258                 uint size= getSizeVectorBoolean(orderConstrs);
259                 for(uint i=0; i<size; i++){
260                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
261                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
262                         BooleanOrder* newBool;
263                         Constraint* first, *second;
264                         if(tmp->second==boolOrder->first){
265                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
266                                 first = encodeTotalOrderSATEncoder(This, tmp);
267                                 second = constraint;
268                                 
269                         }else if (boolOrder->second == tmp->first){
270                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
271                                 first = constraint;
272                                 second = encodeTotalOrderSATEncoder(This, tmp);
273                         }else
274                                 continue;
275                         Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
276                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
277                 }
278                 return constraint;
279         }
280 */      
281         return NULL;
282 }
283
284 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
285         switch(GETPREDICATETYPE(constraint) ){
286                 case TABLEPRED:
287                         return encodeTablePredicateSATEncoder(This, constraint);
288                 case OPERATORPRED:
289                         return encodeOperatorPredicateSATEncoder(This, constraint);
290                 default:
291                         ASSERT(0);
292         }
293         return NULL;
294 }
295
296 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
297         switch(constraint->encoding.type){
298                 case ENUMERATEIMPLICATIONS:
299                 case ENUMERATEIMPLICATIONSNEGATE:
300                         return encodeEnumTablePredicateSATEncoder(This, constraint);
301                 case CIRCUIT:
302                         ASSERT(0);
303                         break;
304                 default:
305                         ASSERT(0);
306         }
307         return NULL;
308 }
309
310 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
311         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
312         FunctionEncodingType encType = constraint->encoding.type;
313         uint size = getSizeVectorTableEntry(entries);
314         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
315         for(uint i=0; i<size; i++){
316                 TableEntry* entry = getVectorTableEntry(entries, i);
317                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
318                         continue;
319                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
320                         continue;
321                 ArrayElement* inputs = &constraint->inputs;
322                 uint inputNum =getSizeArrayElement(inputs);
323                 Constraint* carray[inputNum];
324                 Element* el = getArrayElement(inputs, i);
325                 for(uint j=0; j<inputNum; j++){
326                         carray[inputNum] = getElementValueConstraint(el, entry->inputs[j]);
327                 }
328                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
329         }
330         Constraint* result= allocArrayConstraint(OR, size, constraints);
331         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
332 }
333
334 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
335         switch(constraint->encoding.type){
336                 case ENUMERATEIMPLICATIONS:
337                         break;
338                 case CIRCUIT:
339                         break;
340                 default:
341                         ASSERT(0);
342         }
343         return NULL;
344 }
345
346 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
347         switch(GETFUNCTIONTYPE(This->function)){
348                 case TABLEFUNC:
349                         return encodeTableElementFunctionSATEncoder(encoder, This);
350                 case OPERATORFUNC:
351                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
352                 default:
353                         ASSERT(0);
354         }
355         return NULL;
356 }
357
358 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
359         switch(getElementFunctionEncoding(This)->type){
360                 case ENUMERATEIMPLICATIONS:
361                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
362                         break;
363                 case CIRCUIT:
364                         ASSERT(0);
365                         break;
366                 default:
367                         ASSERT(0);
368         }
369         return NULL;
370 }
371
372 Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
373         //FIXME: for now it just adds/substracts inputs exhustively
374         return NULL;
375 }
376
377 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
378         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
379         ArrayElement* elements= &This->inputs;
380         Table* table = ((FunctionTable*) (This->function))->table;
381         uint size = getSizeVectorTableEntry(&table->entries);
382         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
383         for(uint i=0; i<size; i++){
384                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
385                 uint inputNum =getSizeArrayElement(elements);
386                 Element* el= getArrayElement(elements, i);
387                 Constraint* carray[inputNum];
388                 for(uint j=0; j<inputNum; j++){
389                         carray[inputNum] = getElementValueConstraint(el, entry->inputs[j]);
390                 }
391                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray),
392                         getElementValueBinaryIndexConstraint((Element*)This, entry->output));
393                 constraints[i]=row;
394         }
395         Constraint* result = allocArrayConstraint(OR, size, constraints);
396         return result;
397 }