Adding checks to avoid further processing on UNSAT Problems
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
index ebb8a96a5009dc28fc78ca386d21aaf4ea18b3ef..84166a88ddf6dd51773c04765f0e0107053e0a9b 100644 (file)
@@ -52,8 +52,72 @@ void EncodingGraph::buildGraph() {
        decideEdges();
 }
 
+
+void EncodingGraph::validate() {
+       SetIteratorBooleanEdge *it = solver->getConstraints();
+       while (it->hasNext()) {
+               BooleanEdge be = it->next();
+               if (be->type == PREDICATEOP) {
+                       BooleanPredicate *b = (BooleanPredicate *)be.getBoolean();
+                       if (b->predicate->type == OPERATORPRED) {
+                               PredicateOperator *predicate = (PredicateOperator *) b->predicate;
+                               if (predicate->getOp() == SATC_EQUALS) {
+                                       ASSERT(b->inputs.getSize() == 2);
+                                       Element *e1 = b->inputs.get(0);
+                                       Element *e2 = b->inputs.get(1);
+                                       if (e1->type == ELEMCONST || e1->type == ELEMCONST)
+                                               continue;
+                                       ElementEncoding *enc1 = e1->getElementEncoding();
+                                       ElementEncoding *enc2 = e2->getElementEncoding();
+                                       ASSERT(enc1->getElementEncodingType() != ELEM_UNASSIGNED);
+                                       ASSERT(enc2->getElementEncodingType() != ELEM_UNASSIGNED);
+                                       if (enc1->getElementEncodingType() == enc2->getElementEncodingType() && enc1->getElementEncodingType() == BINARYINDEX && b->getFunctionEncoding()->type == CIRCUIT) {
+                                               for (uint i = 0; i < enc1->encArraySize; i++) {
+                                                       if (enc1->isinUseElement(i)) {
+                                                               uint64_t val1 = enc1->encodingArray[i];
+                                                               if (enc2->isinUseElement(i)) {
+                                                                       ASSERT(val1 == enc2->encodingArray[i]);
+                                                               } else {
+                                                                       for (uint j = 0; j < enc2->encArraySize; j++) {
+                                                                               if (enc2->isinUseElement(j)) {
+                                                                                       ASSERT(val1 != enc2->encodingArray[j]);
+                                                                               }
+                                                                       }
+                                                               }
+                                                       }
+                                               }
+                                       }
+                                       //Now make sure that all the elements in the set are appeared in the encoding array!
+                                       for (uint k = 0; k < b->inputs.getSize(); k++) {
+                                               Element *e = b->inputs.get(k);
+                                               ElementEncoding *enc = e->getElementEncoding();
+                                               Set *s = e->getRange();
+                                               for (uint i = 0; i < s->getSize(); i++) {
+                                                       uint64_t value = s->getElement(i);
+                                                       bool exist = false;
+                                                       for (uint j = 0; j < enc->encArraySize; j++) {
+                                                               if (enc->isinUseElement(j) && enc->encodingArray[j] == value) {
+                                                                       exist = true;
+                                                                       break;
+                                                               }
+                                                       }
+                                                       ASSERT(exist);
+                                               }
+                                       }
+                               }
+                       }
+               }
+       }
+       delete it;
+}
+
+
 void EncodingGraph::encode() {
+       if (solver->isUnSAT() || solver->getTuner()->getTunable(ENCODINGGRAPHOPT, &offon) == 0)
+               return;
+       buildGraph();
        SetIteratorEncodingSubGraph *itesg = subgraphs.iterator();
+       model_print("#SubGraph = %u\n", subgraphs.getSize());
        while (itesg->hasNext()) {
                EncodingSubGraph *sg = itesg->next();
                sg->encode();
@@ -77,10 +141,11 @@ void EncodingGraph::encode() {
                                        encoding->encodingArrayInitialization();
                                } else if (encodetype == BINARYINDEX) {
                                        EncodingSubGraph *subgraph = graphMap.get(n);
-                                        DEBUG("graphMap.get(subgraph=%p, n=%p)\n", subgraph, n);
+                                       DEBUG("graphMap.get(subgraph=%p, n=%p)\n", subgraph, n);
                                        if (subgraph == NULL) {
+                                               encoding->encodingArrayInitialization();
                                                continue;
-                                        }
+                                       }
                                        uint encodingSize = subgraph->getEncodingMaxVal(n) + 1;
                                        uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
                                        encoding->allocInUseArrayElement(paddedSize);
@@ -90,6 +155,7 @@ void EncodingGraph::encode() {
                                                uint64_t value = s->getElement(i);
                                                uint encodingIndex = subgraph->getEncoding(n, value);
                                                encoding->setInUseElement(encodingIndex);
+                                               ASSERT(encoding->isinUseElement(encodingIndex));
                                                encoding->encodingArray[encodingIndex] = value;
                                        }
                                }
@@ -121,8 +187,11 @@ void EncodingGraph::encodeParent(Element *e) {
                                if (left == NULL || right == NULL)
                                        return;
                                EncodingEdge *edge = getEdge(left, right, NULL);
-                               if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
-                                       fenc->setFunctionEncodingType(CIRCUIT);
+                               if (edge != NULL) {
+                                       EncodingSubGraph *leftGraph = graphMap.get(left);
+                                       if (leftGraph != NULL && leftGraph == graphMap.get(right)) {
+                                               fenc->setFunctionEncodingType(CIRCUIT);
+                                       }
                                }
                        }
                }
