changes on the loop termination analysis: associate a labeled statement with a corres...
[IRC.git] / Robust / src / Analysis / Loops / LoopTerminate.java
1 package Analysis.Loops;
2
3 import java.util.HashMap;
4 import java.util.HashSet;
5 import java.util.Iterator;
6 import java.util.Set;
7
8 import Analysis.SSJava.SSJavaAnalysis;
9 import IR.Operation;
10 import IR.State;
11 import IR.Flat.FKind;
12 import IR.Flat.FlatCondBranch;
13 import IR.Flat.FlatMethod;
14 import IR.Flat.FlatNode;
15 import IR.Flat.FlatOpNode;
16 import IR.Flat.TempDescriptor;
17
18 public class LoopTerminate {
19
20   private FlatMethod fm;
21   private LoopInvariant loopInv;
22   private Set<TempDescriptor> inductionSet;
23   // mapping from Induction Variable TempDescriptor to Flat Node that defines
24   // it
25   private HashMap<TempDescriptor, FlatNode> inductionVar2DefNode;
26
27   // mapping from Derived Induction Variable TempDescriptor to its root
28   // induction variable TempDescriptor
29   private HashMap<TempDescriptor, TempDescriptor> derivedVar2basicInduction;
30
31   Set<FlatNode> computed;
32
33   State state;
34   SSJavaAnalysis ssjava;
35
36   /**
37    * Constructor for Loop Termination Analysis
38    */
39   public LoopTerminate(SSJavaAnalysis ssjava, State state) {
40     this.ssjava = ssjava;
41     this.state = state;
42     this.inductionSet = new HashSet<TempDescriptor>();
43     this.inductionVar2DefNode = new HashMap<TempDescriptor, FlatNode>();
44     this.derivedVar2basicInduction = new HashMap<TempDescriptor, TempDescriptor>();
45     this.computed = new HashSet<FlatNode>();
46   }
47
48   /**
49    * starts loop termination analysis
50    * 
51    * @param fm
52    *          FlatMethod for termination analysis
53    * @param loopInv
54    *          LoopInvariants for given method
55    */
56   public void terminateAnalysis(FlatMethod fm, LoopInvariant loopInv) {
57     this.fm = fm;
58     this.loopInv = loopInv;
59     Loops loopFinder = loopInv.root;
60     recurse(fm, loopFinder);
61   }
62
63   /**
64    * iterate over the current level of loops and spawn analysis for its child
65    * 
66    * @param fm
67    *          FlatMethod for loop analysis
68    * @param parent
69    *          the current level of loop
70    */
71   private void recurse(FlatMethod fm, Loops parent) {
72     for (Iterator lpit = parent.nestedLoops().iterator(); lpit.hasNext();) {
73       Loops child = (Loops) lpit.next();
74       processLoop(fm, child);
75       recurse(fm, child);
76     }
77   }
78
79   /**
80    * initialize internal data structure
81    */
82   private void init() {
83     inductionSet.clear();
84     inductionVar2DefNode.clear();
85     derivedVar2basicInduction.clear();
86   }
87
88   /**
89    * analysis loop for termination property
90    * 
91    * @param fm
92    *          FlatMethod that contains loop l
93    * @param l
94    *          analysis target loop l
95    */
96   private void processLoop(FlatMethod fm, Loops l) {
97
98     Set loopElements = l.loopIncElements(); // loop body elements
99     Set loopEntrances = l.loopEntrances(); // loop entrance
100     assert loopEntrances.size() == 1;
101     FlatNode loopEntrance = (FlatNode) loopEntrances.iterator().next();
102
103     String loopLabel = (String) state.fn2labelMap.get(loopEntrance);
104
105     if (loopLabel == null || !loopLabel.startsWith(ssjava.TERMINATE)) {
106       init();
107       detectBasicInductionVar(loopElements);
108       detectDerivedInductionVar(loopElements);
109       checkConditionBranch(loopEntrance, loopElements);
110     }
111
112   }
113
114   /**
115    * check if condition branch node satisfies loop condition
116    * 
117    * @param loopEntrance
118    *          loop entrance flat node
119    * @param loopElements
120    *          elements of current loop and all nested loop
121    */
122   private void checkConditionBranch(FlatNode loopEntrance, Set loopElements) {
123     // In the loop, every guard condition of the loop must be composed by
124     // induction & invariants
125     // every guard condition of the if-statement that leads it to the exit must
126     // be composed by induction&invariants
127
128     Set<FlatNode> tovisit = new HashSet<FlatNode>();
129     Set<FlatNode> visited = new HashSet<FlatNode>();
130     tovisit.add(loopEntrance);
131
132     int numMustTerminateGuardCondtion = 0;
133     int numLoop = 0;
134     while (!tovisit.isEmpty()) {
135       FlatNode fnvisit = tovisit.iterator().next();
136       tovisit.remove(fnvisit);
137       visited.add(fnvisit);
138
139       if (fnvisit.kind() == FKind.FlatCondBranch) {
140         FlatCondBranch fcb = (FlatCondBranch) fnvisit;
141
142         if (fcb.isLoopBranch()) {
143           numLoop++;
144         }
145
146         if (fcb.isLoopBranch() || hasLoopExitNode(fcb, true, loopEntrance, loopElements)) {
147           // current FlatCondBranch can introduce loop exits
148           // in this case, guard condition of it should be composed only by loop
149           // invariants and induction variables
150           Set<FlatNode> condSet = getDefinitionInLoop(fnvisit, fcb.getTest(), loopElements);
151           assert condSet.size() == 1;
152           FlatNode condFn = condSet.iterator().next();
153           // assume that guard condition node is always a conditional inequality
154           if (condFn instanceof FlatOpNode) {
155             FlatOpNode condOp = (FlatOpNode) condFn;
156             // check if guard condition is composed only with induction
157             // variables
158             if (checkConditionNode(condOp, fcb.isLoopBranch(), loopElements)) {
159               numMustTerminateGuardCondtion++;
160             } else {
161               if (!fcb.isLoopBranch()) {
162                 // if it is if-condition and it leads us to loop exit,
163                 // corresponding guard condition should be composed by induction
164                 // and invariants
165                 throw new Error("Loop may never terminate at "
166                     + fm.getMethod().getClassDesc().getSourceFileName() + "::"
167                     + loopEntrance.numLine);
168               }
169             }
170           }
171         }
172       }
173
174       for (int i = 0; i < fnvisit.numNext(); i++) {
175         FlatNode next = fnvisit.getNext(i);
176         if (loopElements.contains(next) && !visited.contains(next)) {
177           tovisit.add(next);
178         }
179       }
180
181     }
182
183     // # of must-terminate loop condition must be equal to or larger than # of
184     // loop
185     if (numMustTerminateGuardCondtion < numLoop) {
186       throw new Error("Loop may never terminate at "
187           + fm.getMethod().getClassDesc().getSourceFileName() + "::" + loopEntrance.numLine);
188     }
189
190   }
191
192   /**
193    * detect derived induction variable
194    * 
195    * @param loopElements
196    *          elements of current loop and all nested loop
197    */
198   private void detectDerivedInductionVar(Set loopElements) {
199     // 2) detect derived induction variables
200     // variable k is a derived induction variable if
201     // 1) there is only one definition of k within the loop, of the form k=j*c
202     // or k=j+d where j is induction variable, c, d are loop-invariant
203     // 2) and if j is a derived induction variable in the family of i, then:
204     // (a) the only definition of j that reaches k is the one in the loop
205     // (b) and there is no definition of i on any path between the definition of
206     // j and the definition of k
207
208     boolean changed = true;
209     Set<TempDescriptor> basicInductionSet = new HashSet<TempDescriptor>();
210     basicInductionSet.addAll(inductionSet);
211
212     while (changed) {
213       changed = false;
214       nextfn: for (Iterator elit = loopElements.iterator(); elit.hasNext();) {
215         FlatNode fn = (FlatNode) elit.next();
216         if (!computed.contains(fn)) {
217           if (fn.kind() == FKind.FlatOpNode) {
218             FlatOpNode fon = (FlatOpNode) fn;
219             int op = fon.getOp().getOp();
220             if (op == Operation.ADD || op == Operation.MULT) {
221               TempDescriptor tdLeft = fon.getLeft();
222               TempDescriptor tdRight = fon.getRight();
223               TempDescriptor tdDest = fon.getDest();
224
225               boolean isLeftLoopInvariant = isLoopInvariantVar(fn, tdLeft, loopElements);
226               boolean isRightLoopInvariant = isLoopInvariantVar(fn, tdRight, loopElements);
227
228               if (isLeftLoopInvariant ^ isRightLoopInvariant) {
229                 TempDescriptor inductionOp;
230                 if (isLeftLoopInvariant) {
231                   inductionOp = tdRight;
232                 } else {
233                   inductionOp = tdLeft;
234                 }
235
236                 if (inductionSet.contains(inductionOp)) {
237                   // find new derived one k
238
239                   if (!basicInductionSet.contains(inductionOp)) {
240                     // in this case, induction variable 'j' is derived from
241                     // basic induction var
242
243                     // check if only definition of j that reaches k is the one
244                     // in the loop
245
246                     Set<FlatNode> defSet = getDefinitionInLoop(fn, inductionOp, loopElements);
247                     if (defSet.size() == 1) {
248                       // check if there is no def of i on any path bet' def of j
249                       // and def of k
250
251                       TempDescriptor originInduc = derivedVar2basicInduction.get(inductionOp);
252                       FlatNode defI = inductionVar2DefNode.get(originInduc);
253                       FlatNode defJ = inductionVar2DefNode.get(inductionOp);
254                       FlatNode defk = fn;
255
256                       if (!checkPath(defI, defJ, defk)) {
257                         continue nextfn;
258                       }
259
260                     }
261                   }
262                   // add new induction var
263
264                   // when tdDest has the form of srctmp(tdDest) = inductionOp +
265                   // loopInvariant
266                   // want to have the definition of srctmp
267                   Set<FlatNode> setUseNode = loopInv.usedef.useMap(fn, tdDest);
268                   assert setUseNode.size() == 1;
269                   assert setUseNode.iterator().next().writesTemps().length == 1;
270
271                   FlatNode srcDefNode = setUseNode.iterator().next();
272                   if (srcDefNode instanceof FlatOpNode) {
273                     if (((FlatOpNode) srcDefNode).getOp().getOp() == Operation.ASSIGN) {
274                       TempDescriptor derivedIndVar = setUseNode.iterator().next().writesTemps()[0];
275                       FlatNode defNode = setUseNode.iterator().next();
276
277                       computed.add(fn);
278                       computed.add(defNode);
279                       inductionSet.add(derivedIndVar);
280                       inductionVar2DefNode.put(derivedIndVar, defNode);
281                       derivedVar2basicInduction.put(derivedIndVar, inductionOp);
282                       changed = true;
283                     }
284                   }
285
286                 }
287
288               }
289
290             }
291
292           }
293         }
294
295       }
296     }
297
298   }
299
300   /**
301    * detect basic induction variable
302    * 
303    * @param loopElements
304    *          elements of current loop and all nested loop
305    */
306   private void detectBasicInductionVar(Set loopElements) {
307     // 1) find out basic induction variable
308     // variable i is a basic induction variable in loop if the only definitions
309     // of i within L are of the form i=i+c where c is loop invariant
310     for (Iterator elit = loopElements.iterator(); elit.hasNext();) {
311       FlatNode fn = (FlatNode) elit.next();
312       if (fn.kind() == FKind.FlatOpNode) {
313         FlatOpNode fon = (FlatOpNode) fn;
314         int op = fon.getOp().getOp();
315         if (op == Operation.ADD) {
316           TempDescriptor tdLeft = fon.getLeft();
317           TempDescriptor tdRight = fon.getRight();
318
319           boolean isLeftLoopInvariant = isLoopInvariantVar(fn, tdLeft, loopElements);
320           boolean isRightLoopInvariant = isLoopInvariantVar(fn, tdRight, loopElements);
321
322           if (isLeftLoopInvariant ^ isRightLoopInvariant) {
323
324             TempDescriptor candidateTemp;
325
326             if (isLeftLoopInvariant) {
327               candidateTemp = tdRight;
328             } else {
329               candidateTemp = tdLeft;
330             }
331
332             Set<FlatNode> defSetOfLoop = getDefinitionInLoop(fn, candidateTemp, loopElements);
333             // basic induction variable must have only one definition within the
334             // loop
335             if (defSetOfLoop.size() == 1) {
336               // find out definition of induction var, form of Flat
337               // Assign:inductionVar = candidateTemp
338               FlatNode indNode = defSetOfLoop.iterator().next();
339               assert indNode.readsTemps().length == 1;
340               TempDescriptor readTemp = indNode.readsTemps()[0];
341               if (readTemp.equals(fon.getDest())) {
342                 inductionVar2DefNode.put(candidateTemp, defSetOfLoop.iterator().next());
343                 inductionVar2DefNode.put(readTemp, defSetOfLoop.iterator().next());
344                 inductionSet.add(fon.getDest());
345                 inductionSet.add(candidateTemp);
346                 computed.add(fn);
347               }
348
349             }
350
351           }
352
353         }
354       }
355     }
356
357   }
358
359   /**
360    * check whether there is no definition node 'def' on any path between 'start'
361    * node and 'end' node
362    * 
363    * @param def
364    * @param start
365    * @param end
366    * @return true if there is no def in-bet start and end
367    */
368   private boolean checkPath(FlatNode def, FlatNode start, FlatNode end) {
369     Set<FlatNode> endSet = new HashSet<FlatNode>();
370     endSet.add(end);
371     return !(start.getReachableSet(endSet)).contains(def);
372   }
373
374   /**
375    * check condition node satisfies termination condition
376    * 
377    * @param fon
378    *          condition node FlatOpNode
379    * @param isLoopCondition
380    *          true if condition is loop condition
381    * @param loopElements
382    *          elements of current loop and all nested loop
383    * @return true if it satisfies termination condition
384    */
385   private boolean checkConditionNode(FlatOpNode fon, boolean isLoopCondition, Set loopElements) {
386     // check flatOpNode that computes loop guard condition
387     // currently we assume that induction variable is always getting bigger
388     // and guard variable is constant
389     // so need to check (1) one of operand should be induction variable
390     // (2) another operand should be constant or loop invariant
391
392     TempDescriptor induction = null;
393     TempDescriptor guard = null;
394
395     int op = fon.getOp().getOp();
396     if (op == Operation.LT || op == Operation.LTE) {
397       if (isLoopCondition) {
398         // loop condition is inductionVar <= loop invariant
399         induction = fon.getLeft();
400         guard = fon.getRight();
401       } else {
402         // if-statement condition is loop invariant <= inductionVar since
403         // inductionVar is getting biggier each iteration
404         induction = fon.getRight();
405         guard = fon.getLeft();
406       }
407     } else if (op == Operation.GT || op == Operation.GTE) {
408       if (isLoopCondition) {
409         // condition is loop invariant >= inductionVar
410         induction = fon.getRight();
411         guard = fon.getLeft();
412       } else {
413         // if-statement condition is loop inductionVar >= invariant
414         induction = fon.getLeft();
415         guard = fon.getRight();
416       }
417     } else {
418       return false;
419     }
420
421     // here, verify that guard operand is an induction variable
422     if (!hasInductionVar(fon, induction, loopElements)) {
423       return false;
424     }
425
426     if (guard != null) {
427       Set guardDefSet = getDefinitionInLoop(fon, guard, loopElements);
428       for (Iterator iterator = guardDefSet.iterator(); iterator.hasNext();) {
429         FlatNode guardDef = (FlatNode) iterator.next();
430         if (!(guardDef.kind() == FKind.FlatLiteralNode) && !loopInv.hoisted.contains(guardDef)) {
431           return false;
432         }
433       }
434     }
435
436     return true;
437   }
438
439   /**
440    * check if TempDescriptor td has at least one induction variable and is
441    * composed only by induction vars +loop invariants
442    * 
443    * @param fn
444    *          FlatNode that contains TempDescriptor 'td'
445    * @param td
446    *          TempDescriptor representing target variable
447    * @param loopElements
448    *          elements of current loop and all nested loop
449    * @return true if 'td' is induction variable
450    */
451   private boolean hasInductionVar(FlatNode fn, TempDescriptor td, Set loopElements) {
452
453     if (inductionSet.contains(td)) {
454       return true;
455     } else {
456       // check if td is composed by induction variables or loop invariants
457       Set<FlatNode> defSet = getDefinitionInLoop(fn, td, loopElements);
458       for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
459         FlatNode defNode = (FlatNode) iterator.next();
460
461         int inductionVarCount = 0;
462         TempDescriptor[] readTemps = defNode.readsTemps();
463         for (int i = 0; i < readTemps.length; i++) {
464           if (!hasInductionVar(defNode, readTemps[i], loopElements)) {
465             if (!isLoopInvariantVar(defNode, readTemps[i], loopElements)) {
466               return false;
467             }
468           } else {
469             inductionVarCount++;
470           }
471         }
472
473         // check definition of td has at least one induction var
474         if (inductionVarCount > 0) {
475           return true;
476         }
477
478       }
479
480       return false;
481     }
482
483   }
484
485   /**
486    * check if TempDescriptor td is loop invariant variable or constant value wrt
487    * the current loop
488    * 
489    * @param fn
490    *          FlatNode that contains TempDescriptor 'td'
491    * @param td
492    *          TempDescriptor representing target variable
493    * @param loopElements
494    *          elements of current loop and all nested loop
495    * @return true if 'td' is loop invariant variable
496    */
497   private boolean isLoopInvariantVar(FlatNode fn, TempDescriptor td, Set loopElements) {
498
499     Set<FlatNode> defset = loopInv.usedef.defMap(fn, td);
500
501     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
502     for (Iterator<FlatNode> defit = defset.iterator(); defit.hasNext();) {
503       FlatNode def = defit.next();
504       if (loopElements.contains(def)) {
505         defSetOfLoop.add(def);
506       }
507     }
508
509     if (defSetOfLoop.size() == 0) {
510       // all definition comes from outside the loop
511       // so it is loop invariant
512       return true;
513     } else if (defSetOfLoop.size() == 1) {
514       // check if def is 1) constant node or 2) loop invariant
515       FlatNode defFlatNode = defSetOfLoop.iterator().next();
516       if (defFlatNode.kind() == FKind.FlatLiteralNode || loopInv.hoisted.contains(defFlatNode)) {
517         return true;
518       }
519     }
520
521     return false;
522
523   }
524
525   /**
526    * compute the set of definitions of variable 'td' inside of the loop
527    * 
528    * @param fn
529    *          FlatNode that uses 'td'
530    * @param td
531    *          target node that we want to have the set of definitions
532    * @param loopElements
533    *          elements of current loop and all nested loop
534    * @return the set of definition nodes for 'td' in the current loop
535    */
536   private Set<FlatNode> getDefinitionInLoop(FlatNode fn, TempDescriptor td, Set loopElements) {
537
538     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
539
540     Set defSet = loopInv.usedef.defMap(fn, td);
541     for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
542       FlatNode defFlatNode = (FlatNode) iterator.next();
543       if (loopElements.contains(defFlatNode)) {
544         defSetOfLoop.add(defFlatNode);
545       }
546     }
547
548     return defSetOfLoop;
549
550   }
551
552   /**
553    * check whether FlatCondBranch introduces loop exit
554    * 
555    * @param fcb
556    *          target node
557    * @param fromTrueBlock
558    *          specify which block is possible to have loop exit
559    * @param loopHeader
560    *          loop header of current loop
561    * @param loopElements
562    *          elements of current loop and all nested loop
563    * @return true if input 'fcb' intrroduces loop exit
564    */
565   private boolean hasLoopExitNode(FlatCondBranch fcb, boolean fromTrueBlock, FlatNode loopHeader,
566       Set loopElements) {
567     // return true if fcb possibly introduces loop exit
568
569     FlatNode next;
570     if (fromTrueBlock) {
571       next = fcb.getNext(0);
572     } else {
573       next = fcb.getNext(1);
574     }
575
576     return hasLoopExitNode(loopHeader, next, loopElements);
577
578   }
579
580   /**
581    * check whether start node reaches loop exit
582    * 
583    * @param loopHeader
584    * @param start
585    * @param loopElements
586    * @return true if a path exist from start to loop exit
587    */
588   private boolean hasLoopExitNode(FlatNode loopHeader, FlatNode start, Set loopElements) {
589
590     Set<FlatNode> tovisit = new HashSet<FlatNode>();
591     Set<FlatNode> visited = new HashSet<FlatNode>();
592     tovisit.add(start);
593
594     while (!tovisit.isEmpty()) {
595
596       FlatNode fn = tovisit.iterator().next();
597       tovisit.remove(fn);
598       visited.add(fn);
599
600       for (int i = 0; i < fn.numNext(); i++) {
601         FlatNode next = fn.getNext(i);
602         if (!visited.contains(next)) {
603           // check that if-body statment introduces loop exit.
604           if (!loopElements.contains(next)) {
605             return true;
606           }
607
608           if (loopInv.domtree.idom(next).equals(fn)) {
609             // add next node only if current node is immediate dominator of the
610             // next node
611             tovisit.add(next);
612           }
613         }
614       }
615
616     }
617
618     return false;
619
620   }
621 }