Edits
[satune.git] / src / AST / element.cc
index 3b0f6e40464a271415836cc3d5f5e9bac298b9d8..f575a9034c2565695ac2f1696d0a5af2541c5e21 100644 (file)
 #include "constraint.h"
 #include "function.h"
 #include "table.h"
+#include "csolver.h"
 
-Element::Element(ASTNodeType _type) : ASTNode(_type) {
-       initDefVectorASTNode(GETELEMENTPARENTS(this));
-       initElementEncoding(&encoding, (Element *) this);
+Element::Element(ASTNodeType _type) :
+       ASTNode(_type),
+       encoding(this) {
 }
 
-ElementSet::ElementSet(Set *s) : Element(ELEMSET), set(s) {
+ElementSet::ElementSet(Set *s) :
+       Element(ELEMSET),
+       set(s) {
 }
 
-ElementFunction::ElementFunction(Function *_function, Element **array, uint numArrays, Boolean *_overflowstatus) : Element(ELEMFUNCRETURN), function(_function), overflowstatus(_overflowstatus) {
-       initArrayInitElement(&inputs, array, numArrays);
-       for (uint i = 0; i < numArrays; i++)
-               pushVectorASTNode(GETELEMENTPARENTS(array[i]), this);
-       initFunctionEncoding(&functionencoding, this);
+ElementSet::ElementSet(ASTNodeType _type, Set *s) :
+       Element(_type),
+       set(s) {
 }
 
-ElementConst::ElementConst(uint64_t _value, VarType _type) : Element(ELEMCONST), value(_value) {
-       uint64_t array[]={value};
-       set = allocSet(_type, array, 1);
+ElementFunction::ElementFunction(Function *_function, Element **array, uint numArrays, BooleanEdge _overflowstatus) :
+       Element(ELEMFUNCRETURN),
+       inputs(array, numArrays),
+       overflowstatus(_overflowstatus),
+       functionencoding(this),
+       function(_function) {
 }
 
-Set *getElementSet(Element *This) {
-       switch (GETELEMENTTYPE(This)) {
-       case ELEMSET:
-               return ((ElementSet *)This)->set;
-       case ELEMCONST:
-               return ((ElementConst *)This)->set;
-       case ELEMFUNCRETURN: {
-               Function *func = ((ElementFunction *)This)->function;
-               switch (GETFUNCTIONTYPE(func)) {
-               case TABLEFUNC:
-                       return ((FunctionTable *)func)->table->range;
-               case OPERATORFUNC:
-                       return ((FunctionOperator *)func)->range;
-               default:
-                       ASSERT(0);
-               }
-       }
-       default:
-               ASSERT(0);
+ElementConst::ElementConst(uint64_t _value, Set *_set) :
+       ElementSet(ELEMCONST, _set),
+       value(_value) {
+}
+
+Element *ElementConst::clone(CSolver *solver, CloneMap *map) {
+       return solver->getElementConst(type, value);
+}
+
+Element *ElementSet::clone(CSolver *solver, CloneMap *map) {
+       Element *e = (Element *) map->get(this);
+       if (e != NULL)
+               return e;
+       e = solver->getElementVar(set->clone(solver, map));
+       map->put(e, e);
+       return e;
+}
+
+Element *ElementFunction::clone(CSolver *solver, CloneMap *map) {
+       Element *array[inputs.getSize()];
+       for (uint i = 0; i < inputs.getSize(); i++) {
+               array[i] = inputs.get(i)->clone(solver, map);
        }
-       ASSERT(0);
-       return NULL;
+       Element *e = solver->applyFunction(function->clone(solver, map), array, inputs.getSize(), overflowstatus->clone(solver, map));
+       return e;
+}
+
+void ElementFunction::updateParents() {
+       for (uint i = 0; i < inputs.getSize(); i++) inputs.get(i)->parents.push(this);
+}
+
+Set *ElementFunction::getRange() {
+       return function->getRange();
+}
+
+void ElementSet::serialize(Serializer *serializer) {
+       if (serializer->isSerialized(this))
+               return;
+       serializer->addObject(this);
+
+       set->serialize(serializer);
+
+       serializer->mywrite(&type, sizeof(ASTNodeType));
+       ElementSet *This = this;
+       serializer->mywrite(&This, sizeof(ElementSet *));
+       serializer->mywrite(&set, sizeof(Set *));
 }
 
-ElementFunction::~ElementFunction() {
-       deleteInlineArrayElement(&inputs);
-       deleteFunctionEncoding(&functionencoding);
+void ElementSet::print() {
+       model_print("{ElementSet:");
+       set->print();
+       model_print(" %p ", this);
+       getElementEncoding()->print();
+       model_print("}");
 }
 
-ElementConst::~ElementConst() {
-       deleteSet(set);
+void ElementConst::serialize(Serializer *serializer) {
+       if (serializer->isSerialized(this))
+               return;
+       serializer->addObject(this);
+
+       set->serialize(serializer);
+
+       serializer->mywrite(&type, sizeof(ASTNodeType));
+       ElementSet *This = this;
+       serializer->mywrite(&This, sizeof(ElementSet *));
+       VarType type = set->getType();
+       serializer->mywrite(&type, sizeof(VarType));
+       serializer->mywrite(&value, sizeof(uint64_t));
 }
 
-Element::~Element() {
-       deleteElementEncoding(&encoding);
-       deleteVectorArrayASTNode(GETELEMENTPARENTS(this));
+void ElementConst::print() {
+       model_print("{ElementConst: %" PRIu64 "}\n", value);
+}
+
+void ElementFunction::serialize(Serializer *serializer) {
+       if (serializer->isSerialized(this))
+               return;
+       serializer->addObject(this);
+
+       function->serialize(serializer);
+       uint size = inputs.getSize();
+       for (uint i = 0; i < size; i++) {
+               Element *input = inputs.get(i);
+               input->serialize(serializer);
+       }
+       serializeBooleanEdge(serializer, overflowstatus);
+
+       serializer->mywrite(&type, sizeof(ASTNodeType));
+       ElementFunction *This = this;
+       serializer->mywrite(&This, sizeof(ElementFunction *));
+       serializer->mywrite(&function, sizeof(Function *));
+       serializer->mywrite(&size, sizeof(uint));
+       for (uint i = 0; i < size; i++) {
+               Element *input = inputs.get(i);
+               serializer->mywrite(&input, sizeof(Element *));
+       }
+       Boolean *overflowstat = overflowstatus.getRaw();
+       serializer->mywrite(&overflowstat, sizeof(Boolean *));
+}
+
+void ElementFunction::print() {
+       model_print("{ElementFunction:\n");
+       function->print();
+       model_print("Elements:\n");
+       uint size = inputs.getSize();
+       for (uint i = 0; i < size; i++) {
+               Element *input = inputs.get(i);
+               input->print();
+       }
+       model_print("}\n");
 }