add new features...they don't break the build, but need to check if they work...
[IRC.git] / Robust / src / Analysis / Locality / DiscoverConflicts.java
index e9127ae389d7cc1ba7776cd4b8fb455ec1c3437f..775c87ebaab3c7c4bc151f8a120ca0e9d4e88384 100644 (file)
@@ -12,6 +12,7 @@ import IR.TypeDescriptor;
 import IR.MethodDescriptor;
 import IR.FieldDescriptor;
 import Analysis.Liveness;
+import Analysis.Loops.GlobalFieldType;
 
 public class DiscoverConflicts {
   Set<FieldDescriptor> fields;
@@ -20,15 +21,19 @@ public class DiscoverConflicts {
   State state;
   Hashtable<LocalityBinding, Set<FlatNode>> treadmap;
   Hashtable<LocalityBinding, Set<TempFlatPair>> transreadmap;
+  Hashtable<LocalityBinding, Set<FlatNode>> twritemap;
+  Hashtable<LocalityBinding, Set<TempFlatPair>> writemap;
   Hashtable<LocalityBinding, Set<FlatNode>> srcmap;
   Hashtable<LocalityBinding, Set<FlatNode>> leftsrcmap;
   Hashtable<LocalityBinding, Set<FlatNode>> rightsrcmap;
   TypeAnalysis typeanalysis;
   Hashtable<LocalityBinding, HashSet<FlatNode>>cannotdelaymap;
   Hashtable<LocalityBinding, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>>> lbtofnmap;
+  boolean inclusive=false;
+  boolean normalassign=false;
+  GlobalFieldType gft;
 
-
-  public DiscoverConflicts(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis) {
+  public DiscoverConflicts(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis, GlobalFieldType gft) {
     this.locality=locality;
     this.fields=new HashSet<FieldDescriptor>();
     this.arrays=new HashSet<TypeDescriptor>();
@@ -40,9 +45,14 @@ public class DiscoverConflicts {
     leftsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
     rightsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
     lbtofnmap=new Hashtable<LocalityBinding, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>>>();
+    if (gft!=null) {
+      twritemap=new Hashtable<LocalityBinding, Set<FlatNode>>();
+      writemap=new Hashtable<LocalityBinding, Set<TempFlatPair>>();
+    }
+    this.gft=gft;
   }
 
-  public DiscoverConflicts(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis, Hashtable<LocalityBinding, HashSet<FlatNode>> cannotdelaymap) {
+  public DiscoverConflicts(LocalityAnalysis locality, State state, TypeAnalysis typeanalysis, Hashtable<LocalityBinding, HashSet<FlatNode>> cannotdelaymap, boolean inclusive, boolean normalassign, GlobalFieldType gft) {
     this.locality=locality;
     this.fields=new HashSet<FieldDescriptor>();
     this.arrays=new HashSet<TypeDescriptor>();
@@ -55,6 +65,13 @@ public class DiscoverConflicts {
     leftsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
     rightsrcmap=new Hashtable<LocalityBinding, Set<FlatNode>>();
     lbtofnmap=new Hashtable<LocalityBinding, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>>>();
+    this.inclusive=inclusive;
+    this.normalassign=normalassign;
+    if (gft!=null) {
+      twritemap=new Hashtable<LocalityBinding, Set<FlatNode>>();
+      writemap=new Hashtable<LocalityBinding, Set<TempFlatPair>>();
+    }
+    this.gft=gft;
   }
 
   public Set<FieldDescriptor> getFields() {
@@ -91,6 +108,15 @@ public class DiscoverConflicts {
       set.add(tfp.f);
     }
     treadmap.put(lb, set);
+    if (gft!=null) {
+      //need to translate write map set
+      set=new HashSet<FlatNode>();
+      for(Iterator<TempFlatPair> it=writemap.get(lb).iterator();it.hasNext();) {
+       TempFlatPair tfp=it.next();
+       set.add(tfp.f);
+      }
+      twritemap.put(lb, set);
+    }
   }
 
   //We have a set of things we write to, figure out what things this
@@ -142,6 +168,10 @@ public class DiscoverConflicts {
     return treadmap.get(lb).contains(fn);
   }
 
+  public boolean getNeedWriteTrans(LocalityBinding lb, FlatNode fn) {
+    return twritemap.get(lb).contains(fn);
+  }
+
   public Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> getMap(LocalityBinding lb) {
     return lbtofnmap.get(lb);
   }
@@ -149,9 +179,19 @@ public class DiscoverConflicts {
   private void analyzeLocality(LocalityBinding lb) {
     MethodDescriptor md=lb.getMethod();
     FlatMethod fm=state.getMethodFlat(md);
+
+    //Compute map from flatnode -> (temps -> source of value)
     Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap=computeTempSets(lb);
     lbtofnmap.put(lb,fnmap);
-    HashSet<TempFlatPair> tfset=computeTranslationSet(lb, fm, fnmap);
+    HashSet<TempFlatPair> writeset=null;
+    if (gft!=null) {
+      writeset=new HashSet<TempFlatPair>();
+    }
+    HashSet<TempFlatPair> tfset=computeTranslationSet(lb, fm, fnmap, writeset);
+    if (gft!=null) {
+      writemap.put(lb, writeset);
+    }
+    
     HashSet<FlatNode> srctrans=new HashSet<FlatNode>();
     HashSet<FlatNode> leftsrctrans=new HashSet<FlatNode>();
     HashSet<FlatNode> rightsrctrans=new HashSet<FlatNode>();
@@ -161,7 +201,6 @@ public class DiscoverConflicts {
     rightsrcmap.put(lb,rightsrctrans);
 
     //compute writes that need translation on source
-
     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
       FlatNode fn=fnit.next();
       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
@@ -281,6 +320,76 @@ public class DiscoverConflicts {
     return fn.kind()==FKind.FlatCall||fn.kind()==FKind.FlatMethod;
   }
 
+  private void computeReadOnly(LocalityBinding lb, Hashtable<FlatNode, Set<TypeDescriptor>> updatedtypemap, Hashtable<FlatNode, Set<FieldDescriptor>> updatedfieldmap) {
+    //inside of transaction, try to convert rw access to ro access
+    MethodDescriptor md=lb.getMethod();
+    FlatMethod fm=state.getMethodFlat(md);
+    Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
+
+    HashSet<FlatNode> toanalyze=new HashSet<FlatNode>();
+    toanalyze.addAll(fm.getNodeSet());
+    
+    while(!toanalyze.isEmpty()) {
+      FlatNode fn=toanalyze.iterator().next();
+      toanalyze.remove(fn);
+      HashSet<TypeDescriptor> updatetypeset=new HashSet<TypeDescriptor>();
+      HashSet<FieldDescriptor> updatefieldset=new HashSet<FieldDescriptor>();
+      
+      //Stop if we aren't in a transaction
+      if (atomictable.get(fn).intValue()==0)
+       continue;
+      
+      //Do merge of all exits
+      for(int i=0;i<fn.numNext();i++) {
+       FlatNode fnnext=fn.getNext(i);
+       if (updatedtypemap.containsKey(fnnext)) {
+         updatetypeset.addAll(updatedtypemap.get(fnnext));
+       }
+       if (updatedfieldmap.containsKey(fnnext)) {
+         updatefieldset.addAll(updatedfieldmap.get(fnnext));
+       }
+      }
+      
+      //process this node
+      if (cannotdelaymap!=null&&cannotdelaymap.containsKey(lb)&&cannotdelaymap.get(lb).contains(fn)!=inclusive) {
+       switch(fn.kind()) {
+       case FKind.FlatSetFieldNode: {
+         FlatSetFieldNode fsfn=(FlatSetFieldNode)fn;
+         updatefieldset.add(fsfn.getField());
+         break;
+       }
+       case FKind.FlatSetElementNode: {
+         FlatSetElementNode fsen=(FlatSetElementNode)fn;
+         updatetypeset.addAll(typeanalysis.expand(fsen.getDst().getType()));
+         break;
+       }
+       case FKind.FlatCall: {
+         FlatCall fcall=(FlatCall)fn;
+         MethodDescriptor mdfc=fcall.getMethod();
+         
+         //get modified fields
+         Set<FieldDescriptor> fields=gft.getFieldsAll(mdfc);
+         updatefieldset.addAll(fields);
+         
+         //get modified arrays
+         Set<TypeDescriptor> arrays=gft.getArraysAll(mdfc);
+         updatetypeset.addAll(typeanalysis.expandSet(arrays));
+         break;
+       }
+       }
+      }
+      
+      if (!updatedtypemap.containsKey(fn)||!updatedfieldmap.containsKey(fn)||
+         !updatedtypemap.get(fn).equals(updatetypeset)||!updatedfieldmap.get(fn).equals(updatefieldset)) {
+       updatedtypemap.put(fn, updatetypeset);
+       updatedfieldmap.put(fn, updatefieldset);
+       for(int i=0;i<fn.numPrev();i++) {
+         toanalyze.add(fn.getPrev(i));
+       }
+      }
+    }
+  }
+
 
   /** Need to figure out which nodes need a transread to make local
   copies.  Transread conceptually tracks conflicts.  This depends on
@@ -288,18 +397,27 @@ public class DiscoverConflicts {
   access fields...If these accesses could conflict, we mark the source
   tempflat pair as needing a transread */
 
-  HashSet<TempFlatPair> computeTranslationSet(LocalityBinding lb, FlatMethod fm, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap) {
+  
+  HashSet<TempFlatPair> computeTranslationSet(LocalityBinding lb, FlatMethod fm, Hashtable<FlatNode, Hashtable<TempDescriptor, Set<TempFlatPair>>> fnmap, Set<TempFlatPair> writeset) {
     HashSet<TempFlatPair> tfset=new HashSet<TempFlatPair>();
 
+    //Compute maps from flatnodes -> sets of things that may be updated after this node
+    Hashtable<FlatNode, Set<TypeDescriptor>> updatedtypemap=null;
+    Hashtable<FlatNode, Set<FieldDescriptor>> updatedfieldmap=null;
+
+    if (writeset!=null&&!lb.isAtomic()) {
+      updatedtypemap=new Hashtable<FlatNode, Set<TypeDescriptor>>();
+      updatedfieldmap=new Hashtable<FlatNode, Set<FieldDescriptor>>();
+      computeReadOnly(lb, updatedtypemap, updatedfieldmap);
+    }
+
     for(Iterator<FlatNode> fnit=fm.getNodeSet().iterator();fnit.hasNext();) {
       FlatNode fn=fnit.next();
-
-      //Check whether this node matters for delayed computation
-      if (cannotdelaymap!=null&&cannotdelaymap.containsKey(lb)&&!cannotdelaymap.get(lb).contains(fn))
+      //Check whether this node matters for cannot delayed computation
+      if (cannotdelaymap!=null&&cannotdelaymap.containsKey(lb)&&cannotdelaymap.get(lb).contains(fn)==inclusive)
        continue;
 
       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
-
       if (atomictable.get(fn).intValue()>0) {
        Hashtable<TempDescriptor, Set<TempFlatPair>> tmap=fnmap.get(fn);
        switch(fn.kind()) {
@@ -311,6 +429,12 @@ public class DiscoverConflicts {
            if (tfpset!=null)
              tfset.addAll(tfpset);
          }
+         if (updatedtypemap!=null&&updatedtypemap.get(fen).contains(fen.getSrc().getType())) {
+           //this could cause conflict...figure out conflict set
+           Set<TempFlatPair> tfpset=tmap.get(fen.getSrc());
+           if (tfpset!=null)
+             writeset.addAll(tfpset);
+         }
          break;
        }
        case FKind.FlatFieldNode: { 
@@ -321,6 +445,12 @@ public class DiscoverConflicts {
            if (tfpset!=null)
              tfset.addAll(tfpset);
          }
+         if (updatedfieldmap!=null&&updatedfieldmap.get(ffn).contains(ffn.getField())) {
+           //this could cause conflict...figure out conflict set
+           Set<TempFlatPair> tfpset=tmap.get(ffn.getSrc());
+           if (tfpset!=null)
+             writeset.addAll(tfpset);
+         }
          break;
        }
        case FKind.FlatSetFieldNode: { 
@@ -329,6 +459,10 @@ public class DiscoverConflicts {
          Set<TempFlatPair> tfpset=tmap.get(fsfn.getDst());
          if (tfpset!=null)
            tfset.addAll(tfpset);
+         if (writeset!=null) {
+           if (tfpset!=null)
+             writeset.addAll(tfpset);
+         }
          break;
        }
        case FKind.FlatSetElementNode: { 
@@ -337,6 +471,10 @@ public class DiscoverConflicts {
          Set<TempFlatPair> tfpset=tmap.get(fsen.getDst());
          if (tfpset!=null)
            tfset.addAll(tfpset);
+         if (writeset!=null) {
+           if (tfpset!=null)
+             writeset.addAll(tfpset);
+         }
          break;
        }
        case FKind.FlatCall: //assume pessimistically that calls do bad things
@@ -347,6 +485,10 @@ public class DiscoverConflicts {
            Set<TempFlatPair> tfpset=tmap.get(rtmp);
            if (tfpset!=null)
              tfset.addAll(tfpset);
+           if (writeset!=null) {
+             if (tfpset!=null)
+               writeset.addAll(tfpset);
+           }
          }
          break;
        }
@@ -430,14 +572,31 @@ public class DiscoverConflicts {
            }
            break;
          }
-         case FKind.FlatOpNode: {
-           FlatOpNode fon=(FlatOpNode)fn;
-           if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()&&
-               ttofn.containsKey(fon.getLeft())) {
-             ttofn.put(fon.getDest(), new HashSet<TempFlatPair>(ttofn.get(fon.getLeft())));
-             break;
+         case FKind.FlatCastNode:
+         case FKind.FlatOpNode: 
+           if (fn.kind()==FKind.FlatCastNode) {
+             FlatCastNode fcn=(FlatCastNode)fn;
+             if (fcn.getDst().getType().isPtr()) {
+               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
+               if (ttofn.containsKey(fcn.getSrc()))
+                 set.addAll(ttofn.get(fcn.getSrc()));
+               if (normalassign)
+                 set.add(new TempFlatPair(fcn.getDst(), fn));
+               ttofn.put(fcn.getDst(), set);
+               break;
+             }
+           } else if (fn.kind()==FKind.FlatOpNode) {
+             FlatOpNode fon=(FlatOpNode)fn;
+             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()) {
+               HashSet<TempFlatPair> set=new HashSet<TempFlatPair>();
+               if (ttofn.containsKey(fon.getLeft()))
+                 set.addAll(ttofn.get(fon.getLeft()));
+               if (normalassign)
+                 set.add(new TempFlatPair(fon.getDest(), fn));
+               ttofn.put(fon.getDest(), set);
+               break;
+             }
            }
-         }
          default:
            //Do kill computation
            TempDescriptor[] writes=fn.writesTemps();
@@ -466,6 +625,9 @@ public class DiscoverConflicts {
   /* See what fields and arrays transactions might modify.  We only
    * look at changes to old objects. */
 
+  //Bug fix: original version forget to check if object is new and
+  //could be optimized
+
   public void computeModified(LocalityBinding lb) {
     MethodDescriptor md=lb.getMethod();
     FlatMethod fm=state.getMethodFlat(md);
@@ -474,14 +636,17 @@ public class DiscoverConflicts {
       FlatNode fn=fnit.next();
       Hashtable<FlatNode, Integer> atomictable=locality.getAtomic(lb);
       if (atomictable.get(fn).intValue()>0) {
+       Set<TempDescriptor> oldtemp=oldtemps.get(fn);
        switch (fn.kind()) {
        case FKind.FlatSetFieldNode:
          FlatSetFieldNode fsfn=(FlatSetFieldNode) fn;
-         fields.add(fsfn.getField());
+         if (oldtemp.contains(fsfn.getDst()))
+           fields.add(fsfn.getField());
          break;
        case FKind.FlatSetElementNode:
          FlatSetElementNode fsen=(FlatSetElementNode) fn;
-         arrays.add(fsen.getDst().getType());
+         if (oldtemp.contains(fsen.getDst()))
+           arrays.add(fsen.getDst().getType());
          break;
        default:
        }
@@ -541,16 +706,27 @@ public class DiscoverConflicts {
          case FKind.FlatNew:
            oldtemps.removeAll(Arrays.asList(fn.readsTemps()));
            break;
-         case FKind.FlatOpNode: {
-           FlatOpNode fon=(FlatOpNode)fn;
-           if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()) {
-             if (oldtemps.contains(fon.getLeft()))
-               oldtemps.add(fon.getDest());
-             else
-               oldtemps.remove(fon.getDest());
-             break;
+         case FKind.FlatOpNode:
+         case FKind.FlatCastNode: 
+           if (fn.kind()==FKind.FlatCastNode) {
+             FlatCastNode fcn=(FlatCastNode)fn;
+             if (fcn.getDst().getType().isPtr()) {
+               if (oldtemps.contains(fcn.getSrc()))
+                 oldtemps.add(fcn.getDst());
+               else
+                 oldtemps.remove(fcn.getDst());
+               break;
+             }
+           } else if (fn.kind()==FKind.FlatOpNode) {
+             FlatOpNode fon=(FlatOpNode)fn;
+             if (fon.getOp().getOp()==Operation.ASSIGN&&fon.getDest().getType().isPtr()) {
+               if (oldtemps.contains(fon.getLeft()))
+                 oldtemps.add(fon.getDest());
+               else
+                 oldtemps.remove(fon.getDest());
+               break;
+             }
            }
-         }
          default: {
            TempDescriptor[] writes=fn.writesTemps();
            for(int i=0;i<writes.length;i++) {