bug fix for comparisons
[IRC.git] / Robust / src / Analysis / Locality / DiscoverConflicts.java
1 package Analysis.Locality;
2
3 import IR.Flat.*;
4 import java.util.Set;
5 import java.util.Arrays;
6 import java.util.HashSet;
7 import java.util.Iterator;
8 import java.util.Hashtable;
9 import IR.State;
10 import IR.Operation;
11 import IR.TypeDescriptor;
12 import IR.MethodDescriptor;
13 import IR.FieldDescriptor;
14
15 public class DiscoverConflicts {
16   Set<FieldDescriptor> fields;
17   Set<TypeDescriptor> arrays;
18   LocalityAnalysis locality;
19   State state;
20   Hashtable<LocalityBinding, Set<FlatNode>> treadmap;
21   Hashtable<LocalityBinding, Set<TempFlatPair>> transreadmap;
22   Hashtable<LocalityBinding, Set<FlatNode>> srcmap;
23   Hashtable<LocalityBinding, Set<FlatNode>> leftsrcmap;
24   Hashtable<LocalityBinding, Set<FlatNode>> rightsrcmap;
25   TypeAnalysis typeanalysis;
26   
27   public DiscoverConflicts(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis) {
28     this.locality=locality;
29     this.fields=new HashSet<FieldDescriptor>();
30     this.arrays=new HashSet<TypeDescriptor>();
31     this.state=state;
32     this.typeanalysis=typeanalysis;
33     transreadmap=new Hashtable<LocalityBinding, Set<TempFlatPair>>();
34     treadmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
35     srcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
36     leftsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
37     rightsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
38   }
39   
40   public void doAnalysis() {
41     //Compute fields and arrays for all transactions
42     Set<LocalityBinding> localityset=locality.getLocalityBindings();
43     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
44       computeModified(lb.next());
45     }
46     expandTypes();
47     //Compute set of nodes that need transread
48     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
49       LocalityBinding l=lb.next();
50       analyzeLocality(l);
51       setNeedReadTrans(l);
52     }
53   }
54
55   public void setNeedReadTrans(LocalityBinding lb) {
56     HashSet<FlatNode> set=new HashSet<FlatNode>();
57     for(Iterator<TempFlatPair> it=transreadmap.get(lb).iterator();it.hasNext();) {
58       TempFlatPair tfp=it.next();
59       set.add(tfp.f);
60     }
61     treadmap.put(lb, set);
62   }
63
64   //We have a set of things we write to, figure out what things this
65   //could effect.
66   public void expandTypes() {
67     Set<TypeDescriptor> expandedarrays=new HashSet<TypeDescriptor>();
68     for(Iterator<TypeDescriptor> it=arrays.iterator();it.hasNext();) {
69       TypeDescriptor td=it.next();
70       expandedarrays.addAll(typeanalysis.expand(td));
71     }
72     arrays=expandedarrays;
73   }
74
75   Hashtable<TempDescriptor, Set<TempFlatPair>> doMerge(FlatNode fn, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset) {
76     Hashtable<TempDescriptor, Set<TempFlatPair>> table=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
77     for(int i=0;i<fn.numPrev();i++) {
78       FlatNode fprev=fn.getPrev(i);
79       Hashtable<TempDescriptor, Set<TempFlatPair>> tabset=tmptofnset.get(fprev);
80       if (tabset!=null) {
81         for(Iterator<TempDescriptor> tmpit=tabset.keySet().iterator();tmpit.hasNext();) {
82           TempDescriptor td=tmpit.next();
83           Set<TempFlatPair> fnset=tabset.get(td);
84           if (!table.containsKey(td))
85             table.put(td, new HashSet<TempFlatPair>());
86           table.get(td).addAll(fnset);
87         }
88       }
89     }
90     return table;
91   }
92   
93   public Set<FlatNode> getNeedSrcTrans(LocalityBinding lb) {
94     return srcmap.get(lb);
95   }
96
97   public boolean getNeedSrcTrans(LocalityBinding lb, FlatNode fn) {
98     return srcmap.get(lb).contains(fn);
99   }
100
101   public boolean getNeedLeftSrcTrans(LocalityBinding lb, FlatNode fn) {
102     return leftsrcmap.get(lb).contains(fn);
103   }
104
105   public boolean getNeedRightSrcTrans(LocalityBinding lb, FlatNode fn) {
106     return rightsrcmap.get(lb).contains(fn);
107   }
108
109   public boolean getNeedTrans(LocalityBinding lb, FlatNode fn) {
110     return treadmap.get(lb).contains(fn);
111   }
112
113   private void analyzeLocality(LocalityBinding lb) {
114     MethodDescriptor md=lb.getMethod();
115     FlatMethod fm=state.getMethodFlat(md);
116     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap=computeTempSets(lb);
117     HashSet<TempFlatPair> tfset=computeTranslationSet(lb, fm, fnmap);
118     HashSet<FlatNode> srctrans=new HashSet<FlatNode>();
119     HashSet<FlatNode> leftsrctrans=new HashSet<FlatNode>();
120     HashSet<FlatNode> rightsrctrans=new HashSet<FlatNode>();
121     transreadmap.put(lb, tfset);
122     srcmap.put(lb,srctrans);
123     leftsrcmap.put(lb,leftsrctrans);
124     rightsrcmap.put(lb,rightsrctrans);
125
126     //compute writes that need translation on source
127
128     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
129       FlatNode fn=fnit.next();
130       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
131       if (atomictable.get(fn).intValue()>0) {
132         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
133         switch(fn.kind()) {
134         case FKind.FlatOpNode: { 
135           FlatOpNode fon=(FlatOpNode)fn;
136           if (fon.getOp().getOp()==Operation.EQUAL||
137               fon.getOp().getOp()==Operation.NOTEQUAL) {
138             if (!fon.getLeft().getType().isPtr())
139               break;
140             Set<TempFlatPair> lefttfpset=tmap.get(fon.getLeft());
141             Set<TempFlatPair> righttfpset=tmap.get(fon.getRight());
142             //handle left operand
143             if (lefttfpset!=null) {
144               for(Iterator<TempFlatPair> tfpit=lefttfpset.iterator();tfpit.hasNext();) {
145                 TempFlatPair tfp=tfpit.next();
146                 if (tfset.contains(tfp)||ok(tfp)) {
147                   leftsrctrans.add(fon);
148                   break;
149                 }
150               }
151             }
152             //handle right operand
153             if (righttfpset!=null) {
154               for(Iterator<TempFlatPair> tfpit=righttfpset.iterator();tfpit.hasNext();) {
155                 TempFlatPair tfp=tfpit.next();
156                 if (tfset.contains(tfp)||ok(tfp)) {
157                   rightsrctrans.add(fon);
158                   break;
159                 }
160               }
161             }
162           }
163           break;
164         }
165         case FKind.FlatSetFieldNode: { 
166           //definitely need to translate these
167           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
168           if (!fsfn.getField().getType().isPtr())
169             break;
170           Set<TempFlatPair> tfpset=tmap.get(fsfn.getSrc());
171           if (tfpset!=null) {
172             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
173               TempFlatPair tfp=tfpit.next();
174               if (tfset.contains(tfp)||ok(tfp)) {
175                 srctrans.add(fsfn);
176                 break;
177               }
178             }
179           }
180           break;
181         }
182         case FKind.FlatSetElementNode: { 
183           //definitely need to translate these
184           FlatSetElementNode fsen=(FlatSetElementNode)fn;
185           if (!fsen.getSrc().getType().isPtr())
186             break;
187           Set<TempFlatPair> tfpset=tmap.get(fsen.getSrc());
188           if (tfpset!=null) {
189             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
190               TempFlatPair tfp=tfpit.next();
191               if (tfset.contains(tfp)||ok(tfp)) {
192                 srctrans.add(fsen);
193                 break;
194               }
195             }
196           }
197           break;
198         }
199         default:
200         }
201       }
202     }
203   }
204
205   public boolean ok(TempFlatPair tfp) {
206     FlatNode fn=tfp.f;
207     return fn.kind()==FKind.FlatCall||fn.kind()==FKind.FlatMethod;
208   }
209
210   //compute set of nodes that need transread on their output
211   HashSet<TempFlatPair> computeTranslationSet(LocalityBinding lb, FlatMethod fm, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap) {
212     HashSet<TempFlatPair> tfset=new HashSet<TempFlatPair>();
213     
214     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
215       FlatNode fn=fnit.next();
216       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
217       if (atomictable.get(fn).intValue()>0) {
218         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
219         switch(fn.kind()) {
220         case FKind.FlatElementNode: {
221           FlatElementNode fen=(FlatElementNode)fn;
222           if (arrays.contains(fen.getSrc().getType())) {
223             //this could cause conflict...figure out conflict set
224             Set<TempFlatPair> tfpset=tmap.get(fen.getSrc());
225             if (tfpset!=null)
226               tfset.addAll(tfpset);
227           }
228           break;
229         }
230         case FKind.FlatFieldNode: { 
231           FlatFieldNode ffn=(FlatFieldNode)fn;
232           if (fields.contains(ffn.getField())) {
233             //this could cause conflict...figure out conflict set
234             Set<TempFlatPair> tfpset=tmap.get(ffn.getSrc());
235             if (tfpset!=null)
236               tfset.addAll(tfpset);
237           }
238           break;
239         }
240         case FKind.FlatSetFieldNode: { 
241           //definitely need to translate these
242           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
243           Set<TempFlatPair> tfpset=tmap.get(fsfn.getDst());
244           if (tfpset!=null)
245             tfset.addAll(tfpset);
246           break;
247         }
248         case FKind.FlatSetElementNode: { 
249           //definitely need to translate these
250           FlatSetElementNode fsen=(FlatSetElementNode)fn;
251           Set<TempFlatPair> tfpset=tmap.get(fsen.getDst());
252           if (tfpset!=null)
253             tfset.addAll(tfpset);
254           break;
255         }
256         case FKind.FlatCall: //assume pessimistically that calls do bad things
257         case FKind.FlatReturnNode: {
258           TempDescriptor []readarray=fn.readsTemps();
259           for(int i=0;i<readarray.length;i++) {
260             TempDescriptor rtmp=readarray[i];
261             Set<TempFlatPair> tfpset=tmap.get(rtmp);
262             if (tfpset!=null)
263               tfset.addAll(tfpset);
264           }
265           break;
266         }
267         default:
268           //do nothing
269         }
270       }
271     }   
272     return tfset;
273   }
274
275   //Map for each node from temps to the flatnode and (the temp) that
276   //first did a field/element read for this value
277
278   Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> computeTempSets(LocalityBinding lb) {
279     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset=new Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>>();
280     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
281     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
282     MethodDescriptor md=lb.getMethod();
283     FlatMethod fm=state.getMethodFlat(md);
284     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
285     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
286     tovisit.add(fm);
287     discovered.add(fm);
288     
289     while(!tovisit.isEmpty()) {
290       FlatNode fn=tovisit.iterator().next();
291       tovisit.remove(fn);
292       for(int i=0;i<fn.numNext();i++) {
293         FlatNode fnext=fn.getNext(i);
294         if (!discovered.contains(fnext)) {
295           discovered.add(fnext);
296           tovisit.add(fnext);
297         }
298       }
299       Hashtable<TempDescriptor, Set<TempFlatPair>> ttofn=null;
300       if (atomictable.get(fn).intValue()!=0) {
301         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
302           //flatatomic enter node...  see what we really need to transread
303           //Set<TempDescriptor> liveset=livetemps.get(fn);
304           ttofn=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
305           /*
306           for(Iterator<TempDescriptor> tmpit=liveset.iterator();tmpit.hasNext();) {
307             TempDescriptor tmp=tmpit.next();
308             if (tmp.getType().isPtr()) {
309               HashSet<TempFlatPair> fnset=new HashSet<TempFlatPair>();
310               fnset.add(new TempFlatPair(tmp, fn));
311               ttofn.put(tmp, fnset);
312             }
313             }*/
314         } else {
315           ttofn=doMerge(fn, tmptofnset);
316           switch(fn.kind()) {
317           case FKind.FlatGlobalConvNode: {
318             FlatGlobalConvNode fgcn=(FlatGlobalConvNode)fn;
319             if (lb==fgcn.getLocality()&&
320                 fgcn.getMakePtr()) {
321               TempDescriptor[] writes=fn.writesTemps();
322               for(int i=0;i<writes.length;i++) {
323                 TempDescriptor wtmp=writes[i];
324                 HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
325                 set.add(new TempFlatPair(wtmp, fn));
326                 ttofn.put(wtmp, set);
327               }
328             }
329             break;
330           }
331           case FKind.FlatFieldNode:
332           case FKind.FlatElementNode: {
333             TempDescriptor[] writes=fn.writesTemps();
334             for(int i=0;i<writes.length;i++) {
335               TempDescriptor wtmp=writes[i];
336               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
337               set.add(new TempFlatPair(wtmp, fn));
338               ttofn.put(wtmp, set);
339             }
340             break;
341           }
342           case FKind.FlatCall:
343           case FKind.FlatMethod: {
344             TempDescriptor[] writes=fn.writesTemps();
345             for(int i=0;i<writes.length;i++) {
346               TempDescriptor wtmp=writes[i];
347               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
348               set.add(new TempFlatPair(wtmp, fn));
349               ttofn.put(wtmp, set);
350             }
351             break;
352           }
353           case FKind.FlatOpNode: {
354             FlatOpNode fon=(FlatOpNode)fn;
355             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()&&
356                 ttofn.containsKey(fon.getLeft())) {
357               ttofn.put(fon.getDest(), new HashSet<TempFlatPair>(ttofn.get(fon.getLeft())));
358               break;
359             }
360           }
361           default:
362             //Do kill computation
363             TempDescriptor[] writes=fn.writesTemps();
364             for(int i=0;i<writes.length;i++) {
365               TempDescriptor wtmp=writes[i];
366               ttofn.remove(writes[i]);
367             }
368           }
369         }
370         if (ttofn!=null) {
371           if (!tmptofnset.containsKey(fn)||
372               !tmptofnset.get(fn).equals(ttofn)) {
373             //enqueue nodes to process
374             tmptofnset.put(fn, ttofn);
375             for(int i=0;i<fn.numNext();i++) {
376               FlatNode fnext=fn.getNext(i);
377               tovisit.add(fnext);
378             }
379           }
380         }
381       }
382     }
383     return tmptofnset;
384   }
385   
386   /* See what fields and arrays transactions might modify. */
387
388   public void computeModified(LocalityBinding lb) {
389     MethodDescriptor md=lb.getMethod();
390     FlatMethod fm=state.getMethodFlat(md);
391     Hashtable<FlatNode, Set<TempDescriptor>> oldtemps=computeOldTemps(lb);
392     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
393       FlatNode fn=fnit.next();
394       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
395       if (atomictable.get(fn).intValue()>0) {
396         switch (fn.kind()) {
397         case FKind.FlatSetFieldNode:
398           FlatSetFieldNode fsfn=(FlatSetFieldNode) fn;
399           fields.add(fsfn.getField());
400           break;
401         case FKind.FlatSetElementNode:
402           FlatSetElementNode fsen=(FlatSetElementNode) fn;
403           arrays.add(fsen.getDst().getType());
404           break;
405         default:
406         }
407       }
408     }
409   }
410     
411   Hashtable<FlatNode, Set<TempDescriptor>> computeOldTemps(LocalityBinding lb) {
412     Hashtable<FlatNode, Set<TempDescriptor>> fntooldtmp=new Hashtable<FlatNode, Set<TempDescriptor>>();
413     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
414     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
415     MethodDescriptor md=lb.getMethod();
416     FlatMethod fm=state.getMethodFlat(md);
417     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
418     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
419     tovisit.add(fm);
420     discovered.add(fm);
421     
422     while(!tovisit.isEmpty()) {
423       FlatNode fn=tovisit.iterator().next();
424       tovisit.remove(fn);
425       for(int i=0;i<fn.numNext();i++) {
426         FlatNode fnext=fn.getNext(i);
427         if (!discovered.contains(fnext)) {
428           discovered.add(fnext);
429           tovisit.add(fnext);
430         }
431       }
432       HashSet<TempDescriptor> oldtemps=null;
433       if (atomictable.get(fn).intValue()!=0) {
434         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
435           //Everything live is old
436           Set<TempDescriptor> lives=livetemps.get(fn);
437           oldtemps=new HashSet<TempDescriptor>();
438           
439           for(Iterator<TempDescriptor> it=lives.iterator();it.hasNext();) {
440             TempDescriptor tmp=it.next();
441             if (tmp.getType().isPtr()) {
442               oldtemps.add(tmp);
443             }
444           }
445         } else {
446           oldtemps=new HashSet<TempDescriptor>();
447           //Compute union of old temporaries
448           for(int i=0;i<fn.numPrev();i++) {
449             Set<TempDescriptor> pset=fntooldtmp.get(fn.getPrev(i));
450             if (pset!=null)
451               oldtemps.addAll(pset);
452           }
453           
454           switch (fn.kind()) {
455           case FKind.FlatNew:
456             oldtemps.removeAll(Arrays.asList(fn.readsTemps()));
457             break;
458           case FKind.FlatOpNode: {
459             FlatOpNode fon=(FlatOpNode)fn;
460             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()) {
461               if (oldtemps.contains(fon.getLeft()))
462                 oldtemps.add(fon.getDest());
463               else
464                 oldtemps.remove(fon.getDest());
465               break;
466             }
467           }
468           default: {
469             TempDescriptor[] writes=fn.writesTemps();
470             for(int i=0;i<writes.length;i++) {
471               TempDescriptor wtemp=writes[i];
472               if (wtemp.getType().isPtr())
473                 oldtemps.add(wtemp);
474             }
475           }
476           }
477         }
478       }
479       
480       if (oldtemps!=null) {
481         if (!fntooldtmp.containsKey(fn)||!fntooldtmp.get(fn).equals(oldtemps)) {
482           fntooldtmp.put(fn, oldtemps);
483           //propagate changes
484           for(int i=0;i<fn.numNext();i++) {
485             FlatNode fnext=fn.getNext(i);
486             tovisit.add(fnext);
487           }
488         }
489       }
490     }
491     return fntooldtmp;
492   }
493 }
494
495 class TempFlatPair {
496     FlatNode f;
497     TempDescriptor t;
498     TempFlatPair(TempDescriptor t, FlatNode f) {
499         this.t=t;
500         this.f=f;
501     }
502
503     public int hashCode() {
504         return f.hashCode()^t.hashCode();
505     }
506     public boolean equals(Object o) {
507         TempFlatPair tf=(TempFlatPair)o;
508         return t.equals(tf.t)&&f.equals(tf.f);
509     }
510 }