IRCE: use SCEVs instead of llvm::Value's for intermediate
[oota-llvm.git] / lib / Transforms / Scalar / InductiveRangeCheckElimination.cpp
index 809e9ee99c12ecffca5cddd4e4a9c4c8aea90cf9..86a00b1590e5144de8424f96a3380324eb5370eb 100644 (file)
@@ -143,17 +143,17 @@ public:
   /// R.getEnd() sle R.getBegin(), then R denotes the empty range.
 
   class Range {
-    Value *Begin;
-    Value *End;
+    const SCEV *Begin;
+    const SCEV *End;
 
   public:
-    Range(Value *Begin, Value *End) : Begin(Begin), End(End) {
+    Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
       assert(Begin->getType() == End->getType() && "ill-typed range!");
     }
 
     Type *getType() const { return Begin->getType(); }
-    Value *getBegin() const { return Begin; }
-    Value *getEnd() const { return End; }
+    const SCEV *getBegin() const { return Begin; }
+    const SCEV *getEnd() const { return End; }
   };
 
   typedef SpecificBumpPtrAllocator<InductiveRangeCheck> AllocatorTy;
@@ -394,21 +394,6 @@ InductiveRangeCheck::create(InductiveRangeCheck::AllocatorTy &A, BranchInst *BI,
   return IRC;
 }
 
-static Value *MaybeSimplify(Value *V) {
-  if (Instruction *I = dyn_cast<Instruction>(V))
-    if (Value *Simplified = SimplifyInstruction(I))
-      return Simplified;
-  return V;
-}
-
-static Value *ConstructSMinOf(Value *X, Value *Y, IRBuilder<> &B) {
-  return MaybeSimplify(B.CreateSelect(B.CreateICmpSLT(X, Y), X, Y));
-}
-
-static Value *ConstructSMaxOf(Value *X, Value *Y, IRBuilder<> &B) {
-  return MaybeSimplify(B.CreateSelect(B.CreateICmpSGT(X, Y), X, Y));
-}
-
 namespace {
 
 /// This class is used to constrain loops to run within a given iteration space.
@@ -738,35 +723,36 @@ LoopConstrainer::calculateSubRanges(Value *&HeaderCountOut) const {
   SCEVExpander Expander(SE, "irce");
   Instruction *InsertPt = OriginalPreheader->getTerminator();
 
-  Value *LatchCountV =
-      MaybeSimplify(Expander.expandCodeFor(LatchTakenCount, Ty, InsertPt));
-
-  IRBuilder<> B(InsertPt);
-
   LoopConstrainer::SubRanges Result;
 
   // I think we can be more aggressive here and make this nuw / nsw if the
   // addition that feeds into the icmp for the latch's terminating branch is nuw
   // / nsw.  In any case, a wrapping 2's complement addition is safe.
   ConstantInt *One = ConstantInt::get(Ty, 1);
-  HeaderCountOut = MaybeSimplify(B.CreateAdd(LatchCountV, One, "header.count"));
+  const SCEV *HeaderCountSCEV = SE.getAddExpr(LatchTakenCount, SE.getSCEV(One));
+  HeaderCountOut = Expander.expandCodeFor(HeaderCountSCEV, Ty, InsertPt);
 
-  const SCEV *RangeBegin = SE.getSCEV(Range.getBegin());
-  const SCEV *RangeEnd = SE.getSCEV(Range.getEnd());
-  const SCEV *HeaderCountSCEV = SE.getSCEV(HeaderCountOut);
   const SCEV *Zero = SE.getConstant(Ty, 0);
 
   // In some cases we can prove that we don't need a pre or post loop
 
   bool ProvablyNoPreloop =
-      SE.isKnownPredicate(ICmpInst::ICMP_SLE, RangeBegin, Zero);
-  if (!ProvablyNoPreloop)
-    Result.ExitPreLoopAt = ConstructSMinOf(HeaderCountOut, Range.getBegin(), B);
+    SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Zero);
+  if (!ProvablyNoPreloop) {
+    const SCEV *ExitPreLoopAtSCEV =
+      SE.getSMinExpr(HeaderCountSCEV, Range.getBegin());
+    Result.ExitPreLoopAt =
+      Expander.expandCodeFor(ExitPreLoopAtSCEV, Ty, InsertPt);
+  }
 
   bool ProvablyNoPostLoop =
-      SE.isKnownPredicate(ICmpInst::ICMP_SLE, HeaderCountSCEV, RangeEnd);
-  if (!ProvablyNoPostLoop)
-    Result.ExitMainLoopAt = ConstructSMinOf(HeaderCountOut, Range.getEnd(), B);
+    SE.isKnownPredicate(ICmpInst::ICMP_SLE, HeaderCountSCEV, Range.getEnd());
+  if (!ProvablyNoPostLoop) {
+    const SCEV *ExitMainLoopAtSCEV =
+      SE.getSMinExpr(HeaderCountSCEV, Range.getEnd());
+    Result.ExitMainLoopAt =
+      Expander.expandCodeFor(ExitMainLoopAtSCEV, Ty, InsertPt);
+  }
 
   return Result;
 }
@@ -1131,18 +1117,15 @@ InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
     return None;
   }
 
-  Value *OffsetV = SCEVExpander(SE, "safe.itr.space").expandCodeFor(
-      getOffset(), getOffset()->getType(), B.GetInsertPoint());
-  OffsetV = MaybeSimplify(OffsetV);
-
-  Value *Begin = MaybeSimplify(B.CreateNeg(OffsetV));
-  Value *End = MaybeSimplify(B.CreateSub(getLength(), OffsetV));
+  const SCEV *Begin = SE.getNegativeSCEV(getOffset());
+  const SCEV *End = SE.getMinusSCEV(SE.getSCEV(getLength()), getOffset());
 
   return InductiveRangeCheck::Range(Begin, End);
 }
 
 static Optional<InductiveRangeCheck::Range>
-IntersectRange(const Optional<InductiveRangeCheck::Range> &R1,
+IntersectRange(ScalarEvolution &SE,
+               const Optional<InductiveRangeCheck::Range> &R1,
                const InductiveRangeCheck::Range &R2, IRBuilder<> &B) {
   if (!R1.hasValue())
     return R2;
@@ -1153,9 +1136,10 @@ IntersectRange(const Optional<InductiveRangeCheck::Range> &R1,
   if (R1Value.getType() != R2.getType())
     return None;
 
-  Value *NewMin = ConstructSMaxOf(R1Value.getBegin(), R2.getBegin(), B);
-  Value *NewMax = ConstructSMinOf(R1Value.getEnd(), R2.getEnd(), B);
-  return InductiveRangeCheck::Range(NewMin, NewMax);
+  const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin());
+  const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd());
+
+  return InductiveRangeCheck::Range(NewBegin, NewEnd);
 }
 
 bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
@@ -1202,7 +1186,7 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
     auto Result = IRC->computeSafeIterationSpace(SE, B);
     if (Result.hasValue()) {
       auto MaybeSafeIterRange =
-        IntersectRange(SafeIterRange, Result.getValue(), B);
+        IntersectRange(SE, SafeIterRange, Result.getValue(), B);
       if (MaybeSafeIterRange.hasValue()) {
         RangeChecksToEliminate.push_back(IRC);
         SafeIterRange = MaybeSafeIterRange.getValue();