Merging + fixing memory bugs
[satune.git] / src / Backend / satelemencoder.cc
index 0efe7a138a1da5edfd310fc9784e6c751ee76286..7687f3532e6b5f84840588aa56ae4dd3df87e032 100644 (file)
@@ -5,9 +5,11 @@
 #include "element.h"
 #include "set.h"
 #include "predicate.h"
+#include "csolver.h"
+#include "tunable.h"
+#include <cmath>
 
-
-void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
+void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool &memo) {
        uint numParents = elem->parents.getSize();
        uint posPolarity = 0;
        uint negPolarity = 0;
@@ -15,10 +17,10 @@ void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
        if (elem->type == ELEMFUNCRETURN) {
                memo = true;
        }
-       for(uint i = 0; i < numParents; i++) {
-               ASTNode * node = elem->parents.get(i);
+       for (uint i = 0; i < numParents; i++) {
+               ASTNode *node = elem->parents.get(i);
                if (node->type == PREDICATEOP) {
-                       BooleanPredicate * pred = (BooleanPredicate *) node;
+                       BooleanPredicate *pred = (BooleanPredicate *) node;
                        Polarity polarity = pred->polarity;
                        FunctionEncodingType encType = pred->encoding.type;
                        bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
@@ -27,49 +29,49 @@ void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
 
                                UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
 
-                               Polarity tpolarity=polarity;
+                               Polarity tpolarity = polarity;
                                if (generateNegation)
                                        tpolarity = negatePolarity(tpolarity);
-                               if (undefStatus ==SATC_FLAGFORCEUNDEFINED)
+                               if (undefStatus == SATC_FLAGFORCEUNDEFINED)
                                        tpolarity = P_BOTHTRUEFALSE;
                                if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
                                        memo = true;
                                if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
                                        memo = true;
                        } else if (pred->predicate->type == OPERATORPRED) {
-                                       if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
-                                               Polarity tpolarity = polarity;
-                                               if (generateNegation)
-                                                       tpolarity = negatePolarity(tpolarity);
-                                               PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
-                                               uint numDomains = pred->inputs.getSize();
-                                               bool isConstant = true;
-                                               for (uint i = 0; i < numDomains; i++) {
-                                                       Element *e = pred->inputs.get(i);
-                                                       if (elem != e && e->type != ELEMCONST) {
-                                                               isConstant = false;
-                                                       }
+                               if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
+                                       Polarity tpolarity = polarity;
+                                       if (generateNegation)
+                                               tpolarity = negatePolarity(tpolarity);
+                                       PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
+                                       uint numDomains = pred->inputs.getSize();
+                                       bool isConstant = true;
+                                       for (uint i = 0; i < numDomains; i++) {
+                                               Element *e = pred->inputs.get(i);
+                                               if (elem != e && e->type != ELEMCONST) {
+                                                       isConstant = false;
                                                }
-                                               if (predicate->getOp() == SATC_EQUALS) {
+                                       }
+                                       if (predicate->getOp() == SATC_EQUALS) {
+                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
+                                                       posPolarity++;
+                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
+                                                       negPolarity++;
+                                       } else {
+                                               if (isConstant) {
                                                        if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
                                                                posPolarity++;
                                                        if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
                                                                negPolarity++;
                                                } else {
-                                                       if (isConstant) {
-                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
-                                                                       posPolarity++;
-                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
-                                                                       negPolarity++;
-                                                       } else {
-                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
-                                                                       memo = true;
-                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
-                                                                       memo = true;                                                    
-                                                       }
+                                                       if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
+                                                               memo = true;
+                                                       if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
+                                                               memo = true;
                                                }
                                        }
-                               } else {
+                               }
+                       } else {
                                ASSERT(0);
                        }
                } else if (node->type == ELEMFUNCRETURN) {
@@ -130,7 +132,7 @@ Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p,
                if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
                        if (elemEnc->numVars == 0)
                                return E_True;
-                       
+
                        if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
                                if (impliesPolarity(elemEnc->polarityArray[i], p)) {
                                        return elemEnc->edgeArray[i];
@@ -142,11 +144,11 @@ Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p,
                                                generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
                                                elemEnc->polarityArray[i] = p;
                                        } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
-                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);                      
-                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_TRUE));
-                                       }       else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
-                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);                     
-                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_FALSE));
+                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);
+                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_TRUE));
+                                       } else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
+                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);
+                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_FALSE));
                                        }
                                        return elemEnc->edgeArray[i];
                                }
@@ -214,16 +216,46 @@ void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
        ASSERT(encoding->type == BINARYVAL);
        allocElementConstraintVariables(encoding, encoding->numBits);
        getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
-       if(encoding->anyValue)
+       if (encoding->element->anyValue)
                generateAnyValueBinaryValueEncoding(encoding);
 }
 
