Performance improvement
[satune.git] / src / Backend / satfuncopencoder.cc
index 9fc3c4124f4f1187836fc93b9ebaf59c458ab055..c2b9025f06008d56ec6014c181653a580fc6af8b 100644 (file)
@@ -22,47 +22,109 @@ Edge SATEncoder::encodeOperatorPredicateSATEncoder(BooleanPredicate *constraint)
        exit(-1);
 }
 
+Edge SATEncoder::encodeEnumEqualsPredicateSATEncoder(BooleanPredicate *constraint) {
+       Polarity polarity = constraint->polarity;
+
+       /* Call base encoders for children */
+       for (uint i = 0; i < 2; i++) {
+               Element *elem = constraint->inputs.get(i);
+               encodeElementSATEncoder(elem);
+       }
+       VectorEdge *clauses = vector;
+
+       Set *set0 = constraint->inputs.get(0)->getRange();
+       uint size0 = set0->getSize();
+
+       Set *set1 = constraint->inputs.get(1)->getRange();
+       uint size1 = set1->getSize();
+
+       uint64_t val0 = set0->getElement(0);
+       uint64_t val1 = set1->getElement(0);
+       if (size0 != 0 && size1 != 0)
+               for (uint i = 0, j = 0; true; ) {
+                       if (val0 == val1) {
+                               Edge carray[2];
+                               carray[0] = getElementValueConstraint(constraint->inputs.get(0), polarity, val0);
+                               carray[1] = getElementValueConstraint(constraint->inputs.get(1), polarity, val0);
+                               Edge term = constraintAND(cnf, 2, carray);
+                               pushVectorEdge(clauses, term);
+                               i++; j++;
+                               if (i < size0)
+                                       val0 = set0->getElement(i);
+                               else
+                                       break;
+                               if (j < size1)
+                                       val1 = set1->getElement(j);
+                               else
+                                       break;
+                       } else if (val0 < val1) {
+                               i++;
+                               if (i < size0)
+                                       val0 = set0->getElement(i);
+                               else
+                                       break;
+                       } else {
+                               j++;
+                               if (j < size1)
+                                       val1 = set1->getElement(j);
+                               else
+                                       break;
+                       }
+               }
+       if (getSizeVectorEdge(clauses) == 0) {
+               return E_False;
+       }
+       Edge cor = constraintOR(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
+       clearVectorEdge(clauses);
+       return cor;
+}
+
 Edge SATEncoder::encodeEnumOperatorPredicateSATEncoder(BooleanPredicate *constraint) {
        PredicateOperator *predicate = (PredicateOperator *)constraint->predicate;
-       uint numDomains = predicate->domains.getSize();
-
+       uint numDomains = constraint->inputs.getSize();
+       Polarity polarity = constraint->polarity;
        FunctionEncodingType encType = constraint->encoding.type;
        bool generateNegation = encType == ENUMERATEIMPLICATIONSNEGATE;
+       if (generateNegation)
+               polarity = negatePolarity(polarity);
+
+       if (!generateNegation && predicate->getOp() == SATC_EQUALS)
+               return encodeEnumEqualsPredicateSATEncoder(constraint);
 
        /* Call base encoders for children */
        for (uint i = 0; i < numDomains; i++) {
                Element *elem = constraint->inputs.get(i);
                encodeElementSATEncoder(elem);
        }
-       VectorEdge *clauses = allocDefVectorEdge();     // Setup array of clauses
+       VectorEdge *clauses = vector;
 
        uint indices[numDomains];       //setup indices
        bzero(indices, sizeof(uint) * numDomains);
 
        uint64_t vals[numDomains];//setup value array
        for (uint i = 0; i < numDomains; i++) {
-               Set *set = predicate->domains.get(i);
+               Set *set = constraint->inputs.get(i)->getRange();
                vals[i] = set->getElement(indices[i]);
        }
 
        bool notfinished = true;
+       Edge carray[numDomains];
        while (notfinished) {
-               Edge carray[numDomains];
-
                if (predicate->evalPredicateOperator(vals) != generateNegation) {
                        //Include this in the set of terms
                        for (uint i = 0; i < numDomains; i++) {
                                Element *elem = constraint->inputs.get(i);
-                               carray[i] = getElementValueConstraint(elem, vals[i]);
+                               carray[i] = getElementValueConstraint(elem, polarity, vals[i]);
                        }
                        Edge term = constraintAND(cnf, numDomains, carray);
                        pushVectorEdge(clauses, term);
+                       ASSERT(getSizeVectorEdge(clauses) > 0);
                }
 
                notfinished = false;
                for (uint i = 0; i < numDomains; i++) {
                        uint index = ++indices[i];
-                       Set *set = predicate->domains.get(i);
+                       Set *set = constraint->inputs.get(i)->getRange();
 
                        if (index < set->getSize()) {
                                vals[i] = set->getElement(index);
@@ -75,11 +137,10 @@ Edge SATEncoder::encodeEnumOperatorPredicateSATEncoder(BooleanPredicate *constra
                }
        }
        if (getSizeVectorEdge(clauses) == 0) {
-               deleteVectorEdge(clauses);
                return E_False;
        }
        Edge cor = constraintOR(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
-       deleteVectorEdge(clauses);
+       clearVectorEdge(clauses);
        return generateNegation ? constraintNegate(cor) : cor;
 }
 
@@ -108,12 +169,9 @@ void SATEncoder::encodeOperatorElementFunctionSATEncoder(ElementFunction *func)
                vals[i] = set->getElement(indices[i]);
        }
 
-       Edge overFlowConstraint = encodeConstraintSATEncoder(func->overflowstatus);
-
        bool notfinished = true;
+       Edge carray[numDomains + 1];
        while (notfinished) {
-               Edge carray[numDomains + 1];
-
                uint64_t result = function->applyFunctionOperator(numDomains, vals);
                bool isInRange = ((FunctionOperator *)func->getFunction())->isInRangeFunction(result);
                bool needClause = isInRange;
@@ -125,10 +183,10 @@ void SATEncoder::encodeOperatorElementFunctionSATEncoder(ElementFunction *func)
                        //Include this in the set of terms
                        for (uint i = 0; i < numDomains; i++) {
                                Element *elem = func->inputs.get(i);
-                               carray[i] = getElementValueConstraint(elem, vals[i]);
+                               carray[i] = getElementValueConstraint(elem, P_FALSE, vals[i]);
                        }
                        if (isInRange) {
-                               carray[numDomains] = getElementValueConstraint(func, result);
+                               carray[numDomains] = getElementValueConstraint(func, P_TRUE, result);
                        }
 
                        Edge clause;
@@ -140,6 +198,7 @@ void SATEncoder::encodeOperatorElementFunctionSATEncoder(ElementFunction *func)
                                break;
                        }
                        case SATC_FLAGFORCESOVERFLOW: {
+                               Edge overFlowConstraint = encodeConstraintSATEncoder(func->overflowstatus);
                                clause = constraintIMPLIES(cnf,constraintAND(cnf, numDomains, carray), constraintAND2(cnf, carray[numDomains], constraintNegate(overFlowConstraint)));
                                break;
                        }
@@ -147,11 +206,13 @@ void SATEncoder::encodeOperatorElementFunctionSATEncoder(ElementFunction *func)
                                if (isInRange) {
                                        clause = constraintIMPLIES(cnf, constraintAND(cnf, numDomains, carray), carray[numDomains]);
                                } else {
+                                       Edge overFlowConstraint = encodeConstraintSATEncoder(func->overflowstatus);
                                        clause = constraintIMPLIES(cnf,constraintAND(cnf, numDomains, carray), overFlowConstraint);
                                }
                                break;
                        }
                        case SATC_FLAGIFFOVERFLOW: {
+                               Edge overFlowConstraint = encodeConstraintSATEncoder(func->overflowstatus);
                                if (isInRange) {
                                        clause = constraintIMPLIES(cnf, constraintAND(cnf, numDomains, carray), constraintAND2(cnf, carray[numDomains], constraintNegate(overFlowConstraint)));
                                } else {
@@ -189,8 +250,8 @@ void SATEncoder::encodeOperatorElementFunctionSATEncoder(ElementFunction *func)
                deleteVectorEdge(clauses);
                return;
        }
-       Edge cor = constraintAND(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
-       addConstraintCNF(cnf, cor);
+       Edge cand = constraintAND(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
+       addConstraintCNF(cnf, cand);
        deleteVectorEdge(clauses);
 }
 
@@ -200,8 +261,8 @@ Edge SATEncoder::encodeCircuitOperatorPredicateEncoder(BooleanPredicate *constra
        encodeElementSATEncoder(elem0);
        Element *elem1 = constraint->inputs.get(1);
        encodeElementSATEncoder(elem1);
-       ElementEncoding *ee0 = getElementEncoding(elem0);
-       ElementEncoding *ee1 = getElementEncoding(elem1);
+       ElementEncoding *ee0 = elem0->getElementEncoding();
+       ElementEncoding *ee1 = elem1->getElementEncoding();
        ASSERT(ee0->numVars == ee1->numVars);
        uint numVars = ee0->numVars;
        switch (predicate->getOp()) {