fbf09e620a097b0573d101ac30412f70f3ce8bf7
[satune.git] / src / Backend / satelemencoder.cc
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "common.h"
4 #include "ops.h"
5 #include "element.h"
6 #include "set.h"
7 #include "predicate.h"
8
9
10 void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
11         uint numParents = elem->parents.getSize();
12         uint posPolarity = 0;
13         uint negPolarity = 0;
14         memo = false;
15         if (elem->type == ELEMFUNCRETURN) {
16                 memo = true;
17         }
18         for(uint i = 0; i < numParents; i++) {
19                 ASTNode * node = elem->parents.get(i);
20                 if (node->type == PREDICATEOP) {
21                         BooleanPredicate * pred = (BooleanPredicate *) node;
22                         Polarity polarity = pred->polarity;
23                         FunctionEncodingType encType = pred->encoding.type;
24                         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
25                         if (pred->predicate->type == TABLEPRED) {
26                                 //Could be smarter, but just do default thing for now
27
28                                 UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
29
30                                 Polarity tpolarity=polarity;
31                                 if (generateNegation)
32                                         tpolarity = negatePolarity(tpolarity);
33                                 if (undefStatus ==SATC_FLAGFORCEUNDEFINED)
34                                         tpolarity = P_BOTHTRUEFALSE;
35                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
36                                         memo = true;
37                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
38                                         memo = true;
39                         } else if (pred->predicate->type == OPERATORPRED) {
40                                         if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
41                                                 Polarity tpolarity = polarity;
42                                                 if (generateNegation)
43                                                         tpolarity = negatePolarity(tpolarity);
44                                                 PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
45                                                 uint numDomains = predicate->domains.getSize();
46                                                 bool isConstant = true;
47                                                 for (uint i = 0; i < numDomains; i++) {
48                                                         Element *e = pred->inputs.get(i);
49                                                         if (elem != e && e->type != ELEMCONST) {
50                                                                 isConstant = false;
51                                                         }
52                                                 }
53                                                 if (predicate->getOp() == SATC_EQUALS) {
54                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
55                                                                 posPolarity++;
56                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
57                                                                 negPolarity++;
58                                                 } else {
59                                                         if (isConstant) {
60                                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
61                                                                         posPolarity++;
62                                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
63                                                                         negPolarity++;
64                                                         } else {
65                                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
66                                                                         memo = true;
67                                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
68                                                                         memo = true;                                                    
69                                                         }
70                                                 }
71                                         }
72                                 } else {
73                                 ASSERT(0);
74                         }
75                 } else if (node->type == ELEMFUNCRETURN) {
76                         //we are input to function, so memoize negative case
77                         memo = true;
78                 } else {
79                         ASSERT(0);
80                 }
81         }
82         if (posPolarity > 1)
83                 memo = true;
84         if (negPolarity > 1)
85                 memo = true;
86 }
87
88
89 Edge SATEncoder::getElementValueConstraint(Element *elem, Polarity p, uint64_t value) {
90         switch (elem->getElementEncoding()->type) {
91         case ONEHOT:
92                 return getElementValueOneHotConstraint(elem, p, value);
93         case UNARY:
94                 return getElementValueUnaryConstraint(elem, p, value);
95         case BINARYINDEX:
96                 return getElementValueBinaryIndexConstraint(elem, p, value);
97         case BINARYVAL:
98                 return getElementValueBinaryValueConstraint(elem, p, value);
99                 break;
100         default:
101                 ASSERT(0);
102                 break;
103         }
104         return E_BOGUS;
105 }
106
107 bool impliesPolarity(Polarity curr, Polarity goal) {
108         return (((int) curr) & ((int)goal)) == ((int) goal);
109 }
110
111 Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p, uint64_t value) {
112         ASTNodeType type = elem->type;
113         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
114         ElementEncoding *elemEnc = elem->getElementEncoding();
115
116         //Check if we need to generate proxy variables
117         if (elemEnc->encoding == EENC_UNKNOWN && elemEnc->numVars > 1) {
118                 bool memo = false;
119                 shouldMemoize(elem, value, memo);
120                 if (memo) {
121                         elemEnc->encoding = EENC_BOTH;
122                         elemEnc->polarityArray = (Polarity *) ourcalloc(1, sizeof(Polarity) * elemEnc->encArraySize);
123                         elemEnc->edgeArray = (Edge *) ourcalloc(1, sizeof(Edge) * elemEnc->encArraySize);
124                 } else {
125                         elemEnc->encoding = EENC_NONE;
126                 }
127         }
128
129         for (uint i = 0; i < elemEnc->encArraySize; i++) {
130                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
131                         if (elemEnc->numVars == 0)
132                                 return E_True;
133                         
134                         if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
135                                 if (impliesPolarity(elemEnc->polarityArray[i], p)) {
136                                         return elemEnc->edgeArray[i];
137                                 } else {
138                                         if (edgeIsNull(elemEnc->edgeArray[i])) {
139                                                 elemEnc->edgeArray[i] = constraintNewVar(cnf);
140                                         }
141                                         if (elemEnc->polarityArray[i] == P_UNDEFINED && p == P_BOTHTRUEFALSE) {
142                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
143                                                 elemEnc->polarityArray[i] = p;
144                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
145                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);                      
146                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_TRUE));
147                                         }       else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
148                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);                     
149                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_FALSE));
150                                         }
151                                         return elemEnc->edgeArray[i];
152                                 }
153                         }
154                         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i);
155                 }
156         }
157         return E_False;
158 }
159
160 Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, Polarity p, uint64_t value) {
161         ASTNodeType type = elem->type;
162         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
163         ElementEncoding *elemEnc = elem->getElementEncoding();
164         for (uint i = 0; i < elemEnc->encArraySize; i++) {
165                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
166                         return (elemEnc->numVars == 0) ? E_True : elemEnc->variables[i];
167                 }
168         }
169         return E_False;
170 }
171
172 Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, Polarity p, uint64_t value) {
173         ASTNodeType type = elem->type;
174         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
175         ElementEncoding *elemEnc = elem->getElementEncoding();
176         for (uint i = 0; i < elemEnc->encArraySize; i++) {
177                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
178                         if (elemEnc->numVars == 0)
179                                 return E_True;
180                         if (i == 0)
181                                 return constraintNegate(elemEnc->variables[0]);
182                         else if ((i + 1) == elemEnc->encArraySize)
183                                 return elemEnc->variables[i - 1];
184                         else
185                                 return constraintAND2(cnf, elemEnc->variables[i - 1], constraintNegate(elemEnc->variables[i]));
186                 }
187         }
188         return E_False;
189 }
190
191 Edge SATEncoder::getElementValueBinaryValueConstraint(Element *element, Polarity p, uint64_t value) {
192         ASTNodeType type = element->type;
193         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
194         ElementEncoding *elemEnc = element->getElementEncoding();
195         if (elemEnc->low <= elemEnc->high) {
196                 if (value < elemEnc->low || value > elemEnc->high)
197                         return E_False;
198         } else {
199                 //Range wraps around 0
200                 if (value < elemEnc->low && value > elemEnc->high)
201                         return E_False;
202         }
203
204         uint64_t valueminusoffset = value - elemEnc->offset;
205         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, valueminusoffset);
206 }
207
208 void allocElementConstraintVariables(ElementEncoding *This, uint numVars) {
209         This->numVars = numVars;
210         This->variables = (Edge *)ourmalloc(sizeof(Edge) * numVars);
211 }
212
213 void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
214         ASSERT(encoding->type == BINARYVAL);
215         allocElementConstraintVariables(encoding, encoding->numBits);
216         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
217         if(encoding->anyValue)
218                 generateAnyValueBinaryValueEncoding(encoding);
219 }
220
221 void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
222         ASSERT(encoding->type == BINARYINDEX);
223         allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
224         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
225         if(encoding->anyValue)
226                 generateAnyValueBinaryIndexEncoding(encoding);
227 }
228
229 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
230         allocElementConstraintVariables(encoding, encoding->encArraySize);
231         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
232         for (uint i = 0; i < encoding->numVars; i++) {
233                 for (uint j = i + 1; j < encoding->numVars; j++) {
234                         addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
235                 }
236         }
237         addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
238         if(encoding->anyValue)
239                 generateAnyValueOneHotEncoding(encoding);
240 }
241
242 void SATEncoder::generateUnaryEncodingVars(ElementEncoding *encoding) {
243         allocElementConstraintVariables(encoding, encoding->encArraySize - 1);
244         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
245         //Add unary constraint
246         for (uint i = 1; i < encoding->numVars; i++) {
247                 addConstraintCNF(cnf, constraintOR2(cnf, encoding->variables[i - 1], constraintNegate(encoding->variables[i])));
248         }
249         if(encoding->anyValue)
250                 generateAnyValueUnaryEncoding(encoding);
251 }
252
253 void SATEncoder::generateElementEncoding(Element *element) {
254         ElementEncoding *encoding = element->getElementEncoding();
255         ASSERT(encoding->type != ELEM_UNASSIGNED);
256         if (encoding->variables != NULL)
257                 return;
258         switch (encoding->type) {
259         case ONEHOT:
260                 generateOneHotEncodingVars(encoding);
261                 return;
262         case BINARYINDEX:
263                 generateBinaryIndexEncodingVars(encoding);
264                 return;
265         case UNARY:
266                 generateUnaryEncodingVars(encoding);
267                 return;
268         case BINARYVAL:
269                 generateBinaryValueEncodingVars(encoding);
270                 return;
271         default:
272                 ASSERT(0);
273         }
274 }
275
276 void SATEncoder::generateAnyValueOneHotEncoding(ElementEncoding *encoding){
277         if(encoding->numVars == 0)
278                 return;
279         Edge carray[encoding->numVars];
280         int size = 0;
281         for (uint i = 0; i < encoding->encArraySize; i++) {
282                 if (encoding->isinUseElement(i)){
283                         carray[size++] = encoding->variables[i];
284                 }
285         }
286         if(size > 0){
287                 addConstraintCNF(cnf, constraintOR(cnf, size, carray));
288         }
289 }
290
291 void SATEncoder::generateAnyValueUnaryEncoding(ElementEncoding *encoding){
292         if (encoding->numVars == 0)
293                 return;
294         Edge carray[encoding->numVars];
295         int size = 0;
296         for (uint i = 0; i < encoding->encArraySize; i++) {
297                 if (encoding->isinUseElement(i)) {
298                         if (i == 0)
299                                  carray[size++] = constraintNegate(encoding->variables[0]);
300                         else if ((i + 1) == encoding->encArraySize)
301                                 carray[size++] = encoding->variables[i - 1];
302                         else
303                                 carray[size++] = constraintAND2(cnf, encoding->variables[i - 1], constraintNegate(encoding->variables[i]));
304                 }
305         }
306         if(size > 0){
307                 addConstraintCNF(cnf, constraintOR(cnf, size, carray));
308         }
309 }
310
311 void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding){
312         if(encoding->numVars == 0)
313                 return;
314         Edge carray[encoding->numVars];
315         int size = 0;
316         int index = -1;
317         for(uint i= encoding->encArraySize-1; i>=0; i--){
318                 if(encoding->isinUseElement(i)){
319                         if(i+1 < encoding->encArraySize){
320                                 index = i+1;
321                         }
322                         break;
323                 }
324         }
325         if( index != -1 ){
326                 carray[size++] = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index);
327         }
328         index = index == -1? encoding->encArraySize-1 : index-1;
329         for(int i= index; i>=0; i--){
330                 if (!encoding->isinUseElement(i)){
331                         carray[size++] = constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i));
332                 }
333         }
334         if(size > 0){
335                 addConstraintCNF(cnf, constraintAND(cnf, size, carray));
336         }
337 }
338
339 void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding){
340         uint64_t minvalueminusoffset = encoding->low - encoding->offset;
341         uint64_t maxvalueminusoffset = encoding->high - encoding->offset;
342         model_print("This is minvalueminus offset: %lu", minvalueminusoffset);
343         Edge lowerbound = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, maxvalueminusoffset);
344         Edge upperbound = constraintNegate(generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, minvalueminusoffset));
345         addConstraintCNF(cnf, constraintAND2(cnf, lowerbound, upperbound));
346 }
347