X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=src%2FAST%2Fset.cc;h=d3801f8d498eb98e3a55440a4ac7faf243a195f3;hb=8289a5fe3c5298b2477ffa611ca976376554afc7;hp=54c66798f9889d3dda6fb433f5c9b56d3547947e;hpb=65275e5ec0b495610fc450d01ced6bd095c9602a;p=satune.git diff --git a/src/AST/set.cc b/src/AST/set.cc index 54c6679..d3801f8 100644 --- a/src/AST/set.cc +++ b/src/AST/set.cc @@ -4,10 +4,6 @@ #include "serializer.h" #include "qsort.h" -Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) { - members = new Vector(); -} - int intcompare(const void *p1, const void *p2) { uint64_t a=*(uint64_t const *) p1; uint64_t b=*(uint64_t const *) p2; @@ -19,6 +15,11 @@ int intcompare(const void *p1, const void *p2) { return 1; } +Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) { + members = new Vector(); +} + + Set::Set(VarType t, uint64_t *elements, uint num) : type(t), isRange(false), low(0), high(0) { members = new Vector(num, elements); bsdqsort(members->expose(), members->getSize(), sizeof(uint64_t), intcompare); @@ -32,12 +33,24 @@ bool Set::exists(uint64_t element) { if (isRange) { return element >= low && element <= high; } else { - uint size = members->getSize(); - for (uint i = 0; i < size; i++) { - if (element == members->get(i)) + //Use Binary Search + uint low=0; + uint high=members->getSize()-1; + while(true) { + uint middle=(low+high)/2; + uint64_t val=members->get(middle); + if (element < val) { + high=middle-1; + if (middle<=low) + return false; + } else if (element > val) { + low=middle+1; + if (middle>=high) + return false; + } else { return true; + } } - return false; } } @@ -82,6 +95,31 @@ Set *Set::clone(CSolver *solver, CloneMap *map) { return s; } +uint Set::getUnionSize(Set *s) { + uint sSize = s->getSize(); + uint thisSize = getSize(); + uint sIndex = 0; + uint thisIndex = 0; + uint sum = 0; + uint64_t sValue = s->getElement(sIndex); + uint64_t thisValue = getElement(thisIndex); + while(thisIndex < thisSize && sIndex < sSize) { + if (sValue < thisValue) { + sValue = s->getElement(++sIndex); + sum++; + } else if (thisValue < sValue) { + thisValue = getElement(++thisIndex); + sum++; + } else { + thisValue = getElement(++thisIndex); + sValue = s->getElement(++sIndex); + sum++; + } + } + sum += (thisSize - thisIndex) + (sSize - sIndex); + + return sum; +} void Set::serialize(Serializer* serializer){ if(serializer->isSerialized(this)) @@ -95,6 +133,8 @@ void Set::serialize(Serializer* serializer){ serializer->mywrite(&isRange, sizeof(bool)); serializer->mywrite(&low, sizeof(uint64_t)); serializer->mywrite(&high, sizeof(uint64_t)); + bool isMutable = isMutableSet(); + serializer->mywrite(&isMutable, sizeof(bool)); uint size = members->getSize(); serializer->mywrite(&size, sizeof(uint)); for(uint i=0; i