060c89ef4c4a2e46d239207d9dd5c7bfdc7b9495
[repair.git] / Repair / RepairCompiler / MCC / IR / UpdateNode.java
1 package MCC.IR;
2 import java.util.*;
3 import MCC.State;
4
5 class UpdateNode {
6     Vector updates;
7     Vector bindings;
8     Vector invariants;
9     Hashtable binding;
10     Rule rule;
11
12     public UpdateNode(Rule r) {
13         updates=new Vector();
14         bindings=new Vector();
15         invariants=new Vector();
16         binding=new Hashtable();
17         rule=r;
18     }
19
20     public Rule getRule() {
21         return rule;
22     }
23
24     public String toString() {
25         String st="";
26         st+="Bindings:\n";
27         for(int i=0;i<bindings.size();i++)
28             st+=bindings.get(i).toString()+"\n";
29         st+="---------------------\n";
30         st+="Updates:\n";
31         for(int i=0;i<updates.size();i++)
32             st+=updates.get(i).toString()+"\n";
33         st+="---------------------\n";
34         st+="Invariants:\n";
35         for(int i=0;i<invariants.size();i++)
36             st+=((Expr)invariants.get(i)).name()+"\n";
37         st+="---------------------\n";
38         return st;
39     }
40   
41     public void addBindings(Vector v) {
42         for (int i=0;i<v.size();i++) {
43             addBinding((Binding)v.get(i));
44         }
45     }
46
47     public boolean checkupdates() {
48         if (!checkconflicts()) /* Do we have conflicting concrete updates */
49             return false;
50         if (computeordering()) /* Ordering exists */
51             return true;
52         return false;
53     }
54
55     private boolean computeordering() {
56         /* Build dependency graph between updates */
57         HashSet graph=new HashSet();
58         Hashtable mapping=new Hashtable();
59         for(int i=0;i<updates.size();i++) {
60             Updates u=(Updates)updates.get(i);
61             GraphNode gn=new GraphNode(String.valueOf(i),u);
62             mapping.put(u, gn);
63             graph.add(gn);
64         }
65         for(int i=0;i<updates.size();i++) {
66             Updates u1=(Updates)updates.get(i);
67             if (u1.isAbstract())
68                 continue;
69             for(int j=0;j<updates.size();j++) {
70                 Updates u2=(Updates)updates.get(j);
71                 if (!u2.isExpr())
72                     continue;
73                 Descriptor d=u1.getDescriptor();
74                 if (u2.getRightExpr().usesDescriptor(d)) {
75                     /* Add edge for dependency */
76                     GraphNode gn1=(GraphNode) mapping.get(u1);
77                     GraphNode gn2=(GraphNode) mapping.get(u2);
78                     GraphNode.Edge e=new GraphNode.Edge("dependency",gn2);
79                     gn1.addEdge(e);
80                 }
81             }
82         }
83
84         if (!GraphNode.DFS.depthFirstSearch(graph))  /* DFS & check for acyclicity */
85             return false;
86
87         TreeSet topologicalsort = new TreeSet(new Comparator() {
88                 public boolean equals(Object obj) { return false; }
89                 public int compare(Object o1, Object o2) {
90                     GraphNode g1 = (GraphNode) o1;
91                     GraphNode g2 = (GraphNode) o2;
92                     return g2.getFinishingTime() - g1.getFinishingTime();
93                 }
94             });
95         topologicalsort.addAll(graph);
96         Vector sortedvector=new Vector();
97         for(Iterator sort=topologicalsort.iterator();sort.hasNext();) {
98             GraphNode gn=(GraphNode)sort.next();
99             sortedvector.add(gn.getOwner());
100         }
101         updates=sortedvector; //replace updates with the sorted array
102         return true;
103     }
104
105     private boolean checkconflicts() {
106         Set toremove=new HashSet();
107         for(int i=0;i<updates.size();i++) {
108             Updates u1=(Updates)updates.get(i);
109             if (!u1.isAbstract()) {
110                 Descriptor d=u1.getDescriptor();
111                 for(int j=0;j<invariants.size();j++) {
112                     Expr invariant=(Expr)invariants.get(j);
113                     if (invariant.usesDescriptor(d))
114                         return false;
115                 }
116             }
117             for(int j=0;j<updates.size();j++) {
118                 Updates u2=(Updates)updates.get(j);
119                 if (i==j)
120                     continue;
121                 if (u1.isAbstract()||u2.isAbstract())
122                     continue;  /* Abstract updates are already accounted for by graph */
123                 if (u1.getDescriptor()!=u2.getDescriptor())
124                     continue; /* No interference - different descriptors */
125                 
126                 if ((u1.getOpcode()==Opcode.GT||u1.getOpcode()==Opcode.GE)&&
127                     (u2.getOpcode()==Opcode.GT||u2.getOpcode()==Opcode.GE))
128                     continue; /* Can be satisfied simultaneously */
129
130                 if ((u1.getOpcode()==Opcode.LT||u1.getOpcode()==Opcode.LE)&&
131                     (u2.getOpcode()==Opcode.LT||u2.getOpcode()==Opcode.LE))
132                     continue;
133                 if ((u1.getOpcode()==u2.getOpcode())&&
134                     u1.isExpr()&&u2.isExpr()&&
135                     u1.getRightExpr().equals(null, u2.getRightExpr())) {
136                     /*We'll remove the second occurence*/
137                     if (i>j)
138                         toremove.add(u1);
139                     else
140                         toremove.add(u2);
141                     continue;
142                 }
143
144                 /* Handle = or != NULL */
145                 if ((((u1.getOpcode()==Opcode.EQ)&&(u2.getOpcode()==Opcode.NE))||
146                      ((u1.getOpcode()==Opcode.NE)&&(u2.getOpcode()==Opcode.EQ)))&&
147                     (((u1.isExpr()&&u1.getRightExpr().isNull())&&(!u2.isExpr()||u2.getRightExpr().isNonNull()))
148                      ||((!u1.isExpr()||u1.getRightExpr().isNonNull())&&(u2.isExpr()&&u2.getRightExpr().isNull())))) {
149                     if (u1.getOpcode()==Opcode.NE)
150                         toremove.add(u1);
151                     else
152                         toremove.add(u2);
153                     continue;
154                 }
155
156                 /* Handle = and != to different constants */
157                 if ((((u1.getOpcode()==Opcode.EQ)&&(u2.getOpcode()==Opcode.NE))||
158                     ((u1.getOpcode()==Opcode.NE)&&(u2.getOpcode()==Opcode.EQ)))&&
159                     (u1.isExpr()&&u1.getRightExpr() instanceof LiteralExpr)&&
160                     (u2.isExpr()&&u2.getRightExpr() instanceof LiteralExpr)&&
161                     !u1.getRightExpr().equals(u2.getRightExpr())) {
162                     if (u1.getOpcode()==Opcode.NE)
163                         toremove.add(u1);
164                     else
165                         toremove.add(u2);
166                     continue;
167                 }
168                 
169                 /* Compatible operations < & <= */
170                 if (((u1.getOpcode()==Opcode.LT)||(u1.getOpcode()==Opcode.LE))&&
171                     ((u2.getOpcode()==Opcode.LT)||(u2.getOpcode()==Opcode.LE)))
172                     continue;
173
174                 /* Compatible operations > & >= */
175                 if (((u1.getOpcode()==Opcode.GT)||(u1.getOpcode()==Opcode.GE))&&
176                     ((u2.getOpcode()==Opcode.GT)||(u2.getOpcode()==Opcode.GE)))
177                     continue;
178                 /* Ranges */
179
180                 //XXXXXX: TODO
181                 /* Equality & Comparisons */
182                 //XXXXXX: TODO
183
184                 return false; /* They interfere */
185             }
186         }
187         updates.removeAll(toremove);
188         return true;
189     }
190
191     public void addBinding(Binding b) {
192         bindings.add(b);
193         binding.put(b.getVar(),b);
194     }
195
196     public int numBindings() {
197         return bindings.size();
198     }
199
200     public Binding getBinding(int i) {
201         return (Binding)bindings.get(i);
202     }
203     
204     public Binding getBinding(VarDescriptor vd) {
205         if (binding.containsKey(vd))
206             return (Binding)binding.get(vd);
207         else
208             return null;
209     }
210
211     public void addInvariant(Expr e) {
212         invariants.add(e);
213     }
214
215     public int numInvariants() {
216         return invariants.size();
217     }
218
219     public Expr getInvariant(int i) {
220         return (Expr)invariants.get(i);
221     }
222
223     public void addUpdate(Updates u) {
224         updates.add(u);
225     }
226
227     public int numUpdates() {
228         return updates.size();
229     }
230     public Updates getUpdate(int i) {
231         return (Updates)updates.get(i);
232     }
233
234     private MultUpdateNode getMultUpdateNode(boolean negate, Descriptor d, RepairGenerator rg) {
235         Termination termination=rg.termination;
236         MultUpdateNode mun=null;
237         GraphNode gn;
238         if (negate)
239             gn=(GraphNode)termination.abstractremove.get(d);
240         else
241             gn=(GraphNode)termination.abstractadd.get(d);
242         TermNode tn=(TermNode)gn.getOwner();
243         for(Iterator edgeit=gn.edges();edgeit.hasNext();) {
244             GraphNode gn2=((GraphNode.Edge) edgeit.next()).getTarget();
245             if (!rg.removed.contains(gn2)) {
246                 TermNode tn2=(TermNode)gn2.getOwner();
247                 if (tn2.getType()==TermNode.UPDATE) {
248                     mun=tn2.getUpdate();
249                     break;
250                 }
251             }
252         }
253         if (mun==null)
254             throw new Error("Can't find update node!");
255         return mun;
256     }
257
258     public void generate_abstract(CodeWriter cr, Updates u, RepairGenerator rg) {
259         State state=rg.state;
260         Expr abstractexpr=u.getLeftExpr();
261         boolean negated=u.negate;
262         Descriptor d=null;
263         Expr left=null;
264         Expr right=null;
265         boolean istuple=false;
266         if (abstractexpr instanceof TupleOfExpr) {
267             TupleOfExpr toe=(TupleOfExpr) abstractexpr;
268             d=toe.relation;
269             left=toe.left;
270             right=toe.right;
271             istuple=true;
272         } else if (abstractexpr instanceof ElementOfExpr) {
273             ElementOfExpr eoe=(ElementOfExpr) abstractexpr;
274             d=eoe.set;
275             left=eoe.element;
276             istuple=false;
277         } else {
278             throw new Error("Unsupported Expr");
279         }
280         MultUpdateNode mun=getMultUpdateNode(negated,d,rg);
281         VarDescriptor leftvar=VarDescriptor.makeNew("leftvar");
282         VarDescriptor rightvar=VarDescriptor.makeNew("rightvar");
283         left.generate(cr, leftvar);
284         if (istuple)
285             right.generate(cr,rightvar);
286
287         if (negated) {
288             if (istuple) {
289                 RelationDescriptor rd=(RelationDescriptor)d;
290                 boolean usageimage=rd.testUsage(RelationDescriptor.IMAGE);
291                 boolean usageinvimage=rd.testUsage(RelationDescriptor.INVIMAGE);
292                 if (usageimage)
293                     cr.outputline(rg.stmodel+"->"+rd.getJustSafeSymbol() + "_hash->remove((int)" + leftvar.getSafeSymbol() + ", (int)" + rightvar.getSafeSymbol() + ");");
294                 if (usageinvimage)
295                     cr.outputline(rg.stmodel+"->"+rd.getJustSafeSymbol() + "_hashinv->remove((int)" + rightvar.getSafeSymbol() + ", (int)" + leftvar.getSafeSymbol() + ");");
296                 
297                 for(int i=0;i<state.vRules.size();i++) {
298                     Rule r=(Rule)state.vRules.get(i);
299                     if (r.getInclusion().getTargetDescriptors().contains(rd)) {
300                         for(int j=0;j<mun.numUpdates();j++) {
301                             UpdateNode un=mun.getUpdate(i);
302                             if (un.getRule()==r) {
303                                 /* Update for rule rule r */
304                                 String name=(String)rg.updatenames.get(un);
305                                 cr.outputline(rg.strepairtable+"->addrelation("+rd.getNum()+","+r.getNum()+","+leftvar.getSafeSymbol()+","+rightvar.getSafeSymbol()+",(int) &"+name+");");
306                             }
307                         }
308                     }
309                 }
310             } else {
311                 SetDescriptor sd=(SetDescriptor) d;
312                 cr.outputline(rg.stmodel+"->"+sd.getJustSafeSymbol() + "_hash->remove((int)" + leftvar.getSafeSymbol() + ", (int)" + leftvar.getSafeSymbol() + ");");
313
314                 for(int i=0;i<state.vRules.size();i++) {
315                     Rule r=(Rule)state.vRules.get(i);
316                     if (r.getInclusion().getTargetDescriptors().contains(sd)) {
317                         for(int j=0;j<mun.numUpdates();j++) {
318                             UpdateNode un=mun.getUpdate(j);
319                             if (un.getRule()==r) {
320                                 /* Update for rule rule r */
321                                 String name=(String)rg.updatenames.get(un);
322                                 cr.outputline(rg.strepairtable+"->addset("+sd.getNum()+","+r.getNum()+","+leftvar.getSafeSymbol()+",(int) &"+name+");");
323                             }
324                         }
325                     }
326                 }
327             }
328         } else {
329             /* Generate update */
330             if (istuple) {
331                 RelationDescriptor rd=(RelationDescriptor) d;
332                 boolean usageimage=rd.testUsage(RelationDescriptor.IMAGE);
333                 boolean usageinvimage=rd.testUsage(RelationDescriptor.INVIMAGE);
334                 if (usageimage)
335                     cr.outputline(rg.stmodel+"->"+rd.getJustSafeSymbol() + "_hash->add((int)" + leftvar.getSafeSymbol() + ", (int)" + rightvar.getSafeSymbol() + ");");
336                 if (usageinvimage)
337                     cr.outputline(rg.stmodel+"->"+rd.getJustSafeSymbol() + "_hashinv->add((int)" + rightvar.getSafeSymbol() + ", (int)" + leftvar.getSafeSymbol() + ");");
338
339                 UpdateNode un=mun.getUpdate(0);
340                 String name=(String)rg.updatenames.get(un);
341                 cr.outputline(name+"(this,"+rg.stmodel+","+rg.strepairtable+","+leftvar.getSafeSymbol()+","+rightvar.getSafeSymbol()+");");
342             } else {
343                 SetDescriptor sd=(SetDescriptor)d;
344                 cr.outputline(rg.stmodel+"->"+sd.getJustSafeSymbol() + "_hash->add((int)" + leftvar.getSafeSymbol() + ", (int)" + leftvar.getSafeSymbol() + ");");
345
346                 UpdateNode un=mun.getUpdate(0);
347                 /* Update for rule rule r */
348                 String name=(String)rg.updatenames.get(un);
349                 cr.outputline(name+"(this,"+rg.stmodel+","+rg.strepairtable+","+leftvar.getSafeSymbol()+");");
350             }
351         }
352         
353     }
354
355     public void generate(CodeWriter cr, boolean removal, boolean modify, String slot0, String slot1, String slot2, RepairGenerator rg) {
356         if (!removal&&!modify)
357             generate_bindings(cr, slot0,slot1);
358         for(int i=0;i<updates.size();i++) {
359             Updates u=(Updates)updates.get(i);
360             VarDescriptor right=VarDescriptor.makeNew("right");
361             if (u.getType()==Updates.ABSTRACT) {
362                 generate_abstract(cr, u, rg);
363                 return;
364             }
365
366             switch(u.getType()) {
367             case Updates.EXPR:
368                 u.getRightExpr().generate(cr,right);
369                 break;
370             case Updates.POSITION:
371             case Updates.ACCESSPATH:
372                 if (u.getRightPos()==0)
373                     cr.outputline("int "+right.getSafeSymbol()+"="+slot0+";");
374                 else if (u.getRightPos()==1)
375                     cr.outputline("int "+right.getSafeSymbol()+"="+slot1+";");
376                 else if (u.getRightPos()==2)
377                     cr.outputline("int "+right.getSafeSymbol()+"="+slot2+";");
378                 else throw new Error("Error w/ Position");
379                 break;
380             default:
381                 throw new Error();
382             }
383
384             if (u.getType()==Updates.ACCESSPATH) {
385                 VarDescriptor newright=VarDescriptor.makeNew("right");
386                 generate_accesspath(cr, right,newright,u);
387                 right=newright;
388             }
389             VarDescriptor left=VarDescriptor.makeNew("left");
390             u.getLeftExpr().generate(cr,left);
391             Opcode op=u.getOpcode();
392             cr.outputline("if (!("+left.getSafeSymbol()+op+right.getSafeSymbol()+"))");
393             cr.startblock();
394
395             if (op==Opcode.GT)
396                 cr.outputline(right.getSafeSymbol()+"++;");
397             else if (op==Opcode.GE)
398                 ;
399             else if (op==Opcode.EQ)
400                 ;
401             else if (op==Opcode.NE)
402                 cr.outputline(right.getSafeSymbol()+"++;");
403             else if (op==Opcode.LT)
404                 cr.outputline(right.getSafeSymbol()+"--;");
405             else if (op==Opcode.LE)
406                 ;
407             else throw new Error();
408             if (u.isGlobal()) {
409                 VarDescriptor vd=((VarExpr)u.getLeftExpr()).getVar();
410                 cr.outputline(vd.getSafeSymbol()+"="+right.getSafeSymbol()+";");
411             } else if (u.isField()) {
412                 /* NEED TO FIX */
413                 Expr subexpr=((DotExpr)u.getLeftExpr()).getExpr();
414                 Expr intindex=((DotExpr)u.getLeftExpr()).getIndex();
415                 VarDescriptor subvd=VarDescriptor.makeNew("subexpr");
416                 VarDescriptor indexvd=VarDescriptor.makeNew("index");
417                 subexpr.generate(cr,subvd);
418                 if (intindex!=null)
419                     intindex.generate(cr,indexvd);
420                 FieldDescriptor fd=(FieldDescriptor)u.getDescriptor();
421                 StructureTypeDescriptor std=(StructureTypeDescriptor)subexpr.getType();
422                 Expr offsetbits = std.getOffsetExpr(fd);
423                 if (fd instanceof ArrayDescriptor) {
424                     fd = ((ArrayDescriptor) fd).getField();
425                 }
426
427                 if (intindex != null) {
428                     Expr basesize = fd.getBaseSizeExpr();
429                     offsetbits = new OpExpr(Opcode.ADD, offsetbits, new OpExpr(Opcode.MULT, basesize, intindex));
430                 }
431                 Expr offsetbytes = new OpExpr(Opcode.SHR, offsetbits,new IntegerLiteralExpr(3));
432                 Expr byteaddress=new OpExpr(Opcode.ADD, offsetbytes, subexpr);
433                 VarDescriptor addr=VarDescriptor.makeNew("byteaddress");
434                 byteaddress.generate(cr,addr);
435
436                 if (fd.getType() instanceof ReservedTypeDescriptor && !fd.getPtr()) {
437                     ReservedTypeDescriptor rtd=(ReservedTypeDescriptor)fd.getType();
438                     if (rtd==ReservedTypeDescriptor.INT) {
439                         cr.outputline("*((int *) "+addr.getSafeSymbol()+")="+right.getSafeSymbol()+";");
440                     } else if (rtd==ReservedTypeDescriptor.SHORT) {
441                         cr.outputline("*((short *) "+addr.getSafeSymbol()+")="+right.getSafeSymbol()+";");
442                     } else if (rtd==ReservedTypeDescriptor.BYTE) {
443                         cr.outputline("*((char *) "+addr.getSafeSymbol()+")="+right.getSafeSymbol()+";");
444                     } else if (rtd==ReservedTypeDescriptor.BIT) {
445                         Expr tmp = new OpExpr(Opcode.SHL, offsetbytes, new IntegerLiteralExpr(3));
446                         Expr offset=new OpExpr(Opcode.SUB, offsetbits, tmp);
447                         Expr mask=new OpExpr(Opcode.SHL, new IntegerLiteralExpr(1), offset);
448                         VarDescriptor maskvar=VarDescriptor.makeNew("mask");
449                         mask.generate(cr,maskvar);
450                         cr.outputline("*((char *) "+addr.getSafeSymbol()+")|="+maskvar.getSafeSymbol()+";");
451                         cr.outputline("if (!"+right.getSafeSymbol()+")");
452                         cr.outputline("*((char *) "+addr.getSafeSymbol()+")^="+maskvar.getSafeSymbol()+";");
453                     } else throw new Error();
454                 } else {
455                     /* Pointer */
456                     cr.outputline("*((int *) "+addr.getSafeSymbol()+")="+right.getSafeSymbol()+";");
457                 }
458             }
459             cr.endblock();
460         }
461     }
462
463
464     private void generate_accesspath(CodeWriter cr, VarDescriptor right, VarDescriptor newright, Updates u) {
465         Vector dotvector=new Vector();
466         Expr ptr=u.getRightExpr();
467         VarExpr rightve=new VarExpr(right);
468         right.td=ReservedTypeDescriptor.INT;
469
470         while(true) {
471             /* Does something other than a dereference? */
472             dotvector.add(ptr);
473             if (ptr instanceof DotExpr)
474                 ptr=((DotExpr)ptr).left;
475             else if (ptr instanceof CastExpr)
476                 ptr=((CastExpr)ptr).getExpr();
477             if (ptr instanceof VarExpr) {
478                 /* Finished constructing vector */
479                 break;
480             }
481         }
482         ArrayAnalysis.AccessPath ap=u.getAccessPath();
483         VarDescriptor init=VarDescriptor.makeNew("init");
484         if (ap.isSet()) {
485             cr.outputline("int "+init.getSafeSymbol()+"="+ap.getSet().getSafeSymbol()+"_hash->firstkey();");
486             init.td=ap.getSet().getType();
487         } else {
488             init=ap.getVar();
489         }
490         Expr newexpr=new VarExpr(init);
491         int apindex=0;
492         for(int i=dotvector.size()-1;i>=0;i--) {
493             Expr e=(Expr)dotvector.get(i);
494             if (e instanceof CastExpr) {
495                 newexpr.td=e.td;
496                 newexpr=new CastExpr(((CastExpr)e).getType(),newexpr);
497             } else if (e instanceof DotExpr) {
498                 DotExpr de=(DotExpr)e;
499                 if (de.getField() instanceof ArrayDescriptor) {
500                     DotExpr de2=new DotExpr(newexpr,de.field,new IntegerLiteralExpr(0));
501                     de2.fd=de.fd;
502                     de2.fieldtype=de.fieldtype;
503                     de2.td=de.td;
504                     OpExpr offset=new OpExpr(Opcode.SUB,rightve,de2);
505                     OpExpr index=new OpExpr(Opcode.DIV,offset,de.fieldtype.getSizeExpr());
506                     if (u.getRightPos()==apindex) {
507                         index.generate(cr,newright);
508                         return;
509                     } else {
510                         DotExpr de3=new DotExpr(newexpr,de.field,index);
511                         de3.fd=de.fd;
512                         de3.td=de.td;
513                         de3.fieldtype=de.fieldtype;
514                         newexpr=de3;
515                     }
516                 } else {
517                     DotExpr de2=new DotExpr(newexpr,de.field,null);
518                     de2.fd=de.fd;
519                     de2.fieldtype=de.fieldtype;
520                     de2.td=de.td;
521                     newexpr=de2;
522                 }
523                 apindex++;
524             } else throw new Error();
525         }
526         throw new Error();
527     }
528
529     private void generate_bindings(CodeWriter cr, String slot0, String slot1) {
530         for(int i=0;i<bindings.size();i++) {
531             Binding b=(Binding)bindings.get(i);
532
533             if (b.getType()==Binding.SEARCH) {
534                 VarDescriptor vd=b.getVar();
535                 cr.outputline(vd.getType().getGenerateType().getSafeSymbol()+" "+vd.getSafeSymbol()+"="+b.getSet().getSafeSymbol()+"_hash->firstkey();");
536             } else if (b.getType()==Binding.CREATE) {
537                 throw new Error("Creation not supported");
538                 //              source.generateSourceAlloc(cr,vd,b.getSet());
539             } else {
540                 VarDescriptor vd=b.getVar();
541                 switch(b.getPosition()) {
542                 case 0:
543                     cr.outputline(vd.getType().getGenerateType().getSafeSymbol()+" "+vd.getSafeSymbol()+"="+slot0+";");
544                     break;
545                 case 1:
546                     cr.outputline(vd.getType().getGenerateType().getSafeSymbol()+" "+vd.getSafeSymbol()+"="+slot1+";");
547                     break;
548                 default:
549                     throw new Error("Slot >1 doesn't exist.");
550                 }
551             }
552         }
553     }
554 }