More fixes
[repair.git] / Repair / RepairCompiler / MCC / IR / RepairGenerator.java
index 6074bbd504f07b78536f80471bdbacc67393d079..0f21555cc39f8cf04682f9f7c6ee3b2d2c75a19e 100755 (executable)
@@ -753,15 +753,18 @@ public class RepairGenerator {
     }
 
     private void generate_checks() {
-
         /* do constraint checks */
-       //        Vector constraints = state.vConstraints;
-
-
-       //        for (int i = 0; i < constraints.size(); i++) {
-       //            Constraint constraint = (Constraint) constraints.elementAt(i); 
-       for (Iterator i = termination.constraintdependence.computeOrdering().iterator(); i.hasNext();) {
-           Constraint constraint = (Constraint) ((GraphNode)i.next()).getOwner();
+       Iterator i;
+       if (Compiler.REPAIR)
+           i=termination.constraintdependence.computeOrdering().iterator();
+       else
+           i=state.vConstraints.iterator();
+       for (; i.hasNext();) {
+           Constraint constraint;
+           if (Compiler.REPAIR)
+               constraint= (Constraint) ((GraphNode)i.next()).getOwner();
+           else
+               constraint=(Constraint)i.next();
            
             {
                final SymbolTable st = constraint.getSymbolTable();
@@ -788,7 +791,7 @@ public class RepairGenerator {
                 cr.outputline("if (maybe)");
                 cr.startblock();
                 cr.outputline("printf(\"maybe fail " +  escape(constraint.toString()) + ". \\n\");");
-                cr.outputline("exit(1);");
+               //cr.outputline("exit(1);");
                 cr.endblock();
 
                 cr.outputline("else if (!" + constraintboolean.getSafeSymbol() + ")");
@@ -836,14 +839,14 @@ public class RepairGenerator {
                            p.generate(cr,predvalue);
                            if (k==0)
                                cr.outputline("int "+costvar.getSafeSymbol()+"=0;");
-
+                           
                            if (negate)
                                cr.outputline("if (maybe||"+predvalue.getSafeSymbol()+")");
                            else
                                cr.outputline("if (maybe||!"+predvalue.getSafeSymbol()+")");
                            cr.outputline(costvar.getSafeSymbol()+"+="+cost.getCost(dpred)+";");
                        }
-
+                       
                        if(!first) {
                            cr.outputline("if ("+costvar.getSafeSymbol()+"<"+mincost.getSafeSymbol()+")");
                            cr.startblock();
@@ -859,7 +862,6 @@ public class RepairGenerator {
                for(int j=0;j<dnfconst.size();j++) {
                    GraphNode gn=(GraphNode)dnfconst.get(j);
                    Conjunction conj=((TermNode)gn.getOwner()).getConjunction();
-
                    if (removed.contains(gn))
                        continue;
                    cr.outputline("case "+j+":");
@@ -893,7 +895,7 @@ public class RepairGenerator {
                    cr.outputline("break;");
                }
                cr.outputline("}");
-
+               
                cr.outputline("if ("+oldmodel.getSafeSymbol()+")");
                cr.outputline("delete "+oldmodel.getSafeSymbol()+";");
                cr.outputline(oldmodel.getSafeSymbol()+"="+newmodel.getSafeSymbol()+";");
@@ -923,13 +925,21 @@ public class RepairGenerator {
        cr.endblock();
        cr.outputline("rebuild:");
        cr.outputline(";");     
-       
     }
     
     private MultUpdateNode getmultupdatenode(Conjunction conj, DNFPredicate dpred, int repairtype) {
-       MultUpdateNode mun=null;
+       Set nodes=getmultupdatenodeset(conj,dpred,repairtype);
+       Iterator it=nodes.iterator();
+       if (it.hasNext())
+           return (MultUpdateNode)it.next();
+       else
+           return null;
+    }
+
+    private Set getmultupdatenodeset(Conjunction conj, DNFPredicate dpred, int repairtype) {
+       HashSet hs=new HashSet();
        GraphNode gn=(GraphNode) termination.conjtonodemap.get(conj);
-       for(Iterator edgeit=gn.edges();(mun==null)&&edgeit.hasNext();) {
+       for(Iterator edgeit=gn.edges();edgeit.hasNext();) {
            GraphNode gn2=((GraphNode.Edge) edgeit.next()).getTarget();
            TermNode tn2=(TermNode)gn2.getOwner();
            if (tn2.getType()==TermNode.ABSTRACT) {
@@ -941,23 +951,54 @@ public class RepairGenerator {
                        if (!removed.contains(gn3)) {
                            TermNode tn3=(TermNode)gn3.getOwner();
                            if (tn3.getType()==TermNode.UPDATE) {
-                               mun=tn3.getUpdate();
-                               break;
+                               hs.add(tn3.getUpdate());
                            }
                        }
                    }
                }
            }
        }
-       return mun;
+       return hs;
+    }
+
+    private AbstractRepair getabstractrepair(Conjunction conj, DNFPredicate dpred, int repairtype) {
+       HashSet hs=new HashSet();
+       MultUpdateNode mun=null;
+       GraphNode gn=(GraphNode) termination.conjtonodemap.get(conj);
+       for(Iterator edgeit=gn.edges();(mun==null)&&edgeit.hasNext();) {
+           GraphNode gn2=((GraphNode.Edge) edgeit.next()).getTarget();
+           TermNode tn2=(TermNode)gn2.getOwner();
+           if (tn2.getType()==TermNode.ABSTRACT) {
+               AbstractRepair ar=tn2.getAbstract();
+               if (((repairtype==-1)||(ar.getType()==repairtype))&&
+                   ar.getPredicate()==dpred) {
+                   return ar;
+               }
+           }
+       }
+       return null;
     }
 
+
     /** Generates abstract (and concrete) repair for a comparison */
 
     private void generatecomparisonrepair(Conjunction conj, DNFPredicate dpred, CodeWriter cr){
-       MultUpdateNode munmodify=getmultupdatenode(conj,dpred,AbstractRepair.MODIFYRELATION);
-       MultUpdateNode munremove=getmultupdatenode(conj,dpred,AbstractRepair.REMOVEFROMRELATION);
-       MultUpdateNode munadd=getmultupdatenode(conj,dpred,AbstractRepair.ADDTORELATION);
+       Set updates=getmultupdatenodeset(conj,dpred,AbstractRepair.MODIFYRELATION);
+       AbstractRepair ar=getabstractrepair(conj,dpred,AbstractRepair.MODIFYRELATION);
+       MultUpdateNode munmodify=null;
+       MultUpdateNode munadd=null;
+       MultUpdateNode munremove=null;
+       for(Iterator it=updates.iterator();it.hasNext();) {
+           MultUpdateNode mun=(MultUpdateNode)it.next();
+           if (mun.getType()==MultUpdateNode.ADD) {
+               munadd=mun;
+           } else if (mun.getType()==MultUpdateNode.REMOVE) { 
+               munremove=mun;
+           } else if (mun.getType()==MultUpdateNode.MODIFY) {
+               munmodify=mun;
+           }
+       }
+       
        ExprPredicate ep=(ExprPredicate)dpred.getPredicate();
        RelationDescriptor rd=(RelationDescriptor)ep.getDescriptor();
        boolean usageimage=rd.testUsage(RelationDescriptor.IMAGE);
@@ -969,7 +1010,16 @@ public class RepairGenerator {
        VarDescriptor leftside=VarDescriptor.makeNew("leftside");
        VarDescriptor rightside=VarDescriptor.makeNew("rightside");
        VarDescriptor newvalue=VarDescriptor.makeNew("newvalue");
-       if (!inverted) {
+       boolean needremoveloop=ar.mayNeedFunctionEnforcement(state)&&ar.needsRemoves(state);
+
+       if (needremoveloop&&((munadd==null)||(munremove==null))) {
+           System.out.println("Warning:  need to have individual remove operations for"+dpred.name());
+           needremoveloop=false;
+       }
+       if (needremoveloop)
+           cr.outputline("while (1) {");
+
+       if (!inverted) {
            ((RelationExpr)expr.getLeftExpr()).getExpr().generate(cr,leftside);
            expr.getRightExpr().generate(cr,newvalue);
            cr.outputline(rd.getRange().getType().getGenerateType().getSafeSymbol()+" "+rightside.getSafeSymbol()+";");
@@ -980,22 +1030,8 @@ public class RepairGenerator {
            cr.outputline(rd.getDomain().getType().getGenerateType().getSafeSymbol()+" "+leftside.getSafeSymbol()+";");
            cr.outputline(rd.getSafeSymbol()+"_hashinv->get("+rightside.getSafeSymbol()+","+leftside.getSafeSymbol()+");");
        }
-       if (negated)
-           if (opcode==Opcode.GT) {
-               opcode=Opcode.LE;
-           } else if (opcode==Opcode.GE) {
-               opcode=Opcode.LT;
-           } else if (opcode==Opcode.LT) {
-               opcode=Opcode.GE;
-           } else if (opcode==Opcode.LE) {
-               opcode=Opcode.GT;
-           } else if (opcode==Opcode.EQ) {
-               opcode=Opcode.NE;
-           } else if (opcode==Opcode.NE) {
-               opcode=Opcode.EQ;
-           } else {
-               throw new Error("Unrecognized Opcode");
-           }
+
+       opcode=Opcode.translateOpcode(negated,opcode);
 
        if (opcode==Opcode.GT) {
            cr.outputline(newvalue.getSafeSymbol()+"++;");
@@ -1015,6 +1051,35 @@ public class RepairGenerator {
        /* Do abstract repairs */
        if (usageimage) {
            cr.outputline(rd.getSafeSymbol()+"_hash->remove("+leftside.getSafeSymbol()+","+rightside.getSafeSymbol()+");");
+       }
+       if (usageinvimage) {
+           cr.outputline(rd.getSafeSymbol()+"_hashinv->remove("+rightside.getSafeSymbol()+","+leftside.getSafeSymbol()+");");
+       }
+
+       if (needremoveloop) {
+           if (!inverted) {
+               cr.outputline("if ("+rd.getSafeSymbol()+"_hash->contains("+leftside.getSafeSymbol()+")) {");
+           } else {
+               cr.outputline("if ("+rd.getSafeSymbol()+"_hashinv->contains("+rightside.getSafeSymbol()+")) {");
+           }
+           for(int i=0;i<state.vRules.size();i++) {
+               Rule r=(Rule)state.vRules.get(i);
+               if (r.getInclusion().getTargetDescriptors().contains(rd)) {
+                   for(int j=0;j<munremove.numUpdates();j++) {
+                       UpdateNode un=munremove.getUpdate(i);
+                       if (un.getRule()==r) {
+                               /* Update for rule r */
+                           String name=(String)updatenames.get(un);
+                           cr.outputline(repairtable.getSafeSymbol()+"->addrelation("+rd.getNum()+","+r.getNum()+","+leftside.getSafeSymbol()+","+rightside.getSafeSymbol()+",(int) &"+name+");");
+                       }
+                   }
+               }
+           }
+           cr.outputline("continue;");
+           cr.outputline("}");
+       }
+
+       if (usageimage) {
            if (!inverted) {
                cr.outputline(rd.getSafeSymbol()+"_hash->add("+leftside.getSafeSymbol()+","+newvalue.getSafeSymbol()+");");
            } else {
@@ -1022,7 +1087,6 @@ public class RepairGenerator {
            }
        }
        if (usageinvimage) {
-           cr.outputline(rd.getSafeSymbol()+"_hashinv->remove("+rightside.getSafeSymbol()+","+leftside.getSafeSymbol()+");");
            if (!inverted) {
                cr.outputline(rd.getSafeSymbol()+"_hashinv->add("+newvalue.getSafeSymbol()+","+leftside.getSafeSymbol()+");");
            } else {
@@ -1030,7 +1094,7 @@ public class RepairGenerator {
            }
        }
        /* Do concrete repairs */
-       if (munmodify!=null) {
+       if (munmodify!=null&&(!ar.mayNeedFunctionEnforcement(state))||(munadd==null)||(ar.needsRemoves(state)&&(munremove==null))) {
            for(int i=0;i<state.vRules.size();i++) {
                Rule r=(Rule)state.vRules.get(i);
                if (r.getInclusion().getTargetDescriptors().contains(rd)) {
@@ -1044,22 +1108,22 @@ public class RepairGenerator {
                    }
                }
            }
-
        } else {
            /* Start with scheduling removal */
-           for(int i=0;i<state.vRules.size();i++) {
-               Rule r=(Rule)state.vRules.get(i);
-               if (r.getInclusion().getTargetDescriptors().contains(rd)) {
-                   for(int j=0;j<munremove.numUpdates();j++) {
-                       UpdateNode un=munremove.getUpdate(i);
-                       if (un.getRule()==r) {
-                           /* Update for rule r */
-                           String name=(String)updatenames.get(un);
-                           cr.outputline(repairtable.getSafeSymbol()+"->addrelation("+rd.getNum()+","+r.getNum()+","+leftside.getSafeSymbol()+","+rightside.getSafeSymbol()+",(int) &"+name+");");
+           if (ar.needsRemoves(state))
+               for(int i=0;i<state.vRules.size();i++) {
+                   Rule r=(Rule)state.vRules.get(i);
+                   if (r.getInclusion().getTargetDescriptors().contains(rd)) {
+                       for(int j=0;j<munremove.numUpdates();j++) {
+                           UpdateNode un=munremove.getUpdate(i);
+                           if (un.getRule()==r) {
+                               /* Update for rule r */
+                               String name=(String)updatenames.get(un);
+                               cr.outputline(repairtable.getSafeSymbol()+"->addrelation("+rd.getNum()+","+r.getNum()+","+leftside.getSafeSymbol()+","+rightside.getSafeSymbol()+",(int) &"+name+");");
+                           }
                        }
                    }
                }
-           }
            /* Now do addition */
            UpdateNode un=munadd.getUpdate(0);
            String name=(String)updatenames.get(un);
@@ -1069,6 +1133,10 @@ public class RepairGenerator {
                cr.outputline(name+"(this,"+newmodel.getSafeSymbol()+","+repairtable.getSafeSymbol()+","+newvalue.getSafeSymbol()+","+rightside.getSafeSymbol()+");");
            }
        }
+       if (needremoveloop) {
+           cr.outputline("break;");
+           cr.outputline("}");
+       }
     }
 
     public void generatesizerepair(Conjunction conj, DNFPredicate dpred, CodeWriter cr) {