Extend ScalarEvolution's multiple-exit support to compute exact
[oota-llvm.git] / lib / Analysis / ScalarEvolutionExpander.cpp
index c5591d702730774b142a265e18bbe0719819c10e..6d7abc02ebea2d74a369b49871cdc5833eb4736f 100644 (file)
@@ -51,21 +51,26 @@ Value *SCEVExpander::InsertCastOfTo(Instruction::CastOps opcode, Value *V,
   if (Argument *A = dyn_cast<Argument>(V)) {
     // Check to see if there is already a cast!
     for (Value::use_iterator UI = A->use_begin(), E = A->use_end();
-         UI != E; ++UI) {
+         UI != E; ++UI)
       if ((*UI)->getType() == Ty)
         if (CastInst *CI = dyn_cast<CastInst>(cast<Instruction>(*UI)))
           if (CI->getOpcode() == opcode) {
             // If the cast isn't the first instruction of the function, move it.
-            if (BasicBlock::iterator(CI) != 
+            if (BasicBlock::iterator(CI) !=
                 A->getParent()->getEntryBlock().begin()) {
-              // If the CastInst is the insert point, change the insert point.
-              if (CI == InsertPt) ++InsertPt;
-              // Splice the cast at the beginning of the entry block.
-              CI->moveBefore(A->getParent()->getEntryBlock().begin());
+              // Recreate the cast at the beginning of the entry block.
+              // The old cast is left in place in case it is being used
+              // as an insert point.
+              Instruction *NewCI =
+                CastInst::Create(opcode, V, Ty, "",
+                                 A->getParent()->getEntryBlock().begin());
+              NewCI->takeName(CI);
+              CI->replaceAllUsesWith(NewCI);
+              return NewCI;
             }
             return CI;
           }
-    }
+
     Instruction *I = CastInst::Create(opcode, V, Ty, V->getName(),
                                       A->getParent()->getEntryBlock().begin());
     InsertedValues.insert(I);
@@ -85,10 +90,13 @@ Value *SCEVExpander::InsertCastOfTo(Instruction::CastOps opcode, Value *V,
             It = cast<InvokeInst>(I)->getNormalDest()->begin();
           while (isa<PHINode>(It)) ++It;
           if (It != BasicBlock::iterator(CI)) {
-            // If the CastInst is the insert point, change the insert point.
-            if (CI == InsertPt) ++InsertPt;
-            // Splice the cast immediately after the operand in question.
-            CI->moveBefore(It);
+            // Recreate the cast at the beginning of the entry block.
+            // The old cast is left in place in case it is being used
+            // as an insert point.
+            Instruction *NewCI = CastInst::Create(opcode, V, Ty, "", It);
+            NewCI->takeName(CI);
+            CI->replaceAllUsesWith(NewCI);
+            return NewCI;
           }
           return CI;
         }
@@ -497,8 +505,9 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
       }
     }
 
-    Value *RestV = expand(Rest);
-    return expand(SE.getAddExpr(S->getStart(), SE.getUnknown(RestV)));
+    // Just do a normal add. Pre-expand the operands to suppress folding.
+    return expand(SE.getAddExpr(SE.getUnknown(expand(S->getStart())),
+                                SE.getUnknown(expand(Rest))));
   }
 
   // {0,+,1} --> Insert a canonical induction variable into the loop!
@@ -546,36 +555,13 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
              getOrInsertCanonicalInductionVariable(L, Ty);
 
   // If this is a simple linear addrec, emit it now as a special case.
-  if (S->isAffine()) {   // {0,+,F} --> i*F
-    Value *F = expandCodeFor(S->getOperand(1), Ty);
-
-    // If the insert point is directly inside of the loop, emit the multiply at
-    // the insert point.  Otherwise, L is a loop that is a parent of the insert
-    // point loop.  If we can, move the multiply to the outer most loop that it
-    // is safe to be in.
-    BasicBlock::iterator MulInsertPt = getInsertionPoint();
-    Loop *InsertPtLoop = SE.LI->getLoopFor(MulInsertPt->getParent());
-    if (InsertPtLoop != L && InsertPtLoop &&
-        L->contains(InsertPtLoop->getHeader())) {
-      do {
-        // If we cannot hoist the multiply out of this loop, don't.
-        if (!InsertPtLoop->isLoopInvariant(F)) break;
-
-        BasicBlock *InsertPtLoopPH = InsertPtLoop->getLoopPreheader();
-
-        // If this loop hasn't got a preheader, we aren't able to hoist the
-        // multiply.
-        if (!InsertPtLoopPH)
-          break;
-
-        // Otherwise, move the insert point to the preheader.
-        MulInsertPt = InsertPtLoopPH->getTerminator();
-        InsertPtLoop = InsertPtLoop->getParentLoop();
-      } while (InsertPtLoop != L);
-    }
-    
-    return InsertBinop(Instruction::Mul, I, F, MulInsertPt);
-  }
+  if (S->isAffine())    // {0,+,F} --> i*F
+    return
+      expand(SE.getTruncateOrNoop(
+        SE.getMulExpr(SE.getUnknown(I),
+                      SE.getNoopOrAnyExtend(S->getOperand(1),
+                                            I->getType())),
+        Ty));
 
   // If this is a chain of recurrences, turn it into a closed form, using the
   // folders, then expandCodeFor the closed form.  This allows the folders to
@@ -671,8 +657,31 @@ Value *SCEVExpander::expand(const SCEV *S) {
     InsertedExpressions.find(S);
   if (I != InsertedExpressions.end())
     return I->second;
-  
+
+  // Compute an insertion point for this SCEV object. Hoist the instructions
+  // as far out in the loop nest as possible.
+  BasicBlock::iterator InsertPt = getInsertionPoint();
+  BasicBlock::iterator SaveInsertPt = InsertPt;
+  for (Loop *L = SE.LI->getLoopFor(InsertPt->getParent()); ;
+       L = L->getParentLoop())
+    if (S->isLoopInvariant(L)) {
+      if (!L) break;
+      if (BasicBlock *Preheader = L->getLoopPreheader())
+        InsertPt = Preheader->getTerminator();
+    } else {
+      // If the SCEV is computable at this level, insert it into the header
+      // after the PHIs (and after any other instructions that we've inserted
+      // there) so that it is guaranteed to dominate any user inside the loop.
+      if (L && S->hasComputableLoopEvolution(L))
+        InsertPt = L->getHeader()->getFirstNonPHI();
+      while (isInsertedInstruction(InsertPt)) ++InsertPt;
+      break;
+    }
+  setInsertionPoint(InsertPt);
+
   Value *V = visit(S);
+
+  setInsertionPoint(SaveInsertPt);
   InsertedExpressions[S] = V;
   return V;
 }
@@ -686,6 +695,9 @@ SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L,
                                                     const Type *Ty) {
   assert(Ty->isInteger() && "Can only insert integer induction variables!");
   const SCEV* H = SE.getAddRecExpr(SE.getIntegerSCEV(0, Ty),
-                                  SE.getIntegerSCEV(1, Ty), L);
-  return expand(H);
+                                   SE.getIntegerSCEV(1, Ty), L);
+  BasicBlock::iterator SaveInsertPt = getInsertionPoint();
+  Value *V = expandCodeFor(H, 0, L->getHeader()->begin());
+  setInsertionPoint(SaveInsertPt);
+  return V;
 }