integrate typeanalysis
[IRC.git] / Robust / src / Analysis / Locality / DiscoverConflicts.java
1 package Analysis.Locality;
2
3 import IR.Flat.*;
4 import java.util.Set;
5 import java.util.Arrays;
6 import java.util.HashSet;
7 import java.util.Iterator;
8 import java.util.Hashtable;
9 import IR.State;
10 import IR.Operation;
11 import IR.TypeDescriptor;
12 import IR.MethodDescriptor;
13 import IR.FieldDescriptor;
14
15 public class DiscoverConflicts {
16   Set<FieldDescriptor> fields;
17   Set<TypeDescriptor> arrays;
18   LocalityAnalysis locality;
19   State state;
20   Hashtable<LocalityBinding, Set<FlatNode>> treadmap;
21   Hashtable<LocalityBinding, Set<TempFlatPair>> transreadmap;
22   Hashtable<LocalityBinding, Set<FlatNode>> srcmap;
23   TypeAnalysis typeanalysis;
24   
25   public DiscoverConflicts(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis) {
26     this.locality=locality;
27     this.fields=new HashSet<FieldDescriptor>();
28     this.arrays=new HashSet<TypeDescriptor>();
29     this.state=state;
30     this.typeanalysis=typeanalysis;
31     transreadmap=new Hashtable<LocalityBinding, Set<TempFlatPair>>();
32     treadmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
33     srcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
34   }
35   
36   public void doAnalysis() {
37     //Compute fields and arrays for all transactions
38     Set<LocalityBinding> localityset=locality.getLocalityBindings();
39     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
40       computeModified(lb.next());
41     }
42     expandTypes();
43     //Compute set of nodes that need transread
44     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
45       LocalityBinding l=lb.next();
46       analyzeLocality(l);
47       setNeedReadTrans(l);
48     }
49   }
50
51   public void setNeedReadTrans(LocalityBinding lb) {
52     HashSet<FlatNode> set=new HashSet<FlatNode>();
53     for(Iterator<TempFlatPair> it=transreadmap.get(lb).iterator();it.hasNext();) {
54       TempFlatPair tfp=it.next();
55       set.add(tfp.f);
56     }
57     treadmap.put(lb, set);
58   }
59
60   public void expandTypes() {
61     Set<TypeDescriptor> expandedarrays=new HashSet<TypeDescriptor>();
62     for(Iterator<TypeDescriptor> it=arrays.iterator();it.hasNext();) {
63       TypeDescriptor td=it.next();
64       expandedarrays.addAll(typeanalysis.expand(td));
65     }
66     arrays=expandedarrays;
67   }
68
69   Hashtable<TempDescriptor, Set<TempFlatPair>> doMerge(FlatNode fn, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset) {
70     Hashtable<TempDescriptor, Set<TempFlatPair>> table=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
71     for(int i=0;i<fn.numPrev();i++) {
72       FlatNode fprev=fn.getPrev(i);
73       Hashtable<TempDescriptor, Set<TempFlatPair>> tabset=tmptofnset.get(fprev);
74       if (tabset!=null) {
75         for(Iterator<TempDescriptor> tmpit=tabset.keySet().iterator();tmpit.hasNext();) {
76           TempDescriptor td=tmpit.next();
77           Set<TempFlatPair> fnset=tabset.get(td);
78           if (!table.containsKey(td))
79             table.put(td, new HashSet<TempFlatPair>());
80           table.get(td).addAll(fnset);
81         }
82       }
83     }
84     return table;
85   }
86   
87   public Set<FlatNode> getNeedSrcTrans(LocalityBinding lb) {
88     return srcmap.get(lb);
89   }
90
91   public boolean getNeedSrcTrans(LocalityBinding lb, FlatNode fn) {
92     return srcmap.get(lb).contains(fn);
93   }
94
95   public boolean getNeedTrans(LocalityBinding lb, FlatNode fn) {
96     return treadmap.get(lb).contains(fn);
97   }
98
99   private void analyzeLocality(LocalityBinding lb) {
100     MethodDescriptor md=lb.getMethod();
101     FlatMethod fm=state.getMethodFlat(md);
102     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap=computeTempSets(lb);
103     HashSet<TempFlatPair> tfset=computeTranslationSet(lb, fm, fnmap);
104     HashSet<FlatNode> srctrans=new HashSet<FlatNode>();
105     transreadmap.put(lb, tfset);
106     srcmap.put(lb,srctrans);
107     
108     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
109       FlatNode fn=fnit.next();
110       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
111       if (atomictable.get(fn).intValue()>0) {
112         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
113         switch(fn.kind()) {
114         case FKind.FlatSetFieldNode: { 
115           //definitely need to translate these
116           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
117           if (!fsfn.getField().getType().isPtr())
118             break;
119           Set<TempFlatPair> tfpset=tmap.get(fsfn.getSrc());
120           if (tfpset!=null) {
121             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
122               TempFlatPair tfp=tfpit.next();
123               if (tfset.contains(tfp)||ok(tfp)) {
124                 srctrans.add(fsfn);
125                 break;
126               }
127             }
128           }
129           break;
130         }
131         case FKind.FlatSetElementNode: { 
132           //definitely need to translate these
133           FlatSetElementNode fsen=(FlatSetElementNode)fn;
134           if (!fsen.getSrc().getType().isPtr())
135             break;
136           Set<TempFlatPair> tfpset=tmap.get(fsen.getSrc());
137           if (tfpset!=null) {
138             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
139               TempFlatPair tfp=tfpit.next();
140               if (tfset.contains(tfp)||ok(tfp)) {
141                 srctrans.add(fsen);
142                 break;
143               }
144             }
145           }
146           break;
147         }
148         default:
149         }
150       }
151     }
152   }
153
154   public boolean ok(TempFlatPair tfp) {
155     FlatNode fn=tfp.f;
156     return fn.kind()==FKind.FlatCall;
157   }
158
159   HashSet<TempFlatPair> computeTranslationSet(LocalityBinding lb, FlatMethod fm, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap) {
160     HashSet<TempFlatPair> tfset=new HashSet<TempFlatPair>();
161     
162     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
163       FlatNode fn=fnit.next();
164       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
165       if (atomictable.get(fn).intValue()>0) {
166         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
167         switch(fn.kind()) {
168         case FKind.FlatElementNode: {
169           FlatElementNode fen=(FlatElementNode)fn;
170           if (arrays.contains(fen.getSrc().getType())) {
171             //this could cause conflict...figure out conflict set
172             Set<TempFlatPair> tfpset=tmap.get(fen.getSrc());
173             if (tfpset!=null)
174               tfset.addAll(tfpset);
175           }
176           break;
177         }
178         case FKind.FlatFieldNode: { 
179           FlatFieldNode ffn=(FlatFieldNode)fn;
180           if (fields.contains(ffn.getField())) {
181             //this could cause conflict...figure out conflict set
182             Set<TempFlatPair> tfpset=tmap.get(ffn.getSrc());
183             if (tfpset!=null)
184               tfset.addAll(tfpset);
185           }
186           break;
187         }
188         case FKind.FlatSetFieldNode: { 
189           //definitely need to translate these
190           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
191           Set<TempFlatPair> tfpset=tmap.get(fsfn.getDst());
192           if (tfpset!=null)
193             tfset.addAll(tfpset);
194           break;
195         }
196         case FKind.FlatSetElementNode: { 
197           //definitely need to translate these
198           FlatSetElementNode fsen=(FlatSetElementNode)fn;
199           Set<TempFlatPair> tfpset=tmap.get(fsen.getDst());
200           if (tfpset!=null)
201             tfset.addAll(tfpset);
202           break;
203         }
204         case FKind.FlatCall: //assume pessimistically that calls do bad things
205         case FKind.FlatReturnNode: {
206           TempDescriptor []readarray=fn.readsTemps();
207           for(int i=0;i<readarray.length;i++) {
208             TempDescriptor rtmp=readarray[i];
209             Set<TempFlatPair> tfpset=tmap.get(rtmp);
210             if (tfpset!=null)
211               tfset.addAll(tfpset);
212           }
213           break;
214         }
215         default:
216           //do nothing
217         }
218       }
219     }   
220     return tfset;
221   }
222
223   Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> computeTempSets(LocalityBinding lb) {
224     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset=new Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>>();
225     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
226     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
227     MethodDescriptor md=lb.getMethod();
228     FlatMethod fm=state.getMethodFlat(md);
229     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
230     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
231     tovisit.add(fm);
232     discovered.add(fm);
233     
234     while(!tovisit.isEmpty()) {
235       FlatNode fn=tovisit.iterator().next();
236       tovisit.remove(fn);
237       for(int i=0;i<fn.numNext();i++) {
238         FlatNode fnext=fn.getNext(i);
239         if (!discovered.contains(fnext)) {
240           discovered.add(fnext);
241           tovisit.add(fnext);
242         }
243       }
244       Hashtable<TempDescriptor, Set<TempFlatPair>> ttofn=null;
245       if (atomictable.get(fn).intValue()!=0) {
246         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
247           //flatatomic enter node...  see what we really need to transread
248           Set<TempDescriptor> liveset=livetemps.get(fn);
249           ttofn=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
250           for(Iterator<TempDescriptor> tmpit=liveset.iterator();tmpit.hasNext();) {
251             TempDescriptor tmp=tmpit.next();
252             if (tmp.getType().isPtr()) {
253               HashSet<TempFlatPair> fnset=new HashSet<TempFlatPair>();
254               fnset.add(new TempFlatPair(tmp, fn));
255               ttofn.put(tmp, fnset);
256             }
257           }
258         } else {
259           ttofn=doMerge(fn, tmptofnset);
260           switch(fn.kind()) {
261           case FKind.FlatFieldNode:
262           case FKind.FlatElementNode: {
263             TempDescriptor[] writes=fn.writesTemps();
264             for(int i=0;i<writes.length;i++) {
265               TempDescriptor wtmp=writes[i];
266               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
267               set.add(new TempFlatPair(wtmp, fn));
268               ttofn.put(wtmp, set);
269             }
270             break;
271           }
272           case FKind.FlatCall:
273           case FKind.FlatMethod: {
274             TempDescriptor[] writes=fn.writesTemps();
275             for(int i=0;i<writes.length;i++) {
276               TempDescriptor wtmp=writes[i];
277               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
278               set.add(new TempFlatPair(wtmp, fn));
279               ttofn.put(wtmp, set);
280             }
281             break;
282           }
283           case FKind.FlatOpNode: {
284             FlatOpNode fon=(FlatOpNode)fn;
285             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()&&
286                 ttofn.containsKey(fon.getLeft())) {
287               ttofn.put(fon.getDest(), new HashSet<TempFlatPair>(ttofn.get(fon.getLeft())));
288               break;
289             }
290           }
291           default:
292             //Do kill computation
293             TempDescriptor[] writes=fn.writesTemps();
294             for(int i=0;i<writes.length;i++) {
295               TempDescriptor wtmp=writes[i];
296               ttofn.remove(writes[i]);
297             }
298           }
299         }
300         if (ttofn!=null) {
301           if (!tmptofnset.containsKey(fn)||
302               !tmptofnset.get(fn).equals(ttofn)) {
303             //enqueue nodes to process
304             tmptofnset.put(fn, ttofn);
305             for(int i=0;i<fn.numNext();i++) {
306               FlatNode fnext=fn.getNext(i);
307               tovisit.add(fnext);
308             }
309           }
310         }
311       }
312     }
313     return tmptofnset;
314   }
315   
316   public void computeModified(LocalityBinding lb) {
317     MethodDescriptor md=lb.getMethod();
318     FlatMethod fm=state.getMethodFlat(md);
319     Hashtable<FlatNode, Set<TempDescriptor>> oldtemps=computeOldTemps(lb);
320     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
321       FlatNode fn=fnit.next();
322       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
323       if (atomictable.get(fn).intValue()>0) {
324         switch (fn.kind()) {
325         case FKind.FlatSetFieldNode:
326           FlatSetFieldNode fsfn=(FlatSetFieldNode) fn;
327           fields.add(fsfn.getField());
328           break;
329         case FKind.FlatSetElementNode:
330           FlatSetElementNode fsen=(FlatSetElementNode) fn;
331           arrays.add(fsen.getDst().getType());
332           break;
333         default:
334         }
335       }
336     }
337   }
338     
339   Hashtable<FlatNode, Set<TempDescriptor>> computeOldTemps(LocalityBinding lb) {
340     Hashtable<FlatNode, Set<TempDescriptor>> fntooldtmp=new Hashtable<FlatNode, Set<TempDescriptor>>();
341     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
342     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
343     MethodDescriptor md=lb.getMethod();
344     FlatMethod fm=state.getMethodFlat(md);
345     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
346     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
347     tovisit.add(fm);
348     discovered.add(fm);
349     
350     while(!tovisit.isEmpty()) {
351       FlatNode fn=tovisit.iterator().next();
352       tovisit.remove(fn);
353       for(int i=0;i<fn.numNext();i++) {
354         FlatNode fnext=fn.getNext(i);
355         if (!discovered.contains(fnext)) {
356           discovered.add(fnext);
357           tovisit.add(fnext);
358         }
359       }
360       HashSet<TempDescriptor> oldtemps=null;
361       if (atomictable.get(fn).intValue()!=0) {
362         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
363           //Everything live is old
364           Set<TempDescriptor> lives=livetemps.get(fn);
365           oldtemps=new HashSet<TempDescriptor>();
366           
367           for(Iterator<TempDescriptor> it=lives.iterator();it.hasNext();) {
368             TempDescriptor tmp=it.next();
369             if (tmp.getType().isPtr()) {
370               oldtemps.add(tmp);
371             }
372           }
373         } else {
374           oldtemps=new HashSet<TempDescriptor>();
375           //Compute union of old temporaries
376           for(int i=0;i<fn.numPrev();i++) {
377             Set<TempDescriptor> pset=fntooldtmp.get(fn.getPrev(i));
378             if (pset!=null)
379               oldtemps.addAll(pset);
380           }
381           
382           switch (fn.kind()) {
383           case FKind.FlatNew:
384             oldtemps.removeAll(Arrays.asList(fn.readsTemps()));
385             break;
386           case FKind.FlatOpNode: {
387             FlatOpNode fon=(FlatOpNode)fn;
388             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()) {
389               if (oldtemps.contains(fon.getLeft()))
390                 oldtemps.add(fon.getDest());
391               else
392                 oldtemps.remove(fon.getDest());
393               break;
394             }
395           }
396           default: {
397             TempDescriptor[] writes=fn.writesTemps();
398             for(int i=0;i<writes.length;i++) {
399               TempDescriptor wtemp=writes[i];
400               if (wtemp.getType().isPtr())
401                 oldtemps.add(wtemp);
402             }
403           }
404           }
405         }
406       }
407       
408       if (oldtemps!=null) {
409         if (!fntooldtmp.containsKey(fn)||!fntooldtmp.get(fn).equals(oldtemps)) {
410           fntooldtmp.put(fn, oldtemps);
411           //propagate changes
412           for(int i=0;i<fn.numNext();i++) {
413             FlatNode fnext=fn.getNext(i);
414             tovisit.add(fnext);
415           }
416         }
417       }
418     }
419     return fntooldtmp;
420   }
421 }
422
423 class TempFlatPair {
424     FlatNode f;
425     TempDescriptor t;
426     TempFlatPair(TempDescriptor t, FlatNode f) {
427         this.t=t;
428         this.f=f;
429     }
430
431     public int hashCode() {
432         return f.hashCode()^t.hashCode();
433     }
434     public boolean equals(Object o) {
435         TempFlatPair tf=(TempFlatPair)o;
436         return t.equals(tf.t)&&f.equals(tf.f);
437     }
438 }