edits
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
index 4fdc1c657951f587aa31206c3242a0fcdea6d1fc..9b446faba24a59be41187379bd96bf4a03762b1a 100644 (file)
@@ -7,6 +7,8 @@
 #include "csolver.h"
 #include "tunable.h"
 #include "qsort.h"
+#include "subgraph.h"
+#include "elementencoding.h"
 
 EncodingGraph::EncodingGraph(CSolver * _solver) :
        solver(_solver) {
@@ -41,6 +43,79 @@ void EncodingGraph::buildGraph() {
                }
        }
        bsdqsort(edgeVector.expose(), edgeVector.getSize(), sizeof(EncodingEdge *), sortEncodingEdge);
+       decideEdges();
+}
+
+void EncodingGraph::encode() {
+       SetIteratorEncodingSubGraph * itesg=subgraphs.iterator();
+       while(itesg->hasNext()) {
+               EncodingSubGraph *sg=itesg->next();
+               sg->encode();
+       }
+       delete itesg;
+
+       ElementIterator it(solver);
+       while(it.hasNext()) {
+               Element * e = it.next();
+               switch(e->type) {
+               case ELEMSET:
+               case ELEMFUNCRETURN: {
+                       ElementEncoding *encoding=getElementEncoding(e);
+                       if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
+                               EncodingNode *n = getNode(e);
+                               ASSERT(n != NULL);
+                               ElementEncodingType encodetype=n->getEncoding();
+                               encoding->setElementEncodingType(encodetype);
+                               if (encodetype == UNARY || encodetype == ONEHOT) {
+                                       encoding->encodingArrayInitialization();
+                               } else if (encodetype == BINARYINDEX) {
+                                       EncodingSubGraph * subgraph = graphMap.get(n);
+                                       uint encodingSize = subgraph->getEncodingSize(n);
+                                       uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
+                                       encoding->allocInUseArrayElement(paddedSize);
+                                       encoding->allocEncodingArrayElement(paddedSize);
+                                       Set * s=e->getRange();
+                                       for(uint i=0;i<s->getSize();i++) {
+                                               uint64_t value=s->getElement(i);
+                                               uint encodingIndex=subgraph->getEncoding(n, value);
+                                               encoding->setInUseElement(encodingIndex);
+                                               encoding->encodingArray[encodingIndex] = value;
+                                       }
+                               }
+                       }
+                       break;
+               }
+               default:
+                       break;
+               }
+               encodeParent(e);
+       }
+}
+
+void EncodingGraph::encodeParent(Element *e) {
+       uint size=e->parents.getSize();
+       for(uint i=0;i<size;i++) {
+               ASTNode * n = e->parents.get(i);
+               if (n->type==PREDICATEOP) {
+                       BooleanPredicate *b=(BooleanPredicate *)n;
+                       FunctionEncoding *fenc=b->getFunctionEncoding();
+                       if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
+                               continue;
+                       Predicate *p=b->getPredicate();
+                       if (p->type==OPERATORPRED) {
+                               PredicateOperator *po=(PredicateOperator *)p;
+                               ASSERT(b->inputs.getSize()==2);
+                               EncodingNode *left=createNode(b->inputs.get(0));
+                               EncodingNode *right=createNode(b->inputs.get(1));
+                               if (left == NULL || right == NULL)
+                                       return;
+                               EncodingEdge *edge=getEdge(left, right, NULL);
+                               if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
+                                       fenc->setFunctionEncodingType(CIRCUIT);
+                               }
+                       }
+               }
+       }
 }
 
 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
@@ -48,6 +123,7 @@ void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
        EncodingSubGraph *graph2=graphMap.get(second);
        if (graph1 == NULL && graph2 == NULL) {
                graph1 = new EncodingSubGraph();
+               subgraphs.add(graph1);
                graphMap.put(first, graph1);
                graph1->addNode(first);
        }
@@ -65,6 +141,7 @@ void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
                        graph1->addNode(node);
                        graphMap.put(node, graph1);
                }
+               subgraphs.remove(graph2);
                delete nodeit;
                delete graph2;
        } else {
@@ -101,7 +178,7 @@ void EncodingGraph::processFunction(ElementFunction *ef) {
                if (left == NULL && right == NULL)
                        return;
                EncodingNode *dst=createNode(ef);
-               EncodingEdge *edge=getEdge(left, right, dst);
+               EncodingEdge *edge=createEdge(left, right, dst);
                edge->numArithOps++;
        }
 }
