#include "set.h"
#include <stddef.h>
+#include "csolver.h"
+#include "serializer.h"
+#include "qsort.h"
-Set *allocSet(VarType t, uint64_t *elements, uint num) {
- Set *This = (Set *)ourmalloc(sizeof(Set));
- This->type = t;
- This->isRange = false;
- This->low = 0;
- This->high = 0;
- This->members = allocVectorArrayInt(num, elements);
- return This;
-}
-
-Set *allocSetRange(VarType t, uint64_t lowrange, uint64_t highrange) {
- Set *This = (Set *)ourmalloc(sizeof(Set));
- This->type = t;
- This->isRange = true;
- This->low = lowrange;
- This->high = highrange;
- This->members = NULL;
- return This;
-}
-
-bool existsInSet(Set *This, uint64_t element) {
- if (This->isRange) {
- return element >= This->low && element <= This->high;
+int intcompare(const void *p1, const void *p2) {
+ uint64_t a=*(uint64_t const *) p1;
+ uint64_t b=*(uint64_t const *) p2;
+ if (a < b)
+ return -1;
+ else if (a==b)
+ return 0;
+ else
+ return 1;
+}
+
+Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) {
+ members = new Vector<uint64_t>();
+}
+
+
+Set::Set(VarType t, uint64_t *elements, uint num) : type(t), isRange(false), low(0), high(0) {
+ members = new Vector<uint64_t>(num, elements);
+ bsdqsort(members->expose(), members->getSize(), sizeof(uint64_t), intcompare);
+}
+
+
+Set::Set(VarType t, uint64_t lowrange, uint64_t highrange) : type(t), isRange(true), low(lowrange), high(highrange), members(NULL) {
+}
+
+bool Set::exists(uint64_t element) {
+ if (isRange) {
+ return element >= low && element <= high;
} else {
- uint size = getSizeVectorInt(This->members);
- for (uint i = 0; i < size; i++) {
- if (element == getVectorInt(This->members, i))
+ //Use Binary Search
+ uint low=0;
+ uint high=members->getSize()-1;
+ while(true) {
+ uint middle=(low+high)/2;
+ uint64_t val=members->get(middle);
+ if (element < val) {
+ high=middle;
+ if (middle==low)
+ return false;
+ } else if (element > val) {
+ low=middle;
+ if (middle==high)
+ return false;
+ } else {
return true;
+ }
}
- return false;
}
}
-uint64_t getSetElement(Set *This, uint index) {
- if (This->isRange)
- return This->low + index;
+uint64_t Set::getElement(uint index) {
+ if (isRange)
+ return low + index;
else
- return getVectorInt(This->members, index);
+ return members->get(index);
+}
+
+uint Set::getSize() {
+ if (isRange) {
+ return high - low + 1;
+ } else {
+ return members->getSize();
+ }
}
-uint getSetSize(Set *This) {
- if (This->isRange) {
- return This->high - This->low + 1;
+uint64_t Set::getMemberAt(uint index) {
+ if (isRange) {
+ return low + index;
} else {
- return getSizeVectorInt(This->members);
+ return members->get(index);
}
}
-void deleteSet(Set *This) {
- if (!This->isRange)
- deleteVectorInt(This->members);
- ourfree(This);
+Set::~Set() {
+ if (!isRange)
+ delete members;
+}
+
+Set *Set::clone(CSolver *solver, CloneMap *map) {
+ Set *s = (Set *) map->get(this);
+ if (s != NULL)
+ return s;
+ if (isRange) {
+ s = solver->createRangeSet(type, low, high);
+ } else {
+ s = solver->createSet(type, members->expose(), members->getSize());
+ }
+ map->put(this, s);
+ return s;
+}
+
+uint Set::getUnionSize(Set *s) {
+ uint sSize = s->getSize();
+ uint thisSize = getSize();
+ uint sIndex = 0;
+ uint thisIndex = 0;
+ uint sum = 0;
+ uint64_t sValue = s->getElement(sIndex);
+ uint64_t thisValue = getElement(thisIndex);
+ while(thisIndex < thisSize && sIndex < sSize) {
+ if (sValue < thisValue) {
+ sValue = s->getElement(++sIndex);
+ sum++;
+ } else if (thisValue < sValue) {
+ thisValue = getElement(++thisIndex);
+ sum++;
+ } else {
+ thisValue = getElement(++thisIndex);
+ sValue = s->getElement(++sIndex);
+ sum++;
+ }
+ }
+ sum += (thisSize - thisIndex) + (sSize - sIndex);
+
+ return sum;
+}
+
+void Set::serialize(Serializer* serializer){
+ if(serializer->isSerialized(this))
+ return;
+ serializer->addObject(this);
+ ASTNodeType asttype = SETTYPE;
+ serializer->mywrite(&asttype, sizeof(ASTNodeType));
+ Set* This = this;
+ serializer->mywrite(&This, sizeof(Set*));
+ serializer->mywrite(&type, sizeof(VarType));
+ serializer->mywrite(&isRange, sizeof(bool));
+ serializer->mywrite(&low, sizeof(uint64_t));
+ serializer->mywrite(&high, sizeof(uint64_t));
+ bool isMutable = isMutableSet();
+ serializer->mywrite(&isMutable, sizeof(bool));
+ uint size = members->getSize();
+ serializer->mywrite(&size, sizeof(uint));
+ for(uint i=0; i<size; i++){
+ uint64_t mem = members->get(i);
+ serializer->mywrite(&mem, sizeof(uint64_t));
+ }
}