Adding a directed search based config for the tuner
[satune.git] / src / Backend / satelemencoder.cc
index 3266f1ca9e5f3be1d9473c3a3be1ffa969d88400..8f06eecad991e3d68d9b83e5d5b47f3bba3a4f6b 100644 (file)
@@ -5,9 +5,10 @@
 #include "element.h"
 #include "set.h"
 #include "predicate.h"
+#include "csolver.h"
+#include "tunable.h"
 
-
-void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
+void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool &memo) {
        uint numParents = elem->parents.getSize();
        uint posPolarity = 0;
        uint negPolarity = 0;
@@ -15,10 +16,10 @@ void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
        if (elem->type == ELEMFUNCRETURN) {
                memo = true;
        }
-       for(uint i = 0; i < numParents; i++) {
-               ASTNode * node = elem->parents.get(i);
+       for (uint i = 0; i < numParents; i++) {
+               ASTNode *node = elem->parents.get(i);
                if (node->type == PREDICATEOP) {
-                       BooleanPredicate * pred = (BooleanPredicate *) node;
+                       BooleanPredicate *pred = (BooleanPredicate *) node;
                        Polarity polarity = pred->polarity;
                        FunctionEncodingType encType = pred->encoding.type;
                        bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
@@ -27,49 +28,49 @@ void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
 
                                UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
 
-                               Polarity tpolarity=polarity;
+                               Polarity tpolarity = polarity;
                                if (generateNegation)
                                        tpolarity = negatePolarity(tpolarity);
-                               if (undefStatus ==SATC_FLAGFORCEUNDEFINED)
+                               if (undefStatus == SATC_FLAGFORCEUNDEFINED)
                                        tpolarity = P_BOTHTRUEFALSE;
                                if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
                                        memo = true;
                                if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
                                        memo = true;
                        } else if (pred->predicate->type == OPERATORPRED) {
-                                       if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
-                                               Polarity tpolarity = polarity;
-                                               if (generateNegation)
-                                                       tpolarity = negatePolarity(tpolarity);
-                                               PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
-                                               uint numDomains = pred->inputs.getSize();
-                                               bool isConstant = true;
-                                               for (uint i = 0; i < numDomains; i++) {
-                                                       Element *e = pred->inputs.get(i);
-                                                       if (elem != e && e->type != ELEMCONST) {
-                                                               isConstant = false;
-                                                       }
+                               if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
+                                       Polarity tpolarity = polarity;
+                                       if (generateNegation)
+                                               tpolarity = negatePolarity(tpolarity);
+                                       PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
+                                       uint numDomains = pred->inputs.getSize();
+                                       bool isConstant = true;
+                                       for (uint i = 0; i < numDomains; i++) {
+                                               Element *e = pred->inputs.get(i);
+                                               if (elem != e && e->type != ELEMCONST) {
+                                                       isConstant = false;
                                                }
-                                               if (predicate->getOp() == SATC_EQUALS) {
+                                       }
+                                       if (predicate->getOp() == SATC_EQUALS) {
+                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
+                                                       posPolarity++;
+                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
+                                                       negPolarity++;
+                                       } else {
+                                               if (isConstant) {
                                                        if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
                                                                posPolarity++;
                                                        if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
                                                                negPolarity++;
                                                } else {
-                                                       if (isConstant) {
-                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
-                                                                       posPolarity++;
-                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
-                                                                       negPolarity++;
-                                                       } else {
-                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
-                                                                       memo = true;
-                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
-                                                                       memo = true;                                                    
-                                                       }
+                                                       if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
+                                                               memo = true;
+                                                       if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
+                                                               memo = true;
                                                }
                                        }
-                               } else {
+                               }
+                       } else {
                                ASSERT(0);
                        }
                } else if (node->type == ELEMFUNCRETURN) {
@@ -130,7 +131,7 @@ Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p,
                if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
                        if (elemEnc->numVars == 0)
                                return E_True;
-                       
+
                        if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
                                if (impliesPolarity(elemEnc->polarityArray[i], p)) {
                                        return elemEnc->edgeArray[i];
@@ -142,11 +143,11 @@ Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p,
                                                generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
                                                elemEnc->polarityArray[i] = p;
                                        } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
-                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);                      
-                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_TRUE));
-                                       }       else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
-                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);                     
-                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_FALSE));
+                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);
+                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_TRUE));
+                                       } else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
+                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);
+                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i]) | ((int)P_FALSE));
                                        }
                                        return elemEnc->edgeArray[i];
                                }
@@ -214,7 +215,7 @@ void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
        ASSERT(encoding->type == BINARYVAL);
        allocElementConstraintVariables(encoding, encoding->numBits);
        getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
