Changes to build code
[IRC.git] / Robust / src / IR / Flat / BuildCode.java
1 package IR.Flat;
2 import IR.*;
3 import java.util.*;
4 import java.io.*;
5
6 public class BuildCode {
7     State state;
8     Hashtable temptovar;
9     Hashtable paramstable;
10     Hashtable tempstable;
11     int tag=0;
12     String localsprefix="__locals__";
13     String paramsprefix="__params__";
14     private static final boolean GENERATEPRECISEGC=true;
15
16     public BuildCode(State st, Hashtable temptovar) {
17         state=st;
18         this.temptovar=temptovar;
19         paramstable=new Hashtable();    
20         tempstable=new Hashtable();
21     }
22
23     public void buildCode() {
24         Iterator it=state.getClassSymbolTable().getDescriptorsIterator();
25         PrintWriter outclassdefs=null;
26         PrintWriter outstructs=null;
27         PrintWriter outmethodheader=null;
28         PrintWriter outmethod=null;
29         try {
30             OutputStream str=new FileOutputStream("structdefs.h");
31             outstructs=new java.io.PrintWriter(str, true);
32             str=new FileOutputStream("methodheaders.h");
33             outmethodheader=new java.io.PrintWriter(str, true);
34             str=new FileOutputStream("classdefs.h");
35             outclassdefs=new java.io.PrintWriter(str, true);
36             str=new FileOutputStream("methods.c");
37             outmethod=new java.io.PrintWriter(str, true);
38         } catch (Exception e) {
39             e.printStackTrace();
40             System.exit(-1);
41         }
42         outstructs.println("#include \"classdefs.h\"");
43         outmethodheader.println("#include \"structdefs.h\"");
44         while(it.hasNext()) {
45             ClassDescriptor cn=(ClassDescriptor)it.next();
46             generateCallStructs(cn, outclassdefs, outstructs, outmethodheader);
47         }
48         outstructs.close();
49         outmethodheader.close();
50
51         /* Build the actual methods */
52         outmethod.println("#include \"methodheaders.h\"");
53         Iterator classit=state.getClassSymbolTable().getDescriptorsIterator();
54         while(classit.hasNext()) {
55             ClassDescriptor cn=(ClassDescriptor)classit.next();
56             generateCallStructs(cn, outclassdefs, outstructs, outmethodheader);
57             Iterator methodit=cn.getMethods();
58             while(methodit.hasNext()) {
59                 /* Classify parameters */
60                 MethodDescriptor md=(MethodDescriptor)methodit.next();
61                 FlatMethod fm=state.getMethodFlat(md);
62                 generateFlatMethod(fm,outmethod);
63             }
64         }
65         outmethod.close();
66     }
67
68     private void generateTempStructs(FlatMethod fm) {
69         MethodDescriptor md=fm.getMethod();
70         ParamsObject objectparams=new ParamsObject(md,tag++);
71         paramstable.put(md, objectparams);
72         for(int i=0;i<fm.numParameters();i++) {
73             TempDescriptor temp=fm.getParameter(i);
74             TypeDescriptor type=temp.getType();
75             if (type.isPtr()&&GENERATEPRECISEGC)
76                 objectparams.addPtr(temp);
77             else
78                 objectparams.addPrim(temp);
79         }
80
81         TempObject objecttemps=new TempObject(objectparams,md,tag++);
82         tempstable.put(md, objecttemps);
83         for(Iterator nodeit=fm.getNodeSet().iterator();nodeit.hasNext();) {
84             FlatNode fn=(FlatNode)nodeit.next();
85             TempDescriptor[] writes=fn.writesTemps();
86             for(int i=0;i<writes.length;i++) {
87                 TempDescriptor temp=writes[i];
88                 TypeDescriptor type=temp.getType();
89                 if (type.isPtr()&&GENERATEPRECISEGC)
90                     objecttemps.addPtr(temp);
91                 else
92                     objecttemps.addPrim(temp);          
93             }
94         }
95     }
96
97     private void generateCallStructs(ClassDescriptor cn, PrintWriter classdefout, PrintWriter output, PrintWriter headersout) {
98         /* Output class structure */
99         Iterator fieldit=cn.getFields();
100         classdefout.println("struct "+cn.getSafeSymbol()+" {");
101         classdefout.println("  int type;");
102         while(fieldit.hasNext()) {
103             FieldDescriptor fd=(FieldDescriptor)fieldit.next();
104             classdefout.println("  "+fd.getType().getSafeSymbol()+" "+fd.getSafeSymbol()+";");
105         }
106         classdefout.println("};\n");
107
108         /* Cycle through methods */
109         Iterator methodit=cn.getMethods();
110         while(methodit.hasNext()) {
111             /* Classify parameters */
112             MethodDescriptor md=(MethodDescriptor)methodit.next();
113             FlatMethod fm=state.getMethodFlat(md);
114             generateTempStructs(fm);
115
116             ParamsObject objectparams=(ParamsObject) paramstable.get(md);
117             TempObject objecttemps=(TempObject) tempstable.get(md);
118
119             /* Output parameter structure */
120             if (GENERATEPRECISEGC) {
121                 output.println("struct "+cn.getSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params {");
122                 output.println("  int type;");
123                 for(int i=0;i<objectparams.numPointers();i++) {
124                     TempDescriptor temp=objectparams.getPointer(i);
125                     output.println("  struct "+temp.getType().getSafeSymbol()+" * "+temp.getSymbol()+";");
126                 }
127                 output.println("  void * next;");
128                 output.println("};\n");
129             }
130
131             /* Output temp structure */
132             if (GENERATEPRECISEGC) {
133                 output.println("struct "+cn.getSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_temps {");
134                 output.println("  int type;");
135                 for(int i=0;i<objecttemps.numPointers();i++) {
136                     TempDescriptor temp=objecttemps.getPointer(i);
137                     if (temp.getType().isNull())
138                         output.println("  void * "+temp.getSymbol()+";");
139                     else 
140                         output.println("  struct "+temp.getType().getSafeSymbol()+" * "+temp.getSymbol()+";");
141                 }
142                 output.println("  void * next;");
143                 output.println("};\n");
144             }
145             
146             /* Output method declaration */
147             if (md.getReturnType()!=null)
148                 headersout.print(md.getReturnType().getSafeSymbol()+" ");
149             headersout.print(cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"(");
150             
151             boolean printcomma=false;
152             if (GENERATEPRECISEGC) {
153                 headersout.print("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params * "+paramsprefix);
154                 printcomma=true;
155             }
156             for(int i=0;i<objectparams.numPrimitives();i++) {
157                 TempDescriptor temp=objectparams.getPrimitive(i);
158                 if (printcomma)
159                     headersout.print(", ");
160                 printcomma=true;
161                 headersout.print(temp.getType().getSafeSymbol()+" "+temp.getSafeSymbol());
162             }
163             headersout.println(");\n");
164         }
165     }
166
167     private void generateFlatMethod(FlatMethod fm, PrintWriter output) {
168         MethodDescriptor md=fm.getMethod();
169         ClassDescriptor cn=md.getClassDesc();
170         ParamsObject objectparams=(ParamsObject)paramstable.get(md);
171
172         generateHeader(md,output);
173         /* Print code */
174         output.println(" {");
175         
176         if (GENERATEPRECISEGC) {
177             output.println("   struct "+cn.getSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_temps "+localsprefix+";");
178         }
179         TempObject objecttemp=(TempObject) tempstable.get(md);
180         for(int i=0;i<objecttemp.numPrimitives();i++) {
181             TempDescriptor td=objecttemp.getPrimitive(i);
182             TypeDescriptor type=td.getType();
183             if (type.isClass())
184                 output.println("   struct "+type.getSafeSymbol()+" * "+td.getSafeSymbol()+";");
185             else
186                 output.println("   "+type.getSafeSymbol()+" "+td.getSafeSymbol()+";");
187         }
188         
189
190         /* Generate labels first */
191         HashSet tovisit=new HashSet();
192         HashSet visited=new HashSet();
193         int labelindex=0;
194         Hashtable nodetolabel=new Hashtable();
195         tovisit.add(fm.methodEntryNode());
196         FlatNode current_node=null;
197
198
199         //Assign labels 1st
200         //Node needs a label if it is
201         while(!tovisit.isEmpty()) {
202             FlatNode fn=(FlatNode)tovisit.iterator().next();
203             tovisit.remove(fn);
204             visited.add(fn);
205             for(int i=0;i<fn.numNext();i++) {
206                 FlatNode nn=fn.getNext(i);
207                 if(i>0) {
208                     //1) Edge >1 of node
209                     nodetolabel.put(nn,new Integer(labelindex++));
210                 }
211                 if (!visited.contains(nn)) {
212                     tovisit.add(nn);
213                 } else {
214                     //2) Join point
215                     nodetolabel.put(nn,new Integer(labelindex++));
216                 }
217             }
218         }
219
220         //Do the actual code generation
221         tovisit=new HashSet();
222         visited=new HashSet();
223         tovisit.add(fm.methodEntryNode());
224         while(current_node!=null||!tovisit.isEmpty()) {
225             if (current_node==null) {
226                 current_node=(FlatNode)tovisit.iterator().next();
227                 tovisit.remove(current_node);
228             }
229             visited.add(current_node);
230             if (nodetolabel.containsKey(current_node))
231                 output.println("L"+nodetolabel.get(current_node)+":");
232             if (current_node.numNext()==0) {
233                 output.print("   ");
234                 generateFlatNode(fm, current_node, output);
235                 current_node=null;
236             } else if(current_node.numNext()==1) {
237                 output.print("   ");
238                 generateFlatNode(fm, current_node, output);
239                 FlatNode nextnode=current_node.getNext(0);
240                 if (visited.contains(nextnode)) {
241                     output.println("goto L"+nodetolabel.get(nextnode));
242                     current_node=null;
243                 } else
244                     current_node=nextnode;
245             } else if (current_node.numNext()==2) {
246                 /* Branch */
247                 output.print("   ");
248                 generateFlatCondBranch(fm, (FlatCondBranch)current_node, "L"+nodetolabel.get(current_node.getNext(1)), output);
249                 if (!visited.contains(current_node.getNext(1)))
250                     tovisit.add(current_node.getNext(1));
251                 if (visited.contains(current_node.getNext(0))) {
252                     output.println("goto L"+nodetolabel.get(current_node.getNext(0)));
253                     current_node=null;
254                 } else
255                     current_node=current_node.getNext(0);
256             } else throw new Error();
257         }
258         output.println("}\n\n");
259     }
260
261     private String generateTemp(FlatMethod fm, TempDescriptor td) {
262         MethodDescriptor md=fm.getMethod();
263         TempObject objecttemps=(TempObject) tempstable.get(md);
264         if (objecttemps.isLocalPrim(td)||objecttemps.isParamPrim(td)) {
265             return td.getSafeSymbol();
266         }
267
268         if (objecttemps.isLocalPtr(td)) {
269             return localsprefix+"."+td.getSafeSymbol();
270         }
271
272         if (objecttemps.isParamPtr(td)) {
273             return paramsprefix+"->"+td.getSafeSymbol();
274         }
275         throw new Error();
276     }
277
278     private void generateFlatNode(FlatMethod fm, FlatNode fn, PrintWriter output) {
279         switch(fn.kind()) {
280         case FKind.FlatCall:
281             generateFlatCall(fm, (FlatCall) fn,output);
282             return;
283         case FKind.FlatFieldNode:
284             generateFlatFieldNode(fm, (FlatFieldNode) fn,output);
285             return;
286         case FKind.FlatSetFieldNode:
287             generateFlatSetFieldNode(fm, (FlatSetFieldNode) fn,output);
288             return;
289         case FKind.FlatNew:
290             generateFlatNew(fm, (FlatNew) fn,output);
291             return;
292         case FKind.FlatOpNode:
293             generateFlatOpNode(fm, (FlatOpNode) fn,output);
294             return;
295         case FKind.FlatCastNode:
296             generateFlatCastNode(fm, (FlatCastNode) fn,output);
297             return;
298         case FKind.FlatLiteralNode:
299             generateFlatLiteralNode(fm, (FlatLiteralNode) fn,output);
300             return;
301         case FKind.FlatReturnNode:
302             generateFlatReturnNode(fm, (FlatReturnNode) fn,output);
303             return;
304         case FKind.FlatNop:
305             output.println("/* nop */");
306             return;
307         }
308         throw new Error();
309
310     }
311
312     private void generateFlatCall(FlatMethod fm, FlatCall fc, PrintWriter output) {
313         MethodDescriptor md=fm.getMethod();
314         ClassDescriptor cn=md.getClassDesc();
315         output.println("   {");
316         boolean needcomma=false;
317         if (GENERATEPRECISEGC) {
318             output.print("       struct "+cn.getSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params __paramlist__={");
319             if (fc.getThis()!=null) {
320                 output.print(generateTemp(fm,fc.getThis()));
321                 needcomma=true;
322             }
323             output.println("};");
324         }
325         output.print("       ");
326
327         /* TODO: Virtual dispatch */
328         if (fc.getReturnType()!=null)
329             output.print(generateTemp(fm,fc.getReturnType())+"=");
330         output.print(cn.getSafeSymbol()+md.getSafeSymbol()++"_"+md.getSafeMethodDescriptor()+"(");
331         needcomma=false;
332         if (GENERATEPRECISEGC) {
333             output.println("__parameterlist__");
334             needcomma=true;
335         }
336         output.println(");");
337         output.println("   }");
338     }
339
340     private void generateFlatFieldNode(FlatMethod fm, FlatFieldNode ffn, PrintWriter output) {
341         output.println(generateTemp(fm, ffn.getDst())+"="+ generateTemp(fm,ffn.getSrc())+"->"+ ffn.getField().getSafeSymbol()+";");
342     }
343
344     private void generateFlatSetFieldNode(FlatMethod fm, FlatSetFieldNode fsfn, PrintWriter output) {
345         output.println(generateTemp(fm, fsfn.getDst())+"->"+ fsfn.getField().getSafeSymbol()+"="+ generateTemp(fm,fsfn.getSrc())+";");
346     }
347
348     private void generateFlatNew(FlatMethod fm, FlatNew fn, PrintWriter output) {
349     }
350
351     private void generateFlatOpNode(FlatMethod fm, FlatOpNode fon, PrintWriter output) {
352         if (fon.getOp().getOp()==Operation.ASSIGN)
353             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+";");
354         else if (fon.getRight()!=null)
355             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+fon.getOp().toString()+generateTemp(fm,fon.getRight())+";");
356         else
357             output.println(generateTemp(fm, fon.getDest())+fon.getOp().toString()+generateTemp(fm, fon.getLeft())+";");
358     }
359
360     private void generateFlatCastNode(FlatMethod fm, FlatCastNode fcn, PrintWriter output) {
361         /* TODO: Make call into runtime */
362         output.println(fcn.getDst()+"=("+fcn.getType().getSafeSymbol()+")"+fcn.getSrc()+";");
363     }
364
365     private void generateFlatLiteralNode(FlatMethod fm, FlatLiteralNode fln, PrintWriter output) {
366     }
367
368     private void generateFlatReturnNode(FlatMethod fm, FlatReturnNode frn, PrintWriter output) {
369         output.println("return "+generateTemp(fm, frn.getReturnTemp())+";");
370     }
371
372     private void generateFlatCondBranch(FlatMethod fm, FlatCondBranch fcb, String label, PrintWriter output) {
373         output.println("if (!"+generateTemp(fm, fcb.getTest())+") goto "+label+";");
374     }
375
376     private void generateHeader(MethodDescriptor md, PrintWriter output) {
377         /* Print header */
378         ParamsObject objectparams=(ParamsObject)paramstable.get(md);
379         ClassDescriptor cn=md.getClassDesc();
380         
381         if (md.getReturnType()!=null)
382             output.print(md.getReturnType().getSafeSymbol()+" ");
383         output.print(cn.getSafeSymbol()+md.getSafeSymbol()++"_"+md.getSafeMethodDescriptor()+"(");
384         
385         boolean printcomma=false;
386         if (GENERATEPRECISEGC) {
387             output.print("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params * "+paramsprefix);
388             printcomma=true;
389         }
390         for(int i=0;i<objectparams.numPrimitives();i++) {
391             TempDescriptor temp=objectparams.getPrimitive(i);
392             if (printcomma)
393                 output.print(", ");
394             printcomma=true;
395             output.print(temp.getType().getSafeSymbol()+" "+temp.getSafeSymbol());
396         }
397         output.print(")");
398     }
399 }