Edits
[satune.git] / src / Backend / satelemencoder.cc
1 #include "satencoder.h"
2 #include "structs.h"
3 #include "common.h"
4 #include "ops.h"
5 #include "element.h"
6 #include "set.h"
7 #include "predicate.h"
8
9
10 void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
11         uint numParents = elem->parents.getSize();
12         uint posPolarity = 0;
13         uint negPolarity = 0;
14         memo = false;
15         if (elem->type == ELEMFUNCRETURN) {
16                 memo = true;
17         }
18         for(uint i = 0; i < numParents; i++) {
19                 ASTNode * node = elem->parents.get(i);
20                 if (node->type == PREDICATEOP) {
21                         BooleanPredicate * pred = (BooleanPredicate *) node;
22                         Polarity polarity = pred->polarity;
23                         FunctionEncodingType encType = pred->encoding.type;
24                         bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
25                         if (pred->predicate->type == TABLEPRED) {
26                                 //Could be smarter, but just do default thing for now
27
28                                 UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
29
30                                 Polarity tpolarity=polarity;
31                                 if (generateNegation)
32                                         tpolarity = negatePolarity(tpolarity);
33                                 if (undefStatus ==SATC_FLAGFORCEUNDEFINED)
34                                         tpolarity = P_BOTHTRUEFALSE;
35                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
36                                         memo = true;
37                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
38                                         memo = true;
39                         } else if (pred->predicate->type == OPERATORPRED) {
40                                         if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
41                                                 Polarity tpolarity = polarity;
42                                                 if (generateNegation)
43                                                         tpolarity = negatePolarity(tpolarity);
44                                                 PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
45                                                 uint numDomains = predicate->domains.getSize();
46                                                 bool isConstant = true;
47                                                 for (uint i = 0; i < numDomains; i++) {
48                                                         Element *e = pred->inputs.get(i);
49                                                         if (elem != e && e->type != ELEMCONST) {
50                                                                 isConstant = false;
51                                                         }
52                                                 }
53                                                 if (predicate->getOp() == SATC_EQUALS) {
54                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
55                                                                 posPolarity++;
56                                                         if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
57                                                                 negPolarity++;
58                                                 } else {
59                                                         if (isConstant) {
60                                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
61                                                                         posPolarity++;
62                                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
63                                                                         negPolarity++;
64                                                         } else {
65                                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
66                                                                         memo = true;
67                                                                 if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
68                                                                         memo = true;                                                    
69                                                         }
70                                                 }
71                                         }
72                                 } else {
73                                 ASSERT(0);
74                         }
75                 } else if (node->type == ELEMFUNCRETURN) {
76                         //we are input to function, so memoize negative case
77                         memo = true;
78                 } else {
79                         ASSERT(0);
80                 }
81         }
82         if (posPolarity > 1)
83                 memo = true;
84         if (negPolarity > 1)
85                 memo = true;
86 }
87
88
89 Edge SATEncoder::getElementValueConstraint(Element *elem, Polarity p, uint64_t value) {
90         switch (elem->getElementEncoding()->type) {
91         case ONEHOT:
92                 return getElementValueOneHotConstraint(elem, p, value);
93         case UNARY:
94                 return getElementValueUnaryConstraint(elem, p, value);
95         case BINARYINDEX:
96                 return getElementValueBinaryIndexConstraint(elem, p, value);
97         case BINARYVAL:
98                 return getElementValueBinaryValueConstraint(elem, p, value);
99                 break;
100         default:
101                 ASSERT(0);
102                 break;
103         }
104         return E_BOGUS;
105 }
106
107 bool impliesPolarity(Polarity curr, Polarity goal) {
108         return (((int) curr) & ((int)goal)) == ((int) goal);
109 }
110
111 Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p, uint64_t value) {
112         ASTNodeType type = elem->type;
113         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
114         ElementEncoding *elemEnc = elem->getElementEncoding();
115
116         //Check if we need to generate proxy variables
117         if (elemEnc->encoding == EENC_UNKNOWN && elemEnc->numVars > 1) {
118                 bool memo = false;
119                 shouldMemoize(elem, value, memo);
120                 if (memo) {
121                         elemEnc->encoding = EENC_BOTH;
122                         elemEnc->polarityArray = (Polarity *) ourcalloc(1, sizeof(Polarity) * elemEnc->encArraySize);
123                         elemEnc->edgeArray = (Edge *) ourcalloc(1, sizeof(Edge) * elemEnc->encArraySize);
124                 } else {
125                         elemEnc->encoding = EENC_NONE;
126                 }
127         }
128
129         for (uint i = 0; i < elemEnc->encArraySize; i++) {
130                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
131                         if (elemEnc->numVars == 0)
132                                 return E_True;
133                         
134                         if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
135                                 if (impliesPolarity(elemEnc->polarityArray[i], p)) {
136                                         return elemEnc->edgeArray[i];
137                                 } else {
138                                         if (edgeIsNull(elemEnc->edgeArray[i])) {
139                                                 elemEnc->edgeArray[i] = constraintNewVar(cnf);
140                                         }
141                                         if (elemEnc->polarityArray[i] == P_UNDEFINED && p == P_BOTHTRUEFALSE) {
142                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
143                                                 elemEnc->polarityArray[i] = p;
144                                         } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
145                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);                      
146                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_TRUE));
147                                         }       else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
148                                                 generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);                     
149                                                 elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_FALSE));
150                                         }
151                                         return elemEnc->edgeArray[i];
152                                 }
153                         }
154                         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i);
155                 }
156         }
157         return E_False;
158 }
159
160 Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, Polarity p, uint64_t value) {
161         ASTNodeType type = elem->type;
162         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
163         ElementEncoding *elemEnc = elem->getElementEncoding();
164         for (uint i = 0; i < elemEnc->encArraySize; i++) {
165                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
166                         return (elemEnc->numVars == 0) ? E_True : elemEnc->variables[i];
167                 }
168         }
169         return E_False;
170 }
171
172 Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, Polarity p, uint64_t value) {
173         ASTNodeType type = elem->type;
174         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
175         ElementEncoding *elemEnc = elem->getElementEncoding();
176         for (uint i = 0; i < elemEnc->encArraySize; i++) {
177                 if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
178                         if (elemEnc->numVars == 0)
179                                 return E_True;
180                         if (i == 0)
181                                 return constraintNegate(elemEnc->variables[0]);
182                         else if ((i + 1) == elemEnc->encArraySize)
183                                 return elemEnc->variables[i - 1];
184                         else
185                                 return constraintAND2(cnf, elemEnc->variables[i - 1], constraintNegate(elemEnc->variables[i]));
186                 }
187         }
188         return E_False;
189 }
190
191 Edge SATEncoder::getElementValueBinaryValueConstraint(Element *element, Polarity p, uint64_t value) {
192         ASTNodeType type = element->type;
193         ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
194         ElementEncoding *elemEnc = element->getElementEncoding();
195         if (elemEnc->low <= elemEnc->high) {
196                 if (value < elemEnc->low || value > elemEnc->high)
197                         return E_False;
198         } else {
199                 //Range wraps around 0
200                 if (value < elemEnc->low && value > elemEnc->high)
201                         return E_False;
202         }
203
204         uint64_t valueminusoffset = value - elemEnc->offset;
205         return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, valueminusoffset);
206 }
207
208 void allocElementConstraintVariables(ElementEncoding *This, uint numVars) {
209         This->numVars = numVars;
210         This->variables = (Edge *)ourmalloc(sizeof(Edge) * numVars);
211 }
212
213 void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
214         ASSERT(encoding->type == BINARYVAL);
215         allocElementConstraintVariables(encoding, encoding->numBits);
216         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
217 }
218
219 void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
220         ASSERT(encoding->type == BINARYINDEX);
221         allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
222         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
223 }
224
225 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
226         allocElementConstraintVariables(encoding, encoding->encArraySize);
227         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
228         for (uint i = 0; i < encoding->numVars; i++) {
229                 for (uint j = i + 1; j < encoding->numVars; j++) {
230                         addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
231                 }
232         }
233         addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
234 }
235
236 void SATEncoder::generateUnaryEncodingVars(ElementEncoding *encoding) {
237         allocElementConstraintVariables(encoding, encoding->encArraySize - 1);
238         getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
239         //Add unary constraint
240         for (uint i = 1; i < encoding->numVars; i++) {
241                 addConstraintCNF(cnf, constraintOR2(cnf, encoding->variables[i - 1], constraintNegate(encoding->variables[i])));
242         }
243 }
244
245 void SATEncoder::generateElementEncoding(Element *element) {
246         ElementEncoding *encoding = element->getElementEncoding();
247         ASSERT(encoding->type != ELEM_UNASSIGNED);
248         if (encoding->variables != NULL)
249                 return;
250         switch (encoding->type) {
251         case ONEHOT:
252                 generateOneHotEncodingVars(encoding);
253                 return;
254         case BINARYINDEX:
255                 generateBinaryIndexEncodingVars(encoding);
256                 return;
257         case UNARY:
258                 generateUnaryEncodingVars(encoding);
259                 return;
260         case BINARYVAL:
261                 generateBinaryValueEncodingVars(encoding);
262                 return;
263         default:
264                 ASSERT(0);
265         }
266 }
267