bug fix
[IRC.git] / Robust / src / Analysis / Locality / DiscoverConflicts.java
1 package Analysis.Locality;
2
3 import IR.Flat.*;
4 import java.util.Set;
5 import java.util.Arrays;
6 import java.util.HashSet;
7 import java.util.Iterator;
8 import java.util.Hashtable;
9 import IR.State;
10 import IR.Operation;
11 import IR.TypeDescriptor;
12 import IR.MethodDescriptor;
13 import IR.FieldDescriptor;
14
15 public class DiscoverConflicts {
16   Set<FieldDescriptor> fields;
17   Set<TypeDescriptor> arrays;
18   LocalityAnalysis locality;
19   State state;
20   Hashtable<LocalityBinding, Set<FlatNode>> treadmap;
21   Hashtable<LocalityBinding, Set<TempFlatPair>> transreadmap;
22   Hashtable<LocalityBinding, Set<FlatNode>> srcmap;
23     
24   public DiscoverConflicts(LocalityAnalysis locality, State state) {
25     this.locality=locality;
26     this.fields=new HashSet<FieldDescriptor>();
27     this.arrays=new HashSet<TypeDescriptor>();
28     this.state=state;
29     transreadmap=new Hashtable<LocalityBinding, Set<TempFlatPair>>();
30     treadmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
31     srcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
32   }
33   
34   public void doAnalysis() {
35     //Compute fields and arrays for all transactions
36     Set<LocalityBinding> localityset=locality.getLocalityBindings();
37     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
38       computeModified(lb.next());
39     }
40     expandTypes();
41     //Compute set of nodes that need transread
42     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
43       LocalityBinding l=lb.next();
44       analyzeLocality(l);
45       setNeedReadTrans(l);
46     }
47   }
48
49   public void setNeedReadTrans(LocalityBinding lb) {
50     HashSet<FlatNode> set=new HashSet<FlatNode>();
51     for(Iterator<TempFlatPair> it=transreadmap.get(lb).iterator();it.hasNext();) {
52       TempFlatPair tfp=it.next();
53       set.add(tfp.f);
54     }
55     treadmap.put(lb, set);
56   }
57
58   public void expandTypes() {
59     //FIX ARRAY...compute super/sub sets of each so we can do simple membership test
60   }
61
62   Hashtable<TempDescriptor, Set<TempFlatPair>> doMerge(FlatNode fn, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset) {
63     Hashtable<TempDescriptor, Set<TempFlatPair>> table=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
64     for(int i=0;i<fn.numPrev();i++) {
65       FlatNode fprev=fn.getPrev(i);
66       Hashtable<TempDescriptor, Set<TempFlatPair>> tabset=tmptofnset.get(fprev);
67       if (tabset!=null) {
68         for(Iterator<TempDescriptor> tmpit=tabset.keySet().iterator();tmpit.hasNext();) {
69           TempDescriptor td=tmpit.next();
70           Set<TempFlatPair> fnset=tabset.get(td);
71           if (!table.containsKey(td))
72             table.put(td, new HashSet<TempFlatPair>());
73           table.get(td).addAll(fnset);
74         }
75       }
76     }
77     return table;
78   }
79   
80   public Set<FlatNode> getNeedSrcTrans(LocalityBinding lb) {
81     return srcmap.get(lb);
82   }
83
84   public boolean getNeedSrcTrans(LocalityBinding lb, FlatNode fn) {
85     return srcmap.get(lb).contains(fn);
86   }
87
88   public boolean getNeedTrans(LocalityBinding lb, FlatNode fn) {
89     return treadmap.get(lb).contains(fn);
90   }
91
92   private void analyzeLocality(LocalityBinding lb) {
93     MethodDescriptor md=lb.getMethod();
94     FlatMethod fm=state.getMethodFlat(md);
95     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap=computeTempSets(lb);
96     HashSet<TempFlatPair> tfset=computeTranslationSet(lb, fm, fnmap);
97     HashSet<FlatNode> srctrans=new HashSet<FlatNode>();
98     transreadmap.put(lb, tfset);
99     srcmap.put(lb,srctrans);
100     
101     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
102       FlatNode fn=fnit.next();
103       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
104       if (atomictable.get(fn).intValue()>0) {
105         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
106         switch(fn.kind()) {
107         case FKind.FlatSetFieldNode: { 
108           //definitely need to translate these
109           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
110           if (!fsfn.getField().getType().isPtr())
111             break;
112           Set<TempFlatPair> tfpset=tmap.get(fsfn.getSrc());
113           if (tfpset!=null) {
114             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
115               TempFlatPair tfp=tfpit.next();
116               if (tfset.contains(tfp)||ok(tfp)) {
117                 srctrans.add(fsfn);
118                 break;
119               }
120             }
121           }
122           break;
123         }
124         case FKind.FlatSetElementNode: { 
125           //definitely need to translate these
126           FlatSetElementNode fsen=(FlatSetElementNode)fn;
127           if (!fsen.getSrc().getType().isPtr())
128             break;
129           Set<TempFlatPair> tfpset=tmap.get(fsen.getSrc());
130           if (tfpset!=null) {
131             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
132               TempFlatPair tfp=tfpit.next();
133               if (tfset.contains(tfp)||ok(tfp)) {
134                 srctrans.add(fsen);
135                 break;
136               }
137             }
138           }
139           break;
140         }
141         default:
142         }
143       }
144     }
145   }
146
147   public boolean ok(TempFlatPair tfp) {
148     FlatNode fn=tfp.f;
149     return fn.kind()==FKind.FlatCall;
150   }
151
152   HashSet<TempFlatPair> computeTranslationSet(LocalityBinding lb, FlatMethod fm, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap) {
153     HashSet<TempFlatPair> tfset=new HashSet<TempFlatPair>();
154     
155     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
156       FlatNode fn=fnit.next();
157       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
158       if (atomictable.get(fn).intValue()>0) {
159         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
160         switch(fn.kind()) {
161         case FKind.FlatElementNode: {
162           FlatElementNode fen=(FlatElementNode)fn;
163           if (arrays.contains(fen.getSrc().getType())) {
164             //this could cause conflict...figure out conflict set
165             Set<TempFlatPair> tfpset=tmap.get(fen.getSrc());
166             if (tfpset!=null)
167               tfset.addAll(tfpset);
168           }
169           break;
170         }
171         case FKind.FlatFieldNode: { 
172           FlatFieldNode ffn=(FlatFieldNode)fn;
173           if (fields.contains(ffn.getField())) {
174             //this could cause conflict...figure out conflict set
175             Set<TempFlatPair> tfpset=tmap.get(ffn.getSrc());
176             if (tfpset!=null)
177               tfset.addAll(tfpset);
178           }
179           break;
180         }
181         case FKind.FlatSetFieldNode: { 
182           //definitely need to translate these
183           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
184           Set<TempFlatPair> tfpset=tmap.get(fsfn.getDst());
185           if (tfpset!=null)
186             tfset.addAll(tfpset);
187           break;
188         }
189         case FKind.FlatSetElementNode: { 
190           //definitely need to translate these
191           FlatSetElementNode fsen=(FlatSetElementNode)fn;
192           Set<TempFlatPair> tfpset=tmap.get(fsen.getDst());
193           if (tfpset!=null)
194             tfset.addAll(tfpset);
195           break;
196         }
197         case FKind.FlatCall: //assume pessimistically that calls do bad things
198         case FKind.FlatReturnNode: {
199           TempDescriptor []readarray=fn.readsTemps();
200           for(int i=0;i<readarray.length;i++) {
201             TempDescriptor rtmp=readarray[i];
202             Set<TempFlatPair> tfpset=tmap.get(rtmp);
203             if (tfpset!=null)
204               tfset.addAll(tfpset);
205           }
206           break;
207         }
208         default:
209           //do nothing
210         }
211       }
212     }   
213     return tfset;
214   }
215
216   Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> computeTempSets(LocalityBinding lb) {
217     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset=new Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>>();
218     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
219     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
220     MethodDescriptor md=lb.getMethod();
221     FlatMethod fm=state.getMethodFlat(md);
222     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
223     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
224     tovisit.add(fm);
225     discovered.add(fm);
226     
227     while(!tovisit.isEmpty()) {
228       FlatNode fn=tovisit.iterator().next();
229       tovisit.remove(fn);
230       for(int i=0;i<fn.numNext();i++) {
231         FlatNode fnext=fn.getNext(i);
232         if (!discovered.contains(fnext)) {
233           discovered.add(fnext);
234           tovisit.add(fnext);
235         }
236       }
237       Hashtable<TempDescriptor, Set<TempFlatPair>> ttofn=null;
238       if (atomictable.get(fn).intValue()!=0) {
239         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
240           //flatatomic enter node...  see what we really need to transread
241           Set<TempDescriptor> liveset=livetemps.get(fn);
242           ttofn=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
243           for(Iterator<TempDescriptor> tmpit=liveset.iterator();tmpit.hasNext();) {
244             TempDescriptor tmp=tmpit.next();
245             if (tmp.getType().isPtr()) {
246               HashSet<TempFlatPair> fnset=new HashSet<TempFlatPair>();
247               fnset.add(new TempFlatPair(tmp, fn));
248               ttofn.put(tmp, fnset);
249             }
250           }
251         } else {
252           ttofn=doMerge(fn, tmptofnset);
253           switch(fn.kind()) {
254           case FKind.FlatFieldNode:
255           case FKind.FlatElementNode: {
256             TempDescriptor[] writes=fn.writesTemps();
257             for(int i=0;i<writes.length;i++) {
258               TempDescriptor wtmp=writes[i];
259               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
260               set.add(new TempFlatPair(wtmp, fn));
261               ttofn.put(wtmp, set);
262             }
263             break;
264           }
265           case FKind.FlatCall:
266           case FKind.FlatMethod: {
267             TempDescriptor[] writes=fn.writesTemps();
268             for(int i=0;i<writes.length;i++) {
269               TempDescriptor wtmp=writes[i];
270               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
271               set.add(new TempFlatPair(wtmp, fn));
272               ttofn.put(wtmp, set);
273             }
274             break;
275           }
276           case FKind.FlatOpNode: {
277             FlatOpNode fon=(FlatOpNode)fn;
278             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()&&
279                 ttofn.containsKey(fon.getLeft())) {
280               ttofn.put(fon.getDest(), new HashSet<TempFlatPair>(ttofn.get(fon.getLeft())));
281               break;
282             }
283           }
284           default:
285             //Do kill computation
286             TempDescriptor[] writes=fn.writesTemps();
287             for(int i=0;i<writes.length;i++) {
288               TempDescriptor wtmp=writes[i];
289               ttofn.remove(writes[i]);
290             }
291           }
292         }
293         if (ttofn!=null) {
294           if (!tmptofnset.containsKey(fn)||
295               !tmptofnset.get(fn).equals(ttofn)) {
296             //enqueue nodes to process
297             tmptofnset.put(fn, ttofn);
298             for(int i=0;i<fn.numNext();i++) {
299               FlatNode fnext=fn.getNext(i);
300               tovisit.add(fnext);
301             }
302           }
303         }
304       }
305     }
306     return tmptofnset;
307   }
308   
309   public void computeModified(LocalityBinding lb) {
310     MethodDescriptor md=lb.getMethod();
311     FlatMethod fm=state.getMethodFlat(md);
312     Hashtable<FlatNode, Set<TempDescriptor>> oldtemps=computeOldTemps(lb);
313     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
314       FlatNode fn=fnit.next();
315       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
316       if (atomictable.get(fn).intValue()>0) {
317         switch (fn.kind()) {
318         case FKind.FlatSetFieldNode:
319           FlatSetFieldNode fsfn=(FlatSetFieldNode) fn;
320           fields.add(fsfn.getField());
321           break;
322         case FKind.FlatSetElementNode:
323           FlatSetElementNode fsen=(FlatSetElementNode) fn;
324           arrays.add(fsen.getDst().getType());
325           break;
326         default:
327         }
328       }
329     }
330   }
331     
332   Hashtable<FlatNode, Set<TempDescriptor>> computeOldTemps(LocalityBinding lb) {
333     Hashtable<FlatNode, Set<TempDescriptor>> fntooldtmp=new Hashtable<FlatNode, Set<TempDescriptor>>();
334     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
335     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
336     MethodDescriptor md=lb.getMethod();
337     FlatMethod fm=state.getMethodFlat(md);
338     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
339     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
340     tovisit.add(fm);
341     discovered.add(fm);
342     
343     while(!tovisit.isEmpty()) {
344       FlatNode fn=tovisit.iterator().next();
345       tovisit.remove(fn);
346       for(int i=0;i<fn.numNext();i++) {
347         FlatNode fnext=fn.getNext(i);
348         if (!discovered.contains(fnext)) {
349           discovered.add(fnext);
350           tovisit.add(fnext);
351         }
352       }
353       HashSet<TempDescriptor> oldtemps=null;
354       if (atomictable.get(fn).intValue()!=0) {
355         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
356           //Everything live is old
357           Set<TempDescriptor> lives=livetemps.get(fn);
358           oldtemps=new HashSet<TempDescriptor>();
359           
360           for(Iterator<TempDescriptor> it=lives.iterator();it.hasNext();) {
361             TempDescriptor tmp=it.next();
362             if (tmp.getType().isPtr()) {
363               oldtemps.add(tmp);
364             }
365           }
366         } else {
367           oldtemps=new HashSet<TempDescriptor>();
368           //Compute union of old temporaries
369           for(int i=0;i<fn.numPrev();i++) {
370             Set<TempDescriptor> pset=fntooldtmp.get(fn.getPrev(i));
371             if (pset!=null)
372               oldtemps.addAll(pset);
373           }
374           
375           switch (fn.kind()) {
376           case FKind.FlatNew:
377             oldtemps.removeAll(Arrays.asList(fn.readsTemps()));
378             break;
379           case FKind.FlatOpNode: {
380             FlatOpNode fon=(FlatOpNode)fn;
381             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()) {
382               if (oldtemps.contains(fon.getLeft()))
383                 oldtemps.add(fon.getDest());
384               else
385                 oldtemps.remove(fon.getDest());
386               break;
387             }
388           }
389           default: {
390             TempDescriptor[] writes=fn.writesTemps();
391             for(int i=0;i<writes.length;i++) {
392               TempDescriptor wtemp=writes[i];
393               if (wtemp.getType().isPtr())
394                 oldtemps.add(wtemp);
395             }
396           }
397           }
398         }
399       }
400       
401       if (oldtemps!=null) {
402         if (!fntooldtmp.containsKey(fn)||!fntooldtmp.get(fn).equals(oldtemps)) {
403           fntooldtmp.put(fn, oldtemps);
404           //propagate changes
405           for(int i=0;i<fn.numNext();i++) {
406             FlatNode fnext=fn.getNext(i);
407             tovisit.add(fnext);
408           }
409         }
410       }
411     }
412     return fntooldtmp;
413   }
414 }
415
416 class TempFlatPair {
417     FlatNode f;
418     TempDescriptor t;
419     TempFlatPair(TempDescriptor t, FlatNode f) {
420         this.t=t;
421         this.f=f;
422     }
423
424     public int hashCode() {
425         return f.hashCode()^t.hashCode();
426     }
427     public boolean equals(Object o) {
428         TempFlatPair tf=(TempFlatPair)o;
429         return t.equals(tf.t)&&f.equals(tf.f);
430     }
431 }