changes
[IRC.git] / Robust / src / Analysis / Locality / DiscoverConflicts.java
1 package Analysis.Locality;
2
3 import IR.Flat.*;
4 import java.util.Set;
5 import java.util.Arrays;
6 import java.util.HashSet;
7 import java.util.Iterator;
8 import java.util.Hashtable;
9 import IR.State;
10 import IR.Operation;
11 import IR.TypeDescriptor;
12 import IR.MethodDescriptor;
13 import IR.FieldDescriptor;
14
15 public class DiscoverConflicts {
16   Set<FieldDescriptor> fields;
17   Set<TypeDescriptor> arrays;
18   LocalityAnalysis locality;
19   State state;
20   Hashtable<LocalityBinding, Set<FlatNode>> treadmap;
21   Hashtable<LocalityBinding, Set<TempFlatPair>> transreadmap;
22   Hashtable<LocalityBinding, Set<FlatNode>> srcmap;
23   Hashtable<LocalityBinding, Set<FlatNode>> leftsrcmap;
24   Hashtable<LocalityBinding, Set<FlatNode>> rightsrcmap;
25   TypeAnalysis typeanalysis;
26   HashSet<FlatNode>cannotdelay;
27
28   public DiscoverConflicts(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis) {
29     this.locality=locality;
30     this.fields=new HashSet<FieldDescriptor>();
31     this.arrays=new HashSet<TypeDescriptor>();
32     this.state=state;
33     this.typeanalysis=typeanalysis;
34     transreadmap=new Hashtable<LocalityBinding, Set<TempFlatPair>>();
35     treadmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
36     srcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
37     leftsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
38     rightsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
39   }
40
41   public DiscoverConflicts(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis, HashSet<FlatNode> cannotdelay) {
42     this.locality=locality;
43     this.fields=new HashSet<FieldDescriptor>();
44     this.arrays=new HashSet<TypeDescriptor>();
45     this.state=state;
46     this.typeanalysis=typeanalysis;
47     this.cannotdelay=cannotdelay;
48     transreadmap=new Hashtable<LocalityBinding, Set<TempFlatPair>>();
49     treadmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
50     srcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
51     leftsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
52     rightsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
53   }
54   
55   public void doAnalysis() {
56     //Compute fields and arrays for all transactions.  Note that we
57     //only look at changes to old objects
58
59     Set<LocalityBinding> localityset=locality.getLocalityBindings();
60     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
61       computeModified(lb.next());
62     }
63     expandTypes();
64     //Compute set of nodes that need transread
65     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
66       LocalityBinding l=lb.next();
67       analyzeLocality(l);
68       setNeedReadTrans(l);
69     }
70   }
71
72   //Change flatnode/temp pairs to just flatnodes that need transactional reads
73
74   public void setNeedReadTrans(LocalityBinding lb) {
75     HashSet<FlatNode> set=new HashSet<FlatNode>();
76     for(Iterator<TempFlatPair> it=transreadmap.get(lb).iterator();it.hasNext();) {
77       TempFlatPair tfp=it.next();
78       set.add(tfp.f);
79     }
80     treadmap.put(lb, set);
81   }
82
83   //We have a set of things we write to, figure out what things this
84   //could effect.
85   public void expandTypes() {
86     Set<TypeDescriptor> expandedarrays=new HashSet<TypeDescriptor>();
87     for(Iterator<TypeDescriptor> it=arrays.iterator();it.hasNext();) {
88       TypeDescriptor td=it.next();
89       expandedarrays.addAll(typeanalysis.expand(td));
90     }
91     arrays=expandedarrays;
92   }
93
94   Hashtable<TempDescriptor, Set<TempFlatPair>> doMerge(FlatNode fn, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset) {
95     Hashtable<TempDescriptor, Set<TempFlatPair>> table=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
96     for(int i=0;i<fn.numPrev();i++) {
97       FlatNode fprev=fn.getPrev(i);
98       Hashtable<TempDescriptor, Set<TempFlatPair>> tabset=tmptofnset.get(fprev);
99       if (tabset!=null) {
100         for(Iterator<TempDescriptor> tmpit=tabset.keySet().iterator();tmpit.hasNext();) {
101           TempDescriptor td=tmpit.next();
102           Set<TempFlatPair> fnset=tabset.get(td);
103           if (!table.containsKey(td))
104             table.put(td, new HashSet<TempFlatPair>());
105           table.get(td).addAll(fnset);
106         }
107       }
108     }
109     return table;
110   }
111   
112   public Set<FlatNode> getNeedSrcTrans(LocalityBinding lb) {
113     return srcmap.get(lb);
114   }
115
116   public boolean getNeedSrcTrans(LocalityBinding lb, FlatNode fn) {
117     return srcmap.get(lb).contains(fn);
118   }
119
120   public boolean getNeedLeftSrcTrans(LocalityBinding lb, FlatNode fn) {
121     return leftsrcmap.get(lb).contains(fn);
122   }
123
124   public boolean getNeedRightSrcTrans(LocalityBinding lb, FlatNode fn) {
125     return rightsrcmap.get(lb).contains(fn);
126   }
127
128   public boolean getNeedTrans(LocalityBinding lb, FlatNode fn) {
129     return treadmap.get(lb).contains(fn);
130   }
131
132   private void analyzeLocality(LocalityBinding lb) {
133     MethodDescriptor md=lb.getMethod();
134     FlatMethod fm=state.getMethodFlat(md);
135     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap=computeTempSets(lb);
136     HashSet<TempFlatPair> tfset=computeTranslationSet(lb, fm, fnmap);
137     HashSet<FlatNode> srctrans=new HashSet<FlatNode>();
138     HashSet<FlatNode> leftsrctrans=new HashSet<FlatNode>();
139     HashSet<FlatNode> rightsrctrans=new HashSet<FlatNode>();
140     transreadmap.put(lb, tfset);
141     srcmap.put(lb,srctrans);
142     leftsrcmap.put(lb,leftsrctrans);
143     rightsrcmap.put(lb,rightsrctrans);
144
145     //compute writes that need translation on source
146
147     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
148       FlatNode fn=fnit.next();
149       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
150       if (atomictable.get(fn).intValue()>0) {
151         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
152         switch(fn.kind()) {
153
154           //We might need to translate arguments to pointer comparison
155
156         case FKind.FlatOpNode: { 
157           FlatOpNode fon=(FlatOpNode)fn;
158           if (fon.getOp().getOp()==Operation.EQUAL||
159               fon.getOp().getOp()==Operation.NOTEQUAL) {
160             if (!fon.getLeft().getType().isPtr())
161               break;
162             Set<TempFlatPair> lefttfpset=tmap.get(fon.getLeft());
163             Set<TempFlatPair> righttfpset=tmap.get(fon.getRight());
164             //handle left operand
165             if (lefttfpset!=null) {
166               for(Iterator<TempFlatPair> tfpit=lefttfpset.iterator();tfpit.hasNext();) {
167                 TempFlatPair tfp=tfpit.next();
168                 if (tfset.contains(tfp)||outofscope(tfp)) {
169                   leftsrctrans.add(fon);
170                   break;
171                 }
172               }
173             }
174             //handle right operand
175             if (righttfpset!=null) {
176               for(Iterator<TempFlatPair> tfpit=righttfpset.iterator();tfpit.hasNext();) {
177                 TempFlatPair tfp=tfpit.next();
178                 if (tfset.contains(tfp)||outofscope(tfp)) {
179                   rightsrctrans.add(fon);
180                   break;
181                 }
182               }
183             }
184           }
185           break;
186         }
187
188         case FKind.FlatSetFieldNode: { 
189           //need to translate these if the value we read from may be a
190           //shadow...  check this by seeing if any of the values we
191           //may read are in the transread set or came from our caller
192           //or a method we called
193
194           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
195           if (!fsfn.getField().getType().isPtr())
196             break;
197           Set<TempFlatPair> tfpset=tmap.get(fsfn.getSrc());
198           if (tfpset!=null) {
199             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
200               TempFlatPair tfp=tfpit.next();
201               if (tfset.contains(tfp)||outofscope(tfp)) {
202                 srctrans.add(fsfn);
203                 break;
204               }
205             }
206           }
207           break;
208         }
209         case FKind.FlatSetElementNode: { 
210           //need to translate these if the value we read from may be a
211           //shadow...  check this by seeing if any of the values we
212           //may read are in the transread set or came from our caller
213           //or a method we called
214
215           FlatSetElementNode fsen=(FlatSetElementNode)fn;
216           if (!fsen.getSrc().getType().isPtr())
217             break;
218           Set<TempFlatPair> tfpset=tmap.get(fsen.getSrc());
219           if (tfpset!=null) {
220             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
221               TempFlatPair tfp=tfpit.next();
222               if (tfset.contains(tfp)||outofscope(tfp)) {
223                 srctrans.add(fsen);
224                 break;
225               }
226             }
227           }
228           break;
229         }
230         default:
231         }
232       }
233     }
234   }
235
236   public boolean outofscope(TempFlatPair tfp) {
237     FlatNode fn=tfp.f;
238     return fn.kind()==FKind.FlatCall||fn.kind()==FKind.FlatMethod;
239   }
240
241
242   /** Need to figure out which nodes need a transread to make local
243   copies.  Transread conceptually tracks conflicts.  This depends on
244   what fields/elements are accessed We iterate over all flatnodes that
245   access fields...If these accesses could conflict, we mark the source
246   tempflat pair as needing a transread */
247
248   HashSet<TempFlatPair> computeTranslationSet(LocalityBinding lb, FlatMethod fm, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap) {
249     HashSet<TempFlatPair> tfset=new HashSet<TempFlatPair>();
250     
251     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
252       FlatNode fn=fnit.next();
253       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
254
255       //Check whether this node matters for delayed computation
256       if (cannotdelay!=null&&!cannotdelay.contains(fn))
257         continue;
258
259       if (atomictable.get(fn).intValue()>0) {
260         Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
261         switch(fn.kind()) {
262         case FKind.FlatElementNode: {
263           FlatElementNode fen=(FlatElementNode)fn;
264           if (arrays.contains(fen.getSrc().getType())) {
265             //this could cause conflict...figure out conflict set
266             Set<TempFlatPair> tfpset=tmap.get(fen.getSrc());
267             if (tfpset!=null)
268               tfset.addAll(tfpset);
269           }
270           break;
271         }
272         case FKind.FlatFieldNode: { 
273           FlatFieldNode ffn=(FlatFieldNode)fn;
274           if (fields.contains(ffn.getField())) {
275             //this could cause conflict...figure out conflict set
276             Set<TempFlatPair> tfpset=tmap.get(ffn.getSrc());
277             if (tfpset!=null)
278               tfset.addAll(tfpset);
279           }
280           break;
281         }
282         case FKind.FlatSetFieldNode: { 
283           //definitely need to translate these
284           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
285           Set<TempFlatPair> tfpset=tmap.get(fsfn.getDst());
286           if (tfpset!=null)
287             tfset.addAll(tfpset);
288           break;
289         }
290         case FKind.FlatSetElementNode: { 
291           //definitely need to translate these
292           FlatSetElementNode fsen=(FlatSetElementNode)fn;
293           Set<TempFlatPair> tfpset=tmap.get(fsen.getDst());
294           if (tfpset!=null)
295             tfset.addAll(tfpset);
296           break;
297         }
298         case FKind.FlatCall: //assume pessimistically that calls do bad things
299         case FKind.FlatReturnNode: {
300           TempDescriptor []readarray=fn.readsTemps();
301           for(int i=0;i<readarray.length;i++) {
302             TempDescriptor rtmp=readarray[i];
303             Set<TempFlatPair> tfpset=tmap.get(rtmp);
304             if (tfpset!=null)
305               tfset.addAll(tfpset);
306           }
307           break;
308         }
309         default:
310           //do nothing
311         }
312       }
313     }   
314     return tfset;
315   }
316
317
318   //This method generates as output for each node
319   //A map from from temps to a set of temp/flat pairs that the
320   //original temp points to
321   //A temp/flat pair gives the flatnode that the value was created at
322   //and the original temp
323
324   Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> computeTempSets(LocalityBinding lb) {
325     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> tmptofnset=new Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>>();
326     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
327     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
328     MethodDescriptor md=lb.getMethod();
329     FlatMethod fm=state.getMethodFlat(md);
330     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
331     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
332     tovisit.add(fm);
333     discovered.add(fm);
334     
335     while(!tovisit.isEmpty()) {
336       FlatNode fn=tovisit.iterator().next();
337       tovisit.remove(fn);
338       for(int i=0;i<fn.numNext();i++) {
339         FlatNode fnext=fn.getNext(i);
340         if (!discovered.contains(fnext)) {
341           discovered.add(fnext);
342           tovisit.add(fnext);
343         }
344       }
345       Hashtable<TempDescriptor, Set<TempFlatPair>> ttofn=null;
346       if (atomictable.get(fn).intValue()!=0) {
347         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
348           //atomic node, start with new set
349           ttofn=new Hashtable<TempDescriptor, Set<TempFlatPair>>();
350         } else {
351           ttofn=doMerge(fn, tmptofnset);
352           switch(fn.kind()) {
353           case FKind.FlatGlobalConvNode: {
354             FlatGlobalConvNode fgcn=(FlatGlobalConvNode)fn;
355             if (lb==fgcn.getLocality()&&
356                 fgcn.getMakePtr()) {
357               TempDescriptor[] writes=fn.writesTemps();
358               for(int i=0;i<writes.length;i++) {
359                 TempDescriptor wtmp=writes[i];
360                 HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
361                 set.add(new TempFlatPair(wtmp, fn));
362                 ttofn.put(wtmp, set);
363               }
364             }
365             break;
366           }
367           case FKind.FlatFieldNode:
368           case FKind.FlatElementNode: {
369             TempDescriptor[] writes=fn.writesTemps();
370             for(int i=0;i<writes.length;i++) {
371               TempDescriptor wtmp=writes[i];
372               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
373               set.add(new TempFlatPair(wtmp, fn));
374               ttofn.put(wtmp, set);
375             }
376             break;
377           }
378           case FKind.FlatCall:
379           case FKind.FlatMethod: {
380             TempDescriptor[] writes=fn.writesTemps();
381             for(int i=0;i<writes.length;i++) {
382               TempDescriptor wtmp=writes[i];
383               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
384               set.add(new TempFlatPair(wtmp, fn));
385               ttofn.put(wtmp, set);
386             }
387             break;
388           }
389           case FKind.FlatOpNode: {
390             FlatOpNode fon=(FlatOpNode)fn;
391             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()&&
392                 ttofn.containsKey(fon.getLeft())) {
393               ttofn.put(fon.getDest(), new HashSet<TempFlatPair>(ttofn.get(fon.getLeft())));
394               break;
395             }
396           }
397           default:
398             //Do kill computation
399             TempDescriptor[] writes=fn.writesTemps();
400             for(int i=0;i<writes.length;i++) {
401               TempDescriptor wtmp=writes[i];
402               ttofn.remove(writes[i]);
403             }
404           }
405         }
406         if (ttofn!=null) {
407           if (!tmptofnset.containsKey(fn)||
408               !tmptofnset.get(fn).equals(ttofn)) {
409             //enqueue nodes to process
410             tmptofnset.put(fn, ttofn);
411             for(int i=0;i<fn.numNext();i++) {
412               FlatNode fnext=fn.getNext(i);
413               tovisit.add(fnext);
414             }
415           }
416         }
417       }
418     }
419     return tmptofnset;
420   }
421   
422   /* See what fields and arrays transactions might modify.  We only
423    * look at changes to old objects. */
424
425   public void computeModified(LocalityBinding lb) {
426     MethodDescriptor md=lb.getMethod();
427     FlatMethod fm=state.getMethodFlat(md);
428     Hashtable<FlatNode, Set<TempDescriptor>> oldtemps=computeOldTemps(lb);
429     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
430       FlatNode fn=fnit.next();
431       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
432       if (atomictable.get(fn).intValue()>0) {
433         switch (fn.kind()) {
434         case FKind.FlatSetFieldNode:
435           FlatSetFieldNode fsfn=(FlatSetFieldNode) fn;
436           fields.add(fsfn.getField());
437           break;
438         case FKind.FlatSetElementNode:
439           FlatSetElementNode fsen=(FlatSetElementNode) fn;
440           arrays.add(fsen.getDst().getType());
441           break;
442         default:
443         }
444       }
445     }
446   }
447     
448
449   //Returns a table that maps a flatnode to a set of temporaries
450   //This set of temporaries is old (meaning they may point to object
451   //allocated before the beginning of the current transaction
452
453   Hashtable<FlatNode, Set<TempDescriptor>> computeOldTemps(LocalityBinding lb) {
454     Hashtable<FlatNode, Set<TempDescriptor>> fntooldtmp=new Hashtable<FlatNode, Set<TempDescriptor>>();
455     HashSet<FlatNode> discovered=new HashSet<FlatNode>();
456     HashSet<FlatNode> tovisit=new HashSet<FlatNode>();
457     MethodDescriptor md=lb.getMethod();
458     FlatMethod fm=state.getMethodFlat(md);
459     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
460     Hashtable<FlatNode, Set<TempDescriptor>> livetemps=locality.computeLiveTemps(fm);
461     tovisit.add(fm);
462     discovered.add(fm);
463     
464     while(!tovisit.isEmpty()) {
465       FlatNode fn=tovisit.iterator().next();
466       tovisit.remove(fn);
467       for(int i=0;i<fn.numNext();i++) {
468         FlatNode fnext=fn.getNext(i);
469         if (!discovered.contains(fnext)) {
470           discovered.add(fnext);
471           tovisit.add(fnext);
472         }
473       }
474       HashSet<TempDescriptor> oldtemps=null;
475       if (atomictable.get(fn).intValue()!=0) {
476         if ((fn.numPrev()>0)&&atomictable.get(fn.getPrev(0)).intValue()==0) {
477           //Everything live is old
478           Set<TempDescriptor> lives=livetemps.get(fn);
479           oldtemps=new HashSet<TempDescriptor>();
480           
481           for(Iterator<TempDescriptor> it=lives.iterator();it.hasNext();) {
482             TempDescriptor tmp=it.next();
483             if (tmp.getType().isPtr()) {
484               oldtemps.add(tmp);
485             }
486           }
487         } else {
488           oldtemps=new HashSet<TempDescriptor>();
489           //Compute union of old temporaries
490           for(int i=0;i<fn.numPrev();i++) {
491             Set<TempDescriptor> pset=fntooldtmp.get(fn.getPrev(i));
492             if (pset!=null)
493               oldtemps.addAll(pset);
494           }
495           
496           switch (fn.kind()) {
497           case FKind.FlatNew:
498             oldtemps.removeAll(Arrays.asList(fn.readsTemps()));
499             break;
500           case FKind.FlatOpNode: {
501             FlatOpNode fon=(FlatOpNode)fn;
502             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()) {
503               if (oldtemps.contains(fon.getLeft()))
504                 oldtemps.add(fon.getDest());
505               else
506                 oldtemps.remove(fon.getDest());
507               break;
508             }
509           }
510           default: {
511             TempDescriptor[] writes=fn.writesTemps();
512             for(int i=0;i<writes.length;i++) {
513               TempDescriptor wtemp=writes[i];
514               if (wtemp.getType().isPtr())
515                 oldtemps.add(wtemp);
516             }
517           }
518           }
519         }
520       }
521       
522       if (oldtemps!=null) {
523         if (!fntooldtmp.containsKey(fn)||!fntooldtmp.get(fn).equals(oldtemps)) {
524           fntooldtmp.put(fn, oldtemps);
525           //propagate changes
526           for(int i=0;i<fn.numNext();i++) {
527             FlatNode fnext=fn.getNext(i);
528             tovisit.add(fnext);
529           }
530         }
531       }
532     }
533     return fntooldtmp;
534   }
535 }
536
537 class TempFlatPair {
538     FlatNode f;
539     TempDescriptor t;
540     TempFlatPair(TempDescriptor t, FlatNode f) {
541         this.t=t;
542         this.f=f;
543     }
544
545     public int hashCode() {
546         return f.hashCode()^t.hashCode();
547     }
548     public boolean equals(Object o) {
549         TempFlatPair tf=(TempFlatPair)o;
550         return t.equals(tf.t)&&f.equals(tf.f);
551     }
552 }