96375a4ee79fa8f5ef24743a0eae0fac44c98adb
[IRC.git] / Robust / src / Analysis / Loops / LoopTerminate.java
1 package Analysis.Loops;
2
3 import java.util.HashSet;
4 import java.util.Hashtable;
5 import java.util.Iterator;
6 import java.util.Set;
7
8 import IR.Operation;
9 import IR.Flat.FKind;
10 import IR.Flat.FlatCondBranch;
11 import IR.Flat.FlatLiteralNode;
12 import IR.Flat.FlatMethod;
13 import IR.Flat.FlatNode;
14 import IR.Flat.FlatOpNode;
15 import IR.Flat.TempDescriptor;
16
17 public class LoopTerminate {
18
19   FlatMethod fm;
20   LoopInvariant loopInv;
21   Set<TempDescriptor> inductionSet;
22
23   public LoopTerminate(FlatMethod fm, LoopInvariant loopInv) {
24     this.fm = fm;
25     this.loopInv = loopInv;
26     this.inductionSet = new HashSet<TempDescriptor>();
27   }
28
29   public void terminateAnalysis() {
30     Loops loopFinder = loopInv.root;
31     if (loopFinder.nestedLoops().size() > 0) {
32       for (Iterator lpit = loopFinder.nestedLoops().iterator(); lpit.hasNext();) {
33         Loops loop = (Loops) lpit.next();
34         Set entrances = loop.loopEntrances();
35         processLoop(fm, loop, loopInv);
36       }
37     }
38   }
39
40   public void processLoop(FlatMethod fm, Loops l, LoopInvariant loopInv) {
41
42     boolean changed = true;
43
44     Set elements = l.loopIncElements(); // loop body elements
45     Set entrances = l.loopEntrances(); // loop entrance
46     assert entrances.size() == 1;
47     FlatNode entrance = (FlatNode) entrances.iterator().next();
48
49     // mapping from Induction Variable TempDescriptor to Definiton Flat Node
50     Hashtable<TempDescriptor, FlatNode> inductionVar2DefNode =
51         new Hashtable<TempDescriptor, FlatNode>();
52
53     // mapping from Derived Induction Variable TempDescriptor to its root
54     // induction variable TempDescriptor
55     Hashtable<TempDescriptor, TempDescriptor> derivedVar2basicInduction =
56         new Hashtable<TempDescriptor, TempDescriptor>();
57
58     Set<FlatNode> computed = new HashSet<FlatNode>();
59
60     // #1 find out basic induction variable
61     // variable i is a basic induction variable in loop if the only definitions
62     // of i within L are of the form i=i+c where c is loop invariant
63     for (Iterator elit = elements.iterator(); elit.hasNext();) {
64       FlatNode fn = (FlatNode) elit.next();
65       if (fn.kind() == FKind.FlatOpNode) {
66         FlatOpNode fon = (FlatOpNode) fn;
67         int op = fon.getOp().getOp();
68         if (op == Operation.ADD /* || op == Operation.SUB */) {
69           TempDescriptor tdLeft = fon.getLeft();
70           TempDescriptor tdRight = fon.getRight();
71
72           boolean isLeftLoopInvariant = isLoopInvariantVar(l, fn, tdLeft);
73           boolean isRightLoopInvariant = isLoopInvariantVar(l, fn, tdRight);
74
75           if (isLeftLoopInvariant ^ isRightLoopInvariant) {
76
77             TempDescriptor candidateTemp;
78
79             if (isLeftLoopInvariant) {
80               candidateTemp = tdRight;
81             } else {
82               candidateTemp = tdLeft;
83             }
84
85             Set<FlatNode> defSetOfLoop = getDefinitionInsideLoop(l, fn, candidateTemp);
86             // basic induction variable must have only one definition within the
87             // loop
88             if (defSetOfLoop.size() == 1) {
89               FlatNode indNode = defSetOfLoop.iterator().next();
90               assert indNode.readsTemps().length == 1;
91               TempDescriptor readTemp = indNode.readsTemps()[0];
92               if (readTemp.equals(fon.getDest())) {
93                 inductionVar2DefNode.put(candidateTemp, defSetOfLoop.iterator().next());
94                 inductionVar2DefNode.put(readTemp, defSetOfLoop.iterator().next());
95                 inductionSet.add(readTemp);
96                 inductionSet.add(candidateTemp);
97                 computed.add(fn);
98               }
99
100             }
101
102           }
103
104         }
105       }
106     }
107
108     // #2 detect derived induction variables
109     // variable k is a derived induction variable if
110     // 1) there is only one definition of k within the loop, of the form k=j*c
111     // or k=j+d where j is induction variable, c, d are loop-invariant
112     // 2) and if j is a derived induction variable in the family of i, then:
113     // (a) the only definition of j that reaches k is the one in the loop
114     // (b) and there is no definition of i on any path between the definition of
115     // j and the definition of k
116
117     Set<TempDescriptor> basicInductionSet = new HashSet<TempDescriptor>();
118     basicInductionSet.addAll(inductionSet);
119
120     while (changed) {
121       changed = false;
122       for (Iterator elit = elements.iterator(); elit.hasNext();) {
123         FlatNode fn = (FlatNode) elit.next();
124         if (!computed.contains(fn)) {
125           if (fn.kind() == FKind.FlatOpNode) {
126             FlatOpNode fon = (FlatOpNode) fn;
127             int op = fon.getOp().getOp();
128             if (op == Operation.ADD || op == Operation.MULT) {
129               TempDescriptor tdLeft = fon.getLeft();
130               TempDescriptor tdRight = fon.getRight();
131               TempDescriptor tdDest = fon.getDest();
132
133               boolean isLeftLoopInvariant = isLoopInvariantVar(l, fn, tdLeft);
134               boolean isRightLoopInvariant = isLoopInvariantVar(l, fn, tdRight);
135
136               if (isLeftLoopInvariant ^ isRightLoopInvariant) {
137                 TempDescriptor inductionOp;
138                 if (isLeftLoopInvariant) {
139                   inductionOp = tdRight;
140                 } else {
141                   inductionOp = tdLeft;
142                 }
143
144                 if (inductionSet.contains(inductionOp)) {
145                   // find new derived one k
146
147                   if (!basicInductionSet.contains(inductionOp)) {
148                     // check if only definition of j that reaches k is the one
149                     // in the loop
150                     Set defSet = getDefinitionInsideLoop(l, fn, inductionOp);
151                     if (defSet.size() == 1) {
152                       // check if there is no def of i on any path bet' def of j
153                       // and def of k
154
155                       TempDescriptor originInduc = derivedVar2basicInduction.get(inductionOp);
156                       FlatNode defI = inductionVar2DefNode.get(originInduc);
157                       FlatNode defJ = inductionVar2DefNode.get(inductionOp);
158                       FlatNode defk = fn;
159
160                       if (!checkPath(defI, defJ, defk)) {
161                         continue;
162                       }
163
164                     }
165                   }
166                   // add new induction var
167
168                   Set<FlatNode> setUseNode = loopInv.usedef.useMap(fn, tdDest);
169                   assert setUseNode.size() == 1;
170                   assert setUseNode.iterator().next().writesTemps().length == 1;
171
172                   TempDescriptor derivedInd = setUseNode.iterator().next().writesTemps()[0];
173                   FlatNode defNode = setUseNode.iterator().next();
174
175                   computed.add(fn);
176                   computed.add(defNode);
177                   inductionSet.add(derivedInd);
178                   inductionVar2DefNode.put(derivedInd, defNode);
179                   derivedVar2basicInduction.put(derivedInd, inductionOp);
180                   changed = true;
181                 }
182
183               }
184
185             }
186
187           }
188         }
189
190       }
191     }
192
193     // #3 check condition branch
194     // In the loop, every guard condition of the loop must be composed by
195     // induction & invariants
196     // every guard condition of the if-statement that leads it to the exit must
197     // be composed by induction&invariants
198
199     Set<FlatNode> tovisit = new HashSet<FlatNode>();
200     Set<FlatNode> visited = new HashSet<FlatNode>();
201     tovisit.add(entrance);
202
203     int countLoopGuardCondtion = 0;
204     int countLoop = 0;
205     while (!tovisit.isEmpty()) {
206       FlatNode fnvisit = tovisit.iterator().next();
207       tovisit.remove(fnvisit);
208       visited.add(fnvisit);
209
210       if (fnvisit.kind() == FKind.FlatCondBranch) {
211         FlatCondBranch fcb = (FlatCondBranch) fnvisit;
212
213         if (fcb.isLoopBranch()) {
214           countLoop++;
215         }
216
217         if (fcb.isLoopBranch() || hasLoopExitNode(fcb, true, entrance, elements)) {
218           // current FlatCondBranch can introduce loop exits
219           // in this case, guard condition of it should be composed only by loop
220           // invariants and induction variables
221           Set<FlatNode> condSet = getDefinitionInsideLoop(l, fnvisit, fcb.getTest());
222           assert condSet.size() == 1;
223           FlatNode condFn = condSet.iterator().next();
224           // assume that guard condition node is always a conditional inequality
225           if (condFn instanceof FlatOpNode) {
226             FlatOpNode condOp = (FlatOpNode) condFn;
227             // check if guard condition is composed only with induction
228             // variables
229             if (checkConditionNode(l, condOp, fcb.isLoopBranch())) {
230               countLoopGuardCondtion++;
231             }
232           }
233         }
234       }
235
236       for (int i = 0; i < fnvisit.numNext(); i++) {
237         FlatNode next = fnvisit.getNext(i);
238         if (!visited.contains(next)) {
239           tovisit.add(next);
240         }
241       }
242
243     }
244
245     // # of must-terminate loop condition must be equal to or larger than # of
246     // loop
247     if (countLoopGuardCondtion < countLoop) {
248       throw new Error("Loop may never terminate at "
249           + fm.getMethod().getClassDesc().getSourceFileName() + "::" + entrance.numLine);
250     }
251
252   }
253
254   private boolean checkPath(FlatNode def, FlatNode start, FlatNode end) {
255
256     // return true if there is no def in-bet start and end
257
258     Set<FlatNode> endSet = new HashSet<FlatNode>();
259     endSet.add(end);
260     if ((start.getReachableSet(endSet)).contains(def)) {
261       return false;
262     }
263
264     return true;
265   }
266
267   private boolean checkConditionNode(Loops l, FlatOpNode fon, boolean isLoopCondition) {
268     // check flatOpNode that computes loop guard condition
269     // currently we assume that induction variable is always getting bigger
270     // and guard variable is constant
271     // so need to check (1) one of operand should be induction variable
272     // (2) another operand should be constant or loop invariant
273
274     TempDescriptor induction = null;
275     TempDescriptor guard = null;
276
277     int op = fon.getOp().getOp();
278     if (op == Operation.LT || op == Operation.LTE) {
279       if (isLoopCondition) {
280         // loop condition is inductionVar <= loop invariant
281         induction = fon.getLeft();
282         guard = fon.getRight();
283       } else {
284         // if-statement condition is loop invariant <= inductionVar since
285         // inductionVar is getting biggier each iteration
286         induction = fon.getRight();
287         guard = fon.getLeft();
288       }
289     } else if (op == Operation.GT || op == Operation.GTE) {
290       if (isLoopCondition) {
291         // condition is loop invariant >= inductionVar
292         induction = fon.getRight();
293         guard = fon.getLeft();
294       } else {
295         // if-statement condition is loop inductionVar >= invariant
296         induction = fon.getLeft();
297         guard = fon.getRight();
298       }
299     } else {
300       return false;
301     }
302
303     // here, verify that guard operand is an induction variable
304     if (!hasInductionVar(l, fon, induction)) {
305       return false;
306     }
307
308     if (guard != null) {
309       Set guardDefSet = getDefinitionInsideLoop(l, fon, guard);
310       for (Iterator iterator = guardDefSet.iterator(); iterator.hasNext();) {
311         FlatNode guardDef = (FlatNode) iterator.next();
312         if (!(guardDef instanceof FlatLiteralNode) && !loopInv.hoisted.contains(guardDef)) {
313           return false;
314         }
315       }
316     }
317
318     return true;
319   }
320
321   private boolean hasInductionVar(Loops l, FlatNode fn, TempDescriptor td) {
322     // check if TempDescriptor td has at least one induction variable and is
323     // composed only by induction vars +loop invariants
324
325     if (inductionSet.contains(td)) {
326       return true;
327     } else {
328       // check if td is composed by induction variables or loop invariants
329       Set<FlatNode> defSet = getDefinitionInsideLoop(l, fn, td);
330       for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
331         FlatNode defNode = (FlatNode) iterator.next();
332
333         int inductionVarCount = 0;
334         TempDescriptor[] readTemps = defNode.readsTemps();
335         for (int i = 0; i < readTemps.length; i++) {
336           if (!hasInductionVar(l, defNode, readTemps[i])) {
337             if (!isLoopInvariantVar(l, defNode, readTemps[i])) {
338               return false;
339             }
340           } else {
341             inductionVarCount++;
342           }
343         }
344
345         // check definition of td has at least one induction var
346         if (inductionVarCount > 0) {
347           return true;
348         }
349
350       }
351
352       return false;
353     }
354
355   }
356
357   private boolean isLoopInvariantVar(Loops l, FlatNode fn, TempDescriptor td) {
358
359     Set elements = l.loopIncElements();
360     Set<FlatNode> defset = loopInv.usedef.defMap(fn, td);
361
362     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
363     for (Iterator<FlatNode> defit = defset.iterator(); defit.hasNext();) {
364       FlatNode def = defit.next();
365       if (elements.contains(def)) {
366         defSetOfLoop.add(def);
367       }
368     }
369
370     if (defSetOfLoop.size() == 0) {
371       // all definition comes from outside the loop
372       // so it is loop invariant
373       return true;
374     } else if (defSetOfLoop.size() == 1) {
375       // check if def is 1) constant node or 2) loop invariant
376       FlatNode defFlatNode = defSetOfLoop.iterator().next();
377       if (defFlatNode instanceof FlatLiteralNode || loopInv.hoisted.contains(defFlatNode)) {
378         return true;
379       }
380     }
381
382     return false;
383
384   }
385
386   private Set<FlatNode> getUseSetOfLoop(FlatNode fn, TempDescriptor td, Set loopElements) {
387
388     Set<FlatNode> useSetOfLoop = new HashSet<FlatNode>();
389
390     Set useSet = loopInv.usedef.useMap(fn, td);
391     for (Iterator iterator = useSet.iterator(); iterator.hasNext();) {
392       FlatNode defFlatNode = (FlatNode) iterator.next();
393       if (loopElements.contains(defFlatNode)) {
394         useSetOfLoop.add(defFlatNode);
395       }
396     }
397
398     return useSetOfLoop;
399
400   }
401
402   private Set<FlatNode> getDefinitionInsideLoop(Loops l, FlatNode fn, TempDescriptor td) {
403
404     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
405     Set loopElements = l.loopIncElements();
406
407     Set defSet = loopInv.usedef.defMap(fn, td);
408     for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
409       FlatNode defFlatNode = (FlatNode) iterator.next();
410       if (loopElements.contains(defFlatNode)) {
411         defSetOfLoop.add(defFlatNode);
412       }
413     }
414
415     return defSetOfLoop;
416
417   }
418
419   private boolean hasLoopExitNode(FlatCondBranch fcb, boolean fromTrueBlock, FlatNode loopHeader,
420       Set loopElements) {
421
422     if (!fromTrueBlock) {
423       // in this case, FlatCondBranch must have two next flat node, one for true
424       // block and one for false block
425       assert fcb.next.size() == 2;
426     }
427
428     FlatNode next;
429     if (fromTrueBlock) {
430       next = fcb.getNext(0);
431     } else {
432       next = fcb.getNext(1);
433     }
434
435     if (hasLoopExitNode(loopHeader, next, loopElements)) {
436       return true;
437     } else {
438       return false;
439     }
440
441   }
442
443   private boolean hasLoopExitNode(FlatNode loopHeader, FlatNode start, Set loopElements) {
444
445     Set<FlatNode> tovisit = new HashSet<FlatNode>();
446     Set<FlatNode> visited = new HashSet<FlatNode>();
447     tovisit.add(start);
448
449     while (!tovisit.isEmpty()) {
450
451       FlatNode fn = tovisit.iterator().next();
452       tovisit.remove(fn);
453       visited.add(fn);
454
455       // check if this loop exit is derived from start node
456       // if not, it has an exit in regarding to the loop header
457       if (!loopElements.contains(fn)) {
458         return true;
459       }
460
461       for (int i = 0; i < fn.numNext(); i++) {
462         FlatNode next = fn.getNext(i);
463         if (!visited.contains(next)) {
464           if (loopInv.domtree.idom(next).equals(fn)) {
465             // add next node only if current node is immediate dominator of the
466             // next node
467             tovisit.add(next);
468           }
469         }
470       }
471
472     }
473
474     return false;
475
476   }
477 }