Merge branch 'master' of ssh://plrg.eecs.uci.edu/home/git/constraint_compiler
[satune.git] / src / Backend / satencoder.c
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "csolver.h"
4 #include "boolean.h"
5 #include "constraint.h"
6 #include "common.h"
7 #include "element.h"
8 #include "function.h"
9 #include "tableentry.h"
10 #include "table.h"
11 #include "order.h"
12 #include "predicate.h"
13 #include "orderpair.h"
14 #include "set.h"
15
16 SATEncoder * allocSATEncoder() {
17         SATEncoder *This=ourmalloc(sizeof (SATEncoder));
18         This->varcount=1;
19         return This;
20 }
21
22 void deleteSATEncoder(SATEncoder *This) {
23         ourfree(This);
24 }
25
26 void initializeConstraintVars(CSolver* csolver, SATEncoder* This){
27         /** We really should not walk the free list to generate constraint
28                         variables...walk the constraint tree instead.  Or even better
29                         yet, just do this as you need to during the encodeAllSATEncoder
30                         walk.  */
31
32 //      FIXME!!!!(); // Make sure Hamed sees comment above
33
34         uint size = getSizeVectorElement(csolver->allElements);
35         for(uint i=0; i<size; i++){
36                 Element* element = getVectorElement(csolver->allElements, i);
37                 generateElementEncodingVariables(This,getElementEncoding(element));
38         }
39 }
40
41
42 Constraint * getElementValueConstraint(Element* This, uint64_t value) {
43         switch(getElementEncoding(This)->type){
44                 case ONEHOT:
45                         ASSERT(0);
46                         break;
47                 case UNARY:
48                         ASSERT(0);
49                         break;
50                 case BINARYINDEX:
51                         ASSERT(0);
52                         break;
53                 case ONEHOTBINARY:
54                         return getElementValueBinaryIndexConstraint(This, value);
55                         break;
56                 case BINARYVAL:
57                         ASSERT(0);
58                         break;
59                 default:
60                         ASSERT(0);
61                         break;
62         }
63         return NULL;
64 }
65 Constraint * getElementValueBinaryIndexConstraint(Element* This, uint64_t value) {
66         ASTNodeType type = GETELEMENTTYPE(This);
67         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
68         ElementEncoding* elemEnc = getElementEncoding(This);
69         for(uint i=0; i<elemEnc->encArraySize; i++){
70                 if( isinUseElement(elemEnc, i) && elemEnc->encodingArray[i]==value){
71                         return generateBinaryConstraint(elemEnc->numVars,
72                                 elemEnc->variables, i);
73                 }
74         }
75         return NULL;
76 }
77
78 void encodeAllSATEncoder(CSolver *csolver, SATEncoder * This) {
79         VectorBoolean *constraints=csolver->constraints;
80         uint size=getSizeVectorBoolean(constraints);
81         for(uint i=0;i<size;i++) {
82                 Boolean *constraint=getVectorBoolean(constraints, i);
83                 encodeConstraintSATEncoder(This, constraint);
84         }
85         
86 //      FIXME: Following line for element!
87 //      size = getSizeVectorElement(csolver->allElements);
88 //      for(uint i=0; i<size; i++){
89 //              Element* element = getVectorElement(csolver->allElements, i);
90 //              switch(GETELEMENTTYPE(element)){
91 //                      case ELEMFUNCRETURN: 
92 //                              encodeFunctionElementSATEncoder(This, (ElementFunction*) element);
93 //                              break;
94 //                      default:        
95 //                              continue;
96 //                              //ElementSets that aren't used in any constraints/functions
97 //                              //will be eliminated.
98 //              }
99 //      }
100 }
101
102 Constraint * encodeConstraintSATEncoder(SATEncoder *This, Boolean *constraint) {
103         switch(GETBOOLEANTYPE(constraint)) {
104         case ORDERCONST:
105                 return encodeOrderSATEncoder(This, (BooleanOrder *) constraint);
106         case BOOLEANVAR:
107                 return encodeVarSATEncoder(This, (BooleanVar *) constraint);
108         case LOGICOP:
109                 return encodeLogicSATEncoder(This, (BooleanLogic *) constraint);
110         case PREDICATEOP:
111                 return encodePredicateSATEncoder(This, (BooleanPredicate *) constraint);
112         default:
113                 model_print("Unhandled case in encodeConstraintSATEncoder %u", GETBOOLEANTYPE(constraint));
114                 exit(-1);
115         }
116 }
117
118 void getArrayNewVarsSATEncoder(SATEncoder* encoder, uint num, Constraint **carray) {
119         for(uint i=0;i<num;i++)
120                 carray[i]=getNewVarSATEncoder(encoder);
121 }
122
123 Constraint * getNewVarSATEncoder(SATEncoder *This) {
124         Constraint * var=allocVarConstraint(VAR, This->varcount);
125         Constraint * varneg=allocVarConstraint(NOTVAR, This->varcount++);
126         setNegConstraint(var, varneg);
127         setNegConstraint(varneg, var);
128         return var;
129 }
130
131 Constraint * encodeVarSATEncoder(SATEncoder *This, BooleanVar * constraint) {
132         if (constraint->var == NULL) {
133                 constraint->var=getNewVarSATEncoder(This);
134         }
135         return constraint->var;
136 }
137
138 Constraint * encodeLogicSATEncoder(SATEncoder *This, BooleanLogic * constraint) {
139         Constraint * array[getSizeArrayBoolean(&constraint->inputs)];
140         for(uint i=0;i<getSizeArrayBoolean(&constraint->inputs);i++)
141                 array[i]=encodeConstraintSATEncoder(This, getArrayBoolean(&constraint->inputs, i));
142
143         switch(constraint->op) {
144         case L_AND:
145                 return allocArrayConstraint(AND, getSizeArrayBoolean(&constraint->inputs), array);
146         case L_OR:
147                 return allocArrayConstraint(OR, getSizeArrayBoolean(&constraint->inputs), array);
148         case L_NOT:
149                 ASSERT(constraint->numArray==1);
150                 return negateConstraint(array[0]);
151         case L_XOR: {
152                 ASSERT(constraint->numArray==2);
153                 Constraint * nleft=negateConstraint(cloneConstraint(array[0]));
154                 Constraint * nright=negateConstraint(cloneConstraint(array[1]));
155                 return allocConstraint(OR,
156                                                                                                          allocConstraint(AND, array[0], nright),
157                                                                                                          allocConstraint(AND, nleft, array[1]));
158         }
159         case L_IMPLIES:
160                 ASSERT(constraint->numArray==2);
161                 return allocConstraint(IMPLIES, array[0], array[1]);
162         default:
163                 model_print("Unhandled case in encodeLogicSATEncoder %u", constraint->op);
164                 exit(-1);
165         }
166 }
167
168
169 Constraint * encodeOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint) {
170         switch( constraint->order->type){
171                 case PARTIAL:
172                         return encodePartialOrderSATEncoder(This, constraint);
173                 case TOTAL:
174                         return encodeTotalOrderSATEncoder(This, constraint);
175                 default:
176                         ASSERT(0);
177         }
178         return NULL;
179 }
180
181
182 Constraint * encodeTotalOrderSATEncoder(SATEncoder *This, BooleanOrder * boolOrder){
183         ASSERT(boolOrder->order->type == TOTAL);
184         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
185         OrderPair* pair = allocOrderPair(boolOrder->first, boolOrder->second);
186         if( !containsBoolConst(boolToConsts, pair) ){
187                 createAllTotalOrderConstraintsSATEncoder(This, boolOrder->order);
188         }
189         Constraint* constraint = getOrderConstraint(boolToConsts, pair);
190         deleteOrderPair(pair);
191         return constraint;
192 }
193
194 void createAllTotalOrderConstraintsSATEncoder(SATEncoder* This, Order* order){
195         ASSERT(order->type == TOTAL);
196         VectorInt* mems = order->set->members;
197         HashTableBoolConst* table = order->boolsToConstraints;
198         uint size = getSizeVectorInt(mems);
199         for(uint i=0; i<size; i++){
200                 uint64_t valueI = getVectorInt(mems, i);
201                 for(uint j=i+1; j<size;j++){
202                         uint64_t valueJ = getVectorInt(mems, j);
203                         OrderPair* pairIJ = allocOrderPair(valueI, valueJ);
204                         ASSERT(!containsBoolConst(table, pairIJ));
205                         Constraint* constIJ = getNewVarSATEncoder(This);
206                         putBoolConst(table, pairIJ, constIJ);
207                         for(uint k=j+1; k<size; k++){
208                                 uint64_t valueK = getVectorInt(mems, k);
209                                 OrderPair* pairJK = allocOrderPair(valueJ, valueK);
210                                 OrderPair* pairIK = allocOrderPair(valueI, valueK);
211                                 Constraint* constJK, *constIK;
212                                 if(!containsBoolConst(table, pairJK)){
213                                         constJK = getNewVarSATEncoder(This);
214                                         putBoolConst(table, pairJK, constJK);
215                                 }
216                                 if(!containsBoolConst(table, pairIK)){
217                                         constIK = getNewVarSATEncoder(This);
218                                         putBoolConst(table, pairIK, constIK);
219                                 }
220                                 generateTransOrderConstraintSATEncoder(This, constIJ, constJK, constIK); 
221                         }
222                 }
223         }
224 }
225
226 Constraint* getOrderConstraint(HashTableBoolConst *table, OrderPair *pair){
227         ASSERT(pair->first!= pair->second);
228         Constraint* constraint= getBoolConst(table, pair);
229         ASSERT(constraint!= NULL);
230         if(pair->first > pair->second)
231                 return constraint;
232         else
233                 return negateConstraint(constraint);
234 }
235
236 Constraint * generateTransOrderConstraintSATEncoder(SATEncoder *This, Constraint *constIJ,Constraint *constJK,Constraint *constIK){
237         //FIXME: first we should add the the constraint to the satsolver!
238         Constraint *carray[] = {constIJ, constJK, negateConstraint(constIK)};
239         Constraint * loop1= allocArrayConstraint(OR, 3, carray);
240         Constraint * carray2[] = {negateConstraint(constIJ), negateConstraint(constJK), constIK};
241         Constraint * loop2= allocArrayConstraint(OR, 3,carray2 );
242         return allocConstraint(AND, loop1, loop2);
243 }
244
245 Constraint * encodePartialOrderSATEncoder(SATEncoder *This, BooleanOrder * constraint){
246         // FIXME: we can have this implementation for partial order. Basically,
247         // we compute the transitivity between two order constraints specified by the client! (also can be used
248         // when client specify sparse constraints for the total order!)
249         ASSERT(boolOrder->order->type == PARTIAL);
250 /*
251         HashTableBoolConst* boolToConsts = boolOrder->order->boolsToConstraints;
252         if( containsBoolConst(boolToConsts, boolOrder) ){
253                 return getBoolConst(boolToConsts, boolOrder);
254         } else {
255                 Constraint* constraint = getNewVarSATEncoder(This); 
256                 putBoolConst(boolToConsts,boolOrder, constraint);
257                 VectorBoolean* orderConstrs = &boolOrder->order->constraints;
258                 uint size= getSizeVectorBoolean(orderConstrs);
259                 for(uint i=0; i<size; i++){
260                         ASSERT(GETBOOLEANTYPE( getVectorBoolean(orderConstrs, i)) == ORDERCONST );
261                         BooleanOrder* tmp = (BooleanOrder*)getVectorBoolean(orderConstrs, i);
262                         BooleanOrder* newBool;
263                         Constraint* first, *second;
264                         if(tmp->second==boolOrder->first){
265                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,tmp->first,boolOrder->second);
266                                 first = encodeTotalOrderSATEncoder(This, tmp);
267                                 second = constraint;
268                                 
269                         }else if (boolOrder->second == tmp->first){
270                                 newBool = (BooleanOrder*)allocBooleanOrder(tmp->order,boolOrder->first,tmp->second);
271                                 first = constraint;
272                                 second = encodeTotalOrderSATEncoder(This, tmp);
273                         }else
274                                 continue;
275                         Constraint* transConstr= encodeTotalOrderSATEncoder(This, newBool);
276                         generateTransOrderConstraintSATEncoder(This, first, second, transConstr );
277                 }
278                 return constraint;
279         }
280 */      
281         return NULL;
282 }
283
284 Constraint * encodePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint) {
285         switch(GETPREDICATETYPE(constraint->predicate) ){
286                 case TABLEPRED:
287                         return encodeTablePredicateSATEncoder(This, constraint);
288                 case OPERATORPRED:
289                         return encodeOperatorPredicateSATEncoder(This, constraint);
290                 default:
291                         ASSERT(0);
292         }
293         return NULL;
294 }
295
296 Constraint * encodeTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
297         switch(constraint->encoding.type){
298                 case ENUMERATEIMPLICATIONS:
299                 case ENUMERATEIMPLICATIONSNEGATE:
300                         return encodeEnumTablePredicateSATEncoder(This, constraint);
301                 case CIRCUIT:
302                         ASSERT(0);
303                         break;
304                 default:
305                         ASSERT(0);
306         }
307         return NULL;
308 }
309
310 Constraint * encodeEnumTablePredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
311         VectorTableEntry* entries = &(((PredicateTable*)constraint->predicate)->table->entries);
312         FunctionEncodingType encType = constraint->encoding.type;
313         uint size = getSizeVectorTableEntry(entries);
314         Constraint* constraints[size];
315         for(uint i=0; i<size; i++){
316                 TableEntry* entry = getVectorTableEntry(entries, i);
317                 if(encType==ENUMERATEIMPLICATIONS && entry->output!= true)
318                         continue;
319                 else if(encType==ENUMERATEIMPLICATIONSNEGATE && entry->output !=false)
320                         continue;
321                 ArrayElement* inputs = &constraint->inputs;
322                 uint inputNum =getSizeArrayElement(inputs);
323                 Constraint* carray[inputNum];
324                 Element* el = getArrayElement(inputs, i);
325                 for(uint j=0; j<inputNum; j++){
326                         carray[j] = getElementValueConstraint(el, entry->inputs[j]);
327                 }
328                 constraints[i]=allocArrayConstraint(AND, inputNum, carray);
329         }
330         Constraint* result= allocArrayConstraint(OR, size, constraints);
331         //FIXME: if it didn't match with any entry
332         return encType==ENUMERATEIMPLICATIONS? result: negateConstraint(result);
333 }
334
335 Constraint * encodeOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
336         switch(constraint->encoding.type){
337                 case ENUMERATEIMPLICATIONS:
338                         return encodeEnumOperatorPredicateSATEncoder(This, constraint);
339                 case CIRCUIT:
340                         ASSERT(0);
341                         break;
342                 default:
343                         ASSERT(0);
344         }
345         return NULL;
346 }
347
348 Constraint * encodeEnumOperatorPredicateSATEncoder(SATEncoder * This, BooleanPredicate * constraint){
349         ASSERT(GETPREDICATETYPE(constraint)==OPERATORPRED);
350         PredicateOperator* predicate = (PredicateOperator*)constraint->predicate;
351         ASSERT(predicate->op == EQUALS); //For now, we just only support equals
352         //getting maximum size of in common elements between two sets!
353         uint size=getSizeVectorInt( getArraySet( &predicate->domains, 0)->members);
354         uint64_t commonElements [size];
355         getEqualitySetIntersection(predicate, &size, commonElements);
356         Constraint*  carray[size];
357         Element* elem1 = getArrayElement( &constraint->inputs, 0);
358         Element* elem2 = getArrayElement( &constraint->inputs, 1);
359         for(uint i=0; i<size; i++){
360                 
361                 carray[i] =  allocConstraint(AND, getElementValueConstraint(elem1, commonElements[i]),
362                         getElementValueConstraint(elem2, commonElements[i]) );
363         }
364         //FIXME: the case when there is no intersection ....
365         return allocArrayConstraint(OR, size, carray);
366 }
367
368 Constraint* encodeFunctionElementSATEncoder(SATEncoder* encoder, ElementFunction *This){
369         switch(GETFUNCTIONTYPE(This->function)){
370                 case TABLEFUNC:
371                         return encodeTableElementFunctionSATEncoder(encoder, This);
372                 case OPERATORFUNC:
373                         return encodeOperatorElementFunctionSATEncoder(encoder, This);
374                 default:
375                         ASSERT(0);
376         }
377         return NULL;
378 }
379
380 Constraint* encodeTableElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
381         switch(getElementFunctionEncoding(This)->type){
382                 case ENUMERATEIMPLICATIONS:
383                         return encodeEnumTableElemFunctionSATEncoder(encoder, This);
384                         break;
385                 case CIRCUIT:
386                         ASSERT(0);
387                         break;
388                 default:
389                         ASSERT(0);
390         }
391         return NULL;
392 }
393
394 Constraint* encodeOperatorElementFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
395         //FIXME: for now it just adds/substracts inputs exhustively
396         return NULL;
397 }
398
399 Constraint* encodeEnumTableElemFunctionSATEncoder(SATEncoder* encoder, ElementFunction* This){
400         ASSERT(GETFUNCTIONTYPE(This->function)==TABLEFUNC);
401         ArrayElement* elements= &This->inputs;
402         Table* table = ((FunctionTable*) (This->function))->table;
403         uint size = getSizeVectorTableEntry(&table->entries);
404         Constraint* constraints[size]; //FIXME: should add a space for the case that didn't match any entries
405         for(uint i=0; i<size; i++){
406                 TableEntry* entry = getVectorTableEntry(&table->entries, i);
407                 uint inputNum =getSizeArrayElement(elements);
408                 Element* el= getArrayElement(elements, i);
409                 Constraint* carray[inputNum];
410                 for(uint j=0; j<inputNum; j++){
411                         carray[inputNum] = getElementValueConstraint(el, entry->inputs[j]);
412                 }
413                 Constraint* row= allocConstraint(IMPLIES, allocArrayConstraint(AND, inputNum, carray),
414                         getElementValueBinaryIndexConstraint((Element*)This, entry->output));
415                 constraints[i]=row;
416         }
417         Constraint* result = allocArrayConstraint(OR, size, constraints);
418         return result;
419 }