8f06eecad991e3d68d9b83e5d5b47f3bba3a4f6b
[satune.git] / src / Backend / satelemencoder.cc
#include "satencoder.h"

#include <vector>

#include "structs.h"
#include "common.h"
#include "ops.h"
#include "element.h"
#include "set.h"
#include "predicate.h"
#include "csolver.h"
#include "tunable.h"
10
11 void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool &memo) {
12         uint numParents = elem->parents.getSize();
13         uint posPolarity = 0;
14         uint negPolarity = 0;
15         memo = false;
16         if (elem->type == ELEMFUNCRETURN) {
17                 memo = true;
18         }
19         for (uint i = 0; i < numParents; i++) {
20                 ASTNode *node = elem->parents.get(i);
21                 if (node->type == PREDICATEOP) {
22                         BooleanPredicate *pred = (BooleanPredicate *) node;
23                         Polarity polarity = pred->polarity;
24                         FunctionEncodingType encType = pred->encoding.type;
25                         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
26                         if (pred->predicate->type == TABLEPRED) {
27                                 //Could be smarter, but just do default thing for now
28
29                                 UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
30
31                                 Polarity tpolarity = polarity;
32                                 if (generateNegation)
33                                         tpolarity = negatePolarity(tpolarity);
34                                 if (undefStatus == SATC_FLAGFORCEUNDEFINED)
35                                         tpolarity = P_BOTHTRUEFALSE;
36                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
37                                         memo = true;
38                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
39                                         memo = true;
40                         } else if (pred->predicate->type == OPERATORPRED) {
41                                 if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
42                                         Polarity tpolarity = polarity;
43                                         if (generateNegation)
44                                                 tpolarity = negatePolarity(tpolarity);
45                                         PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
46                                         uint numDomains = pred->inputs.getSize();
47                                         bool isConstant = true;
48                                         for (uint i = 0; i < numDomains; i++) {
49                                                 Element *e = pred->inputs.get(i);
50                                                 if (elem != e && e->type != ELEMCONST) {
51                                                         isConstant = false;
52                                                 }
53                                         }
54                                         if (predicate->getOp() == SATC_EQUALS) {
55                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
56                                                         posPolarity++;
57                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
58                                                         negPolarity++;
59                                         } else {
60                                                 if (isConstant) {
61                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
62                                                                 posPolarity++;
63                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
64                                                                 negPolarity++;
65                                                 } else {
66                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
67                                                                 memo = true;
68                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
69                                                                 memo = true;
70                                                 }
71                                         }
72                                 }
73                         } else {
74                                 ASSERT(0);
75                         }
76                 } else if (node->type == ELEMFUNCRETURN) {
77                         //we are input to function, so memoize negative case
78                         memo = true;
79                 } else {
80                         ASSERT(0);
81                 }
82         }
83         if (posPolarity > 1)
84                 memo = true;
85         if (negPolarity > 1)
86                 memo = true;
87 }
88
89
90 Edge SATEncoder::getElementValueConstraint(Element *elem, Polarity p, uint64_t value) {
91         switch (elem->getElementEncoding()->type) {
92         case ONEHOT:
93                 return getElementValueOneHotConstraint(elem, p, value);
94         case UNARY:
95                 return getElementValueUnaryConstraint(elem, p, value);
96         case BINARYINDEX:
97                 return getElementValueBinaryIndexConstraint(elem, p, value);
98         case BINARYVAL:
99                 return getElementValueBinaryValueConstraint(elem, p, value);
100                 break;
101         default:
102                 ASSERT(0);
103                 break;
104         }
105         return E_BOGUS;
106 }
107
108 bool impliesPolarity(Polarity curr, Polarity goal) {
109         return (((int) curr) & ((int)goal)) == ((int) goal);
110 }
111
112 Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p, uint64_t value) {
113         ASTNodeType type = elem->type;
114         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
115         ElementEncoding *elemEnc = elem->getElementEncoding();
116
117         //Check if we need to generate proxy variables
118         if (elemEnc->encoding == EENC_UNKNOWN && elemEnc->numVars > 1) {
119                 bool memo = false;
120                 shouldMemoize(elem, value, memo);
121                 if (memo) {
122                         elemEnc->encoding = EENC_BOTH;
123                         elemEnc->polarityArray = (Polarity *) ourcalloc(1, sizeof(Polarity) * elemEnc->encArraySize);
124                         elemEnc->edgeArray = (Edge *) ourcalloc(1, sizeof(Edge) * elemEnc->encArraySize);
125                 } else {
126                         elemEnc->encoding = EENC_NONE;
127                 }
128         }
129
130         for (uint i = 0; i < elemEnc->encArraySize; i++) {
131                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
132                         if (elemEnc->numVars == 0)
133                                 return E_True;
134
135                         if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
136                                 if (impliesPolarity(elemEnc->polarityArray[i], p)) {
137                                         return elemEnc->edgeArray[i];
138                                 } else {
139                                         if (edgeIsNull(elemEnc->edgeArray[i])) {
140                                                 elemEnc->edgeArray[i] = constraintNewVar(cnf);
141                                         }
142                                         if (elemEnc->polarityArray[i] == P_UNDEFINED && p == P_BOTHTRUEFALSE) {
143                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
144                                                 elemEnc->polarityArray[i] = p;
145                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
146                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);
147                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_TRUE));
148                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
149                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);
150                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_FALSE));
151                                         }
152                                         return elemEnc->edgeArray[i];
153                                 }
154                         }
155                         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i);
156                 }
157         }
158         return E_False;
159 }
160
161 Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, Polarity p, uint64_t value) {
162         ASTNodeType type = elem->type;
163         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
164         ElementEncoding *elemEnc = elem->getElementEncoding();
165         for (uint i = 0; i < elemEnc->encArraySize; i++) {
166                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
167                         return (elemEnc->numVars == 0) ? E_True : elemEnc->variables[i];
168                 }
169         }
170         return E_False;
171 }
172
173 Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, Polarity p, uint64_t value) {
174         ASTNodeType type = elem->type;
175         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
176         ElementEncoding *elemEnc = elem->getElementEncoding();
177         for (uint i = 0; i < elemEnc->encArraySize; i++) {
178                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
179                         if (elemEnc->numVars == 0)
180                                 return E_True;
181                         if (i == 0)
182                                 return constraintNegate(elemEnc->variables[0]);
183                         else if ((i + 1) == elemEnc->encArraySize)
184                                 return elemEnc->variables[i - 1];
185                         else
186                                 return constraintAND2(cnf, elemEnc->variables[i - 1], constraintNegate(elemEnc->variables[i]));
187                 }
188         }
189         return E_False;
190 }
191
192 Edge SATEncoder::getElementValueBinaryValueConstraint(Element *element, Polarity p, uint64_t value) {
193         ASTNodeType type = element->type;
194         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
195         ElementEncoding *elemEnc = element->getElementEncoding();
196         if (elemEnc->low <= elemEnc->high) {
197                 if (value < elemEnc->low || value > elemEnc->high)
198                         return E_False;
199         } else {
200                 //Range wraps around 0
201                 if (value < elemEnc->low && value > elemEnc->high)
202                         return E_False;
203         }
204
205         uint64_t valueminusoffset = value - elemEnc->offset;
206         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, valueminusoffset);
207 }
208
209 void allocElementConstraintVariables(ElementEncoding *This, uint numVars) {
210         This->numVars = numVars;
211         This->variables = (Edge *)ourmalloc(sizeof(Edge) * numVars);
212 }
213
214 void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
215         ASSERT(encoding->type == BINARYVAL);
216         allocElementConstraintVariables(encoding, encoding->numBits);
217         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
218         if (encoding->element->anyValue)
219                 generateAnyValueBinaryValueEncoding(encoding);
220 }
221
222 void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
223         ASSERT(encoding->type == BINARYINDEX);
224         allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
225         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
226         if (encoding->element->anyValue){
227                 uint setSize = encoding->element->getRange()->getSize();
228                 uint encArraySize = encoding->encArraySize;
229                 if(setSize < encArraySize * (uint)solver->getTuner()->getTunable(MUSTVALUE, &mustValueBinaryIndex)/10){
230                         generateAnyValueBinaryIndexEncodingPositive(encoding);
231                 } else {
232                         generateAnyValueBinaryIndexEncoding(encoding);
233                 }
234         }
235 }
236
237 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
238         allocElementConstraintVariables(encoding, encoding->encArraySize);
239         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
240         for (uint i = 0; i < encoding->numVars; i++) {
241                 for (uint j = i + 1; j < encoding->numVars; j++) {
242                         addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
243                 }
244         }
245         if (encoding->element->anyValue)
246                 addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
247 }
248
249 void SATEncoder::generateUnaryEncodingVars(ElementEncoding *encoding) {
250         allocElementConstraintVariables(encoding, encoding->encArraySize - 1);
251         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
252         //Add unary constraint
253         for (uint i = 1; i < encoding->numVars; i++) {
254                 addConstraintCNF(cnf, constraintOR2(cnf, encoding->variables[i - 1], constraintNegate(encoding->variables[i])));
255         }
256 }
257
258 void SATEncoder::generateElementEncoding(Element *element) {
259         ElementEncoding *encoding = element->getElementEncoding();
260         ASSERT(encoding->type != ELEM_UNASSIGNED);
261         if (encoding->variables != NULL)
262                 return;
263         switch (encoding->type) {
264         case ONEHOT:
265                 generateOneHotEncodingVars(encoding);
266                 return;
267         case BINARYINDEX:
268                 generateBinaryIndexEncodingVars(encoding);
269                 return;
270         case UNARY:
271                 generateUnaryEncodingVars(encoding);
272                 return;
273         case BINARYVAL:
274                 generateBinaryValueEncodingVars(encoding);
275                 return;
276         default:
277                 ASSERT(0);
278         }
279 }
280
281 void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding) {
282         if (encoding->numVars == 0)
283                 return;
284         int index = -1;
285         for (uint i = encoding->encArraySize - 1; i >= 0; i--) {
286                 if (encoding->isinUseElement(i)) {
287                         if (i + 1 < encoding->encArraySize) {
288                                 index = i + 1;
289                         }
290                         break;
291                 }
292         }
293         if ( index != -1 ) {
294                 addConstraintCNF(cnf, generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index));
295         }
296         index = index == -1 ? encoding->encArraySize - 1 : index - 1;
297         for (int i = index; i >= 0; i--) {
298                 if (!encoding->isinUseElement(i)) {
299                         addConstraintCNF(cnf, constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i)));
300                 }
301         }
302 }
303
304 void SATEncoder::generateAnyValueBinaryIndexEncodingPositive(ElementEncoding *encoding) {
305         if (encoding->numVars == 0)
306                 return;
307         Edge carray[encoding->encArraySize];
308         uint size = 0;
309         for (uint i = 0; i < encoding->encArraySize; i++) {
310                 if (encoding->isinUseElement(i)) {
311                         carray[size] = generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i);
312                         size++;
313                 }
314         }
315         addConstraintCNF(cnf, constraintOR(cnf, size, carray));
316 }
317
318 void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding) {
319         uint64_t minvalueminusoffset = encoding->low - encoding->offset;
320         uint64_t maxvalueminusoffset = encoding->high - encoding->offset;
321         model_print("This is minvalueminus offset: %lu", minvalueminusoffset);
322         Edge lowerbound = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, maxvalueminusoffset);
323         Edge upperbound = constraintNegate(generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, minvalueminusoffset));
324         addConstraintCNF(cnf, lowerbound);
325         addConstraintCNF(cnf, upperbound);
326 }
327