Switch to spaces only.
[IRC.git] / Robust / src / Analysis / Locality / BranchAnalysis.java
1 package Analysis.Locality;
2 import IR.State;
3 import IR.Flat.*;
4 import java.util.*;
5 import java.io.*;
6
7 public class BranchAnalysis {
8   LocalityAnalysis locality;
9   State state;
10   public BranchAnalysis(LocalityAnalysis locality, LocalityBinding lb, Set<FlatNode> nodeset, Set<FlatNode> storeset, State state) {
11     this.locality=locality;
12     this.state=state;
13     doAnalysis(lb, nodeset, storeset);
14   }
15
16   Hashtable<Set<FlatNode>, Vector<FlatNode>> table=new Hashtable<Set<FlatNode>, Vector<FlatNode>>();
17   Hashtable<FlatNode, FlatNode[]> fnmap;
18   Hashtable<FlatNode, Set<FlatNode>> groupmap;
19
20   public int jumpValue(FlatNode fn, int i) {
21     FlatNode next=fnmap.get(fn)[i];
22     Set<FlatNode> group=groupmap.get(fn);
23     if (group==null)
24       return -1;
25     while (next.numNext()==1&&group.contains(next)) {
26       next=fnmap.get(next)[0];
27     }
28     if (group.contains(next))
29       return -1;
30     Vector<FlatNode> exits=table.get(group);
31     int exit=exits.indexOf(next);
32     if (exit<0)
33       throw new Error();
34     return exit;
35   }
36
37   public int numJumps(FlatNode fn) {
38     Set<FlatNode> group=groupmap.get(fn);
39     if (group==null)
40       return -1;
41     Vector<FlatNode> exits=table.get(group);
42     return exits.size();
43   }
44
45   public Vector<FlatNode> getJumps(FlatNode fn) {
46     Set<FlatNode> group=groupmap.get(fn);
47     if (group==null)
48       throw new Error();
49     Vector<FlatNode> exits=table.get(group);
50     return exits;
51   }
52
53   public Set<FlatNode> getTargets() {
54     HashSet<FlatNode> targets=new HashSet<FlatNode>();
55     Collection<Set<FlatNode>> groups=groupmap.values();
56     for(Iterator<Set<FlatNode>> setit=groups.iterator(); setit.hasNext(); ) {
57       Set<FlatNode> group=setit.next();
58       targets.addAll(table.get(group));
59     }
60     return targets;
61   }
62
63   int grouplabelindex=0;
64
65   public boolean hasGroup(FlatNode fn) {
66     return groupmap.contains(fn);
67   }
68
69   Hashtable<Set<FlatNode>, String> grouplabel=new Hashtable<Set<FlatNode>, String>();
70
71   private boolean seenGroup(FlatNode fn) {
72     return grouplabel.containsKey(groupmap.get(fn));
73   }
74
75   private String getGroup(FlatNode fn) {
76     if (!grouplabel.containsKey(groupmap.get(fn)))
77       grouplabel.put(groupmap.get(fn), new String("LG"+(grouplabelindex++)));
78     return grouplabel.get(groupmap.get(fn));
79   }
80
81   public void generateGroupCode(FlatNode fn, PrintWriter output, Hashtable<FlatNode, Integer> nodetolabels) {
82     if (seenGroup(fn)) {
83       String label=getGroup(fn);
84       output.println("goto "+label+";");
85     } else {
86       String label=getGroup(fn);
87       output.println(label+":");
88       if (numJumps(fn)==1) {
89         FlatNode fndst=getJumps(fn).get(0);
90         output.println("goto L"+nodetolabels.get(fndst)+";");
91       } else if (numJumps(fn)==2) {
92         Vector<FlatNode> exits=getJumps(fn);
93         output.println("if(RESTOREBRANCH())");
94         output.println("goto L"+nodetolabels.get(exits.get(1))+";");
95         output.println("else");
96         output.println("goto L"+nodetolabels.get(exits.get(0))+";");
97       } else {
98         Vector<FlatNode> exits=getJumps(fn);
99         output.println("switch(RESTOREBRANCH()) {");
100         for(int i=0; i<exits.size(); i++) {
101           output.println("case "+i+":");
102           output.println("goto L"+nodetolabels.get(exits.get(i))+";");
103         }
104         output.println("}");
105       }
106     }
107   }
108
109   public void doAnalysis(LocalityBinding lb, Set<FlatNode> nodeset, Set<FlatNode> storeset) {
110     Set<FlatNode> transset=computeTransSet(lb);
111     fnmap=computeMap(transset, nodeset, storeset);
112     groupmap=new Hashtable<FlatNode, Set<FlatNode>>();
113
114     for(Iterator<FlatNode> fnit=transset.iterator(); fnit.hasNext(); ) {
115       FlatNode fn=fnit.next();
116       if ((fn.numNext()>1&&storeset.contains(fn))||fn.kind()==FKind.FlatBackEdge||fn.kind()==FKind.FlatNop) {
117         FlatNode[] children=fnmap.get(fn);
118         if (children==null)
119           continue;
120         if (!groupmap.containsKey(fn)) {
121           groupmap.put(fn, new HashSet<FlatNode>());
122           groupmap.get(fn).add(fn);
123         }
124         for(int i=0; i<children.length; i++) {
125           FlatNode child=children[i];
126           if ((child.numNext()>1&&storeset.contains(child))||child.kind()==FKind.FlatBackEdge||child.kind()==FKind.FlatNop) {
127             mergegroups(fn, child, groupmap);
128           }
129         }
130       }
131     }
132     //now we have groupings...
133     Collection<Set<FlatNode>> groups=groupmap.values();
134     for(Iterator<Set<FlatNode>> setit=groups.iterator(); setit.hasNext(); ) {
135       Set<FlatNode> group=setit.next();
136       Vector<FlatNode> exits=new Vector<FlatNode>();
137       table.put(group, exits);
138       for(Iterator<FlatNode> fnit=group.iterator(); fnit.hasNext(); ) {
139         FlatNode fn=fnit.next();
140         FlatNode[] nextnodes=fnmap.get(fn);
141         for(int i=0; i<nextnodes.length; i++) {
142           FlatNode nextnode=nextnodes[i];
143           if (!group.contains(nextnode)) {
144             //outside edge
145             if (!exits.contains(nextnode)) {
146               exits.add(nextnode);
147             }
148           }
149         }
150       }
151     }
152   }
153
154   public void mergegroups(FlatNode fn1, FlatNode fn2, Hashtable<FlatNode, Set<FlatNode>> groupmap) {
155     if (!groupmap.containsKey(fn1)) {
156       groupmap.put(fn1, new HashSet<FlatNode>());
157       groupmap.get(fn1).add(fn1);
158     }
159     if (!groupmap.containsKey(fn2)) {
160       groupmap.put(fn2, new HashSet<FlatNode>());
161       groupmap.get(fn2).add(fn2);
162     }
163     if (groupmap.get(fn1)!=groupmap.get(fn2)) {
164       groupmap.get(fn1).addAll(groupmap.get(fn2));
165       for(Iterator<FlatNode> fnit=groupmap.get(fn2).iterator(); fnit.hasNext(); ) {
166         FlatNode fn3=fnit.next();
167         groupmap.put(fn3, groupmap.get(fn1));
168       }
169     }
170   }
171
172   public Hashtable<FlatNode, FlatNode[]> computeMap(Set<FlatNode> transset, Set<FlatNode> nodeset, Set<FlatNode> storeset) {
173     Set<FlatNode> toprocess=new HashSet<FlatNode>();
174     toprocess.addAll(transset);
175     Hashtable<FlatNode, Set<Object[]>> fntotuple=new Hashtable<FlatNode, Set<Object[]>>();
176     Hashtable<FlatNode, FlatNode[]> fnmap=new Hashtable<FlatNode, FlatNode[]>();
177     while(!toprocess.isEmpty()) {
178       FlatNode fn=toprocess.iterator().next();
179       toprocess.remove(fn);
180       Set<Object[]> incomingtuples=new HashSet<Object[]>();
181
182       for(int i=0; i<fn.numPrev(); i++) {
183         FlatNode fprev=fn.getPrev(i);
184         if (nodeset.contains(fprev)||storeset.contains(fprev)) {
185           for(int j=0; j<fprev.numNext(); j++) {
186             if (fprev.getNext(j)==fn) {
187               Object[] pair=new Object[2];
188               pair[0]=new Integer(j); pair[1]=fprev;
189               incomingtuples.add(pair);
190             }
191           }
192         } else {
193           Set<Object[]> tuple=fntotuple.get(fprev);
194           if (tuple!=null)
195             incomingtuples.addAll(tuple);
196         }
197       }
198
199       if (nodeset.contains(fn)||storeset.contains(fn)||fn.kind()==FKind.FlatAtomicExitNode) {
200         //nodeset contains this node
201         for(Iterator<Object[]> it=incomingtuples.iterator(); it.hasNext(); ) {
202           Object[] pair=it.next();
203           int index=((Integer)pair[0]).intValue();
204           FlatNode node=(FlatNode)pair[1];
205           if (!fnmap.containsKey(node))
206             fnmap.put(node, new FlatNode[node.numNext()]);
207           fnmap.get(node)[index]=fn;
208         }
209         incomingtuples=new HashSet<Object[]>();
210       }
211
212       //add if we need to update
213       if (!fntotuple.containsKey(fn)||
214           !fntotuple.get(fn).equals(incomingtuples)) {
215         fntotuple.put(fn,incomingtuples);
216         for(int i=0; i<fn.numNext(); i++) {
217           if (transset.contains(fn.getNext(i)))
218             toprocess.add(fn.getNext(i));
219         }
220       }
221     }
222     return fnmap;
223   }
224
225
226   public Set<FlatNode> computeTransSet(LocalityBinding lb) {
227     Set<FlatNode> transset=new HashSet();
228     Set<FlatNode> tovisit=new HashSet();
229     tovisit.addAll(state.getMethodFlat(lb.getMethod()).getNodeSet());
230     while(!tovisit.isEmpty()) {
231       FlatNode fn=tovisit.iterator().next();
232       tovisit.remove(fn);
233       if (locality.getAtomic(lb).get(fn).intValue()>0||fn.kind()==FKind.FlatAtomicExitNode)
234         transset.add(fn);
235     }
236     return transset;
237   }
238 }