Fix tabbing
[satune.git] / src / Backend / satelemencoder.cc
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "common.h"
4 #include "ops.h"
5 #include "element.h"
6 #include "set.h"
7 #include "predicate.h"
8
9
10 void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool &memo) {
11         uint numParents = elem->parents.getSize();
12         uint posPolarity = 0;
13         uint negPolarity = 0;
14         memo = false;
15         if (elem->type == ELEMFUNCRETURN) {
16                 memo = true;
17         }
18         for (uint i = 0; i < numParents; i++) {
19                 ASTNode *node = elem->parents.get(i);
20                 if (node->type == PREDICATEOP) {
21                         BooleanPredicate *pred = (BooleanPredicate *) node;
22                         Polarity polarity = pred->polarity;
23                         FunctionEncodingType encType = pred->encoding.type;
24                         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
25                         if (pred->predicate->type == TABLEPRED) {
26                                 //Could be smarter, but just do default thing for now
27
28                                 UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
29
30                                 Polarity tpolarity = polarity;
31                                 if (generateNegation)
32                                         tpolarity = negatePolarity(tpolarity);
33                                 if (undefStatus == SATC_FLAGFORCEUNDEFINED)
34                                         tpolarity = P_BOTHTRUEFALSE;
35                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
36                                         memo = true;
37                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
38                                         memo = true;
39                         } else if (pred->predicate->type == OPERATORPRED) {
40                                 if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
41                                         Polarity tpolarity = polarity;
42                                         if (generateNegation)
43                                                 tpolarity = negatePolarity(tpolarity);
44                                         PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
45                                         uint numDomains = pred->inputs.getSize();
46                                         bool isConstant = true;
47                                         for (uint i = 0; i < numDomains; i++) {
48                                                 Element *e = pred->inputs.get(i);
49                                                 if (elem != e && e->type != ELEMCONST) {
50                                                         isConstant = false;
51                                                 }
52                                         }
53                                         if (predicate->getOp() == SATC_EQUALS) {
54                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
55                                                         posPolarity++;
56                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
57                                                         negPolarity++;
58                                         } else {
59                                                 if (isConstant) {
60                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
61                                                                 posPolarity++;
62                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
63                                                                 negPolarity++;
64                                                 } else {
65                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
66                                                                 memo = true;
67                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
68                                                                 memo = true;
69                                                 }
70                                         }
71                                 }
72                         } else {
73                                 ASSERT(0);
74                         }
75                 } else if (node->type == ELEMFUNCRETURN) {
76                         //we are input to function, so memoize negative case
77                         memo = true;
78                 } else {
79                         ASSERT(0);
80                 }
81         }
82         if (posPolarity > 1)
83                 memo = true;
84         if (negPolarity > 1)
85                 memo = true;
86 }
87
88
89 Edge SATEncoder::getElementValueConstraint(Element *elem, Polarity p, uint64_t value) {
90         switch (elem->getElementEncoding()->type) {
91         case ONEHOT:
92                 return getElementValueOneHotConstraint(elem, p, value);
93         case UNARY:
94                 return getElementValueUnaryConstraint(elem, p, value);
95         case BINARYINDEX:
96                 return getElementValueBinaryIndexConstraint(elem, p, value);
97         case BINARYVAL:
98                 return getElementValueBinaryValueConstraint(elem, p, value);
99                 break;
100         default:
101                 ASSERT(0);
102                 break;
103         }
104         return E_BOGUS;
105 }
106
107 bool impliesPolarity(Polarity curr, Polarity goal) {
108         return (((int) curr) & ((int)goal)) == ((int) goal);
109 }
110
111 Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p, uint64_t value) {
112         ASTNodeType type = elem->type;
113         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
114         ElementEncoding *elemEnc = elem->getElementEncoding();
115
116         //Check if we need to generate proxy variables
117         if (elemEnc->encoding == EENC_UNKNOWN && elemEnc->numVars > 1) {
118                 bool memo = false;
119                 shouldMemoize(elem, value, memo);
120                 if (memo) {
121                         elemEnc->encoding = EENC_BOTH;
122                         elemEnc->polarityArray = (Polarity *) ourcalloc(1, sizeof(Polarity) * elemEnc->encArraySize);
123                         elemEnc->edgeArray = (Edge *) ourcalloc(1, sizeof(Edge) * elemEnc->encArraySize);
124                 } else {
125                         elemEnc->encoding = EENC_NONE;
126                 }
127         }
128
129         for (uint i = 0; i < elemEnc->encArraySize; i++) {
130                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
131                         if (elemEnc->numVars == 0)
132                                 return E_True;
133
134                         if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
135                                 if (impliesPolarity(elemEnc->polarityArray[i], p)) {
136                                         return elemEnc->edgeArray[i];
137                                 } else {
138                                         if (edgeIsNull(elemEnc->edgeArray[i])) {
139                                                 elemEnc->edgeArray[i] = constraintNewVar(cnf);
140                                         }
141                                         if (elemEnc->polarityArray[i] == P_UNDEFINED && p == P_BOTHTRUEFALSE) {
142                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
143                                                 elemEnc->polarityArray[i] = p;
144                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
145                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);
146                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_TRUE));
147                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
148                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);
149                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_FALSE));
150                                         }
151                                         return elemEnc->edgeArray[i];
152                                 }
153                         }
154                         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i);
155                 }
156         }
157         return E_False;
158 }
159
160 Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, Polarity p, uint64_t value) {
161         ASTNodeType type = elem->type;
162         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
163         ElementEncoding *elemEnc = elem->getElementEncoding();
164         for (uint i = 0; i < elemEnc->encArraySize; i++) {
165                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
166                         return (elemEnc->numVars == 0) ? E_True : elemEnc->variables[i];
167                 }
168         }
169         return E_False;
170 }
171
172 Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, Polarity p, uint64_t value) {
173         ASTNodeType type = elem->type;
174         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
175         ElementEncoding *elemEnc = elem->getElementEncoding();
176         for (uint i = 0; i < elemEnc->encArraySize; i++) {
177                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
178                         if (elemEnc->numVars == 0)
179                                 return E_True;
180                         if (i == 0)
181                                 return constraintNegate(elemEnc->variables[0]);
182                         else if ((i + 1) == elemEnc->encArraySize)
183                                 return elemEnc->variables[i - 1];
184                         else
185                                 return constraintAND2(cnf, elemEnc->variables[i - 1], constraintNegate(elemEnc->variables[i]));
186                 }
187         }
188         return E_False;
189 }
190
191 Edge SATEncoder::getElementValueBinaryValueConstraint(Element *element, Polarity p, uint64_t value) {
192         ASTNodeType type = element->type;
193         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
194         ElementEncoding *elemEnc = element->getElementEncoding();
195         if (elemEnc->low <= elemEnc->high) {
196                 if (value < elemEnc->low || value > elemEnc->high)
197                         return E_False;
198         } else {
199                 //Range wraps around 0
200                 if (value < elemEnc->low && value > elemEnc->high)
201                         return E_False;
202         }
203
204         uint64_t valueminusoffset = value - elemEnc->offset;
205         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, valueminusoffset);
206 }
207
208 void allocElementConstraintVariables(ElementEncoding *This, uint numVars) {
209         This->numVars = numVars;
210         This->variables = (Edge *)ourmalloc(sizeof(Edge) * numVars);
211 }
212
213 void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
214         ASSERT(encoding->type == BINARYVAL);
215         allocElementConstraintVariables(encoding, encoding->numBits);
216         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
217         if (encoding->element->anyValue)
218                 generateAnyValueBinaryValueEncoding(encoding);
219 }
220
221 void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
222         ASSERT(encoding->type == BINARYINDEX);
223         allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
224         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
225         if (encoding->element->anyValue)
226                 generateAnyValueBinaryIndexEncoding(encoding);
227 }
228
229 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
230         allocElementConstraintVariables(encoding, encoding->encArraySize);
231         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
232         for (uint i = 0; i < encoding->numVars; i++) {
233                 for (uint j = i + 1; j < encoding->numVars; j++) {
234                         addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
235                 }
236         }
237         if (encoding->element->anyValue)
238                 addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
239 }
240
241 void SATEncoder::generateUnaryEncodingVars(ElementEncoding *encoding) {
242         allocElementConstraintVariables(encoding, encoding->encArraySize - 1);
243         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
244         //Add unary constraint
245         for (uint i = 1; i < encoding->numVars; i++) {
246                 addConstraintCNF(cnf, constraintOR2(cnf, encoding->variables[i - 1], constraintNegate(encoding->variables[i])));
247         }
248 }
249
250 void SATEncoder::generateElementEncoding(Element *element) {
251         ElementEncoding *encoding = element->getElementEncoding();
252         ASSERT(encoding->type != ELEM_UNASSIGNED);
253         if (encoding->variables != NULL)
254                 return;
255         switch (encoding->type) {
256         case ONEHOT:
257                 generateOneHotEncodingVars(encoding);
258                 return;
259         case BINARYINDEX:
260                 generateBinaryIndexEncodingVars(encoding);
261                 return;
262         case UNARY:
263                 generateUnaryEncodingVars(encoding);
264                 return;
265         case BINARYVAL:
266                 generateBinaryValueEncodingVars(encoding);
267                 return;
268         default:
269                 ASSERT(0);
270         }
271 }
272
273 void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding) {
274         if (encoding->numVars == 0)
275                 return;
276         int index = -1;
277         for (uint i = encoding->encArraySize - 1; i >= 0; i--) {
278                 if (encoding->isinUseElement(i)) {
279                         if (i + 1 < encoding->encArraySize) {
280                                 index = i + 1;
281                         }
282                         break;
283                 }
284         }
285         if ( index != -1 ) {
286                 addConstraintCNF(cnf, generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index));
287         }
288         index = index == -1 ? encoding->encArraySize - 1 : index - 1;
289         for (int i = index; i >= 0; i--) {
290                 if (!encoding->isinUseElement(i)) {
291                         addConstraintCNF(cnf, constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i)));
292                 }
293         }
294 }
295
296 void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding) {
297         uint64_t minvalueminusoffset = encoding->low - encoding->offset;
298         uint64_t maxvalueminusoffset = encoding->high - encoding->offset;
299         model_print("This is minvalueminus offset: %lu", minvalueminusoffset);
300         Edge lowerbound = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, maxvalueminusoffset);
301         Edge upperbound = constraintNegate(generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, minvalueminusoffset));
302         addConstraintCNF(cnf, lowerbound);
303         addConstraintCNF(cnf, upperbound);
304 }
305