Performance improvement
[satune.git] / src / Backend / satfuncopencoder.cc
index 0cfce9bec8814031d6f5c240c3b6ebc18c2a61e0..c2b9025f06008d56ec6014c181653a580fc6af8b 100644 (file)
@@ -22,6 +22,63 @@ Edge SATEncoder::encodeOperatorPredicateSATEncoder(BooleanPredicate *constraint)
        exit(-1);
 }
 
+Edge SATEncoder::encodeEnumEqualsPredicateSATEncoder(BooleanPredicate *constraint) {
+       Polarity polarity = constraint->polarity;
+
+       /* Call base encoders for children */
+       for (uint i = 0; i < 2; i++) {
+               Element *elem = constraint->inputs.get(i);
+               encodeElementSATEncoder(elem);
+       }
+       VectorEdge *clauses = vector;
+
+       Set *set0 = constraint->inputs.get(0)->getRange();
+       uint size0 = set0->getSize();
+
+       Set *set1 = constraint->inputs.get(1)->getRange();
+       uint size1 = set1->getSize();
+
+       uint64_t val0 = set0->getElement(0);
+       uint64_t val1 = set1->getElement(0);
+       if (size0 != 0 && size1 != 0)
+               for (uint i = 0, j = 0; true; ) {
+                       if (val0 == val1) {
+                               Edge carray[2];
+                               carray[0] = getElementValueConstraint(constraint->inputs.get(0), polarity, val0);
+                               carray[1] = getElementValueConstraint(constraint->inputs.get(1), polarity, val0);
+                               Edge term = constraintAND(cnf, 2, carray);
+                               pushVectorEdge(clauses, term);
+                               i++; j++;
+                               if (i < size0)
+                                       val0 = set0->getElement(i);
+                               else
+                                       break;
+                               if (j < size1)
+                                       val1 = set1->getElement(j);
+                               else
+                                       break;
+                       } else if (val0 < val1) {
+                               i++;
+                               if (i < size0)
+                                       val0 = set0->getElement(i);
+                               else
+                                       break;
+                       } else {
+                               j++;
+                               if (j < size1)
+                                       val1 = set1->getElement(j);
+                               else
+                                       break;
+                       }
+               }
+       if (getSizeVectorEdge(clauses) == 0) {
+               return E_False;
+       }
+       Edge cor = constraintOR(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
+       clearVectorEdge(clauses);
+       return cor;
+}
+
 Edge SATEncoder::encodeEnumOperatorPredicateSATEncoder(BooleanPredicate *constraint) {
        PredicateOperator *predicate = (PredicateOperator *)constraint->predicate;
        uint numDomains = constraint->inputs.getSize();
@@ -31,6 +88,9 @@ Edge SATEncoder::encodeEnumOperatorPredicateSATEncoder(BooleanPredicate *constra
        if (generateNegation)
                polarity = negatePolarity(polarity);
 
+       if (!generateNegation && predicate->getOp() == SATC_EQUALS)
+               return encodeEnumEqualsPredicateSATEncoder(constraint);
+
        /* Call base encoders for children */
        for (uint i = 0; i < numDomains; i++) {
                Element *elem = constraint->inputs.get(i);
@@ -48,7 +108,7 @@ Edge SATEncoder::encodeEnumOperatorPredicateSATEncoder(BooleanPredicate *constra
        }
 
        bool notfinished = true;
-        Edge carray[numDomains];
+       Edge carray[numDomains];
        while (notfinished) {
                if (predicate->evalPredicateOperator(vals) != generateNegation) {
                        //Include this in the set of terms
@@ -110,7 +170,7 @@ void SATEncoder::encodeOperatorElementFunctionSATEncoder(ElementFunction *func)
        }
 
        bool notfinished = true;
-        Edge carray[numDomains + 1];
+       Edge carray[numDomains + 1];
        while (notfinished) {
                uint64_t result = function->applyFunctionOperator(numDomains, vals);
                bool isInRange = ((FunctionOperator *)func->getFunction())->isInRangeFunction(result);