e875918950177744cfdade17b7cee65dff397deb
[IRC.git] / Robust / src / Analysis / Loops / LoopOptimize.java
1 package Analysis.Loops;
2
3 import IR.Flat.*;
4 import IR.TypeUtil;
5 import IR.Operation;
6 import java.util.Set;
7 import java.util.Vector;
8 import java.util.Iterator;
9 import java.util.Hashtable;
10
11 public class LoopOptimize {
12   LoopInvariant loopinv;
13   public LoopOptimize(TypeUtil typeutil) {
14     loopinv=new LoopInvariant(typeutil);
15   }
16   public void optimize(FlatMethod fm) {
17     loopinv.analyze(fm);
18     dooptimize(fm);
19   } 
20   private void dooptimize(FlatMethod fm) {
21     Loops root=loopinv.root;
22     recurse(root);
23   }
24   private void recurse(Loops parent) {
25     for(Iterator lpit=parent.nestedLoops().iterator();lpit.hasNext();) {
26       Loops child=(Loops)lpit.next();
27       processLoop(child);
28       recurse(child);
29     }
30   }
31   public void processLoop(Loops l) {
32     if (loopinv.tounroll.contains(l)) {
33       unrollLoop(l);
34     } else {
35       hoistOps(l);
36     }
37   }
38   public void hoistOps(Loops l) {
39     Set entrances=l.loopEntrances();
40     assert entrances.size()==1;
41     FlatNode entrance=(FlatNode)entrances.iterator().next();
42     Vector<FlatNode> tohoist=loopinv.table.get(entrance);
43     Set lelements=l.loopIncElements();
44     TempMap t=new TempMap();
45     TempMap tnone=new TempMap();
46     FlatNode first=null;
47     FlatNode last=null;
48     if (tohoist.size()==0)
49       return;
50
51     for(int i=0;i<tohoist.size();i++) {
52       FlatNode fn=tohoist.elementAt(i);
53       TempDescriptor[] writes=fn.writesTemps();
54       FlatNode fnnew=fn.clone(tnone);
55
56       fnnew.rewriteUse(t);
57
58       for(int j=0;j<writes.length;j++) {
59         if (writes[j]!=null) {
60           TempDescriptor cp=writes[j].createNew();
61           t.addPair(writes[j],cp);
62         }
63       }
64       fnnew.rewriteDef(t);
65
66       if (first==null)
67         first=fnnew;
68       else
69         last.addNext(fnnew);
70       last=fnnew;
71       /* Splice out old node */
72       if (writes.length==1) {
73         FlatOpNode fon=new FlatOpNode(writes[0], t.tempMap(writes[0]), null, new Operation(Operation.ASSIGN));
74         fn.replace(fon);
75         if (fn==entrance)
76           entrance=fon;
77       } else if (writes.length>1) {
78         throw new Error();
79       }
80     }
81     /* The chain is built at this point. */
82     
83     FlatNode[] prevarray=new FlatNode[entrance.numPrev()];
84     for(int i=0;i<entrance.numPrev();i++) {
85       prevarray[i]=entrance.getPrev(i);
86     }
87     for(int i=0;i<prevarray.length;i++) {
88       FlatNode prev=prevarray[i];
89
90       if (!lelements.contains(prev)) {
91         //need to fix this edge
92         for(int j=0;j<prev.numNext();j++) {
93           if (prev.getNext(j)==entrance)
94             prev.setNext(j, first);
95         }
96       }
97     }
98     last.addNext(entrance);
99   }
100   public void unrollLoop(Loops l) {
101     assert l.loopEntrances().size()==1;
102     FlatNode entrance=(FlatNode)l.loopEntrances().iterator().next();
103     Set lelements=l.loopIncElements();
104     Set<FlatNode> tohoist=loopinv.hoisted;
105     Hashtable<FlatNode, TempDescriptor> temptable=new Hashtable<FlatNode, TempDescriptor>();
106     Hashtable<FlatNode, FlatNode> copytable=new Hashtable<FlatNode, FlatNode>();
107     Hashtable<FlatNode, FlatNode> copyendtable=new Hashtable<FlatNode, FlatNode>();
108     
109     TempMap t=new TempMap();
110     /* Copy the nodes */
111     for(Iterator it=lelements.iterator();it.hasNext();) {
112       FlatNode fn=(FlatNode)it.next();
113       FlatNode copy=fn.clone(t);
114       FlatNode copyend=copy;
115       if (tohoist.contains(fn)) {
116         TempDescriptor[] writes=fn.writesTemps();
117         TempDescriptor tmp=writes[0];
118         TempDescriptor ntmp=tmp.createNew();
119         temptable.put(fn, ntmp);
120         copyend=new FlatOpNode(ntmp, tmp, null, new Operation(Operation.ASSIGN));
121         copy.addNext(copyend);
122       }
123       copytable.put(fn, copy);
124       copyendtable.put(fn, copyend);
125     }
126     /* Copy the edges */
127     for(Iterator it=lelements.iterator();it.hasNext();) {
128       FlatNode fn=(FlatNode)it.next();
129       FlatNode copyend=copyendtable.get(fn);
130       for(int i=0;i<fn.numNext();i++) {
131         FlatNode nnext=fn.getNext(i);
132         if (nnext==entrance) {
133           /* Back to loop header...point to old graph */
134           copyend.addNext(nnext);
135         } else if (lelements.contains(nnext)) {
136           /* In graph...point to first graph */
137           copyend.addNext(copytable.get(nnext));
138         } else {
139           /* Outside loop */
140           /* Just goto same place as before */
141           copyend.addNext(nnext);
142         }
143       }
144     }
145     /* Splice out loop invariant stuff */
146     for(Iterator it=lelements.iterator();it.hasNext();) {
147       FlatNode fn=(FlatNode)it.next();
148       if (tohoist.contains(fn)) {
149         TempDescriptor[] writes=fn.writesTemps();
150         TempDescriptor tmp=writes[0];
151         FlatOpNode fon=new FlatOpNode(temptable.get(fn),tmp, null, new Operation(Operation.ASSIGN));
152         fn.replace(fon);
153       }
154     }
155   }
156 }