Removing extra constraints for the unary encoding
[satune.git] / src / Backend / satelemencoder.cc
index 27d5720abc32d1beca742ef20f65b0937acb601b..b80978ec0c5a620809544b08376171d939729662 100644 (file)
@@ -4,17 +4,98 @@
 #include "ops.h"
 #include "element.h"
 #include "set.h"
+#include "predicate.h"
 
-Edge SATEncoder::getElementValueConstraint(Element *elem, uint64_t value) {
+
+void SATEncoder::shouldMemoize(Element *elem, uint64_t val, bool & memo) {
+       uint numParents = elem->parents.getSize();
+       uint posPolarity = 0;
+       uint negPolarity = 0;
+       memo = false;
+       if (elem->type == ELEMFUNCRETURN) {
+               memo = true;
+       }
+       for(uint i = 0; i < numParents; i++) {
+               ASTNode * node = elem->parents.get(i);
+               if (node->type == PREDICATEOP) {
+                       BooleanPredicate * pred = (BooleanPredicate *) node;
+                       Polarity polarity = pred->polarity;
+                       FunctionEncodingType encType = pred->encoding.type;
+                       bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
+                       if (pred->predicate->type == TABLEPRED) {
+                               //Could be smarter, but just do default thing for now
+
+                               UndefinedBehavior undefStatus = ((PredicateTable *)pred->predicate)->undefinedbehavior;
+
+                               Polarity tpolarity=polarity;
+                               if (generateNegation)
+                                       tpolarity = negatePolarity(tpolarity);
+                               if (undefStatus ==SATC_FLAGFORCEUNDEFINED)
+                                       tpolarity = P_BOTHTRUEFALSE;
+                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
+                                       memo = true;
+                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
+                                       memo = true;
+                       } else if (pred->predicate->type == OPERATORPRED) {
+                                       if (pred->encoding.type == ENUMERATEIMPLICATIONS || pred->encoding.type == ENUMERATEIMPLICATIONSNEGATE) {
+                                               Polarity tpolarity = polarity;
+                                               if (generateNegation)
+                                                       tpolarity = negatePolarity(tpolarity);
+                                               PredicateOperator *predicate = (PredicateOperator *)pred->predicate;
+                                               uint numDomains = predicate->domains.getSize();
+                                               bool isConstant = true;
+                                               for (uint i = 0; i < numDomains; i++) {
+                                                       Element *e = pred->inputs.get(i);
+                                                       if (elem != e && e->type != ELEMCONST) {
+                                                               isConstant = false;
+                                                       }
+                                               }
+                                               if (predicate->getOp() == SATC_EQUALS) {
+                                                       if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
+                                                               posPolarity++;
+                                                       if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
+                                                               negPolarity++;
+                                               } else {
+                                                       if (isConstant) {
+                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
+                                                                       posPolarity++;
+                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
+                                                                       negPolarity++;
+                                                       } else {
+                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_TRUE)
+                                                                       memo = true;
+                                                               if (tpolarity == P_BOTHTRUEFALSE || tpolarity == P_FALSE)
+                                                                       memo = true;                                                    
+                                                       }
+                                               }
+                                       }
+                               } else {
+                               ASSERT(0);
+                       }
+               } else if (node->type == ELEMFUNCRETURN) {
+                       //we are input to function, so memoize negative case
+                       memo = true;
+               } else {
+                       ASSERT(0);
+               }
+       }
+       if (posPolarity > 1)
+               memo = true;
+       if (negPolarity > 1)
+               memo = true;
+}
+
+
+Edge SATEncoder::getElementValueConstraint(Element *elem, Polarity p, uint64_t value) {
        switch (elem->getElementEncoding()->type) {
        case ONEHOT:
-               return getElementValueOneHotConstraint(elem, value);
+               return getElementValueOneHotConstraint(elem, p, value);
        case UNARY:
-               return getElementValueUnaryConstraint(elem, value);
+               return getElementValueUnaryConstraint(elem, p, value);
        case BINARYINDEX:
-               return getElementValueBinaryIndexConstraint(elem, value);
+               return getElementValueBinaryIndexConstraint(elem, p, value);
        case BINARYVAL:
-               return getElementValueBinaryValueConstraint(elem, value);
+               return getElementValueBinaryValueConstraint(elem, p, value);
                break;
        default:
                ASSERT(0);
@@ -23,19 +104,60 @@ Edge SATEncoder::getElementValueConstraint(Element *elem, uint64_t value) {
        return E_BOGUS;
 }
 
-Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, uint64_t value) {
+bool impliesPolarity(Polarity curr, Polarity goal) {
+       return (((int) curr) & ((int)goal)) == ((int) goal);
+}
+
+Edge SATEncoder::getElementValueBinaryIndexConstraint(Element *elem, Polarity p, uint64_t value) {
        ASTNodeType type = elem->type;
        ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
        ElementEncoding *elemEnc = elem->getElementEncoding();
+
+       //Check if we need to generate proxy variables
+       if (elemEnc->encoding == EENC_UNKNOWN && elemEnc->numVars > 1) {
+               bool memo = false;
+               shouldMemoize(elem, value, memo);
+               if (memo) {
+                       elemEnc->encoding = EENC_BOTH;
+                       elemEnc->polarityArray = (Polarity *) ourcalloc(1, sizeof(Polarity) * elemEnc->encArraySize);
+                       elemEnc->edgeArray = (Edge *) ourcalloc(1, sizeof(Edge) * elemEnc->encArraySize);
+               } else {
+                       elemEnc->encoding = EENC_NONE;
+               }
+       }
+
        for (uint i = 0; i < elemEnc->encArraySize; i++) {
                if (elemEnc->isinUseElement(i) && elemEnc->encodingArray[i] == value) {
-                       return (elemEnc->numVars == 0) ? E_True : generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i);
+                       if (elemEnc->numVars == 0)
+                               return E_True;
+                       
+                       if (elemEnc->encoding != EENC_NONE && elemEnc->numVars > 1) {
+                               if (impliesPolarity(elemEnc->polarityArray[i], p)) {
+                                       return elemEnc->edgeArray[i];
+                               } else {
+                                       if (edgeIsNull(elemEnc->edgeArray[i])) {
+                                               elemEnc->edgeArray[i] = constraintNewVar(cnf);
+                                       }
+                                       if (elemEnc->polarityArray[i] == P_UNDEFINED && p == P_BOTHTRUEFALSE) {
+                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_BOTHTRUEFALSE);
+                                               elemEnc->polarityArray[i] = p;
+                                       } else if (!impliesPolarity(elemEnc->polarityArray[i], P_TRUE)  && impliesPolarity(p, P_TRUE)) {
+                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_TRUE);                      
+                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_TRUE));
+                                       }       else if (!impliesPolarity(elemEnc->polarityArray[i], P_FALSE)  && impliesPolarity(p, P_FALSE)) {
+                                               generateProxy(cnf, generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i), elemEnc->edgeArray[i], P_FALSE);                     
+                                               elemEnc->polarityArray[i] = (Polarity) (((int) elemEnc->polarityArray[i])| ((int)P_FALSE));
+                                       }
+                                       return elemEnc->edgeArray[i];
+                               }
+                       }
+                       return generateBinaryConstraint(cnf, elemEnc->numVars, elemEnc->variables, i);
                }
        }
        return E_False;
 }
 
-Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, uint64_t value) {
+Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, Polarity p, uint64_t value) {
        ASTNodeType type = elem->type;
        ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
        ElementEncoding *elemEnc = elem->getElementEncoding();
@@ -47,7 +169,7 @@ Edge SATEncoder::getElementValueOneHotConstraint(Element *elem, uint64_t value)
        return E_False;
 }
 
-Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, uint64_t value) {
+Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, Polarity p, uint64_t value) {
        ASTNodeType type = elem->type;
        ASSERT(type == ELEMSET || type == ELEMFUNCRETURN || type == ELEMCONST);
        ElementEncoding *elemEnc = elem->getElementEncoding();
@@ -66,7 +188,7 @@ Edge SATEncoder::getElementValueUnaryConstraint(Element *elem, uint64_t value) {
        return E_False;
 }
 
-Edge SATEncoder::getElementValueBinaryValueConstraint(Element *element, uint64_t value) {
+Edge SATEncoder::getElementValueBinaryValueConstraint(Element *element, Polarity p, uint64_t value) {
        ASTNodeType type = element->type;
        ASSERT(type == ELEMSET || type == ELEMFUNCRETURN);
        ElementEncoding *elemEnc = element->getElementEncoding();
@@ -92,12 +214,16 @@ void SATEncoder::generateBinaryValueEncodingVars(ElementEncoding *encoding) {
        ASSERT(encoding->type == BINARYVAL);
        allocElementConstraintVariables(encoding, encoding->numBits);
        getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
+       if(encoding->anyValue)
+               generateAnyValueBinaryValueEncoding(encoding);
 }
 
 void SATEncoder::generateBinaryIndexEncodingVars(ElementEncoding *encoding) {
        ASSERT(encoding->type == BINARYINDEX);
        allocElementConstraintVariables(encoding, NUMBITS(encoding->encArraySize - 1));
        getArrayNewVarsSATEncoder(encoding->numVars, encoding->variables);
+       if(encoding->anyValue)
+               generateAnyValueBinaryIndexEncoding(encoding);
 }
 
 void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
@@ -109,6 +235,8 @@ void SATEncoder::generateOneHotEncodingVars(ElementEncoding *encoding) {
                }
        }
        addConstraintCNF(cnf, constraintOR(cnf, encoding->numVars, encoding->variables));
+       if(encoding->anyValue)
+               generateAnyValueOneHotEncoding(encoding);
 }
 
 void SATEncoder::generateUnaryEncodingVars(ElementEncoding *encoding) {
@@ -143,3 +271,55 @@ void SATEncoder::generateElementEncoding(Element *element) {
        }
 }
 
+void SATEncoder::generateAnyValueOneHotEncoding(ElementEncoding *encoding){
+       if(encoding->numVars == 0)
+               return;
+       Edge carray[encoding->numVars];
+       int size = 0;
+       for (uint i = 0; i < encoding->encArraySize; i++) {
+               if (encoding->isinUseElement(i)){
+                       carray[size++] = encoding->variables[i];
+               }
+       }
+       if(size > 0){
+               addConstraintCNF(cnf, constraintOR(cnf, size, carray));
+       }
+}
+
+void SATEncoder::generateAnyValueBinaryIndexEncoding(ElementEncoding *encoding){
+       if(encoding->numVars == 0)
+               return;
+       Edge carray[encoding->numVars];
+       int size = 0;
+       int index = -1;
+       for(uint i= encoding->encArraySize-1; i>=0; i--){
+               if(encoding->isinUseElement(i)){
+                       if(i+1 < encoding->encArraySize){
+                               index = i+1;
+                       }
+                       break;
+               }
+       }
+       if( index != -1 ){
+               carray[size++] = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, index);
+       }
+       index = index == -1? encoding->encArraySize-1 : index-1;
+       for(int i= index; i>=0; i--){
+               if (!encoding->isinUseElement(i)){
+                       carray[size++] = constraintNegate( generateBinaryConstraint(cnf, encoding->numVars, encoding->variables, i));
+               }
+       }
+       if(size > 0){
+               addConstraintCNF(cnf, constraintAND(cnf, size, carray));
+       }
+}
+
+void SATEncoder::generateAnyValueBinaryValueEncoding(ElementEncoding *encoding){
+       uint64_t minvalueminusoffset = encoding->low - encoding->offset;
+       uint64_t maxvalueminusoffset = encoding->high - encoding->offset;
+       model_print("This is minvalueminus offset: %lu", minvalueminusoffset);
+       Edge lowerbound = generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, maxvalueminusoffset);
+       Edge upperbound = constraintNegate(generateLTValueConstraint(cnf, encoding->numVars, encoding->variables, minvalueminusoffset));
+       addConstraintCNF(cnf, constraintAND2(cnf, lowerbound, upperbound));
+}
+