Off by 1 bug
[satune.git] / src / ASTAnalyses / Encoding / encodinggraph.cc
index 08db96d7b55cc47e02c7be0e394ae1135270ce5a..5416ed0e2d7778444bff22640de8f127dd6ef187 100644 (file)
 #include "subgraph.h"
 #include "elementencoding.h"
 
-EncodingGraph::EncodingGraph(CSolver * _solver) :
+EncodingGraph::EncodingGraph(CSolver *_solver) :
        solver(_solver) {
 }
 
-int sortEncodingEdge(const void * p1, const void *p2) {
-       const EncodingEdge * e1 = * (const EncodingEdge **) p1;
-       const EncodingEdge * e2 = * (const EncodingEdge **) p2;
+int sortEncodingEdge(const void *p1, const void *p2) {
+       const EncodingEdge *e1 = *(const EncodingEdge **) p1;
+       const EncodingEdge *e2 = *(const EncodingEdge **) p2;
        uint64_t v1 = e1->getValue();
        uint64_t v2 = e2->getValue();
        if (v1 < v2)
@@ -29,9 +29,9 @@ int sortEncodingEdge(const void * p1, const void *p2) {
 
 void EncodingGraph::buildGraph() {
        ElementIterator it(solver);
-       while(it.hasNext()) {
-               Element * e = it.next();
-               switch(e->type) {
+       while (it.hasNext()) {
+               Element *e = it.next();
+               switch (e->type) {
                case ELEMSET:
                case ELEMFUNCRETURN:
                        processElement(e);
@@ -47,35 +47,91 @@ void EncodingGraph::buildGraph() {
 }
 
 void EncodingGraph::encode() {
-       SetIteratorEncodingSubGraph * itesg=subgraphs.iterator();
-       while(itesg->hasNext()) {
-               EncodingSubGraph *sg=itesg->next();
+       SetIteratorEncodingSubGraph *itesg = subgraphs.iterator();
+       while (itesg->hasNext()) {
+               EncodingSubGraph *sg = itesg->next();
                sg->encode();
        }
        delete itesg;
 
        ElementIterator it(solver);
-       while(it.hasNext()) {
-               Element * e = it.next();
-               switch(e->type) {
+       while (it.hasNext()) {
+               Element *e = it.next();
+               switch (e->type) {
                case ELEMSET:
                case ELEMFUNCRETURN: {
-                       ElementEncoding *encoding=getElementEncoding(e);
+                       ElementEncoding *encoding = e->getElementEncoding();
                        if (encoding->getElementEncodingType() == ELEM_UNASSIGNED) {
-                               //Do assignment...
+                               EncodingNode *n = getNode(e);
+                               if (n == NULL)
+                                       continue;
+                               ElementEncodingType encodetype = n->getEncoding();
+                               encoding->setElementEncodingType(encodetype);
+                               if (encodetype == UNARY || encodetype == ONEHOT) {
+                                       encoding->encodingArrayInitialization();
+                               } else if (encodetype == BINARYINDEX) {
+                                       EncodingSubGraph *subgraph = graphMap.get(n);
+                                       if (subgraph == NULL)
+                                               continue;
+                                       uint encodingSize = subgraph->getEncodingMaxVal(n)+1;
+                                       uint paddedSize = encoding->getSizeEncodingArray(encodingSize);
+                                       model_print("encoding size=%u\n", encodingSize);
+                                       model_print("padded=%u\n", paddedSize);
+                                       encoding->allocInUseArrayElement(paddedSize);
+                                       encoding->allocEncodingArrayElement(paddedSize);
+                                       Set *s = e->getRange();
+                                       for (uint i = 0; i < s->getSize(); i++) {
+                                               model_print("index=%u\n", i);
+                                               uint64_t value = s->getElement(i);
+                                               uint encodingIndex = subgraph->getEncoding(n, value);
+                                               encoding->setInUseElement(encodingIndex);
+                                               encoding->encodingArray[encodingIndex] = value;
+                                       }
+                               }
                        }
                        break;
                }
                default:
                        break;
                }
+               encodeParent(e);
+       }
+}
+
+void EncodingGraph::encodeParent(Element *e) {
+       uint size = e->parents.getSize();
+       for (uint i = 0; i < size; i++) {
+               ASTNode *n = e->parents.get(i);
+               if (n->type == PREDICATEOP) {
+                       BooleanPredicate *b = (BooleanPredicate *)n;
+                       FunctionEncoding *fenc = b->getFunctionEncoding();
+                       if (fenc->getFunctionEncodingType() != FUNC_UNASSIGNED)
+                               continue;
+                       Predicate *p = b->getPredicate();
+                       if (p->type == OPERATORPRED) {
+                               PredicateOperator *po = (PredicateOperator *)p;
+                               ASSERT(b->inputs.getSize() == 2);
+                               EncodingNode *left = createNode(b->inputs.get(0));
+                               EncodingNode *right = createNode(b->inputs.get(1));
+                               if (left == NULL || right == NULL)
+                                       return;
+                               EncodingEdge *edge = getEdge(left, right, NULL);
+                               if (edge != NULL && edge->getEncoding() == EDGE_MATCH) {
+                                       fenc->setFunctionEncodingType(CIRCUIT);
+                               }
+                       }
+               }
        }
-       
 }
 
 void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
-       EncodingSubGraph *graph1=graphMap.get(first);
-       EncodingSubGraph *graph2=graphMap.get(second);
+       EncodingSubGraph *graph1 = graphMap.get(first);
+       EncodingSubGraph *graph2 = graphMap.get(second);
+       if (graph1 == NULL)
+               first->setEncoding(BINARYINDEX);
+       if (graph2 == NULL)
+               second->setEncoding(BINARYINDEX);
+
        if (graph1 == NULL && graph2 == NULL) {
                graph1 = new EncodingSubGraph();
                subgraphs.add(graph1);
@@ -90,9 +146,9 @@ void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
                first = tmp;
        }
        if (graph1 != NULL && graph2 != NULL) {
-               SetIteratorEncodingNode * nodeit=graph2->nodeIterator();
-               while(nodeit->hasNext()) {
-                       EncodingNode *node=nodeit->next();
+               SetIteratorEncodingNode *nodeit = graph2->nodeIterator();
+               while (nodeit->hasNext()) {
+                       EncodingNode *node = nodeit->next();
                        graph1->addNode(node);
                        graphMap.put(node, graph1);
                }
@@ -107,10 +163,10 @@ void EncodingGraph::mergeNodes(EncodingNode *first, EncodingNode *second) {
 }
 
 void EncodingGraph::processElement(Element *e) {
-       uint size=e->parents.getSize();
-       for(uint i=0;i<size;i++) {
-               ASTNode * n = e->parents.get(i);
-               switch(n->type) {
+       uint size = e->parents.getSize();
+       for (uint i = 0; i < size; i++) {
+               ASTNode *n = e->parents.get(i);
+               switch (n->type) {
                case PREDICATEOP:
                        processPredicate((BooleanPredicate *)n);
                        break;
@@ -124,32 +180,32 @@ void EncodingGraph::processElement(Element *e) {
 }
 
 void EncodingGraph::processFunction(ElementFunction *ef) {
-       Function *f=ef->getFunction();
-       if (f->type==OPERATORFUNC) {
-               FunctionOperator *fo=(FunctionOperator*)f;
+       Function *f = ef->getFunction();
+       if (f->type == OPERATORFUNC) {
+               FunctionOperator *fo = (FunctionOperator *)f;
                ASSERT(ef->inputs.getSize() == 2);
-               EncodingNode *left=createNode(ef->inputs.get(0));
-               EncodingNode *right=createNode(ef->inputs.get(1));
+               EncodingNode *left = createNode(ef->inputs.get(0));
+               EncodingNode *right = createNode(ef->inputs.get(1));
                if (left == NULL && right == NULL)
                        return;
-               EncodingNode *dst=createNode(ef);
-               EncodingEdge *edge=getEdge(left, right, dst);
+               EncodingNode *dst = createNode(ef);
+               EncodingEdge *edge = createEdge(left, right, dst);
                edge->numArithOps++;
        }
 }
 
 void EncodingGraph::processPredicate(BooleanPredicate *b) {
-       Predicate *p=b->getPredicate();
-       if (p->type==OPERATORPRED) {
-               PredicateOperator *po=(PredicateOperator *)p;
-               ASSERT(b->inputs.getSize()==2);
-               EncodingNode *left=createNode(b->inputs.get(0));
-               EncodingNode *right=createNode(b->inputs.get(1));
+       Predicate *p = b->getPredicate();
+       if (p->type == OPERATORPRED) {
+               PredicateOperator *po = (PredicateOperator *)p;
+               ASSERT(b->inputs.getSize() == 2);
+               EncodingNode *left = createNode(b->inputs.get(0));
+               EncodingNode *right = createNode(b->inputs.get(1));
                if (left == NULL || right == NULL)
                        return;
-               EncodingEdge *edge=getEdge(left, right, NULL);
-               CompOp op=po->getOp();
-               switch(op) {
+               EncodingEdge *edge = createEdge(left, right, NULL);
+               CompOp op = po->getOp();
+               switch (op) {
                case SATC_EQUALS:
                        edge->numEquals++;
                        break;
@@ -166,60 +222,60 @@ void EncodingGraph::processPredicate(BooleanPredicate *b) {
 }
 
 uint convertSize(uint cost) {
-       cost = 1.2 * cost; // fudge factor
+       cost = 1.2 * cost;// fudge factor
        return NEXTPOW2(cost);
 }
 
 void EncodingGraph::decideEdges() {
-       uint size=edgeVector.getSize();
-       for(uint i=0; i<size; i++) {
+       uint size = edgeVector.getSize();
+       for (uint i = 0; i < size; i++) {
                EncodingEdge *ee = edgeVector.get(i);
                EncodingNode *left = ee->left;
                EncodingNode *right = ee->right;
-               
+
                if (ee->encoding != EDGE_UNASSIGNED ||
-                               left->encoding != BINARYINDEX ||
-                               right->encoding != BINARYINDEX)
+                               !left->couldBeBinaryIndex() ||
+                               !right->couldBeBinaryIndex())
                        continue;
-               
+
                uint64_t eeValue = ee->getValue();
                if (eeValue == 0)
                        return;
 
                EncodingSubGraph *leftGraph = graphMap.get(left);
                EncodingSubGraph *rightGraph = graphMap.get(right);
-               if (leftGraph == NULL && rightGraph !=NULL) {
-                       EncodingNode *tmp = left; left=right; right=tmp;
+               if (leftGraph == NULL && rightGraph != NULL) {
+                       EncodingNode *tmp = left; left = right; right = tmp;
                        EncodingSubGraph *tmpsg = leftGraph; leftGraph = rightGraph; rightGraph = tmpsg;
                }
 
-               uint leftSize=0, rightSize=0, newSize=0;
-               uint64_t totalCost=0;
+               uint leftSize = 0, rightSize = 0, newSize = 0;
+               uint64_t totalCost = 0;
                if (leftGraph == NULL && rightGraph == NULL) {
-                       leftSize=convertSize(left->getSize());
-                       rightSize=convertSize(right->getSize());
-                       newSize=convertSize(left->s->getUnionSize(right->s));
-                       newSize=(leftSize > newSize) ? leftSize: newSize;
-                       newSize=(rightSize > newSize) ? rightSize: newSize;
+                       leftSize = convertSize(left->getSize());
+                       rightSize = convertSize(right->getSize());
+                       newSize = convertSize(left->s->getUnionSize(right->s));
+                       newSize = (leftSize > newSize) ? leftSize : newSize;
+                       newSize = (rightSize > newSize) ? rightSize : newSize;
                        totalCost = (newSize - leftSize) * left->elements.getSize() +
-                               (newSize - rightSize) * right->elements.getSize();
+                                                                       (newSize - rightSize) * right->elements.getSize();
                } else if (leftGraph != NULL && rightGraph == NULL) {
-                       leftSize=convertSize(leftGraph->encodingSize);
-                       rightSize=convertSize(right->getSize());
-                       newSize=convertSize(leftGraph->estimateNewSize(right));
-                       newSize=(leftSize > newSize) ? leftSize: newSize;
-                       newSize=(rightSize > newSize) ? rightSize: newSize;
+                       leftSize = convertSize(leftGraph->encodingSize);
+                       rightSize = convertSize(right->getSize());
+                       newSize = convertSize(leftGraph->estimateNewSize(right));
+                       newSize = (leftSize > newSize) ? leftSize : newSize;
+                       newSize = (rightSize > newSize) ? rightSize : newSize;
                        totalCost = (newSize - leftSize) * leftGraph->numElements +
-                               (newSize - rightSize) * right->elements.getSize();
+                                                                       (newSize - rightSize) * right->elements.getSize();
                } else {
                        //Neither are null
-                       leftSize=convertSize(leftGraph->encodingSize);
-                       rightSize=convertSize(rightGraph->encodingSize);
-                       newSize=convertSize(leftGraph->estimateNewSize(rightGraph));
-                       newSize=(leftSize > newSize) ? leftSize: newSize;
-                       newSize=(rightSize > newSize) ? rightSize: newSize;
+                       leftSize = convertSize(leftGraph->encodingSize);
+                       rightSize = convertSize(rightGraph->encodingSize);
+                       newSize = convertSize(leftGraph->estimateNewSize(rightGraph));
+                       newSize = (leftSize > newSize) ? leftSize : newSize;
+                       newSize = (rightSize > newSize) ? rightSize : newSize;
                        totalCost = (newSize - leftSize) * leftGraph->numElements +
-                               (newSize - rightSize) * rightGraph->numElements;
+                                                                       (newSize - rightSize) * rightGraph->numElements;
                }
                double conversionfactor = 0.5;
                if ((totalCost * conversionfactor) < eeValue) {
@@ -231,22 +287,28 @@ void EncodingGraph::decideEdges() {
 
 static TunableDesc EdgeEncodingDesc(EDGE_UNASSIGNED, EDGE_MATCH, EDGE_UNASSIGNED);
 
-EncodingEdge * EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
+EncodingEdge *EncodingGraph::getEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
+       EncodingEdge e(left, right, dst);
+       EncodingEdge *result = edgeMap.get(&e);
+       return result;
+}
+
+EncodingEdge *EncodingGraph::createEdge(EncodingNode *left, EncodingNode *right, EncodingNode *dst) {
        EncodingEdge e(left, right, dst);
        EncodingEdge *result = edgeMap.get(&e);
        if (result == NULL) {
-               result=new EncodingEdge(left, right, dst);
-               VarType v1=left->getType();
-               VarType v2=right->getType();
+               result = new EncodingEdge(left, right, dst);
+               VarType v1 = left->getType();
+               VarType v2 = right->getType();
                if (v1 > v2) {
-                       VarType tmp=v2;
-                       v2=v1;
-                       v1=tmp;
+                       VarType tmp = v2;
+                       v2 = v1;
+                       v1 = tmp;
                }
 
-               if ((left != NULL && left->encoding==BINARYINDEX) &&
-                               (right != NULL) && right->encoding==BINARYINDEX) {
-                       EdgeEncodingType type=(EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
+               if ((left != NULL && left->couldBeBinaryIndex()) &&
+                               (right != NULL) && right->couldBeBinaryIndex()) {
+                       EdgeEncodingType type = (EdgeEncodingType)solver->getTuner()->getVarTunable(v1, v2, EDGEENCODING, &EdgeEncodingDesc);
                        result->setEncoding(type);
                        if (type == EDGE_MATCH) {
                                mergeNodes(left, right);
@@ -278,7 +340,7 @@ VarType EncodingNode::getType() const {
 
 static TunableDesc NodeEncodingDesc(ELEM_UNASSIGNED, BINARYINDEX, ELEM_UNASSIGNED);
 
-EncodingNode * EncodingGraph::createNode(Element *e) {
+EncodingNode *EncodingGraph::createNode(Element *e) {
        if (e->type == ELEMCONST)
                return NULL;
        Set *s = e->getRange();
@@ -286,12 +348,21 @@ EncodingNode * EncodingGraph::createNode(Element *e) {
        if (n == NULL) {
                n = new EncodingNode(s);
                n->setEncoding((ElementEncodingType)solver->getTuner()->getVarTunable(n->getType(), NODEENCODING, &NodeEncodingDesc));
+
                encodingMap.put(s, n);
        }
        n->addElement(e);
        return n;
 }
 
+EncodingNode *EncodingGraph::getNode(Element *e) {
+       if (e->type == ELEMCONST)
+               return NULL;
+       Set *s = e->getRange();
+       EncodingNode *n = encodingMap.get(s);
+       return n;
+}
+
 void EncodingNode::addElement(Element *e) {
        elements.add(e);
 }
@@ -319,7 +390,7 @@ EncodingEdge::EncodingEdge(EncodingNode *_left, EncodingNode *_right, EncodingNo
 }
 
 uint hashEncodingEdge(EncodingEdge *edge) {
-       uintptr_t hash=(((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
+       uintptr_t hash = (((uintptr_t) edge->left) >> 2) ^ (((uintptr_t)edge->right) >> 4) ^ (((uintptr_t)edge->dst) >> 6);
        return (uint) hash;
 }