Bug fix
[satune.git] / src / AST / set.cc
index fc8bf150f0bd3433caa3798eb8d0c0196fede809..d3801f8d498eb98e3a55440a4ac7faf243a195f3 100644 (file)
@@ -4,10 +4,6 @@
 #include "serializer.h"
 #include "qsort.h"
 
-Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) {
-       members = new Vector<uint64_t>();
-}
-
 int intcompare(const void *p1, const void *p2) {
        uint64_t a=*(uint64_t const *) p1;
        uint64_t b=*(uint64_t const *) p2;
@@ -19,6 +15,11 @@ int intcompare(const void *p1, const void *p2) {
                return 1;
 }
 
+Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) {
+       members = new Vector<uint64_t>();
+}
+
+
 Set::Set(VarType t, uint64_t *elements, uint num) : type(t), isRange(false), low(0), high(0) {
        members = new Vector<uint64_t>(num, elements);
        bsdqsort(members->expose(), members->getSize(), sizeof(uint64_t), intcompare);
@@ -32,12 +33,24 @@ bool Set::exists(uint64_t element) {
        if (isRange) {
                return element >= low && element <= high;
        } else {
-               uint size = members->getSize();
-               for (uint i = 0; i < size; i++) {
-                       if (element == members->get(i))
+               //Use Binary Search
+               uint low=0;
+               uint high=members->getSize()-1;
+               while(true) {
+                       uint middle=(low+high)/2;
+                       uint64_t val=members->get(middle);
+                       if (element < val) {
+                               high=middle-1;
+                               if (middle<=low)
+                                       return false;
+                       } else if (element > val) {
+                               low=middle+1;
+                               if (middle>=high)
+                                       return false;
+                       } else {
                                return true;
+                       }
                }
-               return false;
        }
 }
 
@@ -82,6 +95,31 @@ Set *Set::clone(CSolver *solver, CloneMap *map) {
        return s;
 }
 
+uint Set::getUnionSize(Set *s) {
+       uint sSize = s->getSize();
+       uint thisSize = getSize();
+       uint sIndex = 0;
+       uint thisIndex = 0;
+       uint sum = 0;
+       uint64_t sValue = s->getElement(sIndex);
+       uint64_t thisValue = getElement(thisIndex);
+       while(thisIndex < thisSize && sIndex < sSize) {
+               if (sValue < thisValue) {
+                       sValue = s->getElement(++sIndex);
+                       sum++;
+               } else if (thisValue < sValue) {
+                       thisValue = getElement(++thisIndex);
+                       sum++;
+               } else {
+                       thisValue = getElement(++thisIndex);
+                       sValue = s->getElement(++sIndex);
+                       sum++;
+               }
+       }
+       sum += (thisSize - thisIndex) + (sSize - sIndex);
+       
+       return sum;
+}
 
 void Set::serialize(Serializer* serializer){
        if(serializer->isSerialized(this))