@@ -131,9 +200,9 @@ void EncodingGraph::encodeParent(Element *e) {
 
 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
        EncodingSubGraph *graph1 = graphMap.get(first);
-        DEBUG("graphMap.get(first=%p, graph1=%p)\n", first, graph1);
+       DEBUG("graphMap.get(first=%p, graph1=%p)\n", first, graph1);
        EncodingSubGraph *graph2 = graphMap.get(second);
-        DEBUG("graphMap.get(second=%p, graph2=%p)\n", second, graph2);
+       DEBUG("graphMap.get(second=%p, graph2=%p)\n", second, graph2);
        if (graph1 == NULL)
                first->setEncoding(BINARYINDEX);
        if (graph2 == NULL)
@@ -142,7 +211,7 @@ void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
        if (graph1 == NULL && graph2 == NULL) {
                graph1 = new EncodingSubGraph();
                subgraphs.add(graph1);
-                DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
+               DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
                graphMap.put(first, graph1);
                graph1->addNode(first);
        }
@@ -154,21 +223,24 @@ void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
                first = tmp;
        }
        if (graph1 != NULL && graph2 != NULL) {
+               if (graph1 == graph2)
+                       return;
+
                SetIteratorEncodingNode *nodeit = graph2->nodeIterator();
                while (nodeit->hasNext()) {
                        EncodingNode *node = nodeit->next();
                        graph1->addNode(node);
-                        DEBUG("graphMap.put(node=%p, graph1=%p)\n", node, graph1);
+                       DEBUG("graphMap.put(node=%p, graph1=%p)\n", node, graph1);
                        graphMap.put(node, graph1);
                }
                subgraphs.remove(graph2);
                delete nodeit;
-                DEBUG("Deleting graph2 =%p \n", graph2);
+               DEBUG("Deleting graph2 =%p \n", graph2);
                delete graph2;
        } else {
                ASSERT(graph1 != NULL && graph2 == NULL);
                graph1->addNode(second);
-                DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
+               DEBUG("graphMap.put(first=%p, graph1=%p)\n", first, graph1);
                graphMap.put(second, graph1);
        }
 }
@@ -233,7 +305,6 @@ void EncodingGraph::processPredicate(BooleanPredicate *b) {
 }
 
 uint convertSize(uint cost) {
-       cost = 1.2 * cost;// fudge factor
        return NEXTPOW2(cost);
 }
 
