1)Making naiveencoder and encoding graph use tuner 2)Adding timeout to the sat solver...
[satune.git] / src / Backend / satelemencoder.cc
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "common.h"
4 #include "ops.h"
5 #include "element.h"
6 #include "set.h"
7 #include "predicate.h"
8 #include "csolver.h"
9 #include "tunable.h"
10
11 void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool &memo) {
12         uint numParents = elem->parents.getSize();
13         uint posPolarity = 0;
14         uint negPolarity = 0;
15         memo = false;
16         if (elem->type == ELEMFUNCRETURN) {
17                 memo = true;
18         }
19         for (uint i = 0; i < numParents; i++) {
20                 ASTNode *node = elem->parents.get(i);
21                 if (node->type == PREDICATEOP) {
22                         BooleanPredicate *pred = (BooleanPredicate *) node;
23                         Polarity polarity = pred->polarity;
24                         FunctionEncodingType encType = pred->encoding.type;
25                         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
26                         if (pred->predicate->type == TABLEPRED) {
27                                 //Could be smarter, but just do default thing for now
28
29                                 UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
30
31                                 Polarity tpolarity = polarity;
32                                 if (generateNegation)
33                                         tpolarity = negatePolarity(tpolarity);
34                                 if (undefStatus == SATC_FLAGFORCEUNDEFINED)
35                                         tpolarity = P_BOTHTRUEFALSE;
36                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
37                                         memo = true;
38                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
39                                         memo = true;
40                         } else if (pred->predicate->type == OPERATORPRED) {
41                                 if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
42                                         Polarity tpolarity = polarity;
43                                         if (generateNegation)
44                                                 tpolarity = negatePolarity(tpolarity);
45                                         PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
46                                         uint numDomains = pred->inputs.getSize();
47                                         bool isConstant = true;
48                                         for (uint i = 0; i < numDomains; i++) {
49                                                 Element *e = pred->inputs.get(i);
50                                                 if (elem != e && e->type != ELEMCONST) {
51                                                         isConstant = false;
52                                                 }
53                                         }
54                                         if (predicate->getOp() == SATC_EQUALS) {
55                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
56                                                         posPolarity++;
57                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
58                                                         negPolarity++;
59                                         } else {
60                                                 if (isConstant) {
61                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
62                                                                 posPolarity++;
63                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
64                                                                 negPolarity++;
65                                                 } else {
66                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
67                                                                 memo = true;
68                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
69                                                                 memo = true;
70                                                 }
71                                         }
72                                 }
73                         } else {
74                                 ASSERT(0);
75                         }
76                 } else if (node->type == ELEMFUNCRETURN) {
77                         //we are input to function, so memoize negative case
78                         memo = true;
79                 } else {
80                         ASSERT(0);
81                 }
82         }
83         if (posPolarity > 1)
84                 memo = true;
85         if (negPolarity > 1)
86                 memo = true;
87 }
88
89
90 Edge SATEncoder::getElementValueConstraint(Element *elem, Polarity p, uint64_t value) {
91         switch (elem->getElementEncoding()->type) {
92         case ONEHOT:
93                 return getElementValueOneHotConstraint(elem, p, value);
94         case UNARY:
95                 return getElementValueUnaryConstraint(elem, p, value);
96         case BINARYINDEX:
97                 return getElementValueBinaryIndexConstraint(elem, p, value);
98         case BINARYVAL:
99                 return getElementValueBinaryValueConstraint(elem, p, value);
100                 break;
101         default:
102                 ASSERT(0);
103                 break;
104         }
105         return E_BOGUS;
106 }
107
108 bool impliesPolarity(Polarity curr, Polarity goal) {
109         return (((int) curr) & ((int)goal)) == ((int) goal);
110 }
111
112 Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p, uint64_t value) {
113         ASTNodeType type = elem->type;
114         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
115         ElementEncoding *elemEnc = elem->getElementEncoding();
116
117         //Check if we need to generate proxy variables
118         if (elemEnc->encoding == EENC_UNKNOWN && elemEnc->numVars > 1) {
119                 bool memo = false;
120                 shouldMemoize(elem, value, memo);
121                 if (memo) {
122                         elemEnc->encoding = EENC_BOTH;
123                         elemEnc->polarityArray = (Polarity *) ourcalloc(1, sizeof(Polarity) * elemEnc->encArraySize);
124                         elemEnc->edgeArray = (Edge *) ourcalloc(1, sizeof(Edge) * elemEnc->encArraySize);
125                 } else {
126                         elemEnc->encoding = EENC_NONE;
127                 }
128         }
129
130         for (uint i = 0; i < elemEnc->encArraySize; i++) {
131                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
132                         if (elemEnc->numVars == 0)
133                                 return E_True;
134
135                         if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
136                                 if (impliesPolarity(elemEnc->polarityArray[i], p)) {
137                                         return elemEnc->edgeArray[i];
138                                 } else {
139                                         if (edgeIsNull(elemEnc->edgeArray[i])) {
140                                                 elemEnc->edgeArray[i] = constraintNewVar(cnf);
141                                         }
142                                         if (elemEnc->polarityArray[i] == P_UNDEFINED && p == P_BOTHTRUEFALSE) {
143                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
144                                                 elemEnc->polarityArray[i] = p;
145                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
146                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);
147                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_TRUE));
148                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
149                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);
150                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_FALSE));
151                                         }
152                                         return elemEnc->edgeArray[i];
153                                 }
154                         }
155                         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i);
156                 }
157         }
158         return E_False;
159 }
160
161 Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, Polarity p, uint64_t value) {
162         ASTNodeType type = elem->type;
163         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
164         ElementEncoding *elemEnc = elem->getElementEncoding();
165         for (uint i = 0; i < elemEnc->encArraySize; i++) {
166                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
167                         return (elemEnc->numVars == 0) ? E_True : elemEnc->variables[i];
168                 }
169         }
170         return E_False;
171 }
172
173 Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, Polarity p, uint64_t value) {
174         ASTNodeType type = elem->type;
175         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
176         ElementEncoding *elemEnc = elem->getElementEncoding();
177         for (uint i = 0; i < elemEnc->encArraySize; i++) {
178                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
179                         if (elemEnc->numVars == 0)
180                                 return E_True;
181                         if (i == 0)
182                                 return constraintNegate(elemEnc->variables[0]);
183                         else if ((i + 1) == elemEnc->encArraySize)
184                                 return elemEnc->variables[i - 1];
185                         else
186                                 return constraintAND2(cnf, elemEnc->variables[i - 1], constraintNegate(elemEnc->variables[i]));
187                 }
188         }
189         return E_False;
190 }
191
192 Edge SATEncoder::getElementValueBinaryValueConstraint(Element *element, Polarity p, uint64_t value) {
193         ASTNodeType type = element->type;
194         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
195         ElementEncoding *elemEnc = element->getElementEncoding();
196         if (elemEnc->low <= elemEnc->high) {
197                 if (value < elemEnc->low || value > elemEnc->high)
198                         return E_False;
199         } else {
200                 //Range wraps around 0
201                 if (value < elemEnc->low && value > elemEnc->high)
202                         return E_False;
203         }
204
205         uint64_t valueminusoffset = value - elemEnc->offset;
206         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, valueminusoffset);
207 }
208
209 void allocElementConstraintVariables(ElementEncoding *This, uint numVars) {
210         This->numVars = numVars;
211         This->variables = (Edge *)ourmalloc(sizeof(Edge) * numVars);
212 }
213
214 void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
215         ASSERT(encoding->type == BINARYVAL);
216         allocElementConstraintVariables(encoding, encoding->numBits);
217         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
218         if (encoding->element->anyValue)
219                 generateAnyValueBinaryValueEncoding(encoding);
220 }
221
222 void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
223         ASSERT(encoding->type == BINARYINDEX);
224         allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
225         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
226         if (encoding->element->anyValue){
227                 uint setSize = encoding->element->getRange()->getSize();
228                 uint encArraySize = encoding->encArraySize;
229                 model_print("setSize=%u\tencArraySize=%u\n", setSize, encArraySize);
230                 if(setSize < encArraySize * (uint)solver->getTuner()->getTunable(MUSTVALUE, &mustValueBinaryIndex)/10){
231                         generateAnyValueBinaryIndexEncodingPositive(encoding);
232                 } else {
233                 generateAnyValueBinaryIndexEncoding(encoding);
234                 }
235         }
236 }
237
238 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
239         allocElementConstraintVariables(encoding, encoding->encArraySize);
240         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
241         for (uint i = 0; i < encoding->numVars; i++) {
242                 for (uint j = i + 1; j < encoding->numVars; j++) {
243                         addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
244                 }
245         }
246         if (encoding->element->anyValue)
247                 addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
248 }
249
250 void SATEncoder::generateUnaryEncodingVars(ElementEncoding *encoding) {
251         allocElementConstraintVariables(encoding, encoding->encArraySize - 1);
252         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
253         //Add unary constraint
254         for (uint i = 1; i < encoding->numVars; i++) {
255                 addConstraintCNF(cnf, constraintOR2(cnf, encoding->variables[i - 1], constraintNegate(encoding->variables[i])));
256         }
257 }
258
259 void SATEncoder::generateElementEncoding(Element *element) {
260         ElementEncoding *encoding = element->getElementEncoding();
261         ASSERT(encoding->type != ELEM_UNASSIGNED);
262         if (encoding->variables != NULL)
263                 return;
264         switch (encoding->type) {
265         case ONEHOT:
266                 generateOneHotEncodingVars(encoding);
267                 return;
268         case BINARYINDEX:
269                 generateBinaryIndexEncodingVars(encoding);
270                 return;
271         case UNARY:
272                 generateUnaryEncodingVars(encoding);
273                 return;
274         case BINARYVAL:
275                 generateBinaryValueEncodingVars(encoding);
276                 return;
277         default:
278                 ASSERT(0);
279         }
280 }
281
282 void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding) {
283         if (encoding->numVars == 0)
284                 return;
285         int index = -1;
286         for (uint i = encoding->encArraySize - 1; i >= 0; i--) {
287                 if (encoding->isinUseElement(i)) {
288                         if (i + 1 < encoding->encArraySize) {
289                                 index = i + 1;
290                         }
291                         break;
292                 }
293         }
294         if ( index != -1 ) {
295                 addConstraintCNF(cnf, generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index));
296         }
297         index = index == -1 ? encoding->encArraySize - 1 : index - 1;
298         for (int i = index; i >= 0; i--) {
299                 if (!encoding->isinUseElement(i)) {
300                         addConstraintCNF(cnf, constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i)));
301                 }
302         }
303 }
304
305 void SATEncoder::generateAnyValueBinaryIndexEncodingPositive(ElementEncoding *encoding) {
306         if (encoding->numVars == 0)
307                 return;
308         Edge carray[encoding->encArraySize];
309         uint size = 0;
310         for (uint i = 0; i < encoding->encArraySize; i++) {
311                 if (encoding->isinUseElement(i)) {
312                         carray[size] = generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i);
313                         size++;
314                 }
315         }
316         addConstraintCNF(cnf, constraintOR(cnf, size, carray));
317 }
318
319 void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding) {
320         uint64_t minvalueminusoffset = encoding->low - encoding->offset;
321         uint64_t maxvalueminusoffset = encoding->high - encoding->offset;
322         model_print("This is minvalueminus offset: %lu", minvalueminusoffset);
323         Edge lowerbound = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, maxvalueminusoffset);
324         Edge upperbound = constraintNegate(generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, minvalueminusoffset));
325         addConstraintCNF(cnf, lowerbound);
326         addConstraintCNF(cnf, upperbound);
327 }
328