changes
[IRC.git] / Robust / src / Analysis / Locality / DiscoverConflicts.java
1 package Analysis.Locality;
2
3 import IR.Flat.*;
4 import java.util.Set;
5 import java.util.Arrays;
6 import java.util.HashSet;
7 import java.util.Iterator;
8 import java.util.Hashtable;
9 import IR.State;
10 import IR.Operation;
11 import IR.TypeDescriptor;
12 import IR.MethodDescriptor;
13 import IR.FieldDescriptor;
14
15 public class DiscoverConflicts {
16   Set<FieldDescriptor> fields;
17   Set<TypeDescriptor> arrays;
18   LocalityAnalysis locality;
19   State state;
20   Hashtable<LocalityBinding, Set<FlatNode>> treadmap;
21   Hashtable<LocalityBinding, Set<TempFlatPair>> transreadmap;
22   Hashtable<LocalityBinding, Set<FlatNode>> srcmap;
23     
24   public DiscoverConflicts(LocalityAnalysis locality, State state) {
25     this.locality=locality;
26     this.fields=new HashSet<FieldDescriptor>();
27     this.arrays=new HashSet<TypeDescriptor>();
28     this.state=state;
29     transreadmap=new Hashtable<LocalityBinding, Set<TempFlatPair>>();
30     treadmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
31     srcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
32   }
33   
34   public void doAnalysis() {
35     //Compute fields and arrays for all transactions
36     Set<LocalityBinding> localityset=locality.getLocalityBindings();
37     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
38       computeModified(lb.next());
39     }
40     expandTypes();
41     //Compute set of nodes that need transread
42     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
43       LocalityBinding l=lb.next();
44       analyzeLocality(l);
45       setNeedReadTrans(l);
46     }
47   }
48
49   public void setNeedReadTrans(LocalityBinding lb) {
50     HashSet<FlatNode> set=new HashSet<FlatNode>();
51     for(Iterator<TempFlatPair> it=transreadmap.get(lb).iterator();it.hasNext();) {
52       TempFlatPair tfp=it.next();
53       set.add(tfp.f);
54     }
55     treadmap.put(lb, set);
56   }
57
58   public void expandTypes() {
59     //FIX ARRAY...compute super/sub sets of each so we can do simple membership test
60   }
61
62   Hashtable<TempDescriptor, Set<TempFlatPair>> doMerge(FlatNode fn, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset) {
63     Hashtable<TempDescriptor, Set<TempFlatPair>> table=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
64     for(int i=0;i<fn.numPrev();i++) {
65       FlatNode fprev=fn.getPrev(i);
66       Hashtable<TempDescriptor, Set<TempFlatPair>> tabset=tmptofnset.get(fprev);
67       if (tabset!=null) {
68         for(Iterator<TempDescriptor> tmpit=tabset.keySet().iterator();tmpit.hasNext();) {
69           TempDescriptor td=tmpit.next();
70           Set<TempFlatPair> fnset=tabset.get(td);
71           if (!table.containsKey(td))
72             table.put(td, new HashSet<TempFlatPair>());
73           table.get(td).addAll(fnset);
74         }
75       }
76     }
77     return table;
78   }
79   
80   public Set<FlatNode> getNeedSrcTrans(LocalityBinding lb) {
81     return srcmap.get(lb);
82   }
83
84   public boolean getNeedSrcTrans(LocalityBinding lb, FlatNode fn) {
85     return srcmap.get(lb).contains(fn);
86   }
87
88   public boolean getNeedTrans(LocalityBinding lb, FlatNode fn) {
89     return treadmap.get(lb).contains(fn);
90   }
91
92   private void analyzeLocality(LocalityBinding lb) {
93     MethodDescriptor md=lb.getMethod();
94     FlatMethod fm=state.getMethodFlat(md);
95     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap=computeTempSets(lb);
96     HashSet<TempFlatPair> tfset=computeTranslationSet(lb, fm, fnmap);
97     HashSet<FlatNode> srctrans=new HashSet<FlatNode>();
98     transreadmap.put(lb, tfset);
99     srcmap.put(lb,srctrans);
100     
101     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
102       FlatNode fn=fnit.next();
103       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
104       if (atomictable.get(fn).intValue()>0) {
105         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
106         switch(fn.kind()) {
107         case FKind.FlatSetFieldNode: { 
108           //definitely need to translate these
109           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
110           if (!fsfn.getField().getType().isPtr())
111             break;
112           Set<TempFlatPair> tfpset=tmap.get(fsfn.getSrc());
113           if (tfpset!=null) {
114             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
115               TempFlatPair tfp=tfpit.next();
116               if (tfset.contains(tfp)) {
117                 srctrans.add(fsfn);
118                 break;
119               }
120             }
121           }
122           break;
123         }
124         case FKind.FlatSetElementNode: { 
125           //definitely need to translate these
126           FlatSetElementNode fsen=(FlatSetElementNode)fn;
127           if (!fsen.getSrc().getType().isPtr())
128             break;
129           Set<TempFlatPair> tfpset=tmap.get(fsen.getSrc());
130           if (tfpset!=null) {
131             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
132               TempFlatPair tfp=tfpit.next();
133               if (tfset.contains(tfp)) {
134                 srctrans.add(fsen);
135                 break;
136               }
137             }
138           }
139           break;
140         }
141         default:
142         }
143       }
144     }
145   }
146
147   HashSet<TempFlatPair> computeTranslationSet(LocalityBinding lb, FlatMethod fm, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap) {
148     HashSet<TempFlatPair> tfset=new HashSet<TempFlatPair>();
149     
150     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
151       FlatNode fn=fnit.next();
152       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
153       if (atomictable.get(fn).intValue()>0) {
154         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
155         switch(fn.kind()) {
156         case FKind.FlatElementNode: {
157           FlatElementNode fen=(FlatElementNode)fn;
158           if (arrays.contains(fen.getSrc().getType())) {
159             //this could cause conflict...figure out conflict set
160             Set<TempFlatPair> tfpset=tmap.get(fen.getSrc());
161             if (tfpset!=null)
162               tfset.addAll(tfpset);
163           }
164           break;
165         }
166         case FKind.FlatFieldNode: { 
167           FlatFieldNode ffn=(FlatFieldNode)fn;
168           if (fields.contains(ffn.getField())) {
169             //this could cause conflict...figure out conflict set
170             Set<TempFlatPair> tfpset=tmap.get(ffn.getSrc());
171             if (tfpset!=null)
172               tfset.addAll(tfpset);
173           }
174           break;
175         }
176         case FKind.FlatSetFieldNode: { 
177           //definitely need to translate these
178           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
179           Set<TempFlatPair> tfpset=tmap.get(fsfn.getDst());
180           if (tfpset!=null)
181             tfset.addAll(tfpset);
182           break;
183         }
184         case FKind.FlatSetElementNode: { 
185           //definitely need to translate these
186           FlatSetElementNode fsen=(FlatSetElementNode)fn;
187           Set<TempFlatPair> tfpset=tmap.get(fsen.getDst());
188           if (tfpset!=null)
189             tfset.addAll(tfpset);
190           break;
191         }
192         case FKind.FlatCall: //assume pessimistically that calls do bad things
193         case FKind.FlatReturnNode: {
194           TempDescriptor []readarray=fn.readsTemps();
195           for(int i=0;i<readarray.length;i++) {
196             TempDescriptor rtmp=readarray[i];
197             Set<TempFlatPair> tfpset=tmap.get(rtmp);
198             if (tfpset!=null)
199               tfset.addAll(tfpset);
200           }
201           break;
202         }
203         default:
204           //do nothing
205         }
206       }
207     }   
208     return tfset;
209   }
210
211   Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> computeTempSets(LocalityBinding lb) {
212     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset=new Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>>();
213     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
214     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
215     MethodDescriptor md=lb.getMethod();
216     FlatMethod fm=state.getMethodFlat(md);
217     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
218     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
219     tovisit.add(fm);
220     discovered.add(fm);
221     
222     while(!tovisit.isEmpty()) {
223       FlatNode fn=tovisit.iterator().next();
224       tovisit.remove(fn);
225       for(int i=0;i<fn.numNext();i++) {
226         FlatNode fnext=fn.getNext(i);
227         if (!discovered.contains(fnext)) {
228           discovered.add(fnext);
229           tovisit.add(fnext);
230         }
231       }
232       Hashtable<TempDescriptor, Set<TempFlatPair>> ttofn=null;
233       if (atomictable.get(fn).intValue()!=0) {
234         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
235           //flatatomic enter node...  see what we really need to transread
236           Set<TempDescriptor> liveset=livetemps.get(fn);
237           ttofn=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
238           for(Iterator<TempDescriptor> tmpit=liveset.iterator();tmpit.hasNext();) {
239             TempDescriptor tmp=tmpit.next();
240             if (tmp.getType().isPtr()) {
241               HashSet<TempFlatPair> fnset=new HashSet<TempFlatPair>();
242               fnset.add(new TempFlatPair(tmp, fn));
243               ttofn.put(tmp, fnset);
244             }
245           }
246         } else {
247           ttofn=doMerge(fn, tmptofnset);
248           switch(fn.kind()) {
249           case FKind.FlatFieldNode:
250           case FKind.FlatElementNode: {
251             TempDescriptor[] writes=fn.writesTemps();
252             for(int i=0;i<writes.length;i++) {
253               TempDescriptor wtmp=writes[i];
254               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
255               set.add(new TempFlatPair(wtmp, fn));
256               ttofn.put(wtmp, set);
257             }
258             break;
259           }
260           case FKind.FlatCall:
261           case FKind.FlatMethod: {
262             TempDescriptor[] writes=fn.writesTemps();
263             for(int i=0;i<writes.length;i++) {
264               TempDescriptor wtmp=writes[i];
265               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
266               set.add(new TempFlatPair(wtmp, fn));
267               ttofn.put(wtmp, set);
268             }
269             break;
270           }
271           case FKind.FlatOpNode: {
272             FlatOpNode fon=(FlatOpNode)fn;
273             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()&&
274                 ttofn.containsKey(fon.getLeft())) {
275               ttofn.put(fon.getDest(), new HashSet<TempFlatPair>(ttofn.get(fon.getLeft())));
276               break;
277             }
278           }
279           default:
280             //Do kill computation
281             TempDescriptor[] writes=fn.writesTemps();
282             for(int i=0;i<writes.length;i++) {
283               TempDescriptor wtmp=writes[i];
284               ttofn.remove(writes[i]);
285             }
286           }
287         }
288         if (ttofn!=null) {
289           if (!tmptofnset.containsKey(fn)||
290               !tmptofnset.get(fn).equals(ttofn)) {
291             //enqueue nodes to process
292             tmptofnset.put(fn, ttofn);
293             for(int i=0;i<fn.numNext();i++) {
294               FlatNode fnext=fn.getNext(i);
295               tovisit.add(fnext);
296             }
297           }
298         }
299       }
300     }
301     return tmptofnset;
302   }
303   
304   public void computeModified(LocalityBinding lb) {
305     MethodDescriptor md=lb.getMethod();
306     FlatMethod fm=state.getMethodFlat(md);
307     Hashtable<FlatNode, Set<TempDescriptor>> oldtemps=computeOldTemps(lb);
308     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
309       FlatNode fn=fnit.next();
310       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
311       if (atomictable.get(fn).intValue()>0) {
312         switch (fn.kind()) {
313         case FKind.FlatSetFieldNode:
314           FlatSetFieldNode fsfn=(FlatSetFieldNode) fn;
315           fields.add(fsfn.getField());
316           break;
317         case FKind.FlatSetElementNode:
318           FlatSetElementNode fsen=(FlatSetElementNode) fn;
319           arrays.add(fsen.getDst().getType());
320           break;
321         default:
322         }
323       }
324     }
325   }
326     
327   Hashtable<FlatNode, Set<TempDescriptor>> computeOldTemps(LocalityBinding lb) {
328     Hashtable<FlatNode, Set<TempDescriptor>> fntooldtmp=new Hashtable<FlatNode, Set<TempDescriptor>>();
329     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
330     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
331     MethodDescriptor md=lb.getMethod();
332     FlatMethod fm=state.getMethodFlat(md);
333     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
334     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
335     tovisit.add(fm);
336     discovered.add(fm);
337     
338     while(!tovisit.isEmpty()) {
339       FlatNode fn=tovisit.iterator().next();
340       tovisit.remove(fn);
341       for(int i=0;i<fn.numNext();i++) {
342         FlatNode fnext=fn.getNext(i);
343         if (!discovered.contains(fnext)) {
344           discovered.add(fnext);
345           tovisit.add(fnext);
346         }
347       }
348       HashSet<TempDescriptor> oldtemps=null;
349       if (atomictable.get(fn).intValue()!=0) {
350         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
351           //Everything live is old
352           Set<TempDescriptor> lives=livetemps.get(fn);
353           oldtemps=new HashSet<TempDescriptor>();
354           
355           for(Iterator<TempDescriptor> it=lives.iterator();it.hasNext();) {
356             TempDescriptor tmp=it.next();
357             if (tmp.getType().isPtr()) {
358               oldtemps.add(tmp);
359             }
360           }
361         } else {
362           oldtemps=new HashSet<TempDescriptor>();
363           //Compute union of old temporaries
364           for(int i=0;i<fn.numPrev();i++) {
365             Set<TempDescriptor> pset=fntooldtmp.get(fn.getPrev(i));
366             if (pset!=null)
367               oldtemps.addAll(pset);
368           }
369           
370           switch (fn.kind()) {
371           case FKind.FlatNew:
372             oldtemps.removeAll(Arrays.asList(fn.readsTemps()));
373             break;
374           case FKind.FlatOpNode: {
375             FlatOpNode fon=(FlatOpNode)fn;
376             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()) {
377               if (oldtemps.contains(fon.getLeft()))
378                 oldtemps.add(fon.getDest());
379               else
380                 oldtemps.remove(fon.getDest());
381               break;
382             }
383           }
384           default: {
385             TempDescriptor[] writes=fn.writesTemps();
386             for(int i=0;i<writes.length;i++) {
387               TempDescriptor wtemp=writes[i];
388               if (wtemp.getType().isPtr())
389                 oldtemps.add(wtemp);
390             }
391           }
392           }
393         }
394       }
395       
396       if (oldtemps!=null) {
397         if (!fntooldtmp.containsKey(fn)||!fntooldtmp.get(fn).equals(oldtemps)) {
398           fntooldtmp.put(fn, oldtemps);
399           //propagate changes
400           for(int i=0;i<fn.numNext();i++) {
401             FlatNode fnext=fn.getNext(i);
402             tovisit.add(fnext);
403           }
404         }
405       }
406     }
407     return fntooldtmp;
408   }
409 }
410
411 class TempFlatPair {
412     FlatNode f;
413     TempDescriptor t;
414     TempFlatPair(TempDescriptor t, FlatNode f) {
415         this.t=t;
416         this.f=f;
417     }
418
419     public int hashCode() {
420         return f.hashCode()^t.hashCode();
421     }
422     public boolean equals(Object o) {
423         TempFlatPair tf=(TempFlatPair)o;
424         return t.equals(tf.t)&&f.equals(tf.f);
425     }
426 }