@@ -254,44 +325,46 @@ void EncodingGraph::decideEdges() {
                        return;
 
                EncodingSubGraph *leftGraph = graphMap.get(left);
-                DEBUG("graphMap.get(left=%p, leftgraph=%p)\n", left, leftGraph);
+               DEBUG("graphMap.get(left=%p, leftgraph=%p)\n", left, leftGraph);
                EncodingSubGraph *rightGraph = graphMap.get(right);
-                DEBUG("graphMap.get(right=%p, rightgraph=%p)\n", right, rightGraph);
+               DEBUG("graphMap.get(right=%p, rightgraph=%p)\n", right, rightGraph);
                if (leftGraph == NULL && rightGraph != NULL) {
                        EncodingNode *tmp = left; left = right; right = tmp;
                        EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
                }
 
-               uint leftSize = 0, rightSize = 0, newSize = 0;
-               uint64_t totalCost = 0;
+               uint leftSize = 0, rightSize = 0, newSize = 0, min = 0;
+               bool merge = false;
                if (leftGraph == NULL && rightGraph == NULL) {
                        leftSize = convertSize(left->getSize());
                        rightSize = convertSize(right->getSize());
                        newSize = convertSize(left->s->getUnionSize(right->s));
                        newSize = (leftSize > newSize) ? leftSize : newSize;
                        newSize = (rightSize > newSize) ? rightSize : newSize;
-                       totalCost = (newSize - leftSize) * left->elements.getSize() +
-                                                                       (newSize - rightSize) * right->elements.getSize();
+                       min = rightSize > leftSize ? leftSize : rightSize;
+                       merge = left->measureSimilarity(right) > 1.5 || min == newSize;
                } else if (leftGraph != NULL && rightGraph == NULL) {
-                       leftSize = convertSize(leftGraph->encodingSize);
+                       leftSize = convertSize(leftGraph->numValues());
                        rightSize = convertSize(right->getSize());
                        newSize = convertSize(leftGraph->estimateNewSize(right));
                        newSize = (leftSize > newSize) ? leftSize : newSize;
                        newSize = (rightSize > newSize) ? rightSize : newSize;
-                       totalCost = (newSize - leftSize) * leftGraph->numElements +
-                                                                       (newSize - rightSize) * right->elements.getSize();
+                       min = rightSize > leftSize ? leftSize : rightSize;
+                       merge = leftGraph->measureSimilarity(right) > 1.5 || min == newSize;
+//                     model_print("Merge=%s\tsimilarity=%f\n", merge?"TRUE":"FALSE", leftGraph->measureSimilarity(right));
                } else {
                        //Neither are null
-                       leftSize = convertSize(leftGraph->encodingSize);
-                       rightSize = convertSize(rightGraph->encodingSize);
+                       leftSize = convertSize(leftGraph->numValues());
+                       rightSize = convertSize(rightGraph->numValues());
                        newSize = convertSize(leftGraph->estimateNewSize(rightGraph));
+//                     model_print("MergingSubGraphs: left=%u\tright=%u\tnewSize=%u\n", leftSize, rightSize, newSize);
                        newSize = (leftSize > newSize) ? leftSize : newSize;
                        newSize = (rightSize > newSize) ? rightSize : newSize;
-                       totalCost = (newSize - leftSize) * leftGraph->numElements +
-                                                                       (newSize - rightSize) * rightGraph->numElements;
+                       min = rightSize > leftSize ? leftSize : rightSize;
+                       merge = leftGraph->measureSimilarity(rightGraph) > 1.5 || min == newSize;
+//                     model_print("Merge=%s\tsimilarity=%f\n", merge?"TRUE":"FALSE", leftGraph->measureSimilarity(rightGraph));
                }
-               double conversionfactor = 0.5;
-               if ((totalCost * conversionfactor) < eeValue) {
+               if (merge) {
                        //add the edge
                        mergeNodes(left, right);
                }
@@ -347,11 +420,32 @@ uint EncodingNode::getSize() const {
        return s->getSize();
 }
 
+uint64_t EncodingNode::getIndex(uint index) {
+       return s->getElement(index);
+}
+
 VarType EncodingNode::getType() const {
        return s->getType();
 }
 
-static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
+double EncodingNode::measureSimilarity(EncodingNode *node) {
+       uint common = 0;
+       for (uint i = 0, j = 0; i < s->getSize() && j < node->s->getSize(); ) {
+               uint64_t item = s->getElement(i);
+               uint64_t item2 = node->s->getElement(j);
+               if (item < item2)
+                       i++;
+               else if (item2 > item)
+                       j++;
+               else {
+                       i++;
+                       j++;
+                       common++;
+               }
+       }
+
+       return common * 1.0 / s->getSize() + common * 1.0 / node->getSize();
+}
 
 EncodingNode *EncodingGraph::createNode(Element *e) {
        if (e->type == ELEMCONST)