get compiler side of STM working
[IRC.git] / Robust / src / Analysis / Loops / LoopOptimize.java
1 package Analysis.Loops;
2
3 import IR.Flat.*;
4 import IR.TypeUtil;
5 import IR.MethodDescriptor;
6 import IR.Operation;
7 import java.util.Set;
8 import java.util.Vector;
9 import java.util.Iterator;
10 import java.util.Hashtable;
11
12 public class LoopOptimize {
13   LoopInvariant loopinv;
14   public LoopOptimize(GlobalFieldType gft, TypeUtil typeutil) {
15     loopinv=new LoopInvariant(typeutil,gft);
16   }
17   public void optimize(FlatMethod fm) {
18     loopinv.analyze(fm);
19     dooptimize(fm);
20   } 
21   private void dooptimize(FlatMethod fm) {
22     Loops root=loopinv.root;
23     recurse(root);
24   }
25   private void recurse(Loops parent) {
26     for(Iterator lpit=parent.nestedLoops().iterator();lpit.hasNext();) {
27       Loops child=(Loops)lpit.next();
28       processLoop(child);
29       recurse(child);
30     }
31   }
32   public void processLoop(Loops l) {
33     if (loopinv.tounroll.contains(l)) {
34       unrollLoop(l);
35     } else {
36       hoistOps(l);
37     }
38   }
39   public void hoistOps(Loops l) {
40     Set entrances=l.loopEntrances();
41     assert entrances.size()==1;
42     FlatNode entrance=(FlatNode)entrances.iterator().next();
43     Vector<FlatNode> tohoist=loopinv.table.get(entrance);
44     Set lelements=l.loopIncElements();
45     TempMap t=new TempMap();
46     TempMap tnone=new TempMap();
47     FlatNode first=null;
48     FlatNode last=null;
49     if (tohoist.size()==0)
50       return;
51
52     for(int i=0;i<tohoist.size();i++) {
53       FlatNode fn=tohoist.elementAt(i);
54       TempDescriptor[] writes=fn.writesTemps();
55       FlatNode fnnew=fn.clone(tnone);
56
57       fnnew.rewriteUse(t);
58
59       for(int j=0;j<writes.length;j++) {
60         if (writes[j]!=null) {
61           TempDescriptor cp=writes[j].createNew();
62           t.addPair(writes[j],cp);
63         }
64       }
65       fnnew.rewriteDef(t);
66
67       if (first==null)
68         first=fnnew;
69       else
70         last.addNext(fnnew);
71       last=fnnew;
72       /* Splice out old node */
73       if (writes.length==1) {
74         FlatOpNode fon=new FlatOpNode(writes[0], t.tempMap(writes[0]), null, new Operation(Operation.ASSIGN));
75         fn.replace(fon);
76         if (fn==entrance)
77           entrance=fon;
78       } else if (writes.length>1) {
79         throw new Error();
80       }
81     }
82     /* The chain is built at this point. */
83     
84     FlatNode[] prevarray=new FlatNode[entrance.numPrev()];
85     for(int i=0;i<entrance.numPrev();i++) {
86       prevarray[i]=entrance.getPrev(i);
87     }
88     for(int i=0;i<prevarray.length;i++) {
89       FlatNode prev=prevarray[i];
90
91       if (!lelements.contains(prev)) {
92         //need to fix this edge
93         for(int j=0;j<prev.numNext();j++) {
94           if (prev.getNext(j)==entrance)
95             prev.setNext(j, first);
96         }
97       }
98     }
99     last.addNext(entrance);
100   }
101   public void unrollLoop(Loops l) {
102     assert l.loopEntrances().size()==1;
103     FlatNode entrance=(FlatNode)l.loopEntrances().iterator().next();
104     Set lelements=l.loopIncElements();
105     Set<FlatNode> tohoist=loopinv.hoisted;
106     Hashtable<FlatNode, TempDescriptor> temptable=new Hashtable<FlatNode, TempDescriptor>();
107     Hashtable<FlatNode, FlatNode> copytable=new Hashtable<FlatNode, FlatNode>();
108     Hashtable<FlatNode, FlatNode> copyendtable=new Hashtable<FlatNode, FlatNode>();
109     
110     TempMap t=new TempMap();
111     /* Copy the nodes */
112     for(Iterator it=lelements.iterator();it.hasNext();) {
113       FlatNode fn=(FlatNode)it.next();
114       FlatNode copy=fn.clone(t);
115       FlatNode copyend=copy;
116       if (tohoist.contains(fn)) {
117         TempDescriptor[] writes=fn.writesTemps();
118         TempDescriptor tmp=writes[0];
119         TempDescriptor ntmp=tmp.createNew();
120         temptable.put(fn, ntmp);
121         copyend=new FlatOpNode(ntmp, tmp, null, new Operation(Operation.ASSIGN));
122         copy.addNext(copyend);
123       }
124       copytable.put(fn, copy);
125       copyendtable.put(fn, copyend);
126     }
127     /* Copy the edges */
128     for(Iterator it=lelements.iterator();it.hasNext();) {
129       FlatNode fn=(FlatNode)it.next();
130       FlatNode copyend=copyendtable.get(fn);
131       for(int i=0;i<fn.numNext();i++) {
132         FlatNode nnext=fn.getNext(i);
133         if (nnext==entrance) {
134           /* Back to loop header...point to old graph */
135           copyend.addNext(nnext);
136         } else if (lelements.contains(nnext)) {
137           /* In graph...point to first graph */
138           copyend.addNext(copytable.get(nnext));
139         } else {
140           /* Outside loop */
141           /* Just goto same place as before */
142           copyend.addNext(nnext);
143         }
144       }
145     }
146     /* Splice out loop invariant stuff */
147     for(Iterator it=lelements.iterator();it.hasNext();) {
148       FlatNode fn=(FlatNode)it.next();
149       if (tohoist.contains(fn)) {
150         TempDescriptor[] writes=fn.writesTemps();
151         TempDescriptor tmp=writes[0];
152         FlatOpNode fon=new FlatOpNode(temptable.get(fn),tmp, null, new Operation(Operation.ASSIGN));
153         fn.replace(fon);
154       }
155     }
156   }
157 }