9342484a21a81ce61e73cd14fae87a9148d9902c
[satune.git] / src / AST / set.cc
1 #include "set.h"
2 #include <stddef.h>
3 #include "csolver.h"
4 #include "serializer.h"
5 #include "qsort.h"
6
/**
 * Three-way comparison callback over uint64_t values.
 * Both arguments are dereferenced as uint64_t. Returns -1 when the first
 * value is smaller, 0 when the two are equal, and 1 when the first is larger.
 */
int intcompare(const void *p1, const void *p2) {
	const uint64_t lhs = *static_cast<const uint64_t *>(p1);
	const uint64_t rhs = *static_cast<const uint64_t *>(p2);
	// Each comparison yields 0 or 1, so the difference is exactly -1, 0 or 1.
	return (lhs > rhs) - (lhs < rhs);
}
17
18 Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) {
19         members = new Vector<uint64_t>();
20 }
21
22
23 Set::Set(VarType t, uint64_t *elements, uint num) : type(t), isRange(false), low(0), high(0) {
24         members = new Vector<uint64_t>(num, elements);
25         bsdqsort(members->expose(), members->getSize(), sizeof(uint64_t), intcompare);
26 }
27
28
29 Set::Set(VarType t, uint64_t lowrange, uint64_t highrange) : type(t), isRange(true), low(lowrange), high(highrange), members(NULL) {
30 }
31
32 bool Set::exists(uint64_t element) {
33         if (isRange) {
34                 return element >= low && element <= high;
35         } else {
36                 //Use Binary Search
37                 uint low = 0;
38                 uint high = members->getSize() - 1;
39                 while (true) {
40                         uint middle = (low + high) / 2;
41                         uint64_t val = members->get(middle);
42                         if (element < val) {
43                                 high = middle - 1;
44                                 if (middle <= low)
45                                         return false;
46                         } else if (element > val) {
47                                 low = middle + 1;
48                                 if (middle >= high)
49                                         return false;
50                         } else {
51                                 return true;
52                         }
53                 }
54         }
55 }
56
57 uint64_t Set::getElement(uint index) {
58         if (isRange)
59                 return low + index;
60         else
61                 return members->get(index);
62         }
63
64 uint Set::getSize() {
65         if (isRange) {
66                 return high - low + 1;
67         } else {
68                 return members->getSize();
69         }
70 }
71
72 Set::~Set() {
73         if (!isRange)
74                 delete members;
75 }
76
77 Set *Set::clone(CSolver *solver, CloneMap *map) {
78         Set *s = (Set *) map->get(this);
79         if (s != NULL)
80                 return s;
81         if (isRange) {
82                 s = solver->createRangeSet(type, low, high);
83         } else {
84                 s = solver->createSet(type, members->expose(), members->getSize());
85         }
86         map->put(this, s);
87         return s;
88 }
89
90 uint Set::getUnionSize(Set *s) {
91         uint sSize = s->getSize();
92         uint thisSize = getSize();
93         uint sIndex = 0;
94         uint thisIndex = 0;
95         uint sum = 0;
96         uint64_t sValue = s->getElement(sIndex);
97         uint64_t thisValue = getElement(thisIndex);
98         while (thisIndex < thisSize && sIndex < sSize) {
99                 if (sValue < thisValue) {
100                         sIndex++;
101                         if (sIndex < sSize)
102                                 sValue = s->getElement(sIndex);
103                         sum++;
104                 } else if (thisValue < sValue) {
105                         thisIndex++;
106                         if (thisIndex < thisSize)
107                                 thisValue = getElement(thisIndex);
108                         sum++;
109                 } else {
110                         thisIndex++;
111                         sIndex++;
112                         if (sIndex < sSize)
113                                 sValue = s->getElement(sIndex);
114                         if (thisIndex < thisSize)
115                                 thisValue = getElement(thisIndex);
116                         sum++;
117                 }
118         }
119         sum += (thisSize - thisIndex) + (sSize - sIndex);
120
121         return sum;
122 }
123
124 void Set::serialize(Serializer *serializer) {
125         if (serializer->isSerialized(this))
126                 return;
127         serializer->addObject(this);
128         ASTNodeType asttype = SETTYPE;
129         serializer->mywrite(&asttype, sizeof(ASTNodeType));
130         Set *This = this;
131         serializer->mywrite(&This, sizeof(Set *));
132         serializer->mywrite(&type, sizeof(VarType));
133         serializer->mywrite(&isRange, sizeof(bool));
134         bool isMutable = isMutableSet();
135         serializer->mywrite(&isMutable, sizeof(bool));
136         if (isRange) {
137                 serializer->mywrite(&low, sizeof(uint64_t));
138                 serializer->mywrite(&high, sizeof(uint64_t));
139         } else {
140                 uint size = members->getSize();
141                 serializer->mywrite(&size, sizeof(uint));
142                 for (uint i = 0; i < size; i++) {
143                         uint64_t mem = members->get(i);
144                         serializer->mywrite(&mem, sizeof(uint64_t));
145                 }
146         }
147 }
148
149 void Set::print() {
150         model_print("{Set(%lu)<%p>:", type, this);
151         if (isRange) {
152                 model_print("Range: low=%lu, high=%lu}", low, high);
153         } else {
154                 uint size = members->getSize();
155                 model_print("Members: ");
156                 for (uint i = 0; i < size; i++) {
157                         uint64_t mem = members->get(i);
158                         model_print("%lu, ", mem);
159                 }
160                 model_print("}");
161         }
162 }