-       if(encoding->anyValue)
+       if (encoding->element->anyValue)
                generateAnyValueBinaryValueEncoding(encoding);
 }
 
@@ -222,8 +223,15 @@ void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
        ASSERT(encoding->type == BINARYINDEX);
        allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
        getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
-       if(encoding->anyValue)
-               generateAnyValueBinaryIndexEncoding(encoding);
+       if (encoding->element->anyValue){
+               uint setSize = encoding->element->getRange()->getSize();
+               uint encArraySize = encoding->encArraySize;
+               if(setSize < encArraySize * (uint)solver->getTuner()->getTunable(MUSTVALUE, &mustValueBinaryIndex)/10){
+                       generateAnyValueBinaryIndexEncodingPositive(encoding);
+               } else {
+                       generateAnyValueBinaryIndexEncoding(encoding);
+               }
+       }
 }
 
 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
@@ -234,9 +242,8 @@ void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
                        addConstraintCNF(cnf, constraintNegate(constraintAND2(cnf, encoding->variables[i], encoding->variables[j])));
                }
        }
-       addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
-       if(encoding->anyValue)
-               generateAnyValueOneHotEncoding(encoding);
+       if (encoding->element->anyValue)
+               addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
 }
 
 void SATEncoder::generateUnaryEncodingVars(ElementEncoding *encoding) {
@@ -271,55 +278,50 @@ void SATEncoder::generateElementEncoding(Element *element) {
        }
 }
 
-void SATEncoder::generateAnyValueOneHotEncoding(ElementEncoding *encoding){
-       if(encoding->numVars == 0)
-               return;
-       Edge carray[encoding->numVars];
-       int size = 0;
-       for (uint i = 0; i < encoding->encArraySize; i++) {
-               if (encoding->isinUseElement(i)){
-                       carray[size++] = encoding->variables[i];
-               }
-       }
-       if(size > 0){
-               addConstraintCNF(cnf, constraintOR(cnf, size, carray));
-       }
-}
-
-void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding){
-       if(encoding->numVars == 0)
+void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding) {
+       if (encoding->numVars == 0)
                return;
-       Edge carray[encoding->numVars];
-       int size = 0;
        int index = -1;
-       for(uint i= encoding->encArraySize-1; i>=0; i--){
-               if(encoding->isinUseElement(i)){
-                       if(i+1 < encoding->encArraySize){
-                               index = i+1;
+       for (uint i = encoding->encArraySize - 1; i >= 0; i--) {
+               if (encoding->isinUseElement(i)) {
+                       if (i + 1 < encoding->encArraySize) {
+                               index = i + 1;
                        }
                        break;
                }
        }
-       if( index != -1 ){
-               carray[size++] = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index);
+       if ( index != -1 ) {
+               addConstraintCNF(cnf, generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index));
        }
-       index = index == -1? encoding->encArraySize-1 : index-1;
-       for(int i= index; i>=0; i--){
-               if (!encoding->isinUseElement(i)){
-                       carray[size++] = constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i));
+       index = index == -1 ? encoding->encArraySize - 1 : index - 1;
+       for (int i = index; i >= 0; i--) {
+               if (!encoding->isinUseElement(i)) {
+                       addConstraintCNF(cnf, constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i)));
                }
        }
-       if(size > 0){
-               addConstraintCNF(cnf, constraintAND(cnf, size, carray));
+}
+
+void SATEncoder::generateAnyValueBinaryIndexEncodingPositive(ElementEncoding *encoding) {
+       if (encoding->numVars == 0)
+               return;
+       Edge carray[encoding->encArraySize];
+       uint size = 0;
+       for (uint i = 0; i < encoding->encArraySize; i++) {
+               if (encoding->isinUseElement(i)) {
+                       carray[size] = generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i);
+                       size++;
+               }
        }
+       addConstraintCNF(cnf, constraintOR(cnf, size, carray));
 }
 
-void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding){
+void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding) {
        uint64_t minvalueminusoffset = encoding->low - encoding->offset;
        uint64_t maxvalueminusoffset = encoding->high - encoding->offset;
        model_print("This is minvalueminus offset: %lu", minvalueminusoffset);
        Edge lowerbound = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, maxvalueminusoffset);
        Edge upperbound = constraintNegate(generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, minvalueminusoffset));
-       addConstraintCNF(cnf, constraintAND2(cnf, lowerbound, upperbound));
+       addConstraintCNF(cnf, lowerbound);
+       addConstraintCNF(cnf, upperbound);
 }