Fixed some errors in the Repair Generator code.
[repair.git] / Repair / RepairCompiler / MCC / IR / RepairGenerator.java
index c096372..79b89a6 100755 (executable)
@@ -753,13 +753,7 @@ public class RepairGenerator {
     }
 
     private void generate_checks() {
-
         /* do constraint checks */
-       //        Vector constraints = state.vConstraints;
-
-
-       //        for (int i = 0; i < constraints.size(); i++) {
-       //            Constraint constraint = (Constraint) constraints.elementAt(i); 
        Iterator i;
        if (Compiler.REPAIR)
            i=termination.constraintdependence.computeOrdering().iterator();
@@ -845,14 +839,14 @@ public class RepairGenerator {
                            p.generate(cr,predvalue);
                            if (k==0)
                                cr.outputline("int "+costvar.getSafeSymbol()+"=0;");
-
+                           
                            if (negate)
                                cr.outputline("if (maybe||"+predvalue.getSafeSymbol()+")");
                            else
                                cr.outputline("if (maybe||!"+predvalue.getSafeSymbol()+")");
                            cr.outputline(costvar.getSafeSymbol()+"+="+cost.getCost(dpred)+";");
                        }
-
+                       
                        if(!first) {
                            cr.outputline("if ("+costvar.getSafeSymbol()+"<"+mincost.getSafeSymbol()+")");
                            cr.startblock();
@@ -868,7 +862,6 @@ public class RepairGenerator {
                for(int j=0;j<dnfconst.size();j++) {
                    GraphNode gn=(GraphNode)dnfconst.get(j);
                    Conjunction conj=((TermNode)gn.getOwner()).getConjunction();
-
                    if (removed.contains(gn))
                        continue;
                    cr.outputline("case "+j+":");
@@ -902,7 +895,7 @@ public class RepairGenerator {
                    cr.outputline("break;");
                }
                cr.outputline("}");
-
+               
                cr.outputline("if ("+oldmodel.getSafeSymbol()+")");
                cr.outputline("delete "+oldmodel.getSafeSymbol()+";");
                cr.outputline(oldmodel.getSafeSymbol()+"="+newmodel.getSafeSymbol()+";");
@@ -932,10 +925,19 @@ public class RepairGenerator {
        cr.endblock();
        cr.outputline("rebuild:");
        cr.outputline(";");     
-       
     }
     
     private MultUpdateNode getmultupdatenode(Conjunction conj, DNFPredicate dpred, int repairtype) {
+       Set nodes=getmultupdatenodeset(conj,dpred,repairtype);
+       Iterator it=nodes.iterator();
+       if (it.hasNext())
+           return (MultUpdateNode)it.next();
+       else
+           return null;
+    }
+
+    private Set getmultupdatenodeset(Conjunction conj, DNFPredicate dpred, int repairtype) {
+       HashSet hs=new HashSet();
        MultUpdateNode mun=null;
        GraphNode gn=(GraphNode) termination.conjtonodemap.get(conj);
        for(Iterator edgeit=gn.edges();(mun==null)&&edgeit.hasNext();) {
@@ -950,23 +952,54 @@ public class RepairGenerator {
                        if (!removed.contains(gn3)) {
                            TermNode tn3=(TermNode)gn3.getOwner();
                            if (tn3.getType()==TermNode.UPDATE) {
-                               mun=tn3.getUpdate();
-                               break;
+                               hs.add(mun);
                            }
                        }
                    }
                }
            }
        }
-       return mun;
+       return hs;
+    }
+
+    private AbstractRepair getabstractrepair(Conjunction conj, DNFPredicate dpred, int repairtype) {
+       HashSet hs=new HashSet();
+       MultUpdateNode mun=null;
+       GraphNode gn=(GraphNode) termination.conjtonodemap.get(conj);
+       for(Iterator edgeit=gn.edges();(mun==null)&&edgeit.hasNext();) {
+           GraphNode gn2=((GraphNode.Edge) edgeit.next()).getTarget();
+           TermNode tn2=(TermNode)gn2.getOwner();
+           if (tn2.getType()==TermNode.ABSTRACT) {
+               AbstractRepair ar=tn2.getAbstract();
+               if (((repairtype==-1)||(ar.getType()==repairtype))&&
+                   ar.getPredicate()==dpred) {
+                   return ar;
+               }
+           }
+       }
+       return null;
     }
 
+
     /** Generates abstract (and concrete) repair for a comparison */
 
     private void generatecomparisonrepair(Conjunction conj, DNFPredicate dpred, CodeWriter cr){
-       MultUpdateNode munmodify=getmultupdatenode(conj,dpred,AbstractRepair.MODIFYRELATION);
-       MultUpdateNode munremove=getmultupdatenode(conj,dpred,AbstractRepair.REMOVEFROMRELATION);
-       MultUpdateNode munadd=getmultupdatenode(conj,dpred,AbstractRepair.ADDTORELATION);
+       Set updates=getmultupdatenodeset(conj,dpred,AbstractRepair.MODIFYRELATION);
+       AbstractRepair ar=getabstractrepair(conj,dpred,AbstractRepair.MODIFYRELATION);
+       MultUpdateNode munmodify=null;
+       MultUpdateNode munadd=null;
+       MultUpdateNode munremove=null;
+       for(Iterator it=updates.iterator();it.hasNext();) {
+           MultUpdateNode mun=(MultUpdateNode)it.next();
+           if (mun.getType()==MultUpdateNode.ADD) {
+               munadd=mun;
+           } else if (mun.getType()==MultUpdateNode.REMOVE) { 
+               munremove=mun;
+           } else if (mun.getType()==MultUpdateNode.MODIFY) {
+               munmodify=mun;
+           }
+       }
+       
        ExprPredicate ep=(ExprPredicate)dpred.getPredicate();
        RelationDescriptor rd=(RelationDescriptor)ep.getDescriptor();
        boolean usageimage=rd.testUsage(RelationDescriptor.IMAGE);
@@ -1053,22 +1086,23 @@ public class RepairGenerator {
                    }
                }
            }
-
        } else {
            /* Start with scheduling removal */
-           for(int i=0;i<state.vRules.size();i++) {
-               Rule r=(Rule)state.vRules.get(i);
-               if (r.getInclusion().getTargetDescriptors().contains(rd)) {
-                   for(int j=0;j<munremove.numUpdates();j++) {
-                       UpdateNode un=munremove.getUpdate(i);
-                       if (un.getRule()==r) {
-                           /* Update for rule r */
-                           String name=(String)updatenames.get(un);
-                           cr.outputline(repairtable.getSafeSymbol()+"->addrelation("+rd.getNum()+","+r.getNum()+","+leftside.getSafeSymbol()+","+rightside.getSafeSymbol()+",(int) &"+name+");");
+           
+           if (ar.needsRemoves(state))
+               for(int i=0;i<state.vRules.size();i++) {
+                   Rule r=(Rule)state.vRules.get(i);
+                   if (r.getInclusion().getTargetDescriptors().contains(rd)) {
+                       for(int j=0;j<munremove.numUpdates();j++) {
+                           UpdateNode un=munremove.getUpdate(i);
+                           if (un.getRule()==r) {
+                               /* Update for rule r */
+                               String name=(String)updatenames.get(un);
+                               cr.outputline(repairtable.getSafeSymbol()+"->addrelation("+rd.getNum()+","+r.getNum()+","+leftside.getSafeSymbol()+","+rightside.getSafeSymbol()+",(int) &"+name+");");
+                           }
                        }
                    }
                }
-           }
            /* Now do addition */
            UpdateNode un=munadd.getUpdate(0);
            String name=(String)updatenames.get(un);