Merging + fixing memory bugs
[satune.git] / src / Backend / satelemencoder.cc
#include "satencoder.h"
#include "structs.h"
#include "common.h"
#include "ops.h"
#include "element.h"
#include "set.h"
#include "predicate.h"
#include "csolver.h"
#include "tunable.h"
#include <cmath>
#include <vector>
11
12 void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool &memo) {
13         uint numParents = elem->parents.getSize();
14         uint posPolarity = 0;
15         uint negPolarity = 0;
16         memo = false;
17         if (elem->type == ELEMFUNCRETURN) {
18                 memo = true;
19         }
20         for (uint i = 0; i < numParents; i++) {
21                 ASTNode *node = elem->parents.get(i);
22                 if (node->type == PREDICATEOP) {
23                         BooleanPredicate *pred = (BooleanPredicate *) node;
24                         Polarity polarity = pred->polarity;
25                         FunctionEncodingType encType = pred->encoding.type;
26                         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
27                         if (pred->predicate->type == TABLEPRED) {
28                                 //Could be smarter, but just do default thing for now
29
30                                 UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
31
32                                 Polarity tpolarity = polarity;
33                                 if (generateNegation)
34                                         tpolarity = negatePolarity(tpolarity);
35                                 if (undefStatus == SATC_FLAGFORCEUNDEFINED)
36                                         tpolarity = P_BOTHTRUEFALSE;
37                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
38                                         memo = true;
39                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
40                                         memo = true;
41                         } else if (pred->predicate->type == OPERATORPRED) {
42                                 if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
43                                         Polarity tpolarity = polarity;
44                                         if (generateNegation)
45                                                 tpolarity = negatePolarity(tpolarity);
46                                         PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
47                                         uint numDomains = pred->inputs.getSize();
48                                         bool isConstant = true;
49                                         for (uint i = 0; i < numDomains; i++) {
50                                                 Element *e = pred->inputs.get(i);
51                                                 if (elem != e && e->type != ELEMCONST) {
52                                                         isConstant = false;
53                                                 }
54                                         }
55                                         if (predicate->getOp() == SATC_EQUALS) {
56                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
57                                                         posPolarity++;
58                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
59                                                         negPolarity++;
60                                         } else {
61                                                 if (isConstant) {
62                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
63                                                                 posPolarity++;
64                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
65                                                                 negPolarity++;
66                                                 } else {
67                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
68                                                                 memo = true;
69                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
70                                                                 memo = true;
71                                                 }
72                                         }
73                                 }
74                         } else {
75                                 ASSERT(0);
76                         }
77                 } else if (node->type == ELEMFUNCRETURN) {
78                         //we are input to function, so memoize negative case
79                         memo = true;
80                 } else {
81                         ASSERT(0);
82                 }
83         }
84         if (posPolarity > 1)
85                 memo = true;
86         if (negPolarity > 1)
87                 memo = true;
88 }
89
90
91 Edge SATEncoder::getElementValueConstraint(Element *elem, Polarity p, uint64_t value) {
92         switch (elem->getElementEncoding()->type) {
93         case ONEHOT:
94                 return getElementValueOneHotConstraint(elem, p, value);
95         case UNARY:
96                 return getElementValueUnaryConstraint(elem, p, value);
97         case BINARYINDEX:
98                 return getElementValueBinaryIndexConstraint(elem, p, value);
99         case BINARYVAL:
100                 return getElementValueBinaryValueConstraint(elem, p, value);
101                 break;
102         default:
103                 ASSERT(0);
104                 break;
105         }
106         return E_BOGUS;
107 }
108
109 bool impliesPolarity(Polarity curr, Polarity goal) {
110         return (((int) curr) & ((int)goal)) == ((int) goal);
111 }
112
113 Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p, uint64_t value) {
114         ASTNodeType type = elem->type;
115         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
116         ElementEncoding *elemEnc = elem->getElementEncoding();
117
118         //Check if we need to generate proxy variables
119         if (elemEnc->encoding == EENC_UNKNOWN && elemEnc->numVars > 1) {
120                 bool memo = false;
121                 shouldMemoize(elem, value, memo);
122                 if (memo) {
123                         elemEnc->encoding = EENC_BOTH;
124                         elemEnc->polarityArray = (Polarity *) ourcalloc(1, sizeof(Polarity) * elemEnc->encArraySize);
125                         elemEnc->edgeArray = (Edge *) ourcalloc(1, sizeof(Edge) * elemEnc->encArraySize);
126                 } else {
127                         elemEnc->encoding = EENC_NONE;
128                 }
129         }
130
131         for (uint i = 0; i < elemEnc->encArraySize; i++) {
132                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
133                         if (elemEnc->numVars == 0)
134                                 return E_True;
135
136                         if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
137                                 if (impliesPolarity(elemEnc->polarityArray[i], p)) {
138                                         return elemEnc->edgeArray[i];
139                                 } else {
140                                         if (edgeIsNull(elemEnc->edgeArray[i])) {
141                                                 elemEnc->edgeArray[i] = constraintNewVar(cnf);
142                                         }
143                                         if (elemEnc->polarityArray[i] == P_UNDEFINED && p == P_BOTHTRUEFALSE) {
144                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
145                                                 elemEnc->polarityArray[i] = p;
146                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
147                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);
148                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_TRUE));
149                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
150                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);
151                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_FALSE));
152                                         }
153                                         return elemEnc->edgeArray[i];
154                                 }
155                         }
156                         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i);
157                 }
158         }
159         return E_False;
160 }
161
162 Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, Polarity p, uint64_t value) {
163         ASTNodeType type = elem->type;
164         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
165         ElementEncoding *elemEnc = elem->getElementEncoding();
166         for (uint i = 0; i < elemEnc->encArraySize; i++) {
167                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
168                         return (elemEnc->numVars == 0) ? E_True : elemEnc->variables[i];
169                 }
170         }
171         return E_False;
172 }
173
174 Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, Polarity p, uint64_t value) {
175         ASTNodeType type = elem->type;
176         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
177         ElementEncoding *elemEnc = elem->getElementEncoding();
178         for (uint i = 0; i < elemEnc->encArraySize; i++) {
179                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
180                         if (elemEnc->numVars == 0)
181                                 return E_True;
182                         if (i == 0)
183                                 return constraintNegate(elemEnc->variables[0]);
184                         else if ((i + 1) == elemEnc->encArraySize)
185                                 return elemEnc->variables[i - 1];
186                         else
187                                 return constraintAND2(cnf, elemEnc->variables[i - 1], constraintNegate(elemEnc->variables[i]));
188                 }
189         }
190         return E_False;
191 }
192
193 Edge SATEncoder::getElementValueBinaryValueConstraint(Element *element, Polarity p, uint64_t value) {
194         ASTNodeType type = element->type;
195         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
196         ElementEncoding *elemEnc = element->getElementEncoding();
197         if (elemEnc->low <= elemEnc->high) {
198                 if (value < elemEnc->low || value > elemEnc->high)
199                         return E_False;
200         } else {
201                 //Range wraps around 0
202                 if (value < elemEnc->low && value > elemEnc->high)
203                         return E_False;
204         }
205
206         uint64_t valueminusoffset = value - elemEnc->offset;
207         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, valueminusoffset);
208 }
209
210 void allocElementConstraintVariables(ElementEncoding *This, uint numVars) {
211         This->numVars = numVars;
212         This->variables = (Edge *)ourmalloc(sizeof(Edge) * numVars);
213 }
214
215 void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
216         ASSERT(encoding->type == BINARYVAL);
217         allocElementConstraintVariables(encoding, encoding->numBits);
218         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
219         if (encoding->element->anyValue)
220                 generateAnyValueBinaryValueEncoding(encoding);
221 }
222
223 void SATEncoder::freezeElementVariables(ElementEncoding *encoding) {
224         ASSERT(encoding->element->frozen);
225         for (uint i = 0; i < encoding->numVars; i++) {
226                 Edge e = encoding->variables[i];
227                 ASSERT(edgeIsVarConst(e));
228                 freezeVariable(cnf, e);
229         }
230         for(uint i=0; i< encoding->encArraySize; i++){
231                 if(encoding->isinUseElement(i) && encoding->encoding != EENC_NONE && encoding->numVars > 1){
232                         Edge e = encoding->edgeArray[i];
233                         if(!edgeIsNull(e)){
234                                 ASSERT(edgeIsVarConst(e));
235                                 freezeVariable(cnf, e);
236                         }
237                 }
238         }
239 }
240
241 void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
242         ASSERT(encoding->type == BINARYINDEX);
243         allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
244         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
245         if (encoding->element->anyValue) {
246                 uint setSize = encoding->element->getRange()->getSize();
247                 int maxIndex = getMaximumUsedSize(encoding);
248                 if (setSize == encoding->encArraySize && maxIndex == (int)setSize) {
249                         return;
250                 }
251                 double ratio = (setSize * (1 + 2 * encoding->numVars)) / (encoding->numVars * (encoding->numVars + maxIndex * 1.0 - setSize));
252 //              model_print("encArraySize=%u\tmaxIndex=%d\tsetSize=%u\tmetric=%f\tnumBits=%u\n", encoding->encArraySize, maxIndex, setSize, ratio, encoding->numVars);
253                 if ( ratio <  pow(2, (uint)solver->getTuner()->getTunable(MUSTVALUE, &mustValueBinaryIndex) - 3)) {
254                         generateAnyValueBinaryIndexEncodingPositive(encoding);
255                 } else {
256                         generateAnyValueBinaryIndexEncoding(encoding);
257                 }
258         }
259 }
260
261 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
262         allocElementConstraintVariables(encoding, encoding->encArraySize);
263         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
264         for (uint i = 0; i < encoding->numVars; i++) {
265                 for (uint j = i + 1; j < encoding->numVars; j++) {
266                         addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
267                 }
268         }
269         if (encoding->element->anyValue)
270                 addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
271 }
272
273 void SATEncoder::generateUnaryEncodingVars(ElementEncoding *encoding) {
274         allocElementConstraintVariables(encoding, encoding->encArraySize - 1);
275         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
276         //Add unary constraint
277         for (uint i = 1; i < encoding->numVars; i++) {
278                 addConstraintCNF(cnf, constraintOR2(cnf, encoding->variables[i - 1], constraintNegate(encoding->variables[i])));
279         }
280 }
281
282 void SATEncoder::generateElementEncoding(Element *element) {
283         ElementEncoding *encoding = element->getElementEncoding();
284         ASSERT(encoding->type != ELEM_UNASSIGNED);
285         if (encoding->variables != NULL)
286                 return;
287         switch (encoding->type) {
288         case ONEHOT:
289                 generateOneHotEncodingVars(encoding);
290                 return;
291         case BINARYINDEX:
292                 generateBinaryIndexEncodingVars(encoding);
293                 return;
294         case UNARY:
295                 generateUnaryEncodingVars(encoding);
296                 return;
297         case BINARYVAL:
298                 generateBinaryValueEncodingVars(encoding);
299                 return;
300         default:
301                 ASSERT(0);
302         }
303 }
304
305 int SATEncoder::getMaximumUsedSize(ElementEncoding *encoding) {
306         if(encoding->encArraySize == 1){
307                 return 1;
308         }
309         for (int i = encoding->encArraySize - 1; i >= 0; i--) {
310                 if (encoding->isinUseElement(i))
311                         return i + 1;
312         }
313         ASSERT(false);
314         return -1;
315 }
316
317 void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding) {
318         if (encoding->numVars == 0)
319                 return;
320         int index = getMaximumUsedSize(encoding);
321         if ( index != (int)encoding->encArraySize ) {
322                 addConstraintCNF(cnf, generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index));
323         }
324         for (int i = index - 1; i >= 0; i--) {
325                 if (!encoding->isinUseElement(i)) {
326                         addConstraintCNF(cnf, constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i)));
327                 }
328         }
329 }
330
331 void SATEncoder::generateAnyValueBinaryIndexEncodingPositive(ElementEncoding *encoding) {
332         if (encoding->numVars == 0)
333                 return;
334         Edge carray[encoding->encArraySize];
335         uint size = 0;
336         for (uint i = 0; i < encoding->encArraySize; i++) {
337                 if (encoding->isinUseElement(i)) {
338                         carray[size] = generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i);
339                         size++;
340                 }
341         }
342         addConstraintCNF(cnf, constraintOR(cnf, size, carray));
343 }
344
345 void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding) {
346         uint64_t minvalueminusoffset = encoding->low - encoding->offset;
347         uint64_t maxvalueminusoffset = encoding->high - encoding->offset;
348         model_print("This is minvalueminus offset: %lu", minvalueminusoffset);
349         Edge lowerbound = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, maxvalueminusoffset);
350         Edge upperbound = constraintNegate(generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, minvalueminusoffset));
351         addConstraintCNF(cnf, lowerbound);
352         addConstraintCNF(cnf, upperbound);
353 }
354