25ccd05e67d6d9e480d45346735ea40ef4d1ed6f
[satune.git] / src / AST / set.cc
1 #include "set.h"
2 #include <stddef.h>
3 #include "csolver.h"
4 #include "serializer.h"
5 #include "qsort.h"
6
// qsort-style three-way comparator for uint64_t values.
// Returns a negative value, zero, or a positive value (-1 / 0 / 1) when the
// value at p1 is less than, equal to, or greater than the value at p2.
// Values are compared directly rather than subtracted, which would overflow.
int intcompare(const void *p1, const void *p2) {
	const uint64_t lhs = *static_cast<const uint64_t *>(p1);
	const uint64_t rhs = *static_cast<const uint64_t *>(p2);
	if (lhs == rhs)
		return 0;
	return (lhs < rhs) ? -1 : 1;
}
17
18 Set::Set(VarType t) : type(t), isRange(false), low(0), high(0) {
19         members = new Vector<uint64_t>();
20 }
21
22
23 Set::Set(VarType t, uint64_t *elements, uint num) : type(t), isRange(false), low(0), high(0) {
24         members = new Vector<uint64_t>(num, elements);
25         bsdqsort(members->expose(), members->getSize(), sizeof(uint64_t), intcompare);
26 }
27
28
29 Set::Set(VarType t, uint64_t lowrange, uint64_t highrange) : type(t), isRange(true), low(lowrange), high(highrange), members(NULL) {
30 }
31
32 bool Set::exists(uint64_t element) {
33         if (isRange) {
34                 return element >= low && element <= high;
35         } else {
36                 //Use Binary Search
37                 uint low=0;
38                 uint high=members->getSize()-1;
39                 while(true) {
40                         uint middle=(low+high)/2;
41                         uint64_t val=members->get(middle);
42                         if (element < val) {
43                                 high=middle-1;
44                                 if (middle<=low)
45                                         return false;
46                         } else if (element > val) {
47                                 low=middle+1;
48                                 if (middle>=high)
49                                         return false;
50                         } else {
51                                 return true;
52                         }
53                 }
54         }
55 }
56
57 uint64_t Set::getElement(uint index) {
58         if (isRange)
59                 return low + index;
60         else
61                 return members->get(index);
62 }
63
64 uint Set::getSize() {
65         if (isRange) {
66                 return high - low + 1;
67         } else {
68                 return members->getSize();
69         }
70 }
71
72 uint64_t Set::getMemberAt(uint index) {
73         if (isRange) {
74                 return low + index;
75         } else {
76                 return members->get(index);
77         }
78 }
79
80 Set::~Set() {
81         if (!isRange)
82                 delete members;
83 }
84
85 Set *Set::clone(CSolver *solver, CloneMap *map) {
86         Set *s = (Set *) map->get(this);
87         if (s != NULL)
88                 return s;
89         if (isRange) {
90                 s = solver->createRangeSet(type, low, high);
91         } else {
92                 s = solver->createSet(type, members->expose(), members->getSize());
93         }
94         map->put(this, s);
95         return s;
96 }
97
98 uint Set::getUnionSize(Set *s) {
99         uint sSize = s->getSize();
100         uint thisSize = getSize();
101         uint sIndex = 0;
102         uint thisIndex = 0;
103         uint sum = 0;
104         uint64_t sValue = s->getElement(sIndex);
105         uint64_t thisValue = getElement(thisIndex);
106         while(thisIndex < thisSize && sIndex < sSize) {
107                 if (sValue < thisValue) {
108                         sValue = s->getElement(++sIndex);
109                         sum++;
110                 } else if (thisValue < sValue) {
111                         thisValue = getElement(++thisIndex);
112                         sum++;
113                 } else {
114                         thisValue = getElement(++thisIndex);
115                         sValue = s->getElement(++sIndex);
116                         sum++;
117                 }
118         }
119         sum += (thisSize - thisIndex) + (sSize - sIndex);
120         
121         return sum;
122 }
123
124 void Set::serialize(Serializer* serializer){
125         if(serializer->isSerialized(this))
126                 return;
127         serializer->addObject(this);
128         ASTNodeType asttype = SETTYPE;
129         serializer->mywrite(&asttype, sizeof(ASTNodeType));
130         Set* This = this;
131         serializer->mywrite(&This, sizeof(Set*));
132         serializer->mywrite(&type, sizeof(VarType));
133         serializer->mywrite(&isRange, sizeof(bool));
134         serializer->mywrite(&low, sizeof(uint64_t));
135         serializer->mywrite(&high, sizeof(uint64_t));
136         bool isMutable = isMutableSet();
137         serializer->mywrite(&isMutable, sizeof(bool));
138         uint size = members->getSize();
139         serializer->mywrite(&size, sizeof(uint));
140         for(uint i=0; i<size; i++){
141                 uint64_t mem = members->get(i);
142                 serializer->mywrite(&mem, sizeof(uint64_t));
143         }
144 }
145
146 void Set::print(){
147         model_print("{Set:");
148         if(isRange){
149                 model_print("Range: low=%lu, high=%lu}\n\n", low, high);
150         } else {
151                 uint size = members->getSize();
152                 model_print("Members: ");
153                 for(uint i=0; i<size; i++){
154                         uint64_t mem = members->get(i);
155                         model_print("%lu, ", mem);
156                 }
157                 model_println("}\n");
158         }
159 }