Merge branch 'master' of ssh://demsky.eecs.uci.edu/home/git/constraint_compiler into...
[satune.git] / src / AST / set.cc
index 003379e003edb9ed64bb9aea0a4840f87a8e70ea..d3801f8d498eb98e3a55440a4ac7faf243a195f3 100644 (file)
 #include "set.h"
 #include <stddef.h>
+#include "csolver.h"
+#include "serializer.h"
+#include "qsort.h"
 
-Set *allocSet(VarType t, uint64_t *elements, uint num) {
-       Set *This = (Set *)ourmalloc(sizeof(Set));
-       This->type = t;
-       This->isRange = false;
-       This->low = 0;
-       This->high = 0;
-       This->members = allocVectorArrayInt(num, elements);
-       return This;
-}
-
-Set *allocSetRange(VarType t, uint64_t lowrange, uint64_t highrange) {
-       Set *This = (Set *)ourmalloc(sizeof(Set));
-       This->type = t;
-       This->isRange = true;
-       This->low = lowrange;
-       This->high = highrange;
-       This->members = NULL;
-       return This;
-}
-
-bool existsInSet(Set *This, uint64_t element) {
-       if (This->isRange) {
-               return element >= This->low && element <= This->high;
+int intcompare(const void *p1, const void *p2) {
+       uint64_t a=*(uint64_t const *) p1;
+       uint64_t b=*(uint64_t const *) p2;
+       if (a < b)
+               return -1;
+       else if (a==b)
+               return 0;
+       else
+               return 1;
+}
+
+Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) {
+       members = new Vector<uint64_t>();
+}
+
+
+Set::Set(VarType t, uint64_t *elements, uint num) : type(t), isRange(false), low(0), high(0) {
+       members = new Vector<uint64_t>(num, elements);
+       bsdqsort(members->expose(), members->getSize(), sizeof(uint64_t), intcompare);
+}
+
+
+Set::Set(VarType t, uint64_t lowrange, uint64_t highrange) : type(t), isRange(true), low(lowrange), high(highrange), members(NULL) {
+}
+
+bool Set::exists(uint64_t element) {
+       if (isRange) {
+               return element >= low && element <= high;
        } else {
-               uint size = getSizeVectorInt(This->members);
-               for (uint i = 0; i < size; i++) {
-                       if (element == getVectorInt(This->members, i))
+               //Use Binary Search
+               uint low=0;
+               uint high=members->getSize()-1;
+               while(true) {
+                       uint middle=(low+high)/2;
+                       uint64_t val=members->get(middle);
+                       if (element < val) {
+                               high=middle-1;
+                               if (middle<=low)
+                                       return false;
+                       } else if (element > val) {
+                               low=middle+1;
+                               if (middle>=high)
+                                       return false;
+                       } else {
                                return true;
+                       }
                }
-               return false;
        }
 }
 
-uint64_t getSetElement(Set *This, uint index) {
-       if (This->isRange)
-               return This->low + index;
+uint64_t Set::getElement(uint index) {
+       if (isRange)
+               return low + index;
        else
-               return getVectorInt(This->members, index);
+               return members->get(index);
+}
+
+uint Set::getSize() {
+       if (isRange) {
+               return high - low + 1;
+       } else {
+               return members->getSize();
+       }
 }
 
-uint getSetSize(Set *This) {
-       if (This->isRange) {
-               return This->high - This->low + 1;
+uint64_t Set::getMemberAt(uint index) {
+       if (isRange) {
+               return low + index;
        } else {
-               return getSizeVectorInt(This->members);
+               return members->get(index);
        }
 }
 
-void deleteSet(Set *This) {
-       if (!This->isRange)
-               deleteVectorInt(This->members);
-       ourfree(This);
+Set::~Set() {
+       if (!isRange)
+               delete members;
+}
+
+Set *Set::clone(CSolver *solver, CloneMap *map) {
+       Set *s = (Set *) map->get(this);
+       if (s != NULL)
+               return s;
+       if (isRange) {
+               s = solver->createRangeSet(type, low, high);
+       } else {
+               s = solver->createSet(type, members->expose(), members->getSize());
+       }
+       map->put(this, s);
+       return s;
+}
+
+uint Set::getUnionSize(Set *s) {
+       uint sSize = s->getSize();
+       uint thisSize = getSize();
+       uint sIndex = 0;
+       uint thisIndex = 0;
+       uint sum = 0;
+       uint64_t sValue = s->getElement(sIndex);
+       uint64_t thisValue = getElement(thisIndex);
+       while(thisIndex < thisSize && sIndex < sSize) {
+               if (sValue < thisValue) {
+                       sValue = s->getElement(++sIndex);
+                       sum++;
+               } else if (thisValue < sValue) {
+                       thisValue = getElement(++thisIndex);
+                       sum++;
+               } else {
+                       thisValue = getElement(++thisIndex);
+                       sValue = s->getElement(++sIndex);
+                       sum++;
+               }
+       }
+       sum += (thisSize - thisIndex) + (sSize - sIndex);
+       
+       return sum;
+}
+
+void Set::serialize(Serializer* serializer){
+       if(serializer->isSerialized(this))
+               return;
+       serializer->addObject(this);
+       ASTNodeType asttype = SETTYPE;
+       serializer->mywrite(&asttype, sizeof(ASTNodeType));
+       Set* This = this;
+       serializer->mywrite(&This, sizeof(Set*));
+       serializer->mywrite(&type, sizeof(VarType));
+       serializer->mywrite(&isRange, sizeof(bool));
+       serializer->mywrite(&low, sizeof(uint64_t));
+       serializer->mywrite(&high, sizeof(uint64_t));
+       bool isMutable = isMutableSet();
+       serializer->mywrite(&isMutable, sizeof(bool));
+       uint size = members->getSize();
+       serializer->mywrite(&size, sizeof(uint));
+       for(uint i=0; i<size; i++){
+               uint64_t mem = members->get(i);
+               serializer->mywrite(&mem, sizeof(uint64_t));
+       }
 }