bug fix
[IRC.git] / Robust / src / IR / Flat / Inliner.java
1 package IR.Flat;
2 import java.util.Hashtable;
3 import java.util.Set;
4 import java.util.HashSet;
5 import java.util.Stack;
6 import java.util.Iterator;
7 import IR.ClassDescriptor;
8 import IR.Operation;
9 import IR.State;
10 import IR.TypeUtil;
11 import IR.MethodDescriptor;
12
13 public class Inliner {
14   public static void inlineAtomic(State state, TypeUtil typeutil, FlatMethod fm, int depth) {
15     Stack<FlatNode> toprocess=new Stack<FlatNode>();
16     HashSet<FlatNode> visited=new HashSet<FlatNode>();
17     Hashtable<FlatNode, Integer> atomictable=new Hashtable<FlatNode, Integer>();
18     HashSet<FlatNode> atomicset=new HashSet<FlatNode>();
19
20     toprocess.push(fm);
21     visited.add(fm);
22     atomictable.put(fm, new Integer(0));
23     while(!toprocess.isEmpty()) {
24       FlatNode fn=toprocess.pop();
25       int atomicval=atomictable.get(fn).intValue();
26       if (fn.kind()==FKind.FlatAtomicEnterNode)
27         atomicval++;
28       else if(fn.kind()==FKind.FlatAtomicExitNode)
29         atomicval--;
30       for(int i=0;i<fn.numNext();i++) {
31         FlatNode fnext=fn.getNext(i);
32         if (!visited.contains(fnext)) {
33           atomictable.put(fnext, new Integer(atomicval));
34           if (atomicval>0)
35             atomicset.add(fnext);
36           visited.add(fnext);
37           toprocess.push(fnext);
38         }
39       }
40     }
41     //make depth 0 be depth infinity
42     if (depth==0)
43       depth=10000000;
44     recursive(state, typeutil, atomicset, depth, new Stack<MethodDescriptor>());
45   }
46   
47
48   public static void recursive(State state, TypeUtil typeutil, Set<FlatNode> fnset, int depth, Stack<MethodDescriptor> toexclude) {
49     for(Iterator<FlatNode> fnit=fnset.iterator();fnit.hasNext();) {
50       FlatNode fn=fnit.next();
51       if (fn.kind()==FKind.FlatCall) {
52         FlatCall fc=(FlatCall)fn;
53         MethodDescriptor md=fc.getMethod();
54
55         if (toexclude.contains(md))
56           continue;
57
58         Set<FlatNode> inlinefnset=inline(fc, typeutil, state);
59         toexclude.push(md);
60         if (depth>1)
61           recursive(state, typeutil, inlinefnset, depth-1, toexclude);
62         toexclude.pop();
63       }
64     }
65   }
66
67   public static Set<FlatNode> inline(FlatCall fc, TypeUtil typeutil, State state) {
68     MethodDescriptor md=fc.getMethod();
69     /* Do we need to do virtual dispatch? */
70     if (md.isStatic()||md.getReturnType()==null||singleCall(typeutil, fc.getThis().getType().getClassDesc(),md)) {
71       //just reuse temps...makes problem with inlining recursion
72       TempMap clonemap=new TempMap();
73       Hashtable<FlatNode, FlatNode> flatmap=new Hashtable<FlatNode, FlatNode>();
74       TempDescriptor rettmp=fc.getReturnTemp();
75       FlatNode aftercallnode=fc.getNext(0);
76       aftercallnode.removePrev(fc);
77
78       FlatMethod fm=state.getMethodFlat(md);
79       //Clone nodes
80       Set<FlatNode> nodeset=fm.getNodeSet();
81       nodeset.remove(fm);
82
83       HashSet<FlatNode> newnodes=new HashSet<FlatNode>();
84
85       //Build the clones
86       for(Iterator<FlatNode> fnit=nodeset.iterator();fnit.hasNext();) {
87         FlatNode fn=fnit.next();
88         if (fn.kind()==FKind.FlatReturnNode) {
89           //Convert FlatReturn node into move
90           TempDescriptor rtmp=((FlatReturnNode)fn).getReturnTemp();
91           if (rtmp!=null) {
92             FlatOpNode fon=new FlatOpNode(rettmp, rtmp, null, new Operation(Operation.ASSIGN));
93             flatmap.put(fn, fon);
94           } else {
95             flatmap.put(fn, aftercallnode);
96           }
97         } else {
98           FlatNode clone=fn.clone(clonemap);
99           newnodes.add(clone);
100           flatmap.put(fn,clone);
101         }
102       }
103       //Build the move chain
104       FlatNode first=new FlatNop();;
105       newnodes.add(first);
106       FlatNode last=first;
107       {
108         int i=0;
109         if (fc.getThis()!=null) {
110           FlatOpNode fon=new FlatOpNode(fm.getParameter(i++), fc.getThis(), null, new Operation(Operation.ASSIGN));
111           newnodes.add(fon);
112           last.addNext(fon);
113           last=fon;
114         }
115         for(int j=0;j<fc.numArgs();i++,j++) {
116           FlatOpNode fon=new FlatOpNode(fm.getParameter(i), fc.getArg(j), null, new Operation(Operation.ASSIGN));
117           newnodes.add(fon);
118           last.addNext(fon);
119           last=fon;
120         }
121       }
122
123       //Add the edges
124       for(Iterator<FlatNode> fnit=nodeset.iterator();fnit.hasNext();) {
125         FlatNode fn=fnit.next();
126         FlatNode fnclone=flatmap.get(fn);
127
128         if (fn.kind()!=FKind.FlatReturnNode) {
129           //don't build old edges out of a flat return node
130           for(int i=0;i<fn.numNext();i++) {
131             FlatNode fnnext=fn.getNext(i);
132             FlatNode fnnextclone=flatmap.get(fnnext);
133             fnclone.setNewNext(i, fnnextclone);
134           }
135         } else {
136           if (fnclone!=aftercallnode)
137             fnclone.addNext(aftercallnode);
138         }
139       }
140
141       //Add edges to beginning of move chain
142       for(int i=0;i<fc.numPrev();i++) {
143         FlatNode fnprev=fc.getPrev(i);
144         for(int j=0;j<fnprev.numNext();j++) {
145           if (fnprev.getNext(j)==fc) {
146             //doing setnewnext to avoid changing the node we are
147             //iterating over
148             fnprev.setNewNext(j, first);
149             break;
150           }
151         }
152       }
153
154       //Add in the edge from move chain to callee
155       last.addNext(flatmap.get(fm.getNext(0)));
156       return newnodes;
157     } else return null;
158   }
159
160   private static boolean singleCall(TypeUtil typeutil, ClassDescriptor thiscd, MethodDescriptor md) {
161     Set subclasses=typeutil.getSubClasses(thiscd);
162     if (subclasses==null)
163       return true;
164     for(Iterator classit=subclasses.iterator(); classit.hasNext();) {
165       ClassDescriptor cd=(ClassDescriptor)classit.next();
166       Set possiblematches=cd.getMethodTable().getSet(md.getSymbol());
167       for(Iterator matchit=possiblematches.iterator(); matchit.hasNext();) {
168         MethodDescriptor matchmd=(MethodDescriptor)matchit.next();
169         if (md.matches(matchmd))
170           return false;
171       }
172     }
173     return true;
174   }
175 }