changes to make sure that i don't step on stephen's work on imports...
[IRC.git] / Robust / src / IR / Tree / JavaBuilder.java
1 package IR.Tree;
2 import IR.*;
3 import IR.Tree.*;
4 import IR.Flat.*;
5 import java.util.*;
6 import java.io.*;
7 import Util.Pair;
8
9 public class JavaBuilder {
10   State state;
11   HashSet<Descriptor> checkedDesc=new HashSet<Descriptor>();
12   HashMap<ClassDescriptor, Integer> classStatus=new HashMap<ClassDescriptor, Integer>();
13   public final int CDNONE=0;
14   public final int CDINIT=1;
15   public final int CDINSTANTIATED=2;
16   BuildIR bir;
17   TypeUtil tu;
18   SemanticCheck sc;
19   BuildFlat bf;
20   Stack<MethodDescriptor> toprocess=new Stack<MethodDescriptor>();
21   HashSet<MethodDescriptor> discovered=new HashSet<MethodDescriptor>();
22
23   /* Maps class/interfaces to all instantiated classes that extend or
24    * implement those classes or interfaces */
25
26   HashMap<ClassDescriptor, Set<ClassDescriptor>> implementationMap=new HashMap<ClassDescriptor, Set<ClassDescriptor>>();
27
28   /* Maps methods to the methods they call */
29   
30   HashMap<MethodDescriptor, Set<MethodDescriptor>> callMap=new HashMap<MethodDescriptor, Set<MethodDescriptor>>();
31
32   /* Invocation map */
33   HashMap<ClassDescriptor, Set<Pair<MethodDescriptor, MethodDescriptor>>> invocationMap=new HashMap<ClassDescriptor, Set<Pair<MethodDescriptor, MethodDescriptor>>>();
34   
35
36   public JavaBuilder(State state) {
37     this.state=state;
38     bir=new BuildIR(state);
39     tu=new TypeUtil(state, bir);
40     sc=new SemanticCheck(state, tu, false);
41     bf=new BuildFlat(state,tu);
42   }
43
44   public TypeUtil getTypeUtil() {
45     return tu;
46   }
47
48   public BuildFlat getBuildFlat() {
49     return bf;
50   }
51
52   public void build() {
53     ClassDescriptor mainClass=sc.getClass(null, state.main, SemanticCheck.INIT);
54     MethodDescriptor mainMethod=tu.getMain();
55     toprocess.push(mainMethod);
56     computeFixPoint();
57   }
58
59   void checkMethod(MethodDescriptor md) {
60     try {
61       sc.checkMethodBody(md.getClassDesc(), md);
62     } catch( Error e ) {
63       System.out.println( "Error in "+md );
64       throw e;
65     }
66   }
67   
68   void initClassDesc(ClassDescriptor cd) {
69     if (classStatus.get(cd)==null) {
70       classStatus.put(cd, CDINIT);
71       //TODO...LOOK FOR STATIC INITIALIZERS
72     }
73   }
74   
75   void computeFixPoint() {
76     while(!toprocess.isEmpty()) {
77       MethodDescriptor md=toprocess.pop();
78       checkMethod(md);
79       initClassDesc(md.getClassDesc());
80       bf.flattenMethod(md.getClassDesc(), md);
81       processFlatMethod(md);
82     }
83   }
84   
85   void processCall(MethodDescriptor md, FlatCall fcall) {
86     MethodDescriptor callmd=fcall.getMethod();
87
88     //First handle easy cases...
89     if (callmd.isStatic()||callmd.isConstructor()) {
90       if (!discovered.contains(callmd)) {
91         discovered.add(callmd);
92         toprocess.push(callmd);
93       }
94       callMap.get(md).add(callmd);
95       return;
96     }
97
98     //Otherwise, handle virtual dispatch...
99     ClassDescriptor cn=callmd.getClassDesc();
100     Set<ClassDescriptor> impSet=implementationMap.get(cn);
101
102     if (!invocationMap.containsKey(cn))
103       invocationMap.put(cn, new HashSet<Pair<MethodDescriptor,MethodDescriptor>>());
104     invocationMap.get(cn).add(new Pair<MethodDescriptor, MethodDescriptor>(md, callmd));
105
106     for(ClassDescriptor cdactual:impSet) {
107       searchimp:
108       while(cdactual!=null) {
109         Set possiblematches=cdactual.getMethodTable().getSetFromSameScope(callmd.getSymbol());
110
111         for(Iterator matchit=possiblematches.iterator(); matchit.hasNext();) {
112           MethodDescriptor matchmd=(MethodDescriptor)matchit.next();
113           if (callmd.matches(matchmd)) {
114             //Found the method that will be called
115             if (!discovered.contains(matchmd)) {
116               discovered.add(matchmd);
117               toprocess.push(matchmd);
118             }
119             callMap.get(md).add(matchmd);
120             
121             break searchimp;
122           }
123         }
124
125         //Didn't find method...look in super class
126         cdactual=cdactual.getSuperDesc();
127       }
128     }
129   }
130
131   void processNew(FlatNew fnew) {
132     TypeDescriptor tdnew=fnew.getType();
133     if (!tdnew.isClass())
134       return;
135     ClassDescriptor cdnew=tdnew.getClassDesc();
136     Stack<ClassDescriptor> tovisit=new Stack<ClassDescriptor>();
137     tovisit.add(cdnew);
138     
139     while(!tovisit.isEmpty()) {
140       ClassDescriptor cdcurr=tovisit.pop();
141       if (!implementationMap.containsKey(cdcurr))
142         implementationMap.put(cdcurr, new HashSet<ClassDescriptor>());
143       if (implementationMap.get(cdcurr).add(cdnew)) {
144         //new implementation...see if it affects implementationmap
145         if (invocationMap.containsKey(cdcurr)) {
146           for(Pair<MethodDescriptor, MethodDescriptor> mdpair:invocationMap.get(cdcurr)) {
147             MethodDescriptor md=mdpair.getFirst();
148             MethodDescriptor callmd=mdpair.getSecond();
149             ClassDescriptor cdactual=cdnew;
150             
151             searchimp:
152             while(cdactual!=null) {
153               Set possiblematches=cdactual.getMethodTable().getSetFromSameScope(callmd.getSymbol());
154               for(Iterator matchit=possiblematches.iterator(); matchit.hasNext();) {
155                 MethodDescriptor matchmd=(MethodDescriptor)matchit.next();
156                 if (callmd.matches(matchmd)) {
157                   //Found the method that will be called
158                   if (!discovered.contains(matchmd)) {
159                     discovered.add(matchmd);
160                     toprocess.push(matchmd);
161                   }
162                   callMap.get(md).add(matchmd);
163                   break searchimp;
164                 }
165               }
166               
167               //Didn't find method...look in super class
168               cdactual=cdactual.getSuperDesc();
169             }
170           }
171         }
172       }
173       if (cdcurr.getSuperDesc()!=null)
174         tovisit.push(cdcurr.getSuperDesc());
175       for(Iterator interit=cdcurr.getSuperInterfaces();interit.hasNext();) {
176         ClassDescriptor cdinter=(ClassDescriptor) interit.next();
177         tovisit.push(cdinter);
178       }
179     }
180   }
181
182   void processFlatMethod(MethodDescriptor md) {
183     if (!callMap.containsKey(md))
184       callMap.put(md, new HashSet<MethodDescriptor>());
185     
186     FlatMethod fm=state.getMethodFlat(md);
187     for(FlatNode fn:fm.getNodeSet()) {
188       switch(fn.kind()) {
189       case FKind.FlatCall: {
190         FlatCall fcall=(FlatCall)fn;
191         processCall(md, fcall);
192         break;
193       }
194       case FKind.FlatNew: {
195         FlatNew fnew=(FlatNew)fn;
196         processNew(fnew);
197         break;
198       }
199       }
200     }
201   }
202
203   public static ParseNode readSourceFile(State state, String sourcefile) {
204     try {
205       Reader fr= new BufferedReader(new FileReader(sourcefile));
206       Lex.Lexer l = new Lex.Lexer(fr);
207       java_cup.runtime.lr_parser g;
208       g = new Parse.Parser(l);
209       ParseNode p=null;
210       try {
211         p=(ParseNode) g./*debug_*/parse().value;
212       } catch (Exception e) {
213         System.err.println("Error parsing file:"+sourcefile);
214         e.printStackTrace();
215         System.exit(-1);
216       }
217       state.addParseNode(p);
218       if (l.numErrors()!=0) {
219         System.out.println("Error parsing "+sourcefile);
220         System.exit(l.numErrors());
221       }
222       state.lines+=l.line_num;
223       return p;
224
225     } catch (Exception e) {
226       throw new Error(e);
227     }
228   }
229
230   public void loadClass(BuildIR bir, String sourcefile) {
231     try {
232       ParseNode pn=readSourceFile(state, sourcefile);
233       bir.buildtree(pn, null,sourcefile);
234     } catch (Exception e) {
235       System.out.println("Error in sourcefile:"+sourcefile);
236       e.printStackTrace();
237       System.exit(-1);
238     } catch (Error e) {
239       System.out.println("Error in sourcefile:"+sourcefile);
240       e.printStackTrace();
241       System.exit(-1);
242     }
243   }
244 }