+void SATEncoder::freezeElementVariables(ElementEncoding *encoding) {
+       ASSERT(encoding->element->frozen);
+       for (uint i = 0; i < encoding->numVars; i++) {
+               Edge e = encoding->variables[i];
+               ASSERT(edgeIsVarConst(e));
+               freezeVariable(cnf, e);
+       }
+       for(uint i=0; i< encoding->encArraySize; i++){
+               if(encoding->isinUseElement(i) && encoding->encoding != EENC_NONE && encoding->numVars > 1){
+                       Edge e = encoding->edgeArray[i];
+                       if(!edgeIsNull(e)){
+                               ASSERT(edgeIsVarConst(e));
+                               freezeVariable(cnf, e);
+                       }
+               }
+       }
+}
+
 void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
        ASSERT(encoding->type == BINARYINDEX);
        allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
        getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
-       if(encoding->anyValue)
-               generateAnyValueBinaryIndexEncoding(encoding);
+       if (encoding->element->anyValue) {
+               uint setSize = encoding->element->getRange()->getSize();
+               int maxIndex = getMaximumUsedSize(encoding);
+               if (setSize == encoding->encArraySize && maxIndex == (int)setSize) {
+                       return;
+               }
+               double ratio = (setSize * (1 + 2 * encoding->numVars)) / (encoding->numVars * (encoding->numVars + maxIndex * 1.0 - setSize));
+//             model_print("encArraySize=%u\tmaxIndex=%d\tsetSize=%u\tmetric=%f\tnumBits=%u\n", encoding->encArraySize, maxIndex, setSize, ratio, encoding->numVars);
+               if ( ratio <  pow(2, (uint)solver->getTuner()->getTunable(MUSTVALUE, &mustValueBinaryIndex) - 3)) {
+                       generateAnyValueBinaryIndexEncodingPositive(encoding);
+               } else {
+                       generateAnyValueBinaryIndexEncoding(encoding);
+               }
+       }
 }
 
 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
@@ -234,7 +266,7 @@ void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
                        addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
                }
        }
-       if(encoding->anyValue)
+       if (encoding->element->anyValue)
                addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
 }
 
@@ -270,40 +302,53 @@ void SATEncoder::generateElementEncoding(Element *element) {
        }
 }
 
-void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding){
-       if(encoding->numVars == 0)
-               return;
-       Edge carray[encoding->numVars];
-       int size = 0;
-       int index = -1;
-       for(uint i= encoding->encArraySize-1; i>=0; i--){
-               if(encoding->isinUseElement(i)){
-                       if(i+1 < encoding->encArraySize){
-                               index = i+1;
-                       }
-                       break;
-               }
+int SATEncoder::getMaximumUsedSize(ElementEncoding *encoding) {
+       if(encoding->encArraySize == 1){
+               return 1;
+       }
+       for (int i = encoding->encArraySize - 1; i >= 0; i--) {
+               if (encoding->isinUseElement(i))
+                       return i + 1;
        }
-       if( index != -1 ){
-               carray[size++] = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index);
+       ASSERT(false);
+       return -1;
+}
+
+void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding) {
+       if (encoding->numVars == 0)
+               return;
+       int index = getMaximumUsedSize(encoding);
+       if ( index != (int)encoding->encArraySize ) {
+               addConstraintCNF(cnf, generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index));
        }
-       index = index == -1? encoding->encArraySize-1 : index-1;
-       for(int i= index; i>=0; i--){
-               if (!encoding->isinUseElement(i)){
-                       carray[size++] = constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i));
+       for (int i = index - 1; i >= 0; i--) {
+               if (!encoding->isinUseElement(i)) {
+                       addConstraintCNF(cnf, constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i)));
                }
        }
-       if(size > 0){
-               addConstraintCNF(cnf, constraintAND(cnf, size, carray));
+}
+
+void SATEncoder::generateAnyValueBinaryIndexEncodingPositive(ElementEncoding *encoding) {
+       if (encoding->numVars == 0)
+               return;
+       Edge carray[encoding->encArraySize];
+       uint size = 0;
+       for (uint i = 0; i < encoding->encArraySize; i++) {
+               if (encoding->isinUseElement(i)) {
+                       carray[size] = generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i);
+                       size++;
+               }
        }
+       addConstraintCNF(cnf, constraintOR(cnf, size, carray));
 }
 
-void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding){
+void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding) {
        uint64_t minvalueminusoffset = encoding->low - encoding->offset;
        uint64_t maxvalueminusoffset = encoding->high - encoding->offset;
        model_print("This is minvalueminus offset: %lu", minvalueminusoffset);
        Edge lowerbound = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, maxvalueminusoffset);
        Edge upperbound = constraintNegate(generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, minvalueminusoffset));
-       addConstraintCNF(cnf, constraintAND2(cnf, lowerbound, upperbound));
+       addConstraintCNF(cnf, lowerbound);
+       addConstraintCNF(cnf, upperbound);
 }