more code towards transaction optimizations
[IRC.git] / Robust / src / Analysis / Locality / DelayComputation.java
1 package Analysis.Locality;
2 import IR.State;
3 import IR.MethodDescriptor;
4 import IR.TypeDescriptor;
5 import IR.FieldDescriptor;
6 import IR.Flat.*;
7 import Analysis.Loops.GlobalFieldType;
8 import java.util.HashSet;
9 import java.util.Hashtable;
10 import java.util.Set;
11 import java.util.Stack;
12 import java.util.Iterator;
13
14 public class DelayComputation {
15   State state;
16   LocalityAnalysis locality;
17   TypeAnalysis typeanalysis;
18   GlobalFieldType gft;
19   DiscoverConflicts dcopts;
20   Hashtable<LocalityBinding, HashSet<FlatNode>> notreadymap;
21   Hashtable<LocalityBinding, HashSet<FlatNode>> cannotdelaymap;
22   Hashtable<LocalityBinding, HashSet<FlatNode>> othermap;
23
24   public DelayComputation(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis, GlobalFieldType gft) {
25     this.locality=locality;
26     this.state=state;
27     this.typeanalysis=typeanalysis;
28     this.gft=gft;
29     this.notreadymap=new Hashtable<LocalityBinding, HashSet<FlatNode>>();
30     this.cannotdelaymap=new Hashtable<LocalityBinding, HashSet<FlatNode>>();
31     this.othermap=new Hashtable<LocalityBinding, HashSet<FlatNode>>();
32   }
33
34   public DiscoverConflicts getConflicts() {
35     return dcopts;
36   }
37
38   public void doAnalysis() {
39     Set<LocalityBinding> localityset=locality.getLocalityBindings();
40     for(Iterator<LocalityBinding> lb=localityset.iterator();lb.hasNext();) {
41       analyzeMethod(lb.next());
42     }
43   }
44
45   public HashSet<FlatNode> getNotReady(LocalityBinding lb) {
46     return notreadymap.get(lb);
47   }
48
49   public HashSet<FlatNode> getCannotDelay(LocalityBinding lb) {
50     return cannotdelaymap.get(lb);
51   }
52
53   public HashSet<FlatNode> getOther(LocalityBinding lb) {
54     return othermap.get(lb);
55   }
56
57   //This method computes which nodes from the first part of the
58   //transaction must store their output for the second part
59   //Note that many nodes don't need to...
60
61   public Set<FlatNode> livecode(LocalityBinding lb) {
62     if (!othermap.containsKey(lb))
63       return null;
64     HashSet<FlatNode> delayedset=notreadymap.get(lb);
65     MethodDescriptor md=lb.getMethod();
66     FlatMethod fm=state.getMethodFlat(md);
67     Hashtable<FlatNode, Hashtable<TempDescriptor, HashSet<FlatNode>>> map=new Hashtable<FlatNode, Hashtable<TempDescriptor, HashSet<FlatNode>>>();
68
69     HashSet<FlatNode> toanalyze=new HashSet<FlatNode>();
70     toanalyze.add(fm);
71     
72     HashSet<FlatNode> livenodes=new HashSet<FlatNode>();
73
74     while(!toanalyze.isEmpty()) {
75       FlatNode fn=toanalyze.iterator().next();
76       toanalyze.remove(fn);
77       Hashtable<TempDescriptor, HashSet<FlatNode>> tmptofn=new Hashtable<TempDescriptor, HashSet<FlatNode>>();
78       
79       //Do merge on incoming edges
80       for(int i=0;i<fn.numPrev();i++) {
81         FlatNode fnprev=fn.getPrev(i);
82         Hashtable<TempDescriptor, HashSet<FlatNode>> prevmap=map.get(fnprev);
83
84         for(Iterator<TempDescriptor> tmpit=prevmap.keySet().iterator();tmpit.hasNext();) {
85           TempDescriptor tmp=tmpit.next();
86           if (!tmptofn.containsKey(tmp))
87             tmptofn.put(tmp, new HashSet<FlatNode>());
88           tmptofn.get(tmp).addAll(prevmap.get(tmp));
89         }
90       }
91
92       if (delayedset.contains(fn)) {
93         //Check our readset
94         TempDescriptor readset[]=fn.readsTemps();
95         for(int i=0;i<readset.length;i++) {
96           TempDescriptor tmp=readset[i];
97           if (tmptofn.containsKey(tmp))
98             livenodes.addAll(tmptofn.get(tmp)); // add live nodes
99         }
100
101         //Do kills
102         TempDescriptor writeset[]=fn.writesTemps();
103         for(int i=0;i<writeset.length;i++) {
104           TempDescriptor tmp=writeset[i];
105           tmptofn.remove(tmp);
106         }
107       } else {
108         //We write -- our reads are done
109         TempDescriptor writeset[]=fn.writesTemps();
110         for(int i=0;i<writeset.length;i++) {
111           TempDescriptor tmp=writeset[i];
112           HashSet<FlatNode> set=new HashSet<FlatNode>();
113           set.add(fn);
114           tmptofn.put(tmp,set);
115         }
116         if (fn.numNext()>1) {
117           //We have a conditional branch...need to handle this carefully
118           Set<FlatNode> set0=getNext(fn, 0, delayedset);
119           Set<FlatNode> set1=getNext(fn, 1, delayedset);
120           if (!set0.equals(set1)||set0.size()>1) {
121             //This branch is important--need to remember how it goes
122             livenodes.add(fn);
123           }
124         }
125       }
126       if (!map.containsKey(fn)||!map.get(fn).equals(tmptofn)) {
127         map.put(fn, tmptofn);
128         //enqueue next ndoes
129         for(int i=0;i<fn.numNext();i++)
130           toanalyze.add(fn.getNext(i));
131       }
132     }
133     return livenodes;
134   }
135   
136   //Returns null if more than one possible next
137
138   public static Set<FlatNode> getNext(FlatNode fn, int i, HashSet<FlatNode> delayset) {
139     FlatNode fnnext=fn.getNext(i);
140     HashSet<FlatNode> reachable=new HashSet<FlatNode>();    
141
142     if (delayset.contains(fnnext)) {
143       reachable.add(fnnext);
144       return reachable;
145     }
146     Stack<FlatNode> nodes=new Stack<FlatNode>();
147     HashSet<FlatNode> visited=new HashSet<FlatNode>();
148     nodes.push(fnnext);
149
150     while(!nodes.isEmpty()) {
151       FlatNode fn2=nodes.pop();
152       if (visited.contains(fn2)) 
153         continue;
154       visited.add(fn2);
155       for (int j=0;j<fn2.numNext();j++) {
156         FlatNode fn2next=fn2.getNext(j);
157         if (delayset.contains(fn2next)) {
158           reachable.add(fn2next);
159         } else
160           nodes.push(fn2next);
161       }
162     }
163     return reachable;
164   }
165
166   public void analyzeMethod(LocalityBinding lb) {
167     MethodDescriptor md=lb.getMethod();
168     FlatMethod fm=state.getMethodFlat(md);
169     System.out.println("Analyzing "+md);
170     HashSet<FlatNode> cannotdelay=new HashSet<FlatNode>();
171     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
172     if (lb.isAtomic()) {
173       //We are in a transaction already...
174       //skip past this method or something
175       return;
176     }
177
178     HashSet<FlatNode> toanalyze=new HashSet<FlatNode>();
179     toanalyze.addAll(fm.getNodeSet());
180
181     //Build the hashtables
182     Hashtable<FlatNode, HashSet<TempDescriptor>> nodelaytemps=new Hashtable<FlatNode, HashSet<TempDescriptor>>();
183     Hashtable<FlatNode, HashSet<FieldDescriptor>> nodelayfieldswr=new Hashtable<FlatNode, HashSet<FieldDescriptor>>();
184     Hashtable<FlatNode, HashSet<TypeDescriptor>> nodelayarrayswr=new Hashtable<FlatNode, HashSet<TypeDescriptor>>();
185     Hashtable<FlatNode, HashSet<FieldDescriptor>> nodelayfieldsrd=new Hashtable<FlatNode, HashSet<FieldDescriptor>>();
186     Hashtable<FlatNode, HashSet<TypeDescriptor>> nodelayarraysrd=new Hashtable<FlatNode, HashSet<TypeDescriptor>>();
187     
188     //Effect of adding something to nodelay set is to move it up past everything in delay set
189     //Have to make sure we can do this commute
190
191     while(!toanalyze.isEmpty()) {
192       FlatNode fn=toanalyze.iterator().next();
193       toanalyze.remove(fn);
194       
195       boolean isatomic=atomictable.get(fn).intValue()>0;
196
197       if (!isatomic)
198         continue;
199       boolean isnodelay=false;
200
201       /* Compute incoming nodelay sets */
202       HashSet<TempDescriptor> nodelaytempset=new HashSet<TempDescriptor>();
203       HashSet<FieldDescriptor> nodelayfieldwrset=new HashSet<FieldDescriptor>();
204       HashSet<TypeDescriptor> nodelayarraywrset=new HashSet<TypeDescriptor>();
205       HashSet<FieldDescriptor> nodelayfieldrdset=new HashSet<FieldDescriptor>();
206       HashSet<TypeDescriptor> nodelayarrayrdset=new HashSet<TypeDescriptor>();
207       for(int i=0;i<fn.numNext();i++) {
208         if (nodelaytemps.containsKey(fn.getNext(i)))
209           nodelaytempset.addAll(nodelaytemps.get(fn.getNext(i)));
210         //do field/array write sets
211         if (nodelayfieldswr.containsKey(fn.getNext(i)))
212           nodelayfieldwrset.addAll(nodelayfieldswr.get(fn.getNext(i)));   
213         if (nodelayarrayswr.containsKey(fn.getNext(i)))
214           nodelayarraywrset.addAll(nodelayarrayswr.get(fn.getNext(i)));   
215         //do read sets
216         if (nodelayfieldsrd.containsKey(fn.getNext(i)))
217           nodelayfieldrdset.addAll(nodelayfieldsrd.get(fn.getNext(i)));   
218         if (nodelayarrayswr.containsKey(fn.getNext(i)))
219           nodelayarraywrset.addAll(nodelayarrayswr.get(fn.getNext(i)));   
220       }
221       
222       /* Check our temp write set */
223
224       TempDescriptor writeset[]=fn.writesTemps();
225       for(int i=0;i<writeset.length;i++) {
226         TempDescriptor tmp=writeset[i];
227         if (nodelaytempset.contains(tmp)) {
228           //We are writing to a nodelay temp
229           //Therefore we are nodelay
230           isnodelay=true;
231           //Kill temp we wrote to
232           nodelaytempset.remove(tmp);
233         }
234       }
235       
236       //See if flatnode is definitely no delay
237       if (fn.kind()==FKind.FlatCall) {
238         isnodelay=true;
239         //Have to deal with fields/arrays
240         FlatCall fcall=(FlatCall)fn;
241         MethodDescriptor mdcall=fcall.getMethod();
242         nodelayfieldwrset.addAll(gft.getFieldsAll(mdcall));
243         nodelayarraywrset.addAll(typeanalysis.expandSet(gft.getArraysAll(mdcall)));
244         //Have to deal with field/array reads
245         nodelayfieldrdset.addAll(gft.getFieldsRdAll(mdcall));
246         nodelayarrayrdset.addAll(typeanalysis.expandSet(gft.getArraysRdAll(mdcall)));
247       }
248       
249       // Can't delay branches
250       if (fn.kind()==FKind.FlatCondBranch) {
251         isnodelay=true;
252       }
253
254       //Check for field conflicts
255       if (fn.kind()==FKind.FlatSetFieldNode) {
256         FieldDescriptor fd=((FlatSetFieldNode)fn).getField();
257         //write conflicts
258         if (nodelayfieldwrset.contains(fd))
259           isnodelay=true;
260         //read 
261         if (nodelayfieldrdset.contains(fd))
262           isnodelay=true;
263       }
264
265       if (fn.kind()==FKind.FlatFieldNode) {
266         FieldDescriptor fd=((FlatFieldNode)fn).getField();
267         //write conflicts
268         if (nodelayfieldwrset.contains(fd))
269           isnodelay=true;
270       }
271
272       //Check for array conflicts
273       if (fn.kind()==FKind.FlatSetElementNode) {
274         TypeDescriptor td=((FlatSetElementNode)fn).getDst().getType();
275         //check for write conflicts
276         if (nodelayarraywrset.contains(td))
277           isnodelay=true;
278         //check for read conflicts
279         if (nodelayarrayrdset.contains(td))
280           isnodelay=true;
281       }
282       if (fn.kind()==FKind.FlatElementNode) {
283         TypeDescriptor td=((FlatElementNode)fn).getSrc().getType();
284         //check for write conflicts
285         if (nodelayarraywrset.contains(td))
286           isnodelay=true;
287       }
288       
289       //If we are no delay, then the temps we read are no delay
290       if (isnodelay) {
291         /* Add our read set */
292         TempDescriptor readset[]=fn.readsTemps();
293         for(int i=0;i<readset.length;i++) {
294           TempDescriptor tmp=readset[i];
295           nodelaytempset.add(tmp);
296         }
297         cannotdelay.add(fn);
298
299         /* Do we write to fields */
300         if (fn.kind()==FKind.FlatSetFieldNode) {
301           nodelayfieldwrset.add(((FlatSetFieldNode)fn).getField());
302         }
303         /* Do we read from fields */
304         if (fn.kind()==FKind.FlatFieldNode) {
305           nodelayfieldrdset.add(((FlatFieldNode)fn).getField());
306         }
307
308         /* Do we write to arrays */
309         if (fn.kind()==FKind.FlatSetElementNode) {
310           //have to do expansion
311           nodelayarraywrset.addAll(typeanalysis.expand(((FlatSetElementNode)fn).getDst().getType()));     
312         }
313         /* Do we read from arrays */
314         if (fn.kind()==FKind.FlatElementNode) {
315           //have to do expansion
316           nodelayarrayrdset.addAll(typeanalysis.expand(((FlatElementNode)fn).getSrc().getType()));        
317         }
318       } else {
319         //Need to know which objects to lock on
320         switch(fn.kind()) {
321         case FKind.FlatSetFieldNode: {
322           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
323           nodelaytempset.add(fsfn.getDst());
324           break;
325         }
326         case FKind.FlatSetElementNode: {
327           FlatSetElementNode fsen=(FlatSetElementNode)fn;
328           nodelaytempset.add(fsen.getDst());
329           break;
330         }
331         case FKind.FlatFieldNode: {
332           FlatFieldNode ffn=(FlatFieldNode)fn;
333           nodelaytempset.add(ffn.getSrc());
334           break;
335         }
336         case FKind.FlatElementNode: {
337           FlatElementNode fen=(FlatElementNode)fn;
338           nodelaytempset.add(fen.getSrc());
339           break;
340         }
341         }
342       }
343       
344       boolean changed=false;
345       //See if we need to propagate changes
346       if (!nodelaytemps.containsKey(fn)||
347           !nodelaytemps.get(fn).equals(nodelaytempset)) {
348         nodelaytemps.put(fn, nodelaytempset);
349         changed=true;
350       }
351
352       //See if we need to propagate changes
353       if (!nodelayfieldswr.containsKey(fn)||
354           !nodelayfieldswr.get(fn).equals(nodelayfieldwrset)) {
355         nodelayfieldswr.put(fn, nodelayfieldwrset);
356         changed=true;
357       }
358
359       //See if we need to propagate changes
360       if (!nodelayfieldsrd.containsKey(fn)||
361           !nodelayfieldsrd.get(fn).equals(nodelayfieldrdset)) {
362         nodelayfieldsrd.put(fn, nodelayfieldrdset);
363         changed=true;
364       }
365
366       //See if we need to propagate changes
367       if (!nodelayarrayswr.containsKey(fn)||
368           !nodelayarrayswr.get(fn).equals(nodelayarraywrset)) {
369         nodelayarrayswr.put(fn, nodelayarraywrset);
370         changed=true;
371       }
372
373       //See if we need to propagate changes
374       if (!nodelayarraysrd.containsKey(fn)||
375           !nodelayarraysrd.get(fn).equals(nodelayarrayrdset)) {
376         nodelayarraysrd.put(fn, nodelayarrayrdset);
377         changed=true;
378       }
379
380       if (changed)
381         for(int i=0;i<fn.numPrev();i++)
382           toanalyze.add(fn.getPrev(i));
383     }//end of while loop
384     HashSet<FlatNode> notreadyset=computeNotReadySet(lb, cannotdelay);
385     HashSet<FlatNode> otherset=new HashSet<FlatNode>();
386     otherset.addAll(fm.getNodeSet());
387     if (lb.getHasAtomic()) {
388       otherset.removeAll(notreadyset);
389       otherset.removeAll(cannotdelay);
390       notreadymap.put(lb, notreadyset);
391       cannotdelaymap.put(lb, cannotdelay);
392       othermap.put(lb, otherset);
393     }
394
395     //We now have:
396     //(1) Cannot delay set -- stuff that must be done before commit
397     //(2) Not ready set -- stuff that must wait until commit
398     //(3) everything else -- stuff that should be done before commit
399   } //end of method
400
401   //Problems:
402   //1) we acquire locks too early to object we don't need to yet
403   //2) we don't realize that certain operations have side effects
404
405   public HashSet<FlatNode> computeNotReadySet(LocalityBinding lb, HashSet<FlatNode> cannotdelay) {
406     //You are in not ready set if:
407     //I. You read a not ready temp
408     //II. You access a field or element and
409     //(A). You are not in the cannot delay set
410     //(B). You read a field/element in the transactional set
411     //(C). The source didn't have a transactional read on it
412
413     dcopts=new DiscoverConflicts(locality, state, typeanalysis);
414     dcopts.doAnalysis();
415     MethodDescriptor md=lb.getMethod();
416     FlatMethod fm=state.getMethodFlat(md);
417     Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
418
419     HashSet<FlatNode> notreadynodes=new HashSet<FlatNode>();
420     HashSet<FlatNode> toanalyze=new HashSet<FlatNode>();
421     toanalyze.addAll(fm.getNodeSet());
422     Hashtable<FlatNode, HashSet<TempDescriptor>> notreadymap=new Hashtable<FlatNode, HashSet<TempDescriptor>>();
423     
424     while(!toanalyze.isEmpty()) {
425       FlatNode fn=toanalyze.iterator().next();
426       toanalyze.remove(fn);
427       boolean isatomic=atomictable.get(fn).intValue()>0;
428
429       if (!isatomic)
430         continue;
431
432       //Compute initial notready set
433       HashSet<TempDescriptor> notreadyset=new HashSet<TempDescriptor>();
434       for(int i=0;i<fn.numPrev();i++) {
435         if (notreadymap.containsKey(fn.getPrev(i)))
436           notreadyset.addAll(notreadymap.get(fn.getPrev(i)));
437       }
438       
439       //Are we ready
440       boolean notready=false;
441
442       //Test our read set first
443       TempDescriptor readset[]=fn.readsTemps();
444       for(int i=0;i<readset.length;i++) {
445         TempDescriptor tmp=readset[i];
446         if (notreadyset.contains(tmp)) {
447           notready=true;
448           break;
449         }
450       }
451
452       if (!notready&&!cannotdelay.contains(fn)) {
453         switch(fn.kind()) {
454         case FKind.FlatFieldNode: {
455           FlatFieldNode ffn=(FlatFieldNode)fn;
456           if (!dcopts.getFields().contains(ffn.getField())) {
457             break;
458           }
459           TempDescriptor tmp=ffn.getSrc();
460           Set<TempFlatPair> tfpset=dcopts.getMap(lb).get(fn).get(tmp);
461           if (tfpset!=null) {
462             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
463               TempFlatPair tfp=tfpit.next();
464               if (!dcopts.getNeedSrcTrans(lb, tfp.f)) {
465                 //if a source didn't need a translation and we are
466                 //accessing it, it did...so therefore we are note
467                 //ready
468                 notready=true;
469                 break;
470               }
471             }
472           }
473           break;
474         }
475         case FKind.FlatSetFieldNode: {
476           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
477           TempDescriptor tmp=fsfn.getDst();
478           Hashtable<TempDescriptor, Set<TempFlatPair>> tmpmap=dcopts.getMap(lb).get(fn);
479           Set<TempFlatPair> tfpset=tmpmap!=null?tmpmap.get(tmp):null;
480
481           if (tfpset!=null) {
482             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
483               TempFlatPair tfp=tfpit.next();
484               if (!dcopts.getNeedSrcTrans(lb, tfp.f)) {
485                 //if a source didn't need a translation and we are
486                 //accessing it, it did...so therefore we are note
487                 //ready
488                 notready=true;
489                 break;
490               }
491             }
492           }
493           break;
494         }
495         case FKind.FlatElementNode: {
496           FlatElementNode fen=(FlatElementNode)fn;
497           if (!dcopts.getArrays().contains(fen.getSrc().getType())) {
498             break;
499           }
500           TempDescriptor tmp=fen.getSrc();
501           Set<TempFlatPair> tfpset=dcopts.getMap(lb).get(fn).get(tmp);
502           if (tfpset!=null) {
503             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
504               TempFlatPair tfp=tfpit.next();
505               if (!dcopts.getNeedSrcTrans(lb, tfp.f)) {
506                 //if a source didn't need a translation and we are
507                 //accessing it, it did...so therefore we are note
508                 //ready
509                 notready=true;
510                 break;
511               }
512             }
513           }
514           break;
515         }
516         case FKind.FlatSetElementNode: {
517           FlatSetElementNode fsen=(FlatSetElementNode)fn;
518           TempDescriptor tmp=fsen.getDst();
519           Set<TempFlatPair> tfpset=dcopts.getMap(lb).get(fn).get(tmp);
520           if (tfpset!=null) {
521             for(Iterator<TempFlatPair> tfpit=tfpset.iterator();tfpit.hasNext();) {
522               TempFlatPair tfp=tfpit.next();
523               if (!dcopts.getNeedSrcTrans(lb, tfp.f)) {
524                 //if a source didn't need a translation and we are
525                 //accessing it, it did...so therefore we are note
526                 //ready
527                 notready=true;
528                 break;
529               }
530             }
531           }
532           break;
533         }
534         }
535       }
536
537       //Fix up things based on our status
538       if (notready) {
539         //add us to the list
540         notreadynodes.add(fn);
541         //Add our writes
542         TempDescriptor writeset[]=fn.writesTemps();
543         for(int i=0;i<writeset.length;i++) {
544           TempDescriptor tmp=writeset[i];
545           notreadyset.add(tmp);
546         }
547       } else {
548         //Kill our writes
549         TempDescriptor writeset[]=fn.writesTemps();
550         for(int i=0;i<writeset.length;i++) {
551           TempDescriptor tmp=writeset[i];
552           notreadyset.remove(tmp);
553         }
554       }
555       
556       //See if we need to propagate changes
557       if (!notreadymap.containsKey(fn)||
558           !notreadymap.get(fn).equals(notreadyset)) {
559         notreadymap.put(fn, notreadyset);
560         for(int i=0;i<fn.numNext();i++)
561           toanalyze.add(fn.getNext(i));
562       }
563     } //end of while
564     return notreadynodes;
565   } //end of computeNotReadySet
566 } //end of class