Merge branch 'hamed' into brian
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         ourfree(This);
24 }
25
26 void initializeConstraintVars(CSolver* csolver, SATEncoder* This){
27         /** We really should not walk the free list to generate constraint
28                         variables...walk the constraint tree instead.  Or even better
29                         yet, just do this as you need to during the encodeAllSATEncoder
30                         walk.  */
31
32 //      FIXME!!!!(); // Make sure Hamed sees comment above
33
34         uint size = getSizeVectorElement(csolver->allElements);
35         for(uint i=0; i<size; i++){
36                 Element* element = getVectorElement(csolver->allElements, i);
37                 generateElementEncodingVariables(This,getElementEncoding(element));
38         }
39 }
40
41
42 Constraint * getElementValueConstraint(Element* This, uint64_t value) {
43         switch(getElementEncoding(This)->type){
44                 case ONEHOT:
45                         ASSERT(0);
46                         break;
47                 case UNARY:
48                         ASSERT(0);
49                         break;
50                 case BINARYINDEX:
51                         ASSERT(0);
52                         break;
53                 case ONEHOTBINARY:
54                         return getElementValueBinaryIndexConstraint(This, value);
55                         break;
56                 case BINARYVAL:
57                         ASSERT(0);
58                         break;
59                 default:
60                         ASSERT(0);
61                         break;
62         }
63         return NULL;
64 }
65 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
66         ASTNodeType type = GETELEMENTTYPE(This);
67         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
68         ElementEncoding* elemEnc = getElementEncoding(This);
69         for(uint i=0; i<elemEnc->encArraySize; i++){
70                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
71                         return generateBinaryConstraint(elemEnc->numVars,
72                                 elemEnc->variables, i);
73                 }
74         }
75         return NULL;
76 }
77
78 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
79         VectorBoolean *constraints=csolver->constraints;
80         uint size=getSizeVectorBoolean(constraints);
81         for(uint i=0;i<size;i++) {
82                 Boolean *constraint=getVectorBoolean(constraints, i);
83                 encodeConstraintSATEncoder(This, constraint);
84         }
85         
86 //      FIXME: Following line for element!
87 //      size = getSizeVectorElement(csolver->allElements);
88 //      for(uint i=0; i<size; i++){
89 //              Element* element = getVectorElement(csolver->allElements, i);
90 //              switch(GETELEMENTTYPE(element)){
91 //                      case ELEMFUNCRETURN: 
92 //                              encodeFunctionElementSATEncoder(This, (ElementFunction*) element);
93 //                              break;
94 //                      default:        
95 //                              continue;
96 //                              //ElementSets that aren't used in any constraints/functions
97 //                              //will be eliminated.
98 //              }
99 //      }
100 }
101
102 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
103         switch(GETBOOLEANTYPE(constraint)) {
104         case ORDERCONST:
105                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
106         case BOOLEANVAR:
107                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
108         case LOGICOP:
109                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
110         case PREDICATEOP:
111                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
112         default:
113                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
114                 exit(-1);
115         }
116 }
117
118 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
119         for(uint i=0;i<num;i++)
120                 carray[i]=getNewVarSATEncoder(encoder);
121 }
122
123 Constraint * getNewVarSATEncoder(SATEncoder *This) {
124         Constraint * var=allocVarConstraint(VAR, This->varcount);
125         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
126         setNegConstraint(var, varneg);
127         setNegConstraint(varneg, var);
128         return var;
129 }
130
131 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
132         if (constraint->var == NULL) {
133                 constraint->var=getNewVarSATEncoder(This);
134         }
135         return constraint->var;
136 }
137
138 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
139         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
140         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
141                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
142
143         switch(constraint->op) {
144         case L_AND:
145                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
146         case L_OR:
147                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
148         case L_NOT:
149                 ASSERT(constraint->numArray==1);
150                 return negateConstraint(array[0]);
151         case L_XOR: {
152                 ASSERT(constraint->numArray==2);
153                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
154                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
155                 return allocConstraint(OR,
156                                                                                                          allocConstraint(AND, array[0], nright),
157                                                                                                          allocConstraint(AND, nleft, array[1]));
158         }
159         case L_IMPLIES:
160                 ASSERT(constraint->numArray==2);
161                 return allocConstraint(IMPLIES, array[0], array[1]);
162         default:
163                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
164                 exit(-1);
165         }
166 }
167
168
169 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
170         switch( constraint->order->type){
171                 case PARTIAL:
172                         return encodePartialOrderSATEncoder(This, constraint);
173                 case TOTAL:
174                         return encodeTotalOrderSATEncoder(This, constraint);
175                 default:
176                         ASSERT(0);
177         }
178         return NULL;
179 }
180
181 Constraint * getPairConstraint(SATEncoder *This, HashTableBoolConst * table, OrderPair * pair) {
182         bool negate = false;
183         OrderPair flipped;
184         if (pair->first > pair->second) {
185                 negate=true;
186                 flipped.first=pair->second;
187                 flipped.second=pair->first;
188                 pair = &flipped;
189         }
190         Constraint * constraint;
191         if (!containsBoolConst(table, pair)) {
192                 constraint = getNewVarSATEncoder(This);
193                 OrderPair * paircopy = allocOrderPair(pair->first, pair->second);
194                 putBoolConst(table, paircopy, constraint);
195         } else
196                 constraint = getBoolConst(table, pair);
197         if (negate)
198                 return negateConstraint(constraint);
199         else
200                 return constraint;
201         
202 }
203
204 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
205         ASSERT(boolOrder->order->type == TOTAL);
206         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
207         OrderPair pair={boolOrder->first, boolOrder->second};
208         Constraint* constraint = getPairConstraint(This, boolToConsts, & pair);
209         return constraint;
210 }
211
212 void createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
213         ASSERT(order->type == TOTAL);
214         VectorInt* mems = order->set->members;
215         HashTableBoolConst* table = order->boolsToConstraints;
216         uint size = getSizeVectorInt(mems);
217         for(uint i=0; i<size; i++){
218                 uint64_t valueI = getVectorInt(mems, i);
219                 for(uint j=i+1; j<size;j++){
220                         uint64_t valueJ = getVectorInt(mems, j);
221                         OrderPair pairIJ = {valueI, valueJ};
222                         Constraint* constIJ=getPairConstraint(This, table, & pairIJ);
223                         for(uint k=j+1; k<size; k++){
224                                 uint64_t valueK = getVectorInt(mems, k);
225                                 OrderPair pairJK = {valueJ, valueK};
226                                 OrderPair pairIK = {valueI, valueK};
227                                 Constraint* constIK = getPairConstraint(This, table, & pairIK);
228                                 Constraint* constJK = getPairConstraint(This, table, & pairJK);
229                                 generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
230                         }
231                 }
232         }
233 }
234
235 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
236         ASSERT(pair->first!= pair->second);
237         Constraint* constraint= getBoolConst(table, pair);
238         ASSERT(constraint!= NULL);
239         if(pair->first > pair->second)
240                 return constraint;
241         else
242                 return negateConstraint(constraint);
243 }
244
245 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
246         //FIXME: first we should add the the constraint to the satsolver!
247         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
248         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
249         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
250         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
251         return allocConstraint(AND, loop1, loop2);
252 }
253
254 Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
255         // FIXME: we can have this implementation for partial order. Basically,
256         // we compute the transitivity between two order constraints specified by the client! (also can be used
257         // when client specify sparse constraints for the total order!)
258         ASSERT(boolOrder->order->type == PARTIAL);
259 /*
260         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
261         if( containsBoolConst(boolToConsts, boolOrder) ){
262                 return getBoolConst(boolToConsts, boolOrder);
263         } else {
264                 Constraint* constraint = getNewVarSATEncoder(This); 
265                 putBoolConst(boolToConsts,boolOrder, constraint);
266                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
267                 uint size= getSizeVectorBoolean(orderConstrs);
268                 for(uint i=0; i<size; i++){
269                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
270                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
271                         BooleanOrder* newBool;
272                         Constraint* first, *second;
273                         if(tmp->second==boolOrder->first){
274                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
275                                 first = encodeTotalOrderSATEncoder(This, tmp);
276                                 second = constraint;
277                                 
278                         }else if (boolOrder->second == tmp->first){
279                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
280                                 first = constraint;
281                                 second = encodeTotalOrderSATEncoder(This, tmp);
282                         }else
283                                 continue;
284                         Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
285                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
286                 }
287                 return constraint;
288         }
289 */      
290         return NULL;
291 }
292
293 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
294         switch(GETPREDICATETYPE(constraint->predicate) ){
295                 case TABLEPRED:
296                         return encodeTablePredicateSATEncoder(This, constraint);
297                 case OPERATORPRED:
298                         return encodeOperatorPredicateSATEncoder(This, constraint);
299                 default:
300                         ASSERT(0);
301         }
302         return NULL;
303 }
304
305 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
306         switch(constraint->encoding.type){
307                 case ENUMERATEIMPLICATIONS:
308                 case ENUMERATEIMPLICATIONSNEGATE:
309                         return encodeEnumTablePredicateSATEncoder(This, constraint);
310                 case CIRCUIT:
311                         ASSERT(0);
312                         break;
313                 default:
314                         ASSERT(0);
315         }
316         return NULL;
317 }
318
319 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
320         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
321         FunctionEncodingType encType = constraint->encoding.type;
322         uint size = getSizeVectorTableEntry(entries);
323         Constraint* constraints[size];
324         for(uint i=0; i<size; i++){
325                 TableEntry* entry = getVectorTableEntry(entries, i);
326                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
327                         continue;
328                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
329                         continue;
330                 ArrayElement* inputs = &constraint->inputs;
331                 uint inputNum =getSizeArrayElement(inputs);
332                 Constraint* carray[inputNum];
333                 Element* el = getArrayElement(inputs, i);
334                 for(uint j=0; j<inputNum; j++){
335                         carray[j] = getElementValueConstraint(el, entry->inputs[j]);
336                 }
337                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
338         }
339         Constraint* result= allocArrayConstraint(OR, size, constraints);
340         //FIXME: if it didn't match with any entry
341         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
342 }
343
344 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
345         switch(constraint->encoding.type){
346                 case ENUMERATEIMPLICATIONS:
347                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
348                 case CIRCUIT:
349                         ASSERT(0);
350                         break;
351                 default:
352                         ASSERT(0);
353         }
354         return NULL;
355 }
356
357 Constraint * encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
358         ASSERT(GETPREDICATETYPE(constraint)==OPERATORPRED);
359         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
360         ASSERT(predicate->op == EQUALS); //For now, we just only support equals
361         //getting maximum size of in common elements between two sets!
362         uint size=getSizeVectorInt( getArraySet( &predicate->domains, 0)->members);
363         uint64_t commonElements [size];
364         getEqualitySetIntersection(predicate, &size, commonElements);
365         Constraint*  carray[size];
366         Element* elem1 = getArrayElement( &constraint->inputs, 0);
367         Element* elem2 = getArrayElement( &constraint->inputs, 1);
368         for(uint i=0; i<size; i++){
369                 
370                 carray[i] =  allocConstraint(AND, getElementValueConstraint(elem1, commonElements[i]),
371                         getElementValueConstraint(elem2, commonElements[i]) );
372         }
373         //FIXME: the case when there is no intersection ....
374         return allocArrayConstraint(OR, size, carray);
375 }
376
377 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
378         switch(GETFUNCTIONTYPE(This->function)){
379                 case TABLEFUNC:
380                         return encodeTableElementFunctionSATEncoder(encoder, This);
381                 case OPERATORFUNC:
382                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
383                 default:
384                         ASSERT(0);
385         }
386         return NULL;
387 }
388
389 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
390         switch(getElementFunctionEncoding(This)->type){
391                 case ENUMERATEIMPLICATIONS:
392                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
393                         break;
394                 case CIRCUIT:
395                         ASSERT(0);
396                         break;
397                 default:
398                         ASSERT(0);
399         }
400         return NULL;
401 }
402
403 Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
404         //FIXME: for now it just adds/substracts inputs exhustively
405         return NULL;
406 }
407
408 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
409         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
410         ArrayElement* elements= &This->inputs;
411         Table* table = ((FunctionTable*) (This->function))->table;
412         uint size = getSizeVectorTableEntry(&table->entries);
413         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
414         for(uint i=0; i<size; i++){
415                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
416                 uint inputNum =getSizeArrayElement(elements);
417                 Element* el= getArrayElement(elements, i);
418                 Constraint* carray[inputNum];
419                 for(uint j=0; j<inputNum; j++){
420                         carray[inputNum] = getElementValueConstraint(el, entry->inputs[j]);
421                 }
422                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray),
423                         getElementValueBinaryIndexConstraint((Element*)This, entry->output));
424                 constraints[i]=row;
425         }
426         Constraint* result = allocArrayConstraint(OR, size, constraints);
427         return result;
428 }