adding a test case
[IRC.git] / Robust / src / IR / Tree / JavaBuilder.java
1 package IR.Tree;
2 import IR.*;
3 import IR.Tree.*;
4 import IR.Flat.*;
5 import java.util.*;
6 import java.io.*;
7 import Util.Pair;
8 import Analysis.CallGraph.CallGraph;
9
10 public class JavaBuilder implements CallGraph {
11   State state;
12   HashSet<Descriptor> checkedDesc=new HashSet<Descriptor>();
13   HashMap<ClassDescriptor, Integer> classStatus=new HashMap<ClassDescriptor, Integer>();
14   public final int CDNONE=0;
15   public final int CDINIT=1;
16   public final int CDINSTANTIATED=2;
17   BuildIR bir;
18   TypeUtil tu;
19   SemanticCheck sc;
20   BuildFlat bf;
21   Stack<MethodDescriptor> toprocess=new Stack<MethodDescriptor>();
22   HashSet<MethodDescriptor> discovered=new HashSet<MethodDescriptor>();
23   HashMap<MethodDescriptor, Set<MethodDescriptor>> canCall=new HashMap<MethodDescriptor, Set<MethodDescriptor>>();
24   MethodDescriptor mainMethod;
25
26   /* Maps class/interfaces to all instantiated classes that extend or
27    * implement those classes or interfaces */
28
29   HashMap<ClassDescriptor, Set<ClassDescriptor>> implementationMap=new HashMap<ClassDescriptor, Set<ClassDescriptor>>();
30
31   /* Maps methods to the methods they call */
32
33   HashMap<MethodDescriptor, Set<MethodDescriptor>> callMap=new HashMap<MethodDescriptor, Set<MethodDescriptor>>();
34
35   HashMap<MethodDescriptor, Set<MethodDescriptor>> revCallMap=new HashMap<MethodDescriptor, Set<MethodDescriptor>>();
36
37   /* Invocation map */
38   HashMap<ClassDescriptor, Set<Pair<MethodDescriptor, MethodDescriptor>>> invocationMap=new HashMap<ClassDescriptor, Set<Pair<MethodDescriptor, MethodDescriptor>>>();
39
40   public Set getAllMethods(Descriptor d) {
41     HashSet tovisit=new HashSet();
42     tovisit.add(d);
43     HashSet callable=new HashSet();
44     while(!tovisit.isEmpty()) {
45       Descriptor md=(Descriptor)tovisit.iterator().next();
46       tovisit.remove(md);
47       Set s=getCalleeSet(md);
48
49       if (s!=null) {
50         for(Iterator it=s.iterator(); it.hasNext(); ) {
51           MethodDescriptor md2=(MethodDescriptor)it.next();
52           if( !callable.contains(md2) ) {
53             callable.add(md2);
54             tovisit.add(md2);
55           }
56         }
57       }
58     }
59     return callable;
60   }
61
62   public Set getMethods(MethodDescriptor md, TypeDescriptor type) {
63     if (canCall.containsKey(md))
64       return canCall.get(md);
65     else
66       return new HashSet();
67   }
68
69   public Set getMethods(MethodDescriptor md) {
70     return getMethods(md, null);
71   }
72
73   public Set getMethodCalls(Descriptor d) {
74     Set set=getAllMethods(d);
75     set.add(d);
76     return set;
77   }
78
79   /* Returns whether there is a reachable call to this method descriptor...Not whether the implementation is called */
80
81   public boolean isCalled(MethodDescriptor md) {
82     return canCall.containsKey(md);
83   }
84
85   public boolean isCallable(MethodDescriptor md) {
86     return !getCallerSet(md).isEmpty()||md==mainMethod;
87   }
88
89   public Set getCalleeSet(Descriptor d) {
90     Set calleeset=callMap.get((MethodDescriptor)d);
91     if (calleeset==null)
92       return new HashSet();
93     else
94       return calleeset;
95   }
96
97   public Set getCallerSet(MethodDescriptor md) {
98     Set callerset=revCallMap.get(md);
99     if (callerset==null)
100       return new HashSet();
101     else
102       return callerset;
103   }
104
105   public Set getFirstReachableMethodContainingSESE(Descriptor d,
106                                                    Set<MethodDescriptor> methodsContainingSESEs) {
107     throw new Error("");
108   }
109
110   public boolean hasLayout(ClassDescriptor cd) {
111     return sc.hasLayout(cd);
112   }
113
114   public JavaBuilder(State state) {
115     this.state=state;
116     bir=new BuildIR(state);
117     tu=new TypeUtil(state, bir);
118     sc=new SemanticCheck(state, tu, false);
119     bf=new BuildFlat(state,tu);
120   }
121
122   public TypeUtil getTypeUtil() {
123     return tu;
124   }
125
126   public BuildFlat getBuildFlat() {
127     return bf;
128   }
129
130   public void build() {
131     //Initialize Strings to keep runtime happy
132     ClassDescriptor stringClass=sc.getClass(null, TypeUtil.StringClass, SemanticCheck.INIT);
133     instantiateClass(stringClass);
134
135     ClassDescriptor mainClass=sc.getClass(null, state.main, SemanticCheck.INIT);
136     mainMethod=tu.getMain();
137
138     canCall.put(mainMethod, new HashSet<MethodDescriptor>());
139     canCall.get(mainMethod).add(mainMethod);
140
141     toprocess.push(mainMethod);
142     computeFixPoint();
143     tu.createFullTable();
144   }
145
146   void checkMethod(MethodDescriptor md) {
147     try {
148       sc.checkMethodBody(md.getClassDesc(), md);
149     } catch( Error e ) {
150       System.out.println("Error in "+md);
151       throw e;
152     }
153   }
154
155   public boolean isInit(ClassDescriptor cd) {
156     return classStatus.get(cd)!=null&&classStatus.get(cd)>=CDINIT;
157   }
158
159   void initClassDesc(ClassDescriptor cd, int init) {
160     if (classStatus.get(cd)==null||classStatus.get(cd)<init) {
161       if (classStatus.get(cd)==null) {
162         MethodDescriptor mdstaticinit = (MethodDescriptor)cd.getMethodTable().get("staticblocks");
163         if (mdstaticinit!=null) {
164           discovered.add(mdstaticinit);
165           toprocess.push(mdstaticinit);
166         }
167       }
168       classStatus.put(cd, init);
169     }
170   }
171
172   void computeFixPoint() {
173     while(!toprocess.isEmpty()) {
174       MethodDescriptor md=toprocess.pop();
175       checkMethod(md);
176       initClassDesc(md.getClassDesc(), CDINIT);
177       bf.flattenMethod(md.getClassDesc(), md);
178       processFlatMethod(md);
179     }
180
181     //make sure every called method descriptor has a flat method
182     for(MethodDescriptor callmd : canCall.keySet())
183       bf.addJustFlatMethod(callmd);
184   }
185
186   void processCall(MethodDescriptor md, FlatCall fcall) {
187     MethodDescriptor callmd=fcall.getMethod();
188     //make sure we have a FlatMethod for the base method...
189     if (!canCall.containsKey(callmd))
190       canCall.put(callmd, new HashSet<MethodDescriptor>());
191
192     //First handle easy cases...
193     if (callmd.isStatic()||callmd.isConstructor()) {
194       if (!discovered.contains(callmd)) {
195         discovered.add(callmd);
196         toprocess.push(callmd);
197       }
198       if (!revCallMap.containsKey(callmd))
199         revCallMap.put(callmd, new HashSet<MethodDescriptor>());
200       revCallMap.get(callmd).add(md);
201       callMap.get(md).add(callmd);
202       canCall.get(callmd).add(callmd);
203       return;
204     }
205
206     //Otherwise, handle virtual dispatch...
207     ClassDescriptor cn=callmd.getClassDesc();
208     Set<ClassDescriptor> impSet=implementationMap.get(cn);
209
210     if (!invocationMap.containsKey(cn))
211       invocationMap.put(cn, new HashSet<Pair<MethodDescriptor,MethodDescriptor>>());
212     invocationMap.get(cn).add(new Pair<MethodDescriptor, MethodDescriptor>(md, callmd));
213
214     if (impSet!=null) {
215       for(ClassDescriptor cdactual : impSet) {
216         searchimp :
217         while(cdactual!=null) {
218           Set possiblematches=cdactual.getMethodTable().getSetFromSameScope(callmd.getSymbol());
219
220           for(Iterator matchit=possiblematches.iterator(); matchit.hasNext(); ) {
221             MethodDescriptor matchmd=(MethodDescriptor)matchit.next();
222             if (callmd.matches(matchmd)) {
223               //Found the method that will be called
224               if (!discovered.contains(matchmd)) {
225                 discovered.add(matchmd);
226                 toprocess.push(matchmd);
227               }
228
229               if (!revCallMap.containsKey(matchmd))
230                 revCallMap.put(matchmd, new HashSet<MethodDescriptor>());
231               revCallMap.get(matchmd).add(md);
232
233               callMap.get(md).add(matchmd);
234               canCall.get(callmd).add(matchmd);
235               break searchimp;
236             }
237           }
238
239           //Didn't find method...look in super class
240           cdactual=cdactual.getSuperDesc();
241         }
242       }
243     }
244   }
245
246   void processNew(FlatNew fnew) {
247     TypeDescriptor tdnew=fnew.getType();
248     if (!tdnew.isClass())
249       return;
250     ClassDescriptor cdnew=tdnew.getClassDesc();
251     //Make sure class is fully initialized
252     sc.checkClass(cdnew, SemanticCheck.INIT);
253     instantiateClass(cdnew);
254   }
255
256   void instantiateClass(ClassDescriptor cdnew) {
257     if (classStatus.containsKey(cdnew)&&classStatus.get(cdnew)==CDINSTANTIATED)
258       return;
259     initClassDesc(cdnew, CDINSTANTIATED);
260
261     Stack<ClassDescriptor> tovisit=new Stack<ClassDescriptor>();
262     tovisit.add(cdnew);
263
264     while(!tovisit.isEmpty()) {
265       ClassDescriptor cdcurr=tovisit.pop();
266       if (!implementationMap.containsKey(cdcurr))
267         implementationMap.put(cdcurr, new HashSet<ClassDescriptor>());
268       if (implementationMap.get(cdcurr).add(cdnew)) {
269         //new implementation...see if it affects implementationmap
270         if (invocationMap.containsKey(cdcurr)) {
271           for(Pair<MethodDescriptor, MethodDescriptor> mdpair : invocationMap.get(cdcurr)) {
272             MethodDescriptor md=mdpair.getFirst();
273             MethodDescriptor callmd=mdpair.getSecond();
274             ClassDescriptor cdactual=cdnew;
275
276 searchimp:
277             while(cdactual!=null) {
278               Set possiblematches=cdactual.getMethodTable().getSetFromSameScope(callmd.getSymbol());
279               for(Iterator matchit=possiblematches.iterator(); matchit.hasNext(); ) {
280                 MethodDescriptor matchmd=(MethodDescriptor)matchit.next();
281                 if (callmd.matches(matchmd)) {
282                   //Found the method that will be called
283                   if (!discovered.contains(matchmd)) {
284                     discovered.add(matchmd);
285                     toprocess.push(matchmd);
286                   }
287                   if (!revCallMap.containsKey(matchmd))
288                     revCallMap.put(matchmd, new HashSet<MethodDescriptor>());
289                   revCallMap.get(matchmd).add(md);
290                   callMap.get(md).add(matchmd);
291                   canCall.get(callmd).add(matchmd);
292                   break searchimp;
293                 }
294               }
295
296               //Didn't find method...look in super class
297               cdactual=cdactual.getSuperDesc();
298             }
299           }
300         }
301       }
302       if (cdcurr.getSuperDesc()!=null)
303         tovisit.push(cdcurr.getSuperDesc());
304       for(Iterator interit=cdcurr.getSuperInterfaces(); interit.hasNext(); ) {
305         ClassDescriptor cdinter=(ClassDescriptor) interit.next();
306         tovisit.push(cdinter);
307       }
308     }
309   }
310
311   void processFlatMethod(MethodDescriptor md) {
312     if (!callMap.containsKey(md))
313       callMap.put(md, new HashSet<MethodDescriptor>());
314
315     FlatMethod fm=state.getMethodFlat(md);
316     for(FlatNode fn: fm.getNodeSet()) {
317       switch(fn.kind()) {
318       case FKind.FlatFieldNode: {
319         FieldDescriptor fd=((FlatFieldNode)fn).getField();
320         if (fd.isStatic()) {
321           ClassDescriptor cd=fd.getClassDescriptor();
322           initClassDesc(cd, CDINIT);
323         }
324         break;
325       }
326
327       case FKind.FlatSetFieldNode: {
328         FieldDescriptor fd=((FlatSetFieldNode)fn).getField();
329         if (fd.isStatic()) {
330           ClassDescriptor cd=fd.getClassDescriptor();
331           initClassDesc(cd, CDINIT);
332         }
333         break;
334       }
335
336       case FKind.FlatCall: {
337         FlatCall fcall=(FlatCall)fn;
338         processCall(md, fcall);
339         break;
340       }
341
342       case FKind.FlatNew: {
343         FlatNew fnew=(FlatNew)fn;
344         processNew(fnew);
345         break;
346       }
347       }
348     }
349   }
350 }