bug fix for barriers...should be able to use joptimize with barriers now
[IRC.git] / Robust / src / Analysis / Loops / CSE.java
1 package Analysis.Loops;
2
3 import IR.Flat.*;
4 import IR.TypeUtil;
5 import IR.Operation;
6 import IR.FieldDescriptor;
7 import IR.MethodDescriptor;
8 import IR.TypeDescriptor;
9 import java.util.Map;
10 import java.util.Iterator;
11 import java.util.Hashtable;
12 import java.util.HashSet;
13 import java.util.Set;
14
15 public class CSE {
16   GlobalFieldType gft;
17   TypeUtil typeutil;
18   public CSE(GlobalFieldType gft, TypeUtil typeutil) {
19     this.gft=gft;
20     this.typeutil=typeutil;
21   }
22
23   public void doAnalysis(FlatMethod fm) {
24     Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>> availexpr=new Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>>();
25     HashSet toprocess=new HashSet();
26     HashSet discovered=new HashSet();
27     toprocess.add(fm);
28     discovered.add(fm);
29     while(!toprocess.isEmpty()) {
30       FlatNode fn=(FlatNode)toprocess.iterator().next();
31       toprocess.remove(fn);
32       for(int i=0;i<fn.numNext();i++) {
33         FlatNode nnext=fn.getNext(i);
34         if (!discovered.contains(nnext)) {
35           toprocess.add(nnext);
36           discovered.add(nnext);
37         }
38       }
39       Hashtable<Expression, TempDescriptor> tab=computeIntersection(fn, availexpr);
40
41       //Do kills of expression/variable mappings
42       TempDescriptor[] write=fn.writesTemps();
43       for(int i=0;i<write.length;i++) {
44         if (tab.containsKey(write[i]))
45           tab.remove(write[i]);
46       }
47       
48       switch(fn.kind()) {
49       case FKind.FlatAtomicEnterNode:
50         {
51           killexpressions(tab, null, null, true);
52           break;
53         }
54       case FKind.FlatCall:
55         {
56           FlatCall fc=(FlatCall) fn;
57           MethodDescriptor md=fc.getMethod();
58           Set<FieldDescriptor> fields=gft.getFields(md);
59           Set<TypeDescriptor> arrays=gft.getArrays(md);
60           killexpressions(tab, fields, arrays, gft.containsAtomic(md)||gft.containsBarrier(md));
61           break;
62         }
63       case FKind.FlatOpNode:
64         {
65           FlatOpNode fon=(FlatOpNode) fn;
66           Expression e=new Expression(fon.getLeft(), fon.getRight(), fon.getOp());
67           tab.put(e, fon.getDest());
68           break;
69         }
70       case FKind.FlatSetFieldNode:
71         {
72           FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
73           Set<FieldDescriptor> fields=new HashSet<FieldDescriptor>();
74           fields.add(fsfn.getField());
75           killexpressions(tab, fields, null, false);
76           Expression e=new Expression(fsfn.getDst(), fsfn.getField());
77           tab.put(e, fsfn.getSrc());
78           break;
79         }
80       case FKind.FlatFieldNode:
81         {
82           FlatFieldNode ffn=(FlatFieldNode)fn;
83           Expression e=new Expression(ffn.getSrc(), ffn.getField());
84           tab.put(e, ffn.getDst());
85           break;
86         }
87       case FKind.FlatSetElementNode:
88         {
89           FlatSetElementNode fsen=(FlatSetElementNode)fn;
90           Expression e=new Expression(fsen.getDst(),fsen.getIndex());
91           tab.put(e, fsen.getSrc());
92           break;
93         }
94       case FKind.FlatElementNode:
95         {
96           FlatElementNode fen=(FlatElementNode)fn;
97           Expression e=new Expression(fen.getSrc(),fen.getIndex());
98           tab.put(e, fen.getDst());
99           break;
100         }
101       default:
102       }
103       
104       if (write.length==1) {
105         TempDescriptor w=write[0];
106         for(Iterator it=tab.entrySet().iterator();it.hasNext();) {
107           Map.Entry m=(Map.Entry)it.next();
108           Expression e=(Expression)m.getKey();
109           if (e.a==w||e.b==w)
110             it.remove();
111         }
112       }
113       if (!availexpr.containsKey(fn)||!availexpr.get(fn).equals(tab)) {
114         availexpr.put(fn, tab);
115         for(int i=0;i<fn.numNext();i++) {
116           FlatNode nnext=fn.getNext(i);
117           toprocess.add(nnext);
118         }
119       }
120     }
121
122     doOptimize(fm, availexpr);
123   }
124     
125   public void doOptimize(FlatMethod fm, Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>> availexpr) {
126     Hashtable<FlatNode, FlatNode> replacetable=new Hashtable<FlatNode, FlatNode>();
127     for(Iterator<FlatNode> it=fm.getNodeSet().iterator();it.hasNext();) {
128       FlatNode fn=it.next();
129       Hashtable<Expression, TempDescriptor> tab=computeIntersection(fn, availexpr);
130       switch(fn.kind()) {
131       case FKind.FlatOpNode:
132         {
133           FlatOpNode fon=(FlatOpNode) fn;
134           Expression e=new Expression(fon.getLeft(), fon.getRight(),fon.getOp());
135           if (tab.containsKey(e)) {
136             TempDescriptor t=tab.get(e);
137             FlatNode newfon=new FlatOpNode(fon.getDest(),t,null,new Operation(Operation.ASSIGN));
138             replacetable.put(fon,newfon);
139           }
140           break;
141         }
142       case FKind.FlatFieldNode:
143         {
144           FlatFieldNode ffn=(FlatFieldNode)fn;
145           Expression e=new Expression(ffn.getSrc(), ffn.getField());
146           if (tab.containsKey(e)) {
147             TempDescriptor t=tab.get(e);
148             FlatNode newfon=new FlatOpNode(ffn.getDst(),t,null,new Operation(Operation.ASSIGN));
149             replacetable.put(ffn,newfon);
150           }
151           break;
152         }
153       case FKind.FlatElementNode:
154         {
155           FlatElementNode fen=(FlatElementNode)fn;
156           Expression e=new Expression(fen.getSrc(),fen.getIndex());
157           if (tab.containsKey(e)) {
158             TempDescriptor t=tab.get(e);
159             FlatNode newfon=new FlatOpNode(fen.getDst(),t,null,new Operation(Operation.ASSIGN));
160             replacetable.put(fen,newfon);
161           }
162           break;
163         }
164       default: 
165       }
166     }
167     for(Iterator<FlatNode> it=replacetable.keySet().iterator();it.hasNext();) {
168       FlatNode fn=it.next();
169       FlatNode newfn=replacetable.get(fn);
170       fn.replace(newfn);
171     }
172   }
173   
174   public Hashtable<Expression, TempDescriptor> computeIntersection(FlatNode fn, Hashtable<FlatNode,Hashtable<Expression, TempDescriptor>> availexpr) {
175     Hashtable<Expression, TempDescriptor> tab=new Hashtable<Expression, TempDescriptor>();
176     boolean first=true;
177     
178     //compute intersection
179     for(int i=0;i<fn.numPrev();i++) {
180       FlatNode prev=fn.getPrev(i);
181       if (first) {
182         if (availexpr.containsKey(prev)) {
183           tab.putAll(availexpr.get(prev));
184           first=false;
185         }
186       } else {
187         if (availexpr.containsKey(prev)) {
188           Hashtable<Expression, TempDescriptor> table=availexpr.get(prev);
189           for(Iterator mapit=tab.entrySet().iterator();mapit.hasNext();) {
190             Object entry=mapit.next();
191             if (!table.contains(entry))
192               mapit.remove();
193           }
194         }
195       }
196     }
197     return tab;
198   }
199
200   public void killexpressions(Hashtable<Expression, TempDescriptor> tab, Set<FieldDescriptor> fields, Set<TypeDescriptor> arrays, boolean killall) {
201     for(Iterator it=tab.entrySet().iterator();it.hasNext();) {
202       Map.Entry m=(Map.Entry)it.next();
203       Expression e=(Expression)m.getKey();
204       if (killall&&(e.f!=null||e.a!=null))
205         it.remove();
206       else if (e.f!=null&&fields!=null&&fields.contains(e.f)) 
207         it.remove();
208       else if ((e.a!=null)&&(arrays!=null)) {
209         for(Iterator<TypeDescriptor> arit=arrays.iterator();arit.hasNext();) {
210           TypeDescriptor artd=arit.next();
211           if (typeutil.isSuperorType(artd,e.a.getType())||
212               typeutil.isSuperorType(e.a.getType(),artd)) {
213             it.remove();
214             break;
215           }
216         }
217       }
218     }
219   }
220 }
221
222 class Expression {
223   Operation op;
224   TempDescriptor a;
225   TempDescriptor b;
226   FieldDescriptor f;
227   Expression(TempDescriptor a, TempDescriptor b, Operation op) {
228     this.a=a;
229     this.b=b;
230     this.op=op;
231   }
232   Expression(TempDescriptor a, FieldDescriptor f) {
233     this.a=a;
234     this.f=f;
235   }
236   Expression(TempDescriptor a, TempDescriptor index) {
237     this.a=a;
238     this.b=index;
239   }
240   public int hashCode() {
241     int h=0;
242     h^=a.hashCode();
243     if (op!=null)
244       h^=op.getOp();
245     if (b!=null)
246       h^=b.hashCode();
247     if (f!=null)
248       h^=f.hashCode();
249     return h;
250   }
251   public boolean equals(Object o) {
252     Expression e=(Expression)o;
253     if (a!=e.a||f!=e.f||b!=e.b)
254       return false;
255     if (op!=null)
256       return op.getOp()==e.op.getOp();
257     return true;
258   }
259 }