Unary encoding of predicates
[satune.git] / src / Backend / satfuncopencoder.cc
index ff91dc31b20ef41e75c7030601acc55ef8439381..7372c951041898a94969a7ae6c60fe6ecbacf80d 100644 (file)
@@ -22,6 +22,141 @@ Edge SATEncoder::encodeOperatorPredicateSATEncoder(BooleanPredicate *constraint)
        exit(-1);
 }
 
+Edge SATEncoder::encodeEnumEqualsPredicateSATEncoder(BooleanPredicate *constraint) {
+       Polarity polarity = constraint->polarity;
+
+       /* Call base encoders for children */
+       for (uint i = 0; i < 2; i++) {
+               Element *elem = constraint->inputs.get(i);
+               encodeElementSATEncoder(elem);
+       }
+       VectorEdge *clauses = vector;
+
+       Set *set0 = constraint->inputs.get(0)->getRange();
+       uint size0 = set0->getSize();
+
+       Set *set1 = constraint->inputs.get(1)->getRange();
+       uint size1 = set1->getSize();
+
+       uint64_t val0 = set0->getElement(0);
+       uint64_t val1 = set1->getElement(0);
+       if (size0 != 0 && size1 != 0)
+               for (uint i = 0, j = 0; true; ) {
+                       if (val0 == val1) {
+                               Edge carray[2];
+                               carray[0] = getElementValueConstraint(constraint->inputs.get(0), polarity, val0);
+                               carray[1] = getElementValueConstraint(constraint->inputs.get(1), polarity, val0);
+                               Edge term = constraintAND(cnf, 2, carray);
+                               pushVectorEdge(clauses, term);
+                               i++; j++;
+                               if (i < size0)
+                                       val0 = set0->getElement(i);
+                               else
+                                       break;
+                               if (j < size1)
+                                       val1 = set1->getElement(j);
+                               else
+                                       break;
+                       } else if (val0 < val1) {
+                               i++;
+                               if (i < size0)
+                                       val0 = set0->getElement(i);
+                               else
+                                       break;
+                       } else {
+                               j++;
+                               if (j < size1)
+                                       val1 = set1->getElement(j);
+                               else
+                                       break;
+                       }
+               }
+       if (getSizeVectorEdge(clauses) == 0) {
+               return E_False;
+       }
+       Edge cor = constraintOR(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
+       clearVectorEdge(clauses);
+       return cor;
+}
+
+Edge SATEncoder::encodeUnaryPredicateSATEncoder(BooleanPredicate *constraint) {
+       Polarity polarity = constraint->polarity;
+       PredicateOperator *predicate = (PredicateOperator *)constraint->predicate;
+       CompOp op = predicate->getOp();
+
+       /* Call base encoders for children */
+       for (uint i = 0; i < 2; i++) {
+               Element *elem = constraint->inputs.get(i);
+               encodeElementSATEncoder(elem);
+       }
+       VectorEdge *clauses = vector;
+
+       Element *elem0 = constraint->inputs.get(0);
+       Element *elem1 = constraint->inputs.get(1);
+
+       //Eliminate symmetric cases
+       if (op == SATC_GT) {
+               op = SATC_LT;
+               Element *tmp = elem0;
+               elem0 = elem1;
+               elem1 = elem0;
+       } else if (op == SATC_GTE) {
+               op = SATC_LTE;
+               Element *tmp = elem0;
+               elem0 = elem1;
+               elem1 = elem0;
+       }
+
+       Set *set0 = elem0->getRange();
+       uint size0 = set0->getSize();
+       Edge *vars0 = elem0->getElementEncoding()->variables;
+
+       Set *set1 = elem1->getRange();
+       uint size1 = set1->getSize();
+       Edge *vars1 = elem1->getElementEncoding()->variables;
+
+
+       uint64_t val0 = set0->getElement(0);
+       uint64_t val1 = set1->getElement(0);
+       if (size0 != 0 && size1 != 0) {
+               for (uint i = 0, j = 0; true; ) {
+                       if (val0 > val1 || (op == SATC_LT && val0 == val1)) {
+                               j++;
+                               if (j == size1) {
+                                       //need to assert val0 isn't this big
+                                       if (i == 0)
+                                               return E_False;//Can't satisfy this constraint
+                                       pushVectorEdge(clauses, constraintNegate(vars0[i - 1]));
+                                       break;
+                               }
+                               val1 = set1->getElement(j);
+                       } else {
+                               if (i == 0) {
+                                       if (j != 0) {
+                                               pushVectorEdge(clauses, vars1[j - 1]);
+                                       }
+                               } else {
+                                       if (j != 0) {
+                                               Edge term = constraintIMPLIES(cnf, vars0[i - 1], vars1[j - 1]);
+                                               pushVectorEdge(clauses, term);
+                                       }
+                               }
+                               i++;
+                               if (i == size0)
+                                       break;
+                               val0 = set0->getElement(i);
+                       }
+               }
+       }
+       //Trivially true constraint
+       if (getSizeVectorEdge(clauses) == 0)
+               return E_True;
+
+       Edge cand = constraintAND(cnf, getSizeVectorEdge(clauses), exposeArrayEdge(clauses));
+       clearVectorEdge(clauses);
+       return cand;
+}
+
 Edge SATEncoder::encodeEnumOperatorPredicateSATEncoder(BooleanPredicate *constraint) {
        PredicateOperator *predicate = (PredicateOperator *)constraint->predicate;
        uint numDomains = constraint->inputs.getSize();
@@ -31,6 +166,17 @@ Edge SATEncoder::encodeEnumOperatorPredicateSATEncoder(BooleanPredicate *constra
        if (generateNegation)
                polarity = negatePolarity(polarity);
 
+       CompOp op = predicate->getOp();
+       if (!generateNegation && op == SATC_EQUALS)
+               return encodeEnumEqualsPredicateSATEncoder(constraint);
+
+       if (!generateNegation && numDomains == 2 &&
+                       (op == SATC_LT || op == SATC_GT || op == SATC_LTE || op == SATC_GTE) &&
+                       constraint->inputs.get(0)->encoding.type == UNARY &&
+                       constraint->inputs.get(1)->encoding.type == UNARY) {
+               return encodeUnaryPredicateSATEncoder(constraint);
+       }
+
        /* Call base encoders for children */
        for (uint i = 0; i < numDomains; i++) {
                Element *elem = constraint->inputs.get(i);