Fix tabbing.... Please fix your editors so they do tabbing correctly!!! (Spaces...
[IRC.git] / Robust / src / Analysis / Disjoint / ProcessStateMachines.java
1 package Analysis.Disjoint;
2 import java.util.*;
3
4 import Analysis.OoOJava.*;
5 import IR.FieldDescriptor;
6 import IR.Flat.*;
7 import Util.Pair;
8
9 public class ProcessStateMachines {
10   protected HashMap<FlatSESEEnterNode, Set<StateMachineForEffects>> groupMap;
11   protected BuildStateMachines bsm;
12   protected RBlockRelationAnalysis taskAnalysis;
13
14   public ProcessStateMachines(BuildStateMachines bsm, RBlockRelationAnalysis taskAnalysis) {
15     this.bsm=bsm;
16     this.taskAnalysis=taskAnalysis;
17     groupMap=new HashMap<FlatSESEEnterNode, Set<StateMachineForEffects>>();
18   }
19
20   public void doProcess() {
21     groupStateMachines();
22     computeConflictEffects();
23     prune();
24     merge();
25     protectAgainstEvilTasks();
26   }
27
28   private void merge() {
29     for(Pair<FlatNode, TempDescriptor> machinepair : bsm.getAllMachineNames()) {
30       StateMachineForEffects sm=bsm.getStateMachine(machinepair);
31       merge(sm);
32     }
33   }
34
35
36   private void merge(StateMachineForEffects sm) {
37     HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> backMap=buildBackMap(sm);
38     boolean mergeAgain=false;
39     HashSet<SMFEState> removedStates=new HashSet<SMFEState>();
40     do {
41       mergeAgain=false;
42       HashMap<Pair<SMFEState, FieldDescriptor>, Set<SMFEState>> revMap=buildReverse(backMap);
43       for(Map.Entry<Pair<SMFEState,FieldDescriptor>, Set<SMFEState>> entry : revMap.entrySet()) {
44         if (entry.getValue().size()>1) {
45           SMFEState first=null;
46           for(SMFEState state : entry.getValue()) {
47             if (removedStates.contains(state))
48               continue;
49             if (first==null) {
50               first=state;
51             } else {
52               mergeAgain=true;
53               System.out.println("MERGING:"+first+" and "+state);
54               //Make sure we don't merge the initial state someplace else
55               if (state==sm.initialState) {
56                 state=first;
57                 first=sm.initialState;
58               }
59               mergeTwoStates(first, state, backMap);
60               removedStates.add(state);
61               sm.fn2state.remove(state.whereDefined);
62             }
63           }
64         }
65       }
66     } while(mergeAgain);
67   }
68
69
70   private HashMap<Pair<SMFEState, FieldDescriptor>, Set<SMFEState>> buildReverse(HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> backMap) {
71     HashMap<Pair<SMFEState, FieldDescriptor>, Set<SMFEState>> revMap=new HashMap<Pair<SMFEState, FieldDescriptor>, Set<SMFEState>>();
72     for(Map.Entry<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>>entry : backMap.entrySet()) {
73       SMFEState state=entry.getKey();
74       for(Pair<SMFEState, FieldDescriptor> pair : entry.getValue()) {
75         if (!revMap.containsKey(pair))
76           revMap.put(pair, new HashSet<SMFEState>());
77         revMap.get(pair).add(state);
78       }
79     }
80     return revMap;
81   }
82
83   private void mergeTwoStates(SMFEState state1, SMFEState state2, HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> backMap) {
84     //Merge effects and conflicts
85     state1.effects.addAll(state2.effects);
86     state1.conflicts.addAll(state2.conflicts);
87
88     //fix up our backmap
89     backMap.get(state1).addAll(backMap.get(state2));
90
91     //merge outgoing transitions
92     for(Map.Entry<Effect, Set<SMFEState>> entry : state2.e2states.entrySet()) {
93       Effect e=entry.getKey();
94       Set<SMFEState> states=entry.getValue();
95       if (state1.e2states.containsKey(e)) {
96         for(SMFEState statetoadd : states) {
97           if (!state1.e2states.get(e).add(statetoadd)) {
98             //already added...reduce reference count
99             statetoadd.refCount--;
100           }
101         }
102       } else {
103         state1.e2states.put(e, states);
104       }
105       Set<SMFEState> states1=state1.e2states.get(e);
106
107       //move now-self edges
108       if (states1.contains(state2)) {
109         states1.remove(state2);
110         states1.add(state1);
111       }
112
113       //fix up the backmap of the edges we point to
114       for(SMFEState st : states1) {
115         HashSet<Pair<SMFEState, FieldDescriptor>> toRemove=new HashSet<Pair<SMFEState, FieldDescriptor>>();
116         HashSet<Pair<SMFEState, FieldDescriptor>> toAdd=new HashSet<Pair<SMFEState, FieldDescriptor>>();
117         for(Pair<SMFEState, FieldDescriptor> backpair : backMap.get(st)) {
118           if (backpair.getFirst()==state2) {
119             Pair<SMFEState, FieldDescriptor> newpair=new Pair<SMFEState, FieldDescriptor>(state1, backpair.getSecond());
120             toRemove.add(backpair);
121             toAdd.add(newpair);
122           }
123         }
124         backMap.get(st).removeAll(toRemove);
125         backMap.get(st).addAll(toAdd);
126       }
127     }
128
129     //Fix up our new incoming edges
130     for(Pair<SMFEState,FieldDescriptor> fromStatePair : backMap.get(state2)) {
131       SMFEState fromState=fromStatePair.getFirst();
132       for(Map.Entry<Effect, Set<SMFEState>> fromEntry : fromState.e2states.entrySet()) {
133         Effect e=fromEntry.getKey();
134         Set<SMFEState> states=fromEntry.getValue();
135         if (states.contains(state2)) {
136           states.remove(state2);
137           if(states.add(state1) && !fromState.equals(state2)) {
138             state1.refCount++;
139           }
140         }
141       }
142     }
143     //Clear out unreachable state's backmap
144     backMap.remove(state2);
145   }
146
147
148   private void prune() {
149     for(Pair<FlatNode, TempDescriptor> machinepair : bsm.getAllMachineNames()) {
150       StateMachineForEffects sm=bsm.getStateMachine(machinepair);
151       pruneNonConflictingStates(sm);
152       pruneEffects(sm);
153     }
154   }
155
156   private void pruneEffects(StateMachineForEffects sm) {
157     for(Iterator<FlatNode> fnit=sm.fn2state.keySet().iterator(); fnit.hasNext(); ) {
158       FlatNode fn=fnit.next();
159       SMFEState state=sm.fn2state.get(fn);
160       for(Iterator<Effect> efit=state.effects.iterator(); efit.hasNext(); ) {
161         Effect e=efit.next();
162         //Is it a conflicting effecting
163         if (state.getConflicts().contains(e))
164           continue;
165         //Does it still have transitions
166         if (state.e2states.containsKey(e))
167           continue;
168         //If no to both, remove it
169         efit.remove();
170       }
171     }
172   }
173
174   private void pruneNonConflictingStates(StateMachineForEffects sm) {
175     Set<SMFEState> canReachConflicts=buildConflictsAndMap(sm);
176     for(Iterator<FlatNode> fnit=sm.fn2state.keySet().iterator(); fnit.hasNext(); ) {
177       FlatNode fn=fnit.next();
178       SMFEState state=sm.fn2state.get(fn);
179       if (canReachConflicts.contains(state)) {
180         for(Iterator<Effect> efit=state.e2states.keySet().iterator(); efit.hasNext(); ) {
181           Effect e=efit.next();
182           Set<SMFEState> stateset=state.e2states.get(e);
183           for(Iterator<SMFEState> stit=stateset.iterator(); stit.hasNext(); ) {
184             SMFEState tostate=stit.next();
185             if(!canReachConflicts.contains(tostate))
186               stit.remove();
187           }
188           if (stateset.isEmpty())
189             efit.remove();
190         }
191       } else {
192         fnit.remove();
193       }
194     }
195   }
196
197   private HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> buildBackMap(StateMachineForEffects sm) {
198     return buildBackMap(sm, null);
199   }
200
201   private HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> buildBackMap(StateMachineForEffects sm, Set<SMFEState> conflictStates) {
202     Stack<SMFEState> toprocess=new Stack<SMFEState>();
203     HashMap<SMFEState, Set<Pair<SMFEState, FieldDescriptor>>> backMap=new HashMap<SMFEState, Set<Pair<SMFEState,FieldDescriptor>>>();
204     toprocess.add(sm.initialState);
205     backMap.put(sm.initialState, new HashSet<Pair<SMFEState, FieldDescriptor>>());
206     while(!toprocess.isEmpty()) {
207       SMFEState state=toprocess.pop();
208       if (!state.getConflicts().isEmpty()&&conflictStates!=null) {
209         conflictStates.add(state);
210       }
211       for(Effect e : state.getEffectsAllowed()) {
212         for(SMFEState stateout : state.transitionsTo(e)) {
213           if (!backMap.containsKey(stateout)) {
214             toprocess.add(stateout);
215             backMap.put(stateout, new HashSet<Pair<SMFEState,FieldDescriptor>>());
216           }
217           Pair<SMFEState, FieldDescriptor> p=new Pair<SMFEState, FieldDescriptor>(state, e.getField());
218           backMap.get(stateout).add(p);
219         }
220       }
221     }
222     return backMap;
223   }
224
225
226   private Set<SMFEState> buildConflictsAndMap(StateMachineForEffects sm) {
227     Set<SMFEState> conflictStates=new HashSet<SMFEState>();
228     HashMap<SMFEState, Set<Pair<SMFEState,FieldDescriptor>>> backMap=buildBackMap(sm, conflictStates);
229
230     Stack<SMFEState> toprocess=new Stack<SMFEState>();
231     Set<SMFEState> canReachConflicts=new HashSet<SMFEState>();
232     toprocess.addAll(conflictStates);
233     canReachConflicts.addAll(conflictStates);
234     while(!toprocess.isEmpty()) {
235       SMFEState state=toprocess.pop();
236
237       for(Pair<SMFEState,FieldDescriptor> instatepair : backMap.get(state)) {
238         SMFEState instate=instatepair.getFirst();
239         if (!canReachConflicts.contains(instate)) {
240           toprocess.add(instate);
241           canReachConflicts.add(instate);
242         }
243       }
244     }
245     return canReachConflicts;
246   }
247
248   private void groupStateMachines() {
249     for(Pair<FlatNode, TempDescriptor> machinePair : bsm.getAllMachineNames()) {
250       FlatNode fn=machinePair.getFirst();
251       StateMachineForEffects sm=bsm.getStateMachine(machinePair);
252       Set<FlatSESEEnterNode> taskSet=taskAnalysis.getPossibleExecutingRBlocks(fn);
253       for(FlatSESEEnterNode sese : taskSet) {
254         if (!groupMap.containsKey(sese))
255           groupMap.put(sese, new HashSet<StateMachineForEffects>());
256         groupMap.get(sese).add(sm);
257       }
258     }
259   }
260
261   private void computeConflictEffects() {
262     //Loop through all state machines
263     for(Pair<FlatNode, TempDescriptor> machinePair : bsm.getAllMachineNames()) {
264       FlatNode fn=machinePair.getFirst();
265       StateMachineForEffects sm=bsm.getStateMachine(machinePair);
266       Set<FlatSESEEnterNode> taskSet=taskAnalysis.getPossibleExecutingRBlocks(fn);
267       for(FlatSESEEnterNode sese : taskSet) {
268         Set<StateMachineForEffects> smgroup=groupMap.get(sese);
269         computeConflictingEffects(sm, smgroup);
270       }
271     }
272   }
273
274   private void computeConflictingEffects(StateMachineForEffects sm, Set<StateMachineForEffects> smgroup) {
275     boolean isStall=sm.getStallorSESE().kind()!=FKind.FlatSESEEnterNode;
276     for(SMFEState state : sm.getStates()) {
277       for(Effect e : state.getEffectsAllowed()) {
278         Alloc a=e.getAffectedAllocSite();
279         FieldDescriptor fd=e.getField();
280         int type=e.getType();
281         boolean hasConflict=false;
282         if (!isStall&&Effect.isWrite(type)) {
283           hasConflict=true;
284         } else {
285           for(StateMachineForEffects othersm : smgroup) {
286             boolean otherIsStall=othersm.getStallorSESE().kind()!=FKind.FlatSESEEnterNode;
287             //Stall sites can't conflict with each other
288             if (isStall&&otherIsStall) continue;
289
290             int effectType=othersm.getEffects(a, fd);
291             if (Effect.isWrite(type)&&effectType!=0) {
292               //Original effect is a write and we have some effect on the same field/alloc site
293               hasConflict=true;
294               break;
295             }
296             if (Effect.isWrite(effectType)) {
297               //We are a write
298               hasConflict=true;
299               break;
300             }
301           }
302         }
303         if (hasConflict) {
304           state.addConflict(e);
305         }
306       }
307     }
308   }
309
310
311   private void protectAgainstEvilTasks() {
312     for( Pair<FlatNode, TempDescriptor> machinepair : bsm.getAllMachineNames() ) {
313       StateMachineForEffects sm = bsm.getStateMachine(machinepair);
314       protectAgainstEvilTasks(sm);
315     }
316   }
317
318   private void protectAgainstEvilTasks(StateMachineForEffects sm) {
319     // first identify the set of <Alloc, Field> pairs for which this
320     // traverser will both read and write, remember the read effect
321     Set<Effect> allocAndFieldRW = new HashSet<Effect>();
322     for( Pair<Alloc, FieldDescriptor> af : sm.effectsMap.keySet() ) {
323       Integer effectType = sm.effectsMap.get(af);
324       if( (effectType & Effect.read)  != 0 &&
325           (effectType & Effect.write) != 0
326           ) {
327         allocAndFieldRW.add(new Effect(af.getFirst(),
328                                        Effect.read,
329                                        af.getSecond()
330                                        )
331                             );
332       }
333     }
334
335     // next check the state machine: if an effect that initiates
336     // a transition is in the allocAndFieldRW set, then mark it
337     // as... POSSIBLY EVIL!!!!!
338     for( SMFEState state : sm.getStates() ) {
339       for( Effect effect : state.getTransitionEffects() ) {
340         if( allocAndFieldRW.contains(effect) ) {
341           sm.addPossiblyEvilEffect(effect);
342         }
343       }
344     }
345   }
346 }