3cfbb3d0f1c1c202bb3af75b21aa286c21523526
[IRC.git] / Robust / src / Analysis / Loops / LoopTerminate.java
1 package Analysis.Loops;
2
3 import java.util.HashMap;
4 import java.util.HashSet;
5 import java.util.Iterator;
6 import java.util.Set;
7
8 import IR.Operation;
9 import IR.Flat.FKind;
10 import IR.Flat.FlatCondBranch;
11 import IR.Flat.FlatMethod;
12 import IR.Flat.FlatNode;
13 import IR.Flat.FlatOpNode;
14 import IR.Flat.TempDescriptor;
15
16 public class LoopTerminate {
17
18   private FlatMethod fm;
19   private LoopInvariant loopInv;
20   private Set<TempDescriptor> inductionSet;
21   // mapping from Induction Variable TempDescriptor to Flat Node that defines
22   // it
23   private HashMap<TempDescriptor, FlatNode> inductionVar2DefNode;
24
25   // mapping from Derived Induction Variable TempDescriptor to its root
26   // induction variable TempDescriptor
27   private HashMap<TempDescriptor, TempDescriptor> derivedVar2basicInduction;
28
29   Set<FlatNode> computed;
30
31   /**
32    * Constructor for Loop Termination Analysis
33    */
34   public LoopTerminate() {
35     this.inductionSet = new HashSet<TempDescriptor>();
36     this.inductionVar2DefNode = new HashMap<TempDescriptor, FlatNode>();
37     this.derivedVar2basicInduction = new HashMap<TempDescriptor, TempDescriptor>();
38     this.computed = new HashSet<FlatNode>();
39   }
40
41   /**
42    * starts loop termination analysis
43    * 
44    * @param fm
45    *          FlatMethod for termination analysis
46    * @param loopInv
47    *          LoopInvariants for given method
48    */
49   public void terminateAnalysis(FlatMethod fm, LoopInvariant loopInv) {
50     this.fm = fm;
51     this.loopInv = loopInv;
52     Loops loopFinder = loopInv.root;
53     recurse(fm, loopFinder);
54   }
55
56   /**
57    * iterate over the current level of loops and spawn analysis for its child
58    * 
59    * @param fm
60    *          FlatMethod for loop analysis
61    * @param parent
62    *          the current level of loop
63    */
64   private void recurse(FlatMethod fm, Loops parent) {
65     for (Iterator lpit = parent.nestedLoops().iterator(); lpit.hasNext();) {
66       Loops child = (Loops) lpit.next();
67       processLoop(fm, child);
68       recurse(fm, child);
69     }
70   }
71
72   /**
73    * initialize internal data structure
74    */
75   private void init() {
76     inductionSet.clear();
77     inductionVar2DefNode.clear();
78     derivedVar2basicInduction.clear();
79   }
80
81   /**
82    * analysis loop for termination property
83    * 
84    * @param fm
85    *          FlatMethod that contains loop l
86    * @param l
87    *          analysis target loop l
88    */
89   private void processLoop(FlatMethod fm, Loops l) {
90
91     Set loopElements = l.loopIncElements(); // loop body elements
92     Set loopEntrances = l.loopEntrances(); // loop entrance
93     assert loopEntrances.size() == 1;
94     FlatNode loopEntrance = (FlatNode) loopEntrances.iterator().next();
95
96     init();
97     detectBasicInductionVar(loopElements);
98     detectDerivedInductionVar(loopElements);
99     checkConditionBranch(loopEntrance, loopElements);
100
101   }
102
103   /**
104    * check if condition branch node satisfies loop condition
105    * 
106    * @param loopEntrance
107    *          loop entrance flat node
108    * @param loopElements
109    *          elements of current loop and all nested loop
110    */
111   private void checkConditionBranch(FlatNode loopEntrance, Set loopElements) {
112     // In the loop, every guard condition of the loop must be composed by
113     // induction & invariants
114     // every guard condition of the if-statement that leads it to the exit must
115     // be composed by induction&invariants
116
117     Set<FlatNode> tovisit = new HashSet<FlatNode>();
118     Set<FlatNode> visited = new HashSet<FlatNode>();
119     tovisit.add(loopEntrance);
120
121     int numMustTerminateGuardCondtion = 0;
122     int numLoop = 0;
123     while (!tovisit.isEmpty()) {
124       FlatNode fnvisit = tovisit.iterator().next();
125       tovisit.remove(fnvisit);
126       visited.add(fnvisit);
127
128       if (fnvisit.kind() == FKind.FlatCondBranch) {
129         FlatCondBranch fcb = (FlatCondBranch) fnvisit;
130
131         if (fcb.isLoopBranch()) {
132           numLoop++;
133         }
134
135         if (fcb.isLoopBranch() || hasLoopExitNode(fcb, true, loopEntrance, loopElements)) {
136           // current FlatCondBranch can introduce loop exits
137           // in this case, guard condition of it should be composed only by loop
138           // invariants and induction variables
139           Set<FlatNode> condSet = getDefinitionInLoop(fnvisit, fcb.getTest(), loopElements);
140           assert condSet.size() == 1;
141           FlatNode condFn = condSet.iterator().next();
142           // assume that guard condition node is always a conditional inequality
143           if (condFn instanceof FlatOpNode) {
144             FlatOpNode condOp = (FlatOpNode) condFn;
145             // check if guard condition is composed only with induction
146             // variables
147             if (checkConditionNode(condOp, fcb.isLoopBranch(), loopElements)) {
148               numMustTerminateGuardCondtion++;
149             } else {
150               if (!fcb.isLoopBranch()) {
151                 // if it is if-condition and it leads us to loop exit,
152                 // corresponding guard condition should be composed by induction
153                 // and invariants
154                 throw new Error("Loop may never terminate at "
155                     + fm.getMethod().getClassDesc().getSourceFileName() + "::"
156                     + loopEntrance.numLine);
157               }
158             }
159           }
160         }
161       }
162
163       for (int i = 0; i < fnvisit.numNext(); i++) {
164         FlatNode next = fnvisit.getNext(i);
165         if (loopElements.contains(next) && !visited.contains(next)) {
166           tovisit.add(next);
167         }
168       }
169
170     }
171
172     // # of must-terminate loop condition must be equal to or larger than # of
173     // loop
174     if (numMustTerminateGuardCondtion < numLoop) {
175       throw new Error("Loop may never terminate at "
176           + fm.getMethod().getClassDesc().getSourceFileName() + "::" + loopEntrance.numLine);
177     }
178
179   }
180
181   /**
182    * detect derived induction variable
183    * 
184    * @param loopElements
185    *          elements of current loop and all nested loop
186    */
187   private void detectDerivedInductionVar(Set loopElements) {
188     // 2) detect derived induction variables
189     // variable k is a derived induction variable if
190     // 1) there is only one definition of k within the loop, of the form k=j*c
191     // or k=j+d where j is induction variable, c, d are loop-invariant
192     // 2) and if j is a derived induction variable in the family of i, then:
193     // (a) the only definition of j that reaches k is the one in the loop
194     // (b) and there is no definition of i on any path between the definition of
195     // j and the definition of k
196
197     boolean changed = true;
198     Set<TempDescriptor> basicInductionSet = new HashSet<TempDescriptor>();
199     basicInductionSet.addAll(inductionSet);
200
201     while (changed) {
202       changed = false;
203       nextfn: for (Iterator elit = loopElements.iterator(); elit.hasNext();) {
204         FlatNode fn = (FlatNode) elit.next();
205         if (!computed.contains(fn)) {
206           if (fn.kind() == FKind.FlatOpNode) {
207             FlatOpNode fon = (FlatOpNode) fn;
208             int op = fon.getOp().getOp();
209             if (op == Operation.ADD || op == Operation.MULT) {
210               TempDescriptor tdLeft = fon.getLeft();
211               TempDescriptor tdRight = fon.getRight();
212               TempDescriptor tdDest = fon.getDest();
213
214               boolean isLeftLoopInvariant = isLoopInvariantVar(fn, tdLeft, loopElements);
215               boolean isRightLoopInvariant = isLoopInvariantVar(fn, tdRight, loopElements);
216
217               if (isLeftLoopInvariant ^ isRightLoopInvariant) {
218                 TempDescriptor inductionOp;
219                 if (isLeftLoopInvariant) {
220                   inductionOp = tdRight;
221                 } else {
222                   inductionOp = tdLeft;
223                 }
224
225                 if (inductionSet.contains(inductionOp)) {
226                   // find new derived one k
227
228                   if (!basicInductionSet.contains(inductionOp)) {
229                     // in this case, induction variable 'j' is derived from
230                     // basic induction var
231
232                     // check if only definition of j that reaches k is the one
233                     // in the loop
234
235                     Set<FlatNode> defSet = getDefinitionInLoop(fn, inductionOp, loopElements);
236                     if (defSet.size() == 1) {
237                       // check if there is no def of i on any path bet' def of j
238                       // and def of k
239
240                       TempDescriptor originInduc = derivedVar2basicInduction.get(inductionOp);
241                       FlatNode defI = inductionVar2DefNode.get(originInduc);
242                       FlatNode defJ = inductionVar2DefNode.get(inductionOp);
243                       FlatNode defk = fn;
244
245                       if (!checkPath(defI, defJ, defk)) {
246                         continue nextfn;
247                       }
248
249                     }
250                   }
251                   // add new induction var
252
253                   // when tdDest has the form of srctmp(tdDest) = inductionOp +
254                   // loopInvariant
255                   // want to have the definition of srctmp
256                   Set<FlatNode> setUseNode = loopInv.usedef.useMap(fn, tdDest);
257                   assert setUseNode.size() == 1;
258                   assert setUseNode.iterator().next().writesTemps().length == 1;
259
260                   FlatNode srcDefNode = setUseNode.iterator().next();
261                   if (srcDefNode instanceof FlatOpNode) {
262                     if (((FlatOpNode) srcDefNode).getOp().getOp() == Operation.ASSIGN) {
263                       TempDescriptor derivedIndVar = setUseNode.iterator().next().writesTemps()[0];
264                       FlatNode defNode = setUseNode.iterator().next();
265
266                       computed.add(fn);
267                       computed.add(defNode);
268                       inductionSet.add(derivedIndVar);
269                       inductionVar2DefNode.put(derivedIndVar, defNode);
270                       derivedVar2basicInduction.put(derivedIndVar, inductionOp);
271                       changed = true;
272                     }
273                   }
274
275                 }
276
277               }
278
279             }
280
281           }
282         }
283
284       }
285     }
286
287   }
288
289   /**
290    * detect basic induction variable
291    * 
292    * @param loopElements
293    *          elements of current loop and all nested loop
294    */
295   private void detectBasicInductionVar(Set loopElements) {
296     // 1) find out basic induction variable
297     // variable i is a basic induction variable in loop if the only definitions
298     // of i within L are of the form i=i+c where c is loop invariant
299     for (Iterator elit = loopElements.iterator(); elit.hasNext();) {
300       FlatNode fn = (FlatNode) elit.next();
301       if (fn.kind() == FKind.FlatOpNode) {
302         FlatOpNode fon = (FlatOpNode) fn;
303         int op = fon.getOp().getOp();
304         if (op == Operation.ADD) {
305           TempDescriptor tdLeft = fon.getLeft();
306           TempDescriptor tdRight = fon.getRight();
307
308           boolean isLeftLoopInvariant = isLoopInvariantVar(fn, tdLeft, loopElements);
309           boolean isRightLoopInvariant = isLoopInvariantVar(fn, tdRight, loopElements);
310
311           if (isLeftLoopInvariant ^ isRightLoopInvariant) {
312
313             TempDescriptor candidateTemp;
314
315             if (isLeftLoopInvariant) {
316               candidateTemp = tdRight;
317             } else {
318               candidateTemp = tdLeft;
319             }
320
321             Set<FlatNode> defSetOfLoop = getDefinitionInLoop(fn, candidateTemp, loopElements);
322             // basic induction variable must have only one definition within the
323             // loop
324             if (defSetOfLoop.size() == 1) {
325               // find out definition of induction var, form of Flat
326               // Assign:inductionVar = candidateTemp
327               FlatNode indNode = defSetOfLoop.iterator().next();
328               assert indNode.readsTemps().length == 1;
329               TempDescriptor readTemp = indNode.readsTemps()[0];
330               if (readTemp.equals(fon.getDest())) {
331                 inductionVar2DefNode.put(candidateTemp, defSetOfLoop.iterator().next());
332                 inductionVar2DefNode.put(readTemp, defSetOfLoop.iterator().next());
333                 inductionSet.add(fon.getDest());
334                 inductionSet.add(candidateTemp);
335                 computed.add(fn);
336               }
337
338             }
339
340           }
341
342         }
343       }
344     }
345
346   }
347
348   /**
349    * check whether there is no definition node 'def' on any path between 'start'
350    * node and 'end' node
351    * 
352    * @param def
353    * @param start
354    * @param end
355    * @return true if there is no def in-bet start and end
356    */
357   private boolean checkPath(FlatNode def, FlatNode start, FlatNode end) {
358     Set<FlatNode> endSet = new HashSet<FlatNode>();
359     endSet.add(end);
360     return !(start.getReachableSet(endSet)).contains(def);
361   }
362
363   /**
364    * check condition node satisfies termination condition
365    * 
366    * @param fon
367    *          condition node FlatOpNode
368    * @param isLoopCondition
369    *          true if condition is loop condition
370    * @param loopElements
371    *          elements of current loop and all nested loop
372    * @return true if it satisfies termination condition
373    */
374   private boolean checkConditionNode(FlatOpNode fon, boolean isLoopCondition, Set loopElements) {
375     // check flatOpNode that computes loop guard condition
376     // currently we assume that induction variable is always getting bigger
377     // and guard variable is constant
378     // so need to check (1) one of operand should be induction variable
379     // (2) another operand should be constant or loop invariant
380
381     TempDescriptor induction = null;
382     TempDescriptor guard = null;
383
384     int op = fon.getOp().getOp();
385     if (op == Operation.LT || op == Operation.LTE) {
386       if (isLoopCondition) {
387         // loop condition is inductionVar <= loop invariant
388         induction = fon.getLeft();
389         guard = fon.getRight();
390       } else {
391         // if-statement condition is loop invariant <= inductionVar since
392         // inductionVar is getting biggier each iteration
393         induction = fon.getRight();
394         guard = fon.getLeft();
395       }
396     } else if (op == Operation.GT || op == Operation.GTE) {
397       if (isLoopCondition) {
398         // condition is loop invariant >= inductionVar
399         induction = fon.getRight();
400         guard = fon.getLeft();
401       } else {
402         // if-statement condition is loop inductionVar >= invariant
403         induction = fon.getLeft();
404         guard = fon.getRight();
405       }
406     } else {
407       return false;
408     }
409
410     // here, verify that guard operand is an induction variable
411     if (!hasInductionVar(fon, induction, loopElements)) {
412       return false;
413     }
414
415     if (guard != null) {
416       Set guardDefSet = getDefinitionInLoop(fon, guard, loopElements);
417       for (Iterator iterator = guardDefSet.iterator(); iterator.hasNext();) {
418         FlatNode guardDef = (FlatNode) iterator.next();
419         if (!(guardDef.kind() == FKind.FlatLiteralNode) && !loopInv.hoisted.contains(guardDef)) {
420           return false;
421         }
422       }
423     }
424
425     return true;
426   }
427
428   /**
429    * check if TempDescriptor td has at least one induction variable and is
430    * composed only by induction vars +loop invariants
431    * 
432    * @param fn
433    *          FlatNode that contains TempDescriptor 'td'
434    * @param td
435    *          TempDescriptor representing target variable
436    * @param loopElements
437    *          elements of current loop and all nested loop
438    * @return true if 'td' is induction variable
439    */
440   private boolean hasInductionVar(FlatNode fn, TempDescriptor td, Set loopElements) {
441
442     if (inductionSet.contains(td)) {
443       return true;
444     } else {
445       // check if td is composed by induction variables or loop invariants
446       Set<FlatNode> defSet = getDefinitionInLoop(fn, td, loopElements);
447       for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
448         FlatNode defNode = (FlatNode) iterator.next();
449
450         int inductionVarCount = 0;
451         TempDescriptor[] readTemps = defNode.readsTemps();
452         for (int i = 0; i < readTemps.length; i++) {
453           if (!hasInductionVar(defNode, readTemps[i], loopElements)) {
454             if (!isLoopInvariantVar(defNode, readTemps[i], loopElements)) {
455               return false;
456             }
457           } else {
458             inductionVarCount++;
459           }
460         }
461
462         // check definition of td has at least one induction var
463         if (inductionVarCount > 0) {
464           return true;
465         }
466
467       }
468
469       return false;
470     }
471
472   }
473
474   /**
475    * check if TempDescriptor td is loop invariant variable or constant value wrt
476    * the current loop
477    * 
478    * @param fn
479    *          FlatNode that contains TempDescriptor 'td'
480    * @param td
481    *          TempDescriptor representing target variable
482    * @param loopElements
483    *          elements of current loop and all nested loop
484    * @return true if 'td' is loop invariant variable
485    */
486   private boolean isLoopInvariantVar(FlatNode fn, TempDescriptor td, Set loopElements) {
487
488     Set<FlatNode> defset = loopInv.usedef.defMap(fn, td);
489
490     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
491     for (Iterator<FlatNode> defit = defset.iterator(); defit.hasNext();) {
492       FlatNode def = defit.next();
493       if (loopElements.contains(def)) {
494         defSetOfLoop.add(def);
495       }
496     }
497
498     if (defSetOfLoop.size() == 0) {
499       // all definition comes from outside the loop
500       // so it is loop invariant
501       return true;
502     } else if (defSetOfLoop.size() == 1) {
503       // check if def is 1) constant node or 2) loop invariant
504       FlatNode defFlatNode = defSetOfLoop.iterator().next();
505       if (defFlatNode.kind() == FKind.FlatLiteralNode || loopInv.hoisted.contains(defFlatNode)) {
506         return true;
507       }
508     }
509
510     return false;
511
512   }
513
514   /**
515    * compute the set of definitions of variable 'td' inside of the loop
516    * 
517    * @param fn
518    *          FlatNode that uses 'td'
519    * @param td
520    *          target node that we want to have the set of definitions
521    * @param loopElements
522    *          elements of current loop and all nested loop
523    * @return the set of definition nodes for 'td' in the current loop
524    */
525   private Set<FlatNode> getDefinitionInLoop(FlatNode fn, TempDescriptor td, Set loopElements) {
526
527     Set<FlatNode> defSetOfLoop = new HashSet<FlatNode>();
528
529     Set defSet = loopInv.usedef.defMap(fn, td);
530     for (Iterator iterator = defSet.iterator(); iterator.hasNext();) {
531       FlatNode defFlatNode = (FlatNode) iterator.next();
532       if (loopElements.contains(defFlatNode)) {
533         defSetOfLoop.add(defFlatNode);
534       }
535     }
536
537     return defSetOfLoop;
538
539   }
540
541   /**
542    * check whether FlatCondBranch introduces loop exit
543    * 
544    * @param fcb
545    *          target node
546    * @param fromTrueBlock
547    *          specify which block is possible to have loop exit
548    * @param loopHeader
549    *          loop header of current loop
550    * @param loopElements
551    *          elements of current loop and all nested loop
552    * @return true if input 'fcb' intrroduces loop exit
553    */
554   private boolean hasLoopExitNode(FlatCondBranch fcb, boolean fromTrueBlock, FlatNode loopHeader,
555       Set loopElements) {
556     // return true if fcb possibly introduces loop exit
557
558     FlatNode next;
559     if (fromTrueBlock) {
560       next = fcb.getNext(0);
561     } else {
562       next = fcb.getNext(1);
563     }
564
565     return hasLoopExitNode(loopHeader, next, loopElements);
566
567   }
568
569   /**
570    * check whether start node reaches loop exit
571    * 
572    * @param loopHeader
573    * @param start
574    * @param loopElements
575    * @return true if a path exist from start to loop exit
576    */
577   private boolean hasLoopExitNode(FlatNode loopHeader, FlatNode start, Set loopElements) {
578
579     Set<FlatNode> tovisit = new HashSet<FlatNode>();
580     Set<FlatNode> visited = new HashSet<FlatNode>();
581     tovisit.add(start);
582
583     while (!tovisit.isEmpty()) {
584
585       FlatNode fn = tovisit.iterator().next();
586       tovisit.remove(fn);
587       visited.add(fn);
588
589       for (int i = 0; i < fn.numNext(); i++) {
590         FlatNode next = fn.getNext(i);
591         if (!visited.contains(next)) {
592           // check that if-body statment introduces loop exit.
593           if (!loopElements.contains(next)) {
594             return true;
595           }
596
597           if (loopInv.domtree.idom(next).equals(fn)) {
598             // add next node only if current node is immediate dominator of the
599             // next node
600             tovisit.add(next);
601           }
602         }
603       }
604
605     }
606
607     return false;
608
609   }
610 }