@@ -115,7 +192,7 @@ void EncodingGraph::processPredicate(BooleanPredicate *b) {
                EncodingNode *right=createNode(b->inputs.get(1));
                if (left == NULL || right == NULL)
                        return;
-               EncodingEdge *edge=getEdge(left, right, NULL);
+               EncodingEdge *edge=createEdge(left, right, NULL);
                CompOp op=po->getOp();
                switch(op) {
                case SATC_EQUALS:
@@ -142,14 +219,18 @@ void EncodingGraph::decideEdges() {
        uint size=edgeVector.getSize();
        for(uint i=0; i<size; i++) {
                EncodingEdge *ee = edgeVector.get(i);
-               if (ee->encoding != EDGE_UNASSIGNED)
+               EncodingNode *left = ee->left;
+               EncodingNode *right = ee->right;
+               
+               if (ee->encoding != EDGE_UNASSIGNED ||
+                               left->encoding != BINARYINDEX ||
+                               right->encoding != BINARYINDEX)
                        continue;
                
                uint64_t eeValue = ee->getValue();
                if (eeValue == 0)
                        return;
-               EncodingNode *left = ee->left;
-               EncodingNode *right = ee->right;
+
                EncodingSubGraph *leftGraph = graphMap.get(left);
                EncodingSubGraph *rightGraph = graphMap.get(right);
                if (leftGraph == NULL && rightGraph !=NULL) {
@@ -196,6 +277,12 @@ void EncodingGraph::decideEdges() {
 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
 
 EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
+       EncodingEdge e(left, right, dst);
+       EncodingEdge *result = edgeMap.get(&e);
+       return result;
+}
+
+EncodingEdge * EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
        EncodingEdge e(left, right, dst);
        EncodingEdge *result = edgeMap.get(&e);
        if (result == NULL) {
@@ -256,6 +343,14 @@ EncodingNode * EncodingGraph::createNode(Element *e) {
        return n;
 }
 
+EncodingNode * EncodingGraph::getNode(Element *e) {
+       if (e->type == ELEMCONST)
+               return NULL;
+       Set *s = e->getRange();
+       EncodingNode *n = encodingMap.get(s);
+       return n;
+}
+
 void EncodingNode::addElement(Element *e) {
        elements.add(e);
 }
@@ -298,57 +393,4 @@ uint64_t EncodingEdge::getValue() const {
        return numEquals * min + numComparisons * lSize * rSize;
 }
 
-EncodingSubGraph::EncodingSubGraph() :
-       encodingSize(0),
-       numElements(0) {
-}
-
-uint EncodingSubGraph::estimateNewSize(EncodingSubGraph *sg) {
-       uint newSize=0;
-       SetIteratorEncodingNode * nit = sg->nodes.iterator();
-       while(nit->hasNext()) {
-               EncodingNode *en = nit->next();
-               uint size=estimateNewSize(en);
-               if (size > newSize)
-                       newSize = size;
-       }
-       delete nit;
-       return newSize;
-}
 
-uint EncodingSubGraph::estimateNewSize(EncodingNode *n) {
-       SetIteratorEncodingEdge * eeit = n->edges.iterator();
-       uint newsize=n->getSize();
-       while(eeit->hasNext()) {
-               EncodingEdge * ee = eeit->next();
-               if (ee->left != NULL && ee->left != n && nodes.contains(ee->left)) {
-                       uint intersectSize = n->s->getUnionSize(ee->left->s);
-                       if (intersectSize > newsize)
-                               newsize = intersectSize;
-               }
-               if (ee->right != NULL && ee->right != n && nodes.contains(ee->right)) {
-                       uint intersectSize = n->s->getUnionSize(ee->right->s);
-                       if (intersectSize > newsize)
-                               newsize = intersectSize;
-               }
-               if (ee->dst != NULL && ee->dst != n && nodes.contains(ee->dst)) {
-                       uint intersectSize = n->s->getUnionSize(ee->dst->s);
-                       if (intersectSize > newsize)
-                               newsize = intersectSize;
-               }
-       }
-       delete eeit;
-       return newsize;
-}
-
-void EncodingSubGraph::addNode(EncodingNode *n) {
-       nodes.add(n);
-       uint newSize=estimateNewSize(n);
-       numElements += n->elements.getSize();
-       if (newSize > encodingSize)
-               encodingSize=newSize;
-}
-
-SetIteratorEncodingNode * EncodingSubGraph::nodeIterator() {
-       return nodes.iterator();
-}