e019b12a3c3f333adfbbaf7b79c0485cb09d72d3
[satune.git] / src / AST / set.cc
1 #include "set.h"
2 #include <stddef.h>
3 #include "csolver.h"
4 #include "serializer.h"
5 #include "qsort.h"
6
/**
 * qsort-style comparator for two uint64_t values.
 * @param p1 pointer to the first uint64_t
 * @param p2 pointer to the second uint64_t
 * @return -1 if *p1 < *p2, 0 if they are equal, 1 if *p1 > *p2
 */
int intcompare(const void *p1, const void *p2) {
        const uint64_t lhs = *(uint64_t const *) p1;
        const uint64_t rhs = *(uint64_t const *) p2;
        // Each comparison yields 0 or 1, so the difference is exactly -1, 0 or 1.
        return (lhs > rhs) - (lhs < rhs);
}
17
18 Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) {
19         members = new Vector<uint64_t>();
20 }
21
22
23 Set::Set(VarType t, uint64_t *elements, uint num) : type(t), isRange(false), low(0), high(0) {
24         members = new Vector<uint64_t>(num, elements);
25         bsdqsort(members->expose(), members->getSize(), sizeof(uint64_t), intcompare);
26 }
27
28
29 Set::Set(VarType t, uint64_t lowrange, uint64_t highrange) : type(t), isRange(true), low(lowrange), high(highrange), members(NULL) {
30 }
31
32 bool Set::exists(uint64_t element) {
33         if (isRange) {
34                 return element >= low && element <= high;
35         } else {
36                 //Use Binary Search
37                 uint low = 0;
38                 uint high = members->getSize() - 1;
39                 while (true) {
40                         uint middle = (low + high) / 2;
41                         uint64_t val = members->get(middle);
42                         if (element < val) {
43                                 high = middle - 1;
44                                 if (middle <= low)
45                                         return false;
46                         } else if (element > val) {
47                                 low = middle + 1;
48                                 if (middle >= high)
49                                         return false;
50                         } else {
51                                 return true;
52                         }
53                 }
54         }
55 }
56
57 uint64_t Set::getElement(uint index) {
58         if (isRange)
59                 return low + index;
60         else
61                 return members->get(index);
62 }
63
64 uint Set::getSize() {
65         if (isRange) {
66                 return high - low + 1;
67         } else {
68                 return members->getSize();
69         }
70 }
71
72 uint64_t Set::getMemberAt(uint index) {
73         if (isRange) {
74                 return low + index;
75         } else {
76                 return members->get(index);
77         }
78 }
79
80 Set::~Set() {
81         if (!isRange)
82                 delete members;
83 }
84
85 Set *Set::clone(CSolver *solver, CloneMap *map) {
86         Set *s = (Set *) map->get(this);
87         if (s != NULL)
88                 return s;
89         if (isRange) {
90                 s = solver->createRangeSet(type, low, high);
91         } else {
92                 s = solver->createSet(type, members->expose(), members->getSize());
93         }
94         map->put(this, s);
95         return s;
96 }
97
98 uint Set::getUnionSize(Set *s) {
99         uint sSize = s->getSize();
100         uint thisSize = getSize();
101         uint sIndex = 0;
102         uint thisIndex = 0;
103         uint sum = 0;
104         uint64_t sValue = s->getElement(sIndex);
105         uint64_t thisValue = getElement(thisIndex);
106         while (thisIndex < thisSize && sIndex < sSize) {
107                 if (sValue < thisValue) {
108                         sIndex++;
109                         if (sIndex < sSize)
110                                 sValue = s->getElement(sIndex);
111                         sum++;
112                 } else if (thisValue < sValue) {
113                         thisIndex++;
114                         if (thisIndex < thisSize)
115                                 thisValue = getElement(thisIndex);
116                         sum++;
117                 } else {
118                         thisIndex++;
119                         sIndex++;
120                         if (sIndex < sSize)
121                                 sValue = s->getElement(sIndex);
122                         if (thisIndex < thisSize)
123                                 thisValue = getElement(thisIndex);
124                         sum++;
125                 }
126         }
127         sum += (thisSize - thisIndex) + (sSize - sIndex);
128
129         return sum;
130 }
131
132 void Set::serialize(Serializer *serializer) {
133         if (serializer->isSerialized(this))
134                 return;
135         serializer->addObject(this);
136         ASTNodeType asttype = SETTYPE;
137         serializer->mywrite(&asttype, sizeof(ASTNodeType));
138         Set *This = this;
139         serializer->mywrite(&This, sizeof(Set *));
140         serializer->mywrite(&type, sizeof(VarType));
141         serializer->mywrite(&isRange, sizeof(bool));
142         bool isMutable = isMutableSet();
143         serializer->mywrite(&isMutable, sizeof(bool));
144         if (isRange) {
145                 serializer->mywrite(&low, sizeof(uint64_t));
146                 serializer->mywrite(&high, sizeof(uint64_t));
147         } else {
148                 uint size = members->getSize();
149                 serializer->mywrite(&size, sizeof(uint));
150                 for (uint i = 0; i < size; i++) {
151                         uint64_t mem = members->get(i);
152                         serializer->mywrite(&mem, sizeof(uint64_t));
153                 }
154         }
155 }
156
157 void Set::print() {
158         model_print("{Set<%p>:", this);
159         if (isRange) {
160                 model_print("Range: low=%lu, high=%lu}", low, high);
161         } else {
162                 uint size = members->getSize();
163                 model_print("Members: ");
164                 for (uint i = 0; i < size; i++) {
165                         uint64_t mem = members->get(i);
166                         model_print("%lu, ", mem);
167                 }
168                 model_print("}");
169         }
170 }