Revert r230921, "Revert some changes that were made to fix PR20680.", for now.
[oota-llvm.git] / lib / Transforms / Scalar / IndVarSimplify.cpp
index 462c2b6f90d636cf5904855a1d951c17d1089192..f99ebbc453f55bc2a472c46605d778c19259d876 100644 (file)
@@ -1705,15 +1705,51 @@ LinearFunctionTestReplace(Loop *L,
   // compare against the post-incremented value, otherwise we must compare
   // against the preincremented value.
   if (L->getExitingBlock() == L->getLoopLatch()) {
-    // Add one to the "backedge-taken" count to get the trip count.
-    // This addition may overflow, which is valid as long as the comparison is
-    // truncated to BackedgeTakenCount->getType().
-    IVCount = SE->getAddExpr(BackedgeTakenCount,
-                             SE->getConstant(BackedgeTakenCount->getType(), 1));
     // The BackedgeTaken expression contains the number of times that the
     // backedge branches to the loop header.  This is one less than the
     // number of times the loop executes, so use the incremented indvar.
-    CmpIndVar = IndVar->getIncomingValueForBlock(L->getExitingBlock());
+    llvm::Value *IncrementedIndvar =
+        IndVar->getIncomingValueForBlock(L->getExitingBlock());
+    const auto *IncrementedIndvarSCEV =
+        cast<SCEVAddRecExpr>(SE->getSCEV(IncrementedIndvar));
+    // It is unsafe to use the incremented indvar if it has a wrapping flag, we
+    // don't want to compare against a poison value.  Check the SCEV that
+    // corresponds to the incremented indvar, the SCEVExpander will only insert
+    // flags in the IR if the SCEV originally had wrapping flags.
+    // FIXME: In theory, SCEV could drop flags even though they exist in IR.
+    // A more robust solution would involve getting a new expression for
+    // CmpIndVar by applying non-NSW/NUW AddExprs.
+    auto WrappingFlags =
+        ScalarEvolution::setFlags(SCEV::FlagNUW, SCEV::FlagNSW);
+    const SCEV *IVInit = IncrementedIndvarSCEV->getStart();
+    if (SE->getTypeSizeInBits(IVInit->getType()) >
+        SE->getTypeSizeInBits(IVCount->getType()))
+      IVInit = SE->getTruncateExpr(IVInit, IVCount->getType());
+    unsigned BitWidth = SE->getTypeSizeInBits(IVCount->getType());
+    Type *WideTy = IntegerType::get(SE->getContext(), BitWidth + 1);
+    // Check if InitIV + BECount+1 requires sign/zero extension.
+    // If not, clear the corresponding flag from WrappingFlags because it is not
+    // necessary for those flags in the IncrementedIndvarSCEV expression.
+    if (SE->getSignExtendExpr(SE->getAddExpr(IVInit, BackedgeTakenCount),
+                              WideTy) ==
+        SE->getAddExpr(SE->getSignExtendExpr(IVInit, WideTy),
+                       SE->getSignExtendExpr(BackedgeTakenCount, WideTy)))
+      WrappingFlags = ScalarEvolution::clearFlags(WrappingFlags, SCEV::FlagNSW);
+    if (SE->getZeroExtendExpr(SE->getAddExpr(IVInit, BackedgeTakenCount),
+                              WideTy) ==
+        SE->getAddExpr(SE->getZeroExtendExpr(IVInit, WideTy),
+                       SE->getZeroExtendExpr(BackedgeTakenCount, WideTy)))
+      WrappingFlags = ScalarEvolution::clearFlags(WrappingFlags, SCEV::FlagNUW);
+    if (!ScalarEvolution::maskFlags(IncrementedIndvarSCEV->getNoWrapFlags(),
+                                    WrappingFlags)) {
+      // Add one to the "backedge-taken" count to get the trip count.
+      // This addition may overflow, which is valid as long as the comparison is
+      // truncated to BackedgeTakenCount->getType().
+      IVCount =
+          SE->getAddExpr(BackedgeTakenCount,
+                         SE->getConstant(BackedgeTakenCount->getType(), 1));
+      CmpIndVar = IncrementedIndvar;
+    }
   }
 
   Value *ExitCnt = genLoopLimit(IndVar, IVCount, L, Rewriter, SE);