typo
[IRC.git] / Robust / src / IR / Flat / RuntimeConflictResolver.java
1 package IR.Flat;
2 import java.io.File;
3 import java.io.FileNotFoundException;
4 import java.util.ArrayList;
5 import java.util.HashSet;
6 import java.util.Hashtable;
7 import java.util.Iterator;
8 import java.util.Set;
9 import java.util.Vector;
10 import Util.Pair;
11 import Analysis.Disjoint.*;
12 import Analysis.Pointer.*;
13 import Analysis.Pointer.AllocFactory.AllocNode;
14 import IR.State;
15 import IR.TypeDescriptor;
16 import Analysis.OoOJava.ConflictGraph;
17 import Analysis.OoOJava.ConflictNode;
18 import Analysis.OoOJava.OoOJavaAnalysis;
19 import Util.CodePrinter;
20
21 /* An instance of this class manages all OoOJava coarse-grained runtime conflicts
22  * by generating C-code to either rule out the conflict at runtime or resolve one.
23  * 
24  * How to Use:
25  * 1) Instantiate singleton object (String input is to specify output dir)
26  * 2) Call void close() 
27  */
28 public class RuntimeConflictResolver {
29   private CodePrinter headerFile, cFile;
30   private static final String hashAndQueueCFileDir = "oooJava/";
31   
32   //This keeps track of taints we've traversed to prevent printing duplicate traverse functions
33   //The Integer keeps track of the weakly connected group it's in (used in enumerateHeapRoots)
34   //private Hashtable<Taint, Integer> doneTaints;
35   private Hashtable<Pair, Integer> idMap=new Hashtable<Pair,Integer>();
36   
37   //Keeps track of stallsites that we've generated code for. 
38   protected Hashtable <FlatNode, TempDescriptor> processedStallSites = new Hashtable <FlatNode, TempDescriptor>();
39  
40   public int currentID=1;
41   private int totalWeakGroups;
42   private OoOJavaAnalysis oooa;  
43   private State globalState;
44   
45   // initializing variables can be found in printHeader()
46   private static final String allocSiteInC = "allocsite";
47   private static final String queryAndAddToVistedHashtable = "hashRCRInsert";
48   private static final String enqueueInC = "enqueueRCRQueue(";
49   private static final String dequeueFromQueueInC = "dequeueRCRQueue()";
50   private static final String clearQueue = "resetRCRQueue()";
51   // Make hashtable; hashRCRCreate(unsigned int size, double loadfactor)
52   private static final String mallocVisitedHashtable = "hashRCRCreate(128, 0.75)";
53   private static final String deallocVisitedHashTable = "hashRCRDelete()";
54   private static final String resetVisitedHashTable = "hashRCRreset()";
55
56   public RuntimeConflictResolver( String buildir, 
57                                   OoOJavaAnalysis oooa, 
58                                   State state) 
59   throws FileNotFoundException {
60     this.oooa         = oooa;
61     this.globalState  = state;
62
63     processedStallSites = new Hashtable <FlatNode, TempDescriptor>();
64     BuildStateMachines bsm  = oooa.getBuildStateMachines();
65     totalWeakGroups         = bsm.getTotalNumOfWeakGroups();
66     
67     setupOutputFiles(buildir);
68
69     for( Pair<FlatNode, TempDescriptor> p: bsm.getAllMachineNames() ) {
70       FlatNode                taskOrStallSite      =  p.getFirst();
71       TempDescriptor          var                  =  p.getSecond();
72       StateMachineForEffects  stateMachine         = bsm.getStateMachine( taskOrStallSite, var );
73
74       //prints the traversal code
75       printCMethod( taskOrStallSite, var, stateMachine); 
76     }
77     
78     //IMPORTANT must call .close() elsewhere to finish printing the C files.  
79   }
80   
81   /*
82    * This method generates a C method for every inset variable and rblock. 
83    * 
84    * The C method works by generating a large switch statement that will run the appropriate 
85    * checking code for each object based on the current state. The switch statement is 
86    * surrounded by a while statement which dequeues objects to be checked from a queue. An
87    * object is added to a queue only if it contains a conflict (in itself or in its referencees)
88    * and we came across it while checking through it's referencer. Because of this property, 
89    * conflicts will be signaled by the referencer; the only exception is the inset variable which can 
90    * signal a conflict within itself. 
91    */
92   
93   private void printCMethod( FlatNode               taskOrStallSite,
94                              TempDescriptor         var,
95                              StateMachineForEffects smfe) {
96
97     // collect info for code gen
98     FlatSESEEnterNode task          = null;
99     String            inVar         = var.getSafeSymbol();
100     SMFEState         initialState  = smfe.getInitialState();
101     boolean           isStallSite   = !(taskOrStallSite instanceof FlatSESEEnterNode);    
102     int               weakID        = smfe.getWeaklyConnectedGroupID(taskOrStallSite);
103     
104     String blockName;    
105     //No need generate code for empty traverser
106     if (smfe.isEmpty())
107       return;
108
109     if( isStallSite ) {
110       blockName = taskOrStallSite.toString();
111       processedStallSites.put(taskOrStallSite, var);
112     } else {
113       task = (FlatSESEEnterNode) taskOrStallSite;
114       
115       //if the task is the main task, there's no traverser
116       if(task.isMainSESE)
117         return;
118       
119       blockName = task.getPrettyIdentifier();
120     }
121
122
123     
124     String methodName = "void traverse___" + inVar + removeInvalidChars(blockName) + "___(void * InVar, ";
125     int    index      = -1;
126
127     if( isStallSite ) {
128       methodName += "SESEstall *record)";
129     } else {
130       methodName += task.getSESErecordName() +" *record)";
131       //TODO check that this HACK is correct (i.e. adding and then polling immediately afterwards)
132       task.addInVarForDynamicCoarseConflictResolution(var);
133       index = task.getInVarsForDynamicCoarseConflictResolution().indexOf( var );
134     }
135     
136     cFile.println( methodName + " {");
137     headerFile.println( methodName + ";" );
138
139     cFile.println(  "  int totalcount = RUNBIAS;");      
140     if( isStallSite ) {
141       cFile.println("  record->rcrRecords[0].count = RUNBIAS;");
142     } else {
143       cFile.println("  record->rcrRecords["+index+"].count = RUNBIAS;");
144     }
145
146     //clears queue and hashtable that keeps track of where we've been. 
147     cFile.println(clearQueue + ";");
148     cFile.println(resetVisitedHashTable + ";"); 
149     cFile.println("  RCRQueueEntry * queueEntry; //needed for dequeuing");
150     
151     cFile.println("  int traverserState = "+initialState.getID()+";");
152
153     //generic cast to ___Object___ to access ptr->allocsite field. 
154     cFile.println("  struct ___Object___ * ptr = (struct ___Object___ *) InVar;");
155     cFile.println("  if (InVar != NULL) {");
156     cFile.println("    " + queryAndAddToVistedHashtable + "(ptr, "+initialState.getID()+");");
157     cFile.println("    do {");
158
159     if( !isStallSite ) {
160       cFile.println("      if(unlikely(record->common.doneExecuting)) {");
161       cFile.println("        record->common.rcrstatus=0;");
162       cFile.println("        return;");
163       cFile.println("      }");
164     }
165
166     
167     // Traverse the StateMachineForEffects (a graph)
168     // that serves as a plan for building the heap examiner code.
169     // SWITCH on the states in the state machine, THEN
170     //   SWITCH on the concrete object's allocation site THEN
171     //     consider conflicts, enqueue more work, inline more SWITCHES, etc.
172       
173     boolean needswitch=smfe.getStates().size()>1;
174
175     if (needswitch) {
176       cFile.println("  switch( traverserState ) {");
177     }
178     for(SMFEState state: smfe.getStates()) {
179
180       if(state.getRefCount() != 1 || initialState == state) {
181         if (needswitch) {
182           cFile.println("    case "+state.getID()+":");
183         } else {
184           cFile.println("  if(traverserState=="+state.getID()+") {");
185         }
186         
187         printAllocChecksInsideState("ptr->allocsite", state, taskOrStallSite, var, "ptr", 0, weakID);
188         
189         cFile.println("      break;");
190       }
191     }
192     
193     if (needswitch) {
194       cFile.println("        default: break;");
195     }
196     cFile.println("      } // end switch on traverser state");
197     cFile.println("      queueEntry = " + dequeueFromQueueInC + ";");
198     cFile.println("      if(queueEntry == NULL) {");
199     cFile.println("        break;");
200     cFile.println("      }");
201     cFile.println("      ptr = queueEntry->object;");
202     cFile.println("      traverserState = queueEntry->traverserState;");
203     cFile.println("    } while(ptr != NULL);");
204     cFile.println("  } // end if inVar not null");
205    
206
207     if( isStallSite ) {
208       cFile.println("  if(atomic_sub_and_test(totalcount,&(record->rcrRecords[0].count))) {");
209       cFile.println("    psem_give_tag(record->common.parentsStallSem, record->tag);");
210       cFile.println("    BARRIER();");
211       cFile.println("  }");
212     } else {
213       cFile.println("  if(atomic_sub_and_test(totalcount,&(record->rcrRecords["+index+"].count))) {");
214       cFile.println("    int flag=LOCKXCHG32(&(record->rcrRecords["+index+"].flag),0);");
215       cFile.println("    if(flag) {");
216       //we have resolved a heap root...see if this was the last dependence
217       cFile.println("      if(atomic_sub_and_test(1, &(record->common.unresolvedDependencies))) workScheduleSubmit((void *)record);");
218       cFile.println("    }");
219       cFile.println("  }");
220     }
221
222     cFile.println("}");
223     cFile.flush();
224   }
225   
226   public void printAllocChecksInsideState(String input, SMFEState state, FlatNode fn, TempDescriptor tmp, String prefix, int depth, int weakID) {
227     EffectsTable et = new EffectsTable(state);
228     boolean needswitch=et.getAllAllocs().size()>1;
229     if (needswitch) {
230       cFile.println("      switch(" + input + ") {");
231     }
232
233     //we assume that all allocs given in the effects are starting locs. 
234     for(Alloc a: et.getAllAllocs()) {
235       if (needswitch) {
236         cFile.println("    case "+a.getUniqueAllocSiteID()+": {");
237       } else {
238         cFile.println("     if("+input+"=="+a.getUniqueAllocSiteID()+") {");
239       }
240       addChecker(a, fn, tmp, state, et, "ptr", 0, weakID);
241       if (needswitch) {
242         cFile.println("       }");
243         cFile.println("       break;");
244       }
245     }
246     if (needswitch) {
247       cFile.println("      default:");
248       cFile.println("        break;");
249     }
250     cFile.println("      }");
251   }
252   
253   public void addChecker(Alloc a, FlatNode fn, TempDescriptor tmp, SMFEState state, EffectsTable et, String prefix, int depth, int weakID) {
254     insertEntriesIntoHashStructureNew(fn, tmp, et, a, prefix, depth, weakID);
255     
256     int pdepth = depth+1;
257     
258     if(a.getType().isArray()) {
259       String childPtr = "((struct ___Object___ **)(((char *) &(((struct ArrayObject *)"+ prefix+")->___length___))+sizeof(int)))[i]";
260       String currPtr = "arrayElement" + pdepth;
261       
262       cFile.println("  int i;");
263       cFile.println("  struct ___Object___ * "+currPtr+";");
264       cFile.println("  for(i = 0; i<((struct ArrayObject *) " + prefix + " )->___length___; i++ ) {");
265       
266       for(Effect e: et.getEffects(a)) {
267         if (!state.transitionsTo(e).isEmpty()) {
268           printRefSwitch(fn, tmp, pdepth, childPtr, currPtr, state.transitionsTo(e), weakID);
269         }
270       }
271       cFile.println("}");
272     }  else {
273       //All other cases
274       String currPtr = "myPtr" + pdepth;
275       cFile.println("    struct ___Object___ * "+currPtr+";");
276       
277       for(Effect e: et.getEffects(a)) {
278         if (!state.transitionsTo(e).isEmpty()) {
279           String childPtr = "((struct "+a.getType().getSafeSymbol()+" *)"+prefix +")->" + e.getField().getSafeSymbol();
280           printRefSwitch(fn, tmp, pdepth, childPtr, currPtr, state.transitionsTo(e), weakID);
281         }
282       }
283     }
284   }
285
286   private void printRefSwitch(FlatNode fn, TempDescriptor tmp, int pdepth, String childPtr, String currPtr, Set<SMFEState> transitions, int weakID) {    
287     
288     for(SMFEState tr: transitions) {
289       if(tr.getRefCount() == 1) {       //in-lineable case
290         //Don't need to update state counter since we don't care really if it's inlined...
291         cFile.println("    "+currPtr+"= (struct ___Object___ * ) " + childPtr + ";");
292         cFile.println("    if (" + currPtr + " != NULL) { ");
293         
294         printAllocChecksInsideState(currPtr+"->"+allocSiteInC, tr, fn, tmp, currPtr, pdepth+1, weakID);
295         
296         cFile.println("    }"); //break for internal switch and if
297       } else {                          //non-inlineable cases
298         cFile.println("    " + enqueueInC + childPtr + ", "+tr.getID()+");");
299       } 
300     }
301   }
302   
303   
304   //FlatNode and TempDescriptor are what are used to make the taint
305   private void insertEntriesIntoHashStructureNew(FlatNode fn, TempDescriptor tmp, EffectsTable et, Alloc a, String prefix, int depth, int weakID) {
306     int index = 0;
307     boolean isRblock = (fn instanceof FlatSESEEnterNode);
308     if (isRblock) {
309       FlatSESEEnterNode fsese = (FlatSESEEnterNode) fn;
310       index = fsese.getInVarsForDynamicCoarseConflictResolution().indexOf(tmp);
311     }
312     
313     String strrcr = isRblock ? "&record->rcrRecords[" + index + "], " : "NULL, ";
314     String tasksrc =isRblock ? "(SESEcommon *) record, ":"(SESEcommon *)(((INTPTR)record)|1LL), ";
315
316     if(et.hasWriteConflict(a)) {
317       cFile.append("    int tmpkey" + depth + " = rcr_generateKey(" + prefix + ");\n");
318       if (et.leadsToConflict(a))
319         cFile.append("    int tmpvar" + depth + " = rcr_WTWRITEBINCASE(allHashStructures[" + weakID + "], tmpkey" + depth + ", " + tasksrc + strrcr + index + ");\n");
320       else
321         cFile.append("    int tmpvar" + depth + " = rcr_WRITEBINCASE(allHashStructures["+ weakID + "], tmpkey" + depth + ", " + tasksrc + strrcr + index + ");\n");
322     } else  if(et.hasReadConflict(a)) { 
323       cFile.append("    int tmpkey" + depth + " = rcr_generateKey(" + prefix + ");\n");
324       if (et.leadsToConflict(a))
325         cFile.append("    int tmpvar" + depth + " = rcr_WTREADBINCASE(allHashStructures[" + weakID + "], tmpkey" + depth + ", " + tasksrc + strrcr + index + ");\n");
326       else
327         cFile.append("    int tmpvar" + depth + " = rcr_READBINCASE(allHashStructures["+ weakID + "], tmpkey" + depth + ", " + tasksrc + strrcr + index + ");\n");
328     }
329
330     if (et.hasReadConflict(a) || et.hasWriteConflict(a)) {
331       cFile.append("if (!(tmpvar" + depth + "&READYMASK)) totalcount--;\n");
332     }
333   }
334
335   private void setupOutputFiles(String buildir) throws FileNotFoundException {
336     cFile = new CodePrinter(new File(buildir + "RuntimeConflictResolver" + ".c"));
337     headerFile = new CodePrinter(new File(buildir + "RuntimeConflictResolver" + ".h"));
338     
339     cFile.println("#include \"" + hashAndQueueCFileDir + "hashRCR.h\"\n#include \""
340         + hashAndQueueCFileDir + "Queue_RCR.h\"\n#include <stdlib.h>");
341     cFile.println("#include \"classdefs.h\"");
342     cFile.println("#include \"structdefs.h\"");
343     cFile.println("#include \"mlp_runtime.h\"");
344     cFile.println("#include \"RuntimeConflictResolver.h\"");
345     cFile.println("#include \"hashStructure.h\"");
346     
347     headerFile.println("#ifndef __3_RCR_H_");
348     headerFile.println("#define __3_RCR_H_");
349   }
350   
351   //The official way to generate the name for a traverser call
352   public String getTraverserInvocation(TempDescriptor invar, String varString, FlatNode fn) {
353     String flatname;
354     if(fn instanceof FlatSESEEnterNode) {  //is SESE block
355       flatname = ((FlatSESEEnterNode) fn).getPrettyIdentifier();
356     } else {  //is stallsite
357       flatname = fn.toString();
358     }
359     
360     return "traverse___" + invar.getSafeSymbol() + removeInvalidChars(flatname) + "___("+varString+");";
361   }
362   
363   public String removeInvalidChars(String in) {
364     StringBuilder s = new StringBuilder(in);
365     for(int i = 0; i < s.length(); i++) {
366       if(s.charAt(i) == ' ' || 
367          s.charAt(i) == '.' || 
368          s.charAt(i) == '=' ||
369          s.charAt(i) == '[' ||
370          s.charAt(i) == ']'    ) {
371
372         s.deleteCharAt(i);
373         i--;
374       }
375     }
376     return s.toString();
377   }
378
379   public int getTraverserID(TempDescriptor invar, FlatNode fn) {
380     Pair<TempDescriptor, FlatNode> t = new Pair<TempDescriptor, FlatNode>(invar, fn);
381     if (idMap.containsKey(t)) {
382       return idMap.get(t).intValue();
383     }
384     int value=currentID++;
385     idMap.put(t, new Integer(value));
386     return value;
387   }
388   
389   public void close() {
390     //Prints out the master traverser Invocation that'll call all other traversers
391     //based on traverserID
392     printMasterTraverserInvocation();    
393     createMasterHashTableArray();
394     
395     // Adds Extra supporting methods
396     cFile.println("void initializeStructsRCR() {\n  " + mallocVisitedHashtable + ";\n  " + clearQueue + ";\n}");
397     cFile.println("void destroyRCR() {\n  " + deallocVisitedHashTable + ";\n}");
398     
399     headerFile.println("void initializeStructsRCR();\nvoid destroyRCR();");
400     headerFile.println("#endif\n");
401
402     cFile.close();
403     headerFile.close();
404   }
405
406   private void printMasterTraverserInvocation() {
407     headerFile.println("int tasktraverse(SESEcommon * record);");
408     cFile.println("int tasktraverse(SESEcommon * record) {");
409     cFile.println("  if(!CAS(&record->rcrstatus,1,2)) {");
410
411     //release traverser reference...no traversal necessary
412     cFile.println("#ifndef OOO_DISABLE_TASKMEMPOOL");
413     cFile.println("    RELEASE_REFERENCE_TO(record);");
414     cFile.println("#endif");
415
416     cFile.println("    return;");
417     cFile.println("  }");
418     cFile.println("  switch(record->classID) {");
419     
420     for(Iterator<FlatSESEEnterNode> seseit=oooa.getAllSESEs().iterator();seseit.hasNext();) {
421       FlatSESEEnterNode fsen=seseit.next();
422       cFile.println(    "    /* "+fsen.getPrettyIdentifier()+" */");
423       cFile.println(    "    case "+fsen.getIdentifier()+": {");
424       cFile.println(    "      "+fsen.getSESErecordName()+" * rec=("+fsen.getSESErecordName()+" *) record;");
425       Vector<TempDescriptor> invars=fsen.getInVarsForDynamicCoarseConflictResolution();
426       for(int i=0;i<invars.size();i++) {
427         TempDescriptor tmp=invars.get(i);
428         
429         /* In some cases we don't want to a dynamic traversal if it is
430          * unlikely to increase parallelism...these are cases where we
431          * are just enabling a stall site to possible clear faster*/
432
433         boolean isValidToPrune=true;
434         for( FlatSESEEnterNode parentSESE: fsen.getParents() ) {
435           ConflictGraph     graph      = oooa.getConflictGraph(parentSESE);
436           String            id         = tmp + "_sese" + fsen.getPrettyIdentifier();
437           ConflictNode      node       = graph.getId2cn().get(id);
438           isValidToPrune &= node.IsValidToPrune();
439         }
440         if (i!=0) {
441           cFile.println("      if (record->rcrstatus!=0)");
442         }
443         
444         if(globalState.NOSTALLTR && isValidToPrune) {
445           cFile.println("    /*  " + getTraverserInvocation(tmp, "rec->"+tmp+", rec", fsen)+"*/");
446         } else {
447           cFile.println("      " + getTraverserInvocation(tmp, "rec->"+tmp+", rec", fsen));
448         }
449       }
450       //release traverser reference...traversal finished...
451       //executing thread will clean bins for us
452       cFile.println("     record->rcrstatus=0;");
453       cFile.println("#ifndef OOO_DISABLE_TASKMEMPOOL");
454       cFile.println("    RELEASE_REFERENCE_TO(record);");
455       cFile.println("#endif");
456       cFile.println(    "    }");
457       cFile.println(    "    break;");
458     }
459     
460     for(FlatNode stallsite: processedStallSites.keySet()) {
461       TempDescriptor var = processedStallSites.get(stallsite);
462       
463       cFile.println(    "    case -" + getTraverserID(var, stallsite)+ ": {");
464       cFile.println(    "      SESEstall * rec=(SESEstall*) record;");
465       cFile.println(    "      " + getTraverserInvocation(var, "rec->___obj___, rec", stallsite)+";");
466       cFile.println(    "     record->rcrstatus=0;");
467       cFile.println(    "    }");
468       cFile.println("    break;");
469     }
470
471     cFile.println("    default:");
472     cFile.println("      printf(\"Invalid SESE ID was passed in: %d.\\n\",record->classID);");
473     cFile.println("      break;");
474     cFile.println("  }");
475     cFile.println("}");
476   }
477   
478   private void createMasterHashTableArray() {
479     headerFile.println("struct Hashtable_rcr ** createAndFillMasterHashStructureArray();");
480     cFile.println("struct Hashtable_rcr ** createAndFillMasterHashStructureArray() {");
481
482     cFile.println("  struct Hashtable_rcr **table=rcr_createMasterHashTableArray("+totalWeakGroups + ");");
483     
484     for(int i = 0; i < totalWeakGroups; i++) {
485       cFile.println("  table["+i+"] = (struct Hashtable_rcr *) rcr_createHashtable();");
486     }
487     cFile.println("  return table;");
488     cFile.println("}");
489   }
490
491   public int getWeakID(TempDescriptor invar, FlatNode fn) {
492     //return weakMap.get(new Pair(invar, fn)).intValue();
493     return 0;
494   }
495
496
497   public boolean hasEmptyTraversers(FlatSESEEnterNode fsen) {
498     boolean hasEmpty = true;
499     
500     Set<FlatSESEEnterNode> children = fsen.getChildren();
501     for (Iterator<FlatSESEEnterNode> iterator = children.iterator(); iterator.hasNext();) {
502       FlatSESEEnterNode child = (FlatSESEEnterNode) iterator.next();
503       hasEmpty &= child.getInVarsForDynamicCoarseConflictResolution().size() == 0;
504     }
505     return hasEmpty;
506   }  
507
508   
509   //Simply rehashes and combines all effects for a AffectedAllocSite + Field.
510   private class EffectsTable {
511     private Hashtable<Alloc,Set<Effect>> table;
512     SMFEState state;
513
514     public EffectsTable(SMFEState state) {
515       table = new Hashtable<Alloc, Set<Effect>>();
516       this.state=state;
517       for(Effect e: state.getEffectsAllowed()) {
518         Set<Effect> eg;
519         if((eg = table.get(e.getAffectedAllocSite())) == null) {
520           eg = new HashSet<Effect>();
521           table.put(e.getAffectedAllocSite(), eg);
522         }
523         eg.add(e);
524       }
525     }
526     
527     public boolean leadsToConflict(Alloc a) {
528       for(Effect e:getEffects(a)) {
529         if (!state.transitionsTo(e).isEmpty())
530           return true;
531       }
532       return false;
533     }
534
535     public boolean hasWriteConflict(Alloc a) {
536       for(Effect e:getEffects(a)) {
537         if (e.isWrite() && state.getConflicts().contains(e))
538           return true;
539       }
540       return false;
541     }
542
543     public boolean hasReadConflict(Alloc a) {
544       for(Effect e:getEffects(a)) {
545         if (e.isRead() && state.getConflicts().contains(e))
546           return true;
547       }
548       return false;
549     }
550
551     public Set<Effect> getEffects(Alloc a) {
552       return table.get(a);
553     }
554
555     public Set<Alloc> getAllAllocs() {
556       return table.keySet();
557     }
558   }
559 }