finish branch elimination optimization for fission code
authorbdemsky <bdemsky>
Fri, 30 Oct 2009 00:14:10 +0000 (00:14 +0000)
committerbdemsky <bdemsky>
Fri, 30 Oct 2009 00:14:10 +0000 (00:14 +0000)
Robust/src/Analysis/Locality/BranchAnalysis.java

index 16ef0d20f0259ec610524e5dfc97dccfcb33eb0a..31a1402666573e12f6f1aa49e8226ae555766457 100644 (file)
@@ -1,12 +1,16 @@
 package Analysis.Locality;
+import IR.State;
+import IR.Flat.*;
+import java.util.*;
+import java.io.*;
 
 public class BranchAnalysis {
   LocalityAnalysis locality;
   State state;
-  public BranchAnalysis(Locality locality, LocalityAnalysis lb, Set<FlatNode> nodeset, State state) {
+  public BranchAnalysis(LocalityAnalysis locality, LocalityBinding lb, Set<FlatNode> nodeset, Set<FlatNode> storeset, State state) {
     this.locality=locality;
     this.state=state;
-    doAnalysis(lb, nodeset);
+    doAnalysis(lb, nodeset, storeset);
   }
 
   Hashtable<Set<FlatNode>, Vector<FlatNode>> table=new Hashtable<Set<FlatNode>, Vector<FlatNode>>();
@@ -35,14 +39,78 @@ public class BranchAnalysis {
     return exits.size();
   }
 
-  public void doAnalysis(LocalityAnalysis lb, Set<FlatNode> nodeset) {
+  public Vector<FlatNode> getJumps(FlatNode fn) {
+    Set<FlatNode> group=groupmap.get(fn);
+    if (group==null)
+      throw new Error();
+    Vector<FlatNode> exits=table.get(group);
+    return exits;
+  }
+
+  public Set<FlatNode> getTargets() {
+    HashSet<FlatNode> targets=new HashSet<FlatNode>();
+    Collection<Set<FlatNode>> groups=groupmap.values();
+    for(Iterator<Set<FlatNode>> setit=groups.iterator();setit.hasNext();) {
+      Set<FlatNode> group=setit.next();
+      targets.addAll(table.get(group));
+    }
+    return targets;
+  }
+
+  int grouplabelindex=0;
+
+  public boolean hasGroup(FlatNode fn) {
+    return groupmap.contains(fn);
+  }
+
+  Hashtable<Set<FlatNode>, String> grouplabel=new Hashtable<Set<FlatNode>, String>();
+
+  private boolean seenGroup(FlatNode fn) {
+    return grouplabel.containsKey(groupmap.get(fn));
+  }
+
+  private String getGroup(FlatNode fn) {
+    if (!grouplabel.containsKey(groupmap.get(fn)))
+      grouplabel.put(groupmap.get(fn), new String("LG"+(grouplabelindex++)));
+    return grouplabel.get(groupmap.get(fn));
+  }
+
+  public void generateGroupCode(FlatNode fn, PrintWriter output, Hashtable<FlatNode, Integer> nodetolabels) {
+    if (seenGroup(fn)) {
+      String label=getGroup(fn);
+      output.println("goto "+label+";");
+    } else {
+      String label=getGroup(fn);
+      output.println(label+":");
+      if (numJumps(fn)==1) {
+       FlatNode fndst=getJumps(fn).get(0);
+       output.println("goto "+nodetolabels.get(fndst)+";");
+      } else if (numJumps(fn)==2) {
+       Vector<FlatNode> exits=getJumps(fn);
+       output.println("if(RESTOREBRANCH())");
+       output.println("goto L"+nodetolabels.get(exits.get(1))+";");
+       output.println("else");
+       output.println("goto L"+nodetolabels.get(exits.get(0))+";");
+      } else {
+       Vector<FlatNode> exits=getJumps(fn);
+       output.println("switch(RESTOREBRANCH()) {");
+       for(int i=0;i<exits.size();i++) {
+         output.println("case "+i+":");
+         output.println("goto L"+nodetolabels.get(exits.get(i))+";");
+       }
+       output.println("}");
+      }
+    }
+  }
+
+  public void doAnalysis(LocalityBinding lb, Set<FlatNode> nodeset, Set<FlatNode> storeset) {
     Set<FlatNode> transset=computeTransSet(lb);
-    fnmap=computeMap(transset, nodeset);
+    fnmap=computeMap(transset, nodeset, storeset);
     groupmap=new Hashtable<FlatNode, Set<FlatNode>>();
 
     for(Iterator<FlatNode> fnit=transset.iterator();fnit.hasNext();) {
       FlatNode fn=fnit.next();
-      if (fn.numNext()>1) {
+      if (fn.numNext()>1&&storeset.contains(fn)) {
        FlatNode[] children=fnmap.get(fn);
        if (!groupmap.containsKey(fn)) {
          groupmap.put(fn, new HashSet<FlatNode>());
@@ -50,9 +118,9 @@ public class BranchAnalysis {
        }
        for(int i=0;i<children.length;i++) {
          FlatNode child=children[i];
-         if (child.numNext()>1)
+         if (child.numNext()>1&&storeset.contains(child))
            mergegroups(fn, child, groupmap);
-       }
+         }
       }
     }
     //now we have groupings...
@@ -95,7 +163,7 @@ public class BranchAnalysis {
     }
   }
 
-  public Hashtable<FlatNode, FlatNode[]> computeMap(Set<FlatNode> transset, Set<FlatNode> nodeset) {
+  public Hashtable<FlatNode, FlatNode[]> computeMap(Set<FlatNode> transset, Set<FlatNode> nodeset, Set<FlatNode> storeset) {
     Set<FlatNode> toprocess=new HashSet<FlatNode>();
     toprocess.addAll(transset);
     Hashtable<FlatNode, Set<Object[]>> fntotuple=new Hashtable<FlatNode, Set<Object[]>>();
@@ -107,21 +175,22 @@ public class BranchAnalysis {
 
       for(int i=0;i<fn.numPrev();i++) {
        FlatNode fprev=fn.getPrev(i);
-       if (nodeset.contains(fprev)) {
+       if (nodeset.contains(fprev)||storeset.contains(fprev)) {
          for(int j=0;j<fprev.numNext();j++) {
            if (fprev.getNext(j)==fn) {
              Object[] pair=new Object[2];
-             pair[0]=new Integer(j);pair[1]=fn;
+             pair[0]=new Integer(j);pair[1]=fprev;
              incomingtuples.add(pair);
            }
          }
        } else {
          Set<Object[]> tuple=fntotuple.get(fprev);
-         incomingtuples.addAll(tuple);
+         if (tuple!=null)
+           incomingtuples.addAll(tuple);
        }
       }
 
-      if (nodeset.contains(fn)) {
+      if (nodeset.contains(fn)||storeset.contains(fn)||fn.kind()==FKind.FlatAtomicExitNode) {
        //nodeset contains this node
        for(Iterator<Object[]> it=incomingtuples.iterator();it.hasNext();) {
          Object[] pair=it.next();
@@ -137,7 +206,7 @@ public class BranchAnalysis {
       //add if we need to update
       if (!fntotuple.containsKey(fn)||
          !fntotuple.get(fn).equals(incomingtuples)) {
-       tntotuple.put(fn,incomingtuples);
+       fntotuple.put(fn,incomingtuples);
        for(int i=0;i<fn.numNext();i++) {
          if (transset.contains(fn.getNext(i)))
            toprocess.add(fn.getNext(i));
@@ -148,14 +217,14 @@ public class BranchAnalysis {
   }
 
 
-  public Set<FlatNode> computeTransSet(LocalityAnalysis lb) {
+  public Set<FlatNode> computeTransSet(LocalityBinding lb) {
     Set<FlatNode> transset=new HashSet();
     Set<FlatNode> tovisit=new HashSet();
     tovisit.addAll(state.getMethodFlat(lb.getMethod()).getNodeSet());
     while(!tovisit.isEmpty()) {
       FlatNode fn=tovisit.iterator().next();
       tovisit.remove(fn);
-      if (locality.getAtomic(lb).get(fn).intValue()>0)
+      if (locality.getAtomic(lb).get(fn).intValue()>0||fn.kind()==FKind.FlatAtomicExitNode)
        transset.add(fn);
     }
     return transset;