Bug fix for refCount in SMFEStates. I'm not sure whether this fix covers all cases (it...
[IRC.git] / Robust / src / Analysis / Disjoint / ProcessStateMachines.java
1 package Analysis.Disjoint;
2 import java.util.*;
3
4 import Analysis.OoOJava.*;
5 import IR.FieldDescriptor;
6 import IR.Flat.*;
7 import Util.Pair;
8
9 public class ProcessStateMachines {
10   protected HashMap<FlatSESEEnterNode, Set<StateMachineForEffects>> groupMap;
11   protected BuildStateMachines bsm;
12   protected RBlockRelationAnalysis taskAnalysis;
13
14   public ProcessStateMachines(BuildStateMachines bsm, RBlockRelationAnalysis taskAnalysis) {
15     this.bsm=bsm;
16     this.taskAnalysis=taskAnalysis;
17     groupMap=new HashMap<FlatSESEEnterNode, Set<StateMachineForEffects>>();
18   }
19
20   public void doProcess() {
21     groupStateMachines();
22     computeConflictEffects();
23     prune();
24     merge();
25   }
26
27   private void merge() {
28     for(Pair<FlatNode, TempDescriptor> machinepair: bsm.getAllMachineNames()) {
29       StateMachineForEffects sm=bsm.getStateMachine(machinepair);
30       merge(sm);
31     }
32   }
33
34
35   private void merge(StateMachineForEffects sm) {
36     HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> backMap=buildBackMap(sm);
37     boolean mergeAgain=false;
38     HashSet<SMFEState> removedStates=new HashSet<SMFEState>();
39     do {
40       mergeAgain=false;
41       HashMap<Pair<SMFEState, FieldDescriptor>, Set<SMFEState>> revMap=buildReverse(backMap);
42       for(Map.Entry<Pair<SMFEState,FieldDescriptor>, Set<SMFEState>> entry:revMap.entrySet()) {
43         if (entry.getValue().size()>1) {
44           SMFEState first=null;
45           for(SMFEState state:entry.getValue()) {
46             if (removedStates.contains(state))
47               continue;
48             if (first==null) {
49               first=state;
50             } else {
51               mergeAgain=true;
52               System.out.println("MERGING:"+first+" and "+state);
53               //Make sure we don't merge the initial state someplace else
54               if (state==sm.initialState) {
55                 state=first;
56                 first=sm.initialState;
57               }
58               mergeTwoStates(first, state, backMap);
59               removedStates.add(state);
60               sm.fn2state.remove(state.whereDefined);
61             }
62           }
63         }
64       }
65     } while(mergeAgain);
66   }
67
68
69   private HashMap<Pair<SMFEState, FieldDescriptor>, Set<SMFEState>> buildReverse(HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> backMap) {
70     HashMap<Pair<SMFEState, FieldDescriptor>, Set<SMFEState>> revMap=new HashMap<Pair<SMFEState, FieldDescriptor>, Set<SMFEState>>();
71     for(Map.Entry<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>>entry:backMap.entrySet()) {
72       SMFEState state=entry.getKey();
73       for(Pair<SMFEState, FieldDescriptor> pair:entry.getValue()) {
74         if (!revMap.containsKey(pair))
75           revMap.put(pair, new HashSet<SMFEState>());
76         revMap.get(pair).add(state);
77       }
78     }
79     return revMap;
80   }
81
82   private void mergeTwoStates(SMFEState state1, SMFEState state2, HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> backMap) {
83     //Merge effects and conflicts
84     state1.effects.addAll(state2.effects);
85     state1.conflicts.addAll(state2.conflicts);
86
87     //fix up our backmap
88     backMap.get(state1).addAll(backMap.get(state2));
89
90     //merge outgoing transitions
91     for(Map.Entry<Effect, Set<SMFEState>> entry:state2.e2states.entrySet()) {
92       Effect e=entry.getKey();
93       Set<SMFEState> states=entry.getValue();
94       if (state1.e2states.containsKey(e)) {
95         for(SMFEState statetoadd:states) {
96           if (!state1.e2states.get(e).add(statetoadd)) {
97             //already added...reduce reference count
98             statetoadd.refCount--;
99           }
100         }
101       } else {
102         state1.e2states.put(e, states);
103       }
104       Set<SMFEState> states1=state1.e2states.get(e);
105
106       //move now-self edges
107       if (states1.contains(state2)) {
108         states1.remove(state2);
109         states1.add(state1);
110       }
111
112       //fix up the backmap of the edges we point to
113       for(SMFEState st:states1) {
114         HashSet<Pair<SMFEState, FieldDescriptor>> toRemove=new HashSet<Pair<SMFEState, FieldDescriptor>>();
115         HashSet<Pair<SMFEState, FieldDescriptor>> toAdd=new HashSet<Pair<SMFEState, FieldDescriptor>>();
116         for(Pair<SMFEState, FieldDescriptor> backpair:backMap.get(st)) {
117           if (backpair.getFirst()==state2) {
118             Pair<SMFEState, FieldDescriptor> newpair=new Pair<SMFEState, FieldDescriptor>(state1, backpair.getSecond());
119             toRemove.add(backpair);
120             toAdd.add(newpair);
121           }
122         }
123         backMap.get(st).removeAll(toRemove);
124         backMap.get(st).addAll(toAdd);
125       }
126     }
127
128     //Fix up our new incoming edges
129     for(Pair<SMFEState,FieldDescriptor> fromStatePair:backMap.get(state2)) {
130       SMFEState fromState=fromStatePair.getFirst();
131       for(Map.Entry<Effect, Set<SMFEState>> fromEntry:fromState.e2states.entrySet()) {
132         Effect e=fromEntry.getKey();
133         Set<SMFEState> states=fromEntry.getValue();
134         if (states.contains(state2)) {
135           states.remove(state2);
136     if(states.add(state1) && !fromState.equals(state2)) {
137       state1.refCount++; 
138     }
139         }
140       }
141     }
142     //Clear out unreachable state's backmap
143     backMap.remove(state2);
144   }
145
146
147   private void prune() {
148     for(Pair<FlatNode, TempDescriptor> machinepair: bsm.getAllMachineNames()) {
149       StateMachineForEffects sm=bsm.getStateMachine(machinepair);
150       pruneNonConflictingStates(sm);
151       pruneEffects(sm);
152     }
153   }
154
155   private void pruneEffects(StateMachineForEffects sm) {
156     for(Iterator<FlatNode> fnit=sm.fn2state.keySet().iterator(); fnit.hasNext();) {
157       FlatNode fn=fnit.next();
158       SMFEState state=sm.fn2state.get(fn);
159       for(Iterator<Effect> efit=state.effects.iterator();efit.hasNext();) {
160         Effect e=efit.next();
161         //Is it a conflicting effecting
162         if (state.getConflicts().contains(e))
163           continue;
164         //Does it still have transitions
165         if (state.e2states.containsKey(e))
166           continue;
167         //If no to both, remove it
168         efit.remove();
169       }
170     }
171   }
172
173   private void pruneNonConflictingStates(StateMachineForEffects sm) {
174     Set<SMFEState> canReachConflicts=buildConflictsAndMap(sm);
175     for(Iterator<FlatNode> fnit=sm.fn2state.keySet().iterator(); fnit.hasNext();) {
176       FlatNode fn=fnit.next();
177       SMFEState state=sm.fn2state.get(fn);
178       if (canReachConflicts.contains(state)) {
179         for(Iterator<Effect> efit=state.e2states.keySet().iterator(); efit.hasNext();) {
180           Effect e=efit.next();
181           Set<SMFEState> stateset=state.e2states.get(e);
182           for(Iterator<SMFEState> stit=stateset.iterator(); stit.hasNext();) {
183             SMFEState tostate=stit.next();
184             if(!canReachConflicts.contains(tostate))
185               stit.remove();
186           }
187           if (stateset.isEmpty())
188             efit.remove();
189         }
190       } else {
191         fnit.remove();
192       }
193     }
194   }
195   
196   private HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> buildBackMap(StateMachineForEffects sm) {
197     return buildBackMap(sm, null);
198   }
199
200   private HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> buildBackMap(StateMachineForEffects sm, Set<SMFEState> conflictStates) {
201     Stack<SMFEState> toprocess=new Stack<SMFEState>();
202     HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> backMap=new HashMap<SMFEState, Set<Pair<SMFEState,FieldDescriptor>>>();
203     toprocess.add(sm.initialState);
204     backMap.put(sm.initialState, new HashSet<Pair<SMFEState, FieldDescriptor>>());
205     while(!toprocess.isEmpty()) {
206       SMFEState state=toprocess.pop();
207       if (!state.getConflicts().isEmpty()&&conflictStates!=null) {
208         conflictStates.add(state);
209       }
210       for(Effect e:state.getEffectsAllowed()) {
211         for(SMFEState stateout:state.transitionsTo(e)) {
212           if (!backMap.containsKey(stateout)) {
213             toprocess.add(stateout);
214             backMap.put(stateout, new HashSet<Pair<SMFEState,FieldDescriptor>>());
215           }
216           Pair<SMFEState, FieldDescriptor> p=new Pair<SMFEState, FieldDescriptor>(state, e.getField());
217           backMap.get(stateout).add(p);
218         }
219       }
220     }
221     return backMap;
222   }
223
224   
225   private Set<SMFEState> buildConflictsAndMap(StateMachineForEffects sm) {
226     Set<SMFEState> conflictStates=new HashSet<SMFEState>();
227     HashMap<SMFEState, Set<Pair<SMFEState,FieldDescriptor>>> backMap=buildBackMap(sm, conflictStates);
228
229     Stack<SMFEState> toprocess=new Stack<SMFEState>();
230     Set<SMFEState> canReachConflicts=new HashSet<SMFEState>();
231     toprocess.addAll(conflictStates);
232     canReachConflicts.addAll(conflictStates);
233     while(!toprocess.isEmpty()) {
234       SMFEState state=toprocess.pop();
235
236       for(Pair<SMFEState,FieldDescriptor> instatepair:backMap.get(state)) {
237         SMFEState instate=instatepair.getFirst();
238         if (!canReachConflicts.contains(instate)) {
239           toprocess.add(instate);
240           canReachConflicts.add(instate);
241         }
242       }
243     }
244     return canReachConflicts;
245   }
246   
247   private void groupStateMachines() {
248     for(Pair<FlatNode, TempDescriptor> machinePair: bsm.getAllMachineNames()) {
249       FlatNode fn=machinePair.getFirst();
250       StateMachineForEffects sm=bsm.getStateMachine(machinePair);
251       Set<FlatSESEEnterNode> taskSet=taskAnalysis.getPossibleExecutingRBlocks(fn);
252       for(FlatSESEEnterNode sese:taskSet) {
253         if (!groupMap.containsKey(sese))
254           groupMap.put(sese, new HashSet<StateMachineForEffects>());
255         groupMap.get(sese).add(sm);
256       }
257     }
258   }
259
260   private void computeConflictEffects() {
261     //Loop through all state machines
262     for(Pair<FlatNode, TempDescriptor> machinePair: bsm.getAllMachineNames()) {
263       FlatNode fn=machinePair.getFirst();
264       StateMachineForEffects sm=bsm.getStateMachine(machinePair);
265       Set<FlatSESEEnterNode> taskSet=taskAnalysis.getPossibleExecutingRBlocks(fn);
266       for(FlatSESEEnterNode sese:taskSet) {
267         Set<StateMachineForEffects> smgroup=groupMap.get(sese);
268         computeConflictingEffects(sm, smgroup);
269       }
270     }
271   }
272   
273   private void computeConflictingEffects(StateMachineForEffects sm, Set<StateMachineForEffects> smgroup) {
274     boolean isStall=sm.getStallorSESE().kind()!=FKind.FlatSESEEnterNode;
275     for(SMFEState state:sm.getStates()) {
276       for(Effect e:state.getEffectsAllowed()) {
277         Alloc a=e.getAffectedAllocSite();
278         FieldDescriptor fd=e.getField();
279         int type=e.getType();
280         boolean hasConflict=false;
281         if (!isStall&&Effect.isWrite(type)) {
282           hasConflict=true;
283         } else {
284           for(StateMachineForEffects othersm:smgroup) {
285             boolean otherIsStall=othersm.getStallorSESE().kind()!=FKind.FlatSESEEnterNode;
286             //Stall sites can't conflict with each other
287             if (isStall&&otherIsStall) continue;
288
289             int effectType=othersm.getEffects(a, fd);
290             if (Effect.isWrite(type)&&effectType!=0) {
291               //Original effect is a write and we have some effect on the same field/alloc site
292               hasConflict=true;
293               break;
294             }
295             if (Effect.isWrite(effectType)) {
296               //We are a write
297               hasConflict=true;
298               break;
299             }
300           }
301         }
302         if (hasConflict) {
303           state.addConflict(e);
304         }
305       }
306     }
307   }
308 }