c778b1d69af3443467c8d06bb40b48e63dba27a0
[IRC.git] / Robust / src / IR / Flat / BuildCode.java
1 package IR.Flat;
2 import IR.*;
3 import java.util.*;
4 import java.io.*;
5
6 public class BuildCode {
7     State state;
8     Hashtable temptovar;
9     Hashtable paramstable;
10     Hashtable tempstable;
11     int tag=0;
12     String localsprefix="___locals___";
13     String paramsprefix="___params___";
14     private static final boolean GENERATEPRECISEGC=true;
15
16     public BuildCode(State st, Hashtable temptovar) {
17         state=st;
18         this.temptovar=temptovar;
19         paramstable=new Hashtable();    
20         tempstable=new Hashtable();
21     }
22
23     public void buildCode() {
24         Iterator it=state.getClassSymbolTable().getDescriptorsIterator();
25         PrintWriter outclassdefs=null;
26         PrintWriter outstructs=null;
27         PrintWriter outmethodheader=null;
28         PrintWriter outmethod=null;
29         try {
30             OutputStream str=new FileOutputStream("structdefs.h");
31             outstructs=new java.io.PrintWriter(str, true);
32             str=new FileOutputStream("methodheaders.h");
33             outmethodheader=new java.io.PrintWriter(str, true);
34             str=new FileOutputStream("classdefs.h");
35             outclassdefs=new java.io.PrintWriter(str, true);
36             str=new FileOutputStream("methods.c");
37             outmethod=new java.io.PrintWriter(str, true);
38         } catch (Exception e) {
39             e.printStackTrace();
40             System.exit(-1);
41         }
42         outstructs.println("#include \"classdefs.h\"");
43         outmethodheader.println("#include \"structdefs.h\"");
44
45         // Output the C declarations
46         // These could mutually reference each other
47         while(it.hasNext()) {
48             ClassDescriptor cn=(ClassDescriptor)it.next();
49             outclassdefs.println("struct "+cn.getSafeSymbol()+";");
50         }
51         outclassdefs.println("");
52
53         it=state.getClassSymbolTable().getDescriptorsIterator();
54         while(it.hasNext()) {
55             ClassDescriptor cn=(ClassDescriptor)it.next();
56             generateCallStructs(cn, outclassdefs, outstructs, outmethodheader);
57         }
58         outstructs.close();
59         outmethodheader.close();
60
61         /* Build the actual methods */
62         outmethod.println("#include \"methodheaders.h\"");
63         Iterator classit=state.getClassSymbolTable().getDescriptorsIterator();
64         while(classit.hasNext()) {
65             ClassDescriptor cn=(ClassDescriptor)classit.next();
66             Iterator methodit=cn.getMethods();
67             while(methodit.hasNext()) {
68                 /* Classify parameters */
69                 MethodDescriptor md=(MethodDescriptor)methodit.next();
70                 FlatMethod fm=state.getMethodFlat(md);
71                 generateFlatMethod(fm,outmethod);
72             }
73         }
74         outmethod.close();
75     }
76
77     private void generateTempStructs(FlatMethod fm) {
78         MethodDescriptor md=fm.getMethod();
79         ParamsObject objectparams=new ParamsObject(md,tag++);
80         paramstable.put(md, objectparams);
81         for(int i=0;i<fm.numParameters();i++) {
82             TempDescriptor temp=fm.getParameter(i);
83             TypeDescriptor type=temp.getType();
84             if (type.isPtr()&&GENERATEPRECISEGC)
85                 objectparams.addPtr(temp);
86             else
87                 objectparams.addPrim(temp);
88         }
89
90         TempObject objecttemps=new TempObject(objectparams,md,tag++);
91         tempstable.put(md, objecttemps);
92         for(Iterator nodeit=fm.getNodeSet().iterator();nodeit.hasNext();) {
93             FlatNode fn=(FlatNode)nodeit.next();
94             TempDescriptor[] writes=fn.writesTemps();
95             for(int i=0;i<writes.length;i++) {
96                 TempDescriptor temp=writes[i];
97                 TypeDescriptor type=temp.getType();
98                 if (type.isPtr()&&GENERATEPRECISEGC)
99                     objecttemps.addPtr(temp);
100                 else
101                     objecttemps.addPrim(temp);
102             }
103         }
104     }
105
106     private void generateCallStructs(ClassDescriptor cn, PrintWriter classdefout, PrintWriter output, PrintWriter headersout) {
107         /* Output class structure */
108         Iterator fieldit=cn.getFields();
109         classdefout.println("struct "+cn.getSafeSymbol()+" {");
110         classdefout.println("  int type;");
111         while(fieldit.hasNext()) {
112             FieldDescriptor fd=(FieldDescriptor)fieldit.next();
113             if (fd.getType().isClass())
114                 classdefout.println("  struct "+fd.getType().getSafeSymbol()+" * "+fd.getSafeSymbol()+";");
115             else 
116                 classdefout.println("  "+fd.getType().getSafeSymbol()+" "+fd.getSafeSymbol()+";");
117         }
118         classdefout.println("};\n");
119
120         /* Cycle through methods */
121         Iterator methodit=cn.getMethods();
122         while(methodit.hasNext()) {
123             /* Classify parameters */
124             MethodDescriptor md=(MethodDescriptor)methodit.next();
125             FlatMethod fm=state.getMethodFlat(md);
126             generateTempStructs(fm);
127
128             ParamsObject objectparams=(ParamsObject) paramstable.get(md);
129             TempObject objecttemps=(TempObject) tempstable.get(md);
130
131             /* Output parameter structure */
132             if (GENERATEPRECISEGC) {
133                 output.println("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params {");
134                 output.println("  int type;");
135                 output.println("  void * next;");
136                 for(int i=0;i<objectparams.numPointers();i++) {
137                     TempDescriptor temp=objectparams.getPointer(i);
138                     output.println("  struct "+temp.getType().getSafeSymbol()+" * "+temp.getSafeSymbol()+";");
139                 }
140                 output.println("};\n");
141             }
142
143             /* Output temp structure */
144             if (GENERATEPRECISEGC) {
145                 output.println("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_locals {");
146                 output.println("  int type;");
147                 output.println("  void * next;");
148                 for(int i=0;i<objecttemps.numPointers();i++) {
149                     TempDescriptor temp=objecttemps.getPointer(i);
150                     if (temp.getType().isNull())
151                         output.println("  void * "+temp.getSafeSymbol()+";");
152                     else
153                         output.println("  struct "+temp.getType().getSafeSymbol()+" * "+temp.getSafeSymbol()+";");
154                 }
155                 output.println("};\n");
156             }
157             
158             /* Output method declaration */
159             if (md.getReturnType()!=null) {
160                 if (md.getReturnType().isClass())
161                     headersout.print("struct " + md.getReturnType().getSafeSymbol()+" * ");
162                 else
163                     headersout.print(md.getReturnType().getSafeSymbol()+" ");
164             } else 
165                 //catch the constructor case
166                 headersout.print("void ");
167             headersout.print(cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"(");
168             
169             boolean printcomma=false;
170             if (GENERATEPRECISEGC) {
171                 headersout.print("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params * "+paramsprefix);
172                 printcomma=true;
173             }
174             for(int i=0;i<objectparams.numPrimitives();i++) {
175                 TempDescriptor temp=objectparams.getPrimitive(i);
176                 if (printcomma)
177                     headersout.print(", ");
178                 printcomma=true;
179                 headersout.print(temp.getType().getSafeSymbol()+" "+temp.getSafeSymbol());
180             }
181             headersout.println(");\n");
182         }
183     }
184
185     private void generateFlatMethod(FlatMethod fm, PrintWriter output) {
186         MethodDescriptor md=fm.getMethod();
187         ClassDescriptor cn=md.getClassDesc();
188         ParamsObject objectparams=(ParamsObject)paramstable.get(md);
189
190         generateHeader(md,output);
191         /* Print code */
192         output.println(" {");
193         
194         if (GENERATEPRECISEGC) {
195             output.println("   struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_locals "+localsprefix+";");
196         }
197         TempObject objecttemp=(TempObject) tempstable.get(md);
198         for(int i=0;i<objecttemp.numPrimitives();i++) {
199             TempDescriptor td=objecttemp.getPrimitive(i);
200             TypeDescriptor type=td.getType();
201             if (type.isClass())
202                 output.println("   struct "+type.getSafeSymbol()+" * "+td.getSafeSymbol()+";");
203             else
204                 output.println("   "+type.getSafeSymbol()+" "+td.getSafeSymbol()+";");
205         }
206         
207
208         /* Generate labels first */
209         HashSet tovisit=new HashSet();
210         HashSet visited=new HashSet();
211         int labelindex=0;
212         Hashtable nodetolabel=new Hashtable();
213         tovisit.add(fm.methodEntryNode());
214         FlatNode current_node=null;
215
216         //Assign labels 1st
217         //Node needs a label if it is
218         while(!tovisit.isEmpty()) {
219             FlatNode fn=(FlatNode)tovisit.iterator().next();
220             tovisit.remove(fn);
221             visited.add(fn);
222             for(int i=0;i<fn.numNext();i++) {
223                 FlatNode nn=fn.getNext(i);
224                 if(i>0) {
225                     //1) Edge >1 of node
226                     nodetolabel.put(nn,new Integer(labelindex++));
227                 }
228                 if (!visited.contains(nn)&&!tovisit.contains(nn)) {
229                     tovisit.add(nn);
230                 } else {
231                     //2) Join point
232                     nodetolabel.put(nn,new Integer(labelindex++));
233                 }
234             }
235         }
236
237         //Do the actual code generation
238         tovisit=new HashSet();
239         visited=new HashSet();
240         tovisit.add(fm.methodEntryNode());
241         while(current_node!=null||!tovisit.isEmpty()) {
242             if (current_node==null) {
243                 current_node=(FlatNode)tovisit.iterator().next();
244                 tovisit.remove(current_node);
245             }
246             visited.add(current_node);
247             if (nodetolabel.containsKey(current_node))
248                 output.println("L"+nodetolabel.get(current_node)+":");
249             if (current_node.numNext()==0) {
250                 output.print("   ");
251                 generateFlatNode(fm, current_node, output);
252                 current_node=null;
253             } else if(current_node.numNext()==1) {
254                 output.print("   ");
255                 generateFlatNode(fm, current_node, output);
256                 FlatNode nextnode=current_node.getNext(0);
257                 if (visited.contains(nextnode)) {
258                     output.println("goto L"+nodetolabel.get(nextnode)+";");
259                     current_node=null;
260                 } else
261                     current_node=nextnode;
262             } else if (current_node.numNext()==2) {
263                 /* Branch */
264                 output.print("   ");
265                 generateFlatCondBranch(fm, (FlatCondBranch)current_node, "L"+nodetolabel.get(current_node.getNext(1)), output);
266                 if (!visited.contains(current_node.getNext(1)))
267                     tovisit.add(current_node.getNext(1));
268                 if (visited.contains(current_node.getNext(0))) {
269                     output.println("goto L"+nodetolabel.get(current_node.getNext(0))+";");
270                     current_node=null;
271                 } else
272                     current_node=current_node.getNext(0);
273             } else throw new Error();
274         }
275         output.println("}\n\n");
276     }
277
278     private String generateTemp(FlatMethod fm, TempDescriptor td) {
279         MethodDescriptor md=fm.getMethod();
280         TempObject objecttemps=(TempObject) tempstable.get(md);
281         if (objecttemps.isLocalPrim(td)||objecttemps.isParamPrim(td)) {
282             return td.getSafeSymbol();
283         }
284
285         if (objecttemps.isLocalPtr(td)) {
286             return localsprefix+"."+td.getSafeSymbol();
287         }
288
289         if (objecttemps.isParamPtr(td)) {
290             return paramsprefix+"->"+td.getSafeSymbol();
291         }
292         throw new Error();
293     }
294
295     private void generateFlatNode(FlatMethod fm, FlatNode fn, PrintWriter output) {
296         switch(fn.kind()) {
297         case FKind.FlatCall:
298             generateFlatCall(fm, (FlatCall) fn,output);
299             return;
300         case FKind.FlatFieldNode:
301             generateFlatFieldNode(fm, (FlatFieldNode) fn,output);
302             return;
303         case FKind.FlatSetFieldNode:
304             generateFlatSetFieldNode(fm, (FlatSetFieldNode) fn,output);
305             return;
306         case FKind.FlatNew:
307             generateFlatNew(fm, (FlatNew) fn,output);
308             return;
309         case FKind.FlatOpNode:
310             generateFlatOpNode(fm, (FlatOpNode) fn,output);
311             return;
312         case FKind.FlatCastNode:
313             generateFlatCastNode(fm, (FlatCastNode) fn,output);
314             return;
315         case FKind.FlatLiteralNode:
316             generateFlatLiteralNode(fm, (FlatLiteralNode) fn,output);
317             return;
318         case FKind.FlatReturnNode:
319             generateFlatReturnNode(fm, (FlatReturnNode) fn,output);
320             return;
321         case FKind.FlatNop:
322             output.println("/* nop */");
323             return;
324         }
325         throw new Error();
326
327     }
328
329     private void generateFlatCall(FlatMethod fm, FlatCall fc, PrintWriter output) {
330         MethodDescriptor md=fc.getMethod();
331         ParamsObject objectparams=(ParamsObject) paramstable.get(md);
332         ClassDescriptor cn=md.getClassDesc();
333         output.println("{");
334         if (GENERATEPRECISEGC) {
335             output.print("       struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params __parameterlist__={");
336             
337             output.print(objectparams.getUID());
338             output.print(", & "+localsprefix);
339             if (fc.getThis()!=null) {
340                 output.print(", ");
341                 output.print(generateTemp(fm,fc.getThis()));
342             }
343             for(int i=0;i<fc.numArgs();i++) {
344                 VarDescriptor var=md.getParameter(i);
345                 TempDescriptor paramtemp=(TempDescriptor)temptovar.get(var);
346                 if (objectparams.isParamPtr(paramtemp)) {
347                     TempDescriptor targ=fc.getArg(i);
348                     output.print(", ");
349                     output.print(generateTemp(fm, targ));
350                 }
351             }
352             output.println("};");
353         }
354         output.print("       ");
355
356         /* TODO: Virtual dispatch */
357         if (fc.getReturnTemp()!=null)
358             output.print(generateTemp(fm,fc.getReturnTemp())+"=");
359         output.print(cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"(");
360         boolean needcomma=false;
361         if (GENERATEPRECISEGC) {
362             output.print("&__parameterlist__");
363             needcomma=true;
364         }
365         for(int i=0;i<fc.numArgs();i++) {
366             VarDescriptor var=md.getParameter(i);
367             TempDescriptor paramtemp=(TempDescriptor)temptovar.get(var);
368             if (objectparams.isParamPrim(paramtemp)) {
369                 TempDescriptor targ=fc.getArg(i);
370                 if (needcomma)
371                     output.print(", ");
372                 output.print(generateTemp(fm, targ));
373                 needcomma=true;
374             }
375         }
376         output.println(");");
377         output.println("   }");
378     }
379
380     private void generateFlatFieldNode(FlatMethod fm, FlatFieldNode ffn, PrintWriter output) {
381         output.println(generateTemp(fm, ffn.getDst())+"="+ generateTemp(fm,ffn.getSrc())+"->"+ ffn.getField().getSafeSymbol()+";");
382     }
383
384     private void generateFlatSetFieldNode(FlatMethod fm, FlatSetFieldNode fsfn, PrintWriter output) {
385         output.println(generateTemp(fm, fsfn.getDst())+"->"+ fsfn.getField().getSafeSymbol()+"="+ generateTemp(fm,fsfn.getSrc())+";");
386     }
387
388     private void generateFlatNew(FlatMethod fm, FlatNew fn, PrintWriter output) {
389         output.println(generateTemp(fm,fn.getDst())+"=allocate_new("+fn.getType().getClassDesc().getId()+");");
390     }
391
392     private void generateFlatOpNode(FlatMethod fm, FlatOpNode fon, PrintWriter output) {
393
394         if (fon.getRight()!=null)
395             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+fon.getOp().toString()+generateTemp(fm,fon.getRight())+";");
396         else if (fon.getOp().getOp()==Operation.ASSIGN)
397             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+";");
398         else if (fon.getOp().getOp()==Operation.UNARYPLUS)
399             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+";");
400         else if (fon.getOp().getOp()==Operation.UNARYMINUS)
401             output.println(generateTemp(fm, fon.getDest())+" = -"+generateTemp(fm, fon.getLeft())+";");
402         else if (fon.getOp().getOp()==Operation.POSTINC)
403             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+"++;");
404         else if (fon.getOp().getOp()==Operation.POSTDEC)
405             output.println(generateTemp(fm, fon.getDest())+" = "+generateTemp(fm, fon.getLeft())+"--;");
406         else if (fon.getOp().getOp()==Operation.PREINC)
407             output.println(generateTemp(fm, fon.getDest())+" = ++"+generateTemp(fm, fon.getLeft())+";");
408         else if (fon.getOp().getOp()==Operation.PREDEC)
409             output.println(generateTemp(fm, fon.getDest())+" = --"+generateTemp(fm, fon.getLeft())+";");
410         else
411             output.println(generateTemp(fm, fon.getDest())+fon.getOp().toString()+generateTemp(fm, fon.getLeft())+";");
412     }
413
414     private void generateFlatCastNode(FlatMethod fm, FlatCastNode fcn, PrintWriter output) {
415         /* TODO: Make call into runtime */
416         output.println(generateTemp(fm,fcn.getDst())+"=("+fcn.getType().getSafeSymbol()+")"+generateTemp(fm,fcn.getSrc())+";");
417     }
418
419     private void generateFlatLiteralNode(FlatMethod fm, FlatLiteralNode fln, PrintWriter output) {
420         if (fln.getValue()==null)
421             output.println(generateTemp(fm, fln.getDst())+"=0;");
422         else if (fln.getType().getSymbol().equals(TypeUtil.StringClass))
423             output.println(generateTemp(fm, fln.getDst())+"=newstring(\""+FlatLiteralNode.escapeString((String)fln.getValue())+"\");");
424         else
425             output.println(generateTemp(fm, fln.getDst())+"="+fln.getValue()+";");
426     }
427
428     private void generateFlatReturnNode(FlatMethod fm, FlatReturnNode frn, PrintWriter output) {
429         output.println("return "+generateTemp(fm, frn.getReturnTemp())+";");
430     }
431
432     private void generateFlatCondBranch(FlatMethod fm, FlatCondBranch fcb, String label, PrintWriter output) {
433         output.println("if (!"+generateTemp(fm, fcb.getTest())+") goto "+label+";");
434     }
435
436     private void generateHeader(MethodDescriptor md, PrintWriter output) {
437         /* Print header */
438         ParamsObject objectparams=(ParamsObject)paramstable.get(md);
439         ClassDescriptor cn=md.getClassDesc();
440         
441         if (md.getReturnType()!=null) {
442             if (md.getReturnType().isClass())
443                 output.print("struct " + md.getReturnType().getSafeSymbol()+" * ");
444             else
445                 output.print(md.getReturnType().getSafeSymbol()+" ");
446         } else 
447             //catch the constructor case
448             output.print("void ");
449
450         output.print(cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"(");
451         
452         boolean printcomma=false;
453         if (GENERATEPRECISEGC) {
454             output.print("struct "+cn.getSafeSymbol()+md.getSafeSymbol()+"_"+md.getSafeMethodDescriptor()+"_params * "+paramsprefix);
455             printcomma=true;
456         }
457         for(int i=0;i<objectparams.numPrimitives();i++) {
458             TempDescriptor temp=objectparams.getPrimitive(i);
459             if (printcomma)
460                 output.print(", ");
461             printcomma=true;
462             output.print(temp.getType().getSafeSymbol()+" "+temp.getSafeSymbol());
463         }
464         output.print(")");
465     }
466 }