From 15465ea1a2de1290e7a4f695a5d57039f7c76796 Mon Sep 17 00:00:00 2001
From: Sanjoy Das
Date: Mon, 2 Nov 2015 02:06:01 +0000
Subject: [PATCH] [SCEV] Fix PR25369

Have `getConstantEvolutionLoopExitValue` work correctly with multiple
entry loops.

As far as I can tell, `getConstantEvolutionLoopExitValue` never did the
right thing for multiple entry loops; and before r249712 it would
silently return an incorrect answer. r249712 changed SCEV to fail an
assert on a multiple entry loop, and this change fixes the underlying
issue.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@251770 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Analysis/ScalarEvolution.cpp         | 53 ++++++++--------
 test/Analysis/ScalarEvolution/pr25369.ll | 78 ++++++++++++++++++++++++
 2 files changed, 104 insertions(+), 27 deletions(-)
 create mode 100644 test/Analysis/ScalarEvolution/pr25369.ll

diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index e7380682d9f..66c5290b8a0 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -5928,6 +5928,30 @@ static Constant *EvaluateExpression(Value *V, const Loop *L,
                                   TLI);
 }
 
+
+// If every incoming value to PN except the one for BB is a specific Constant,
+// return that, else return nullptr.
+static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
+  Constant *IncomingVal = nullptr;
+
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+    if (PN->getIncomingBlock(i) == BB)
+      continue;
+
+    auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
+    if (!CurrentVal)
+      return nullptr;
+
+    if (IncomingVal != CurrentVal) {
+      if (IncomingVal)
+        return nullptr;
+      IncomingVal = CurrentVal;
+    }
+  }
+
+  return IncomingVal;
+}
+
 /// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
 /// in the header of its containing loop, we know the loop executes a
 /// constant number of times, and the PHI node is just a recurrence
@@ -5953,25 +5977,10 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
   if (!Latch)
     return nullptr;
 
-  // Since the loop has one latch, the PHI node must have two entries. One
-  // entry must be a constant (coming in from outside of the loop), and the
-  // second must be derived from the same PHI.
-
-  BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
-                             ? PN->getIncomingBlock(1)
-                             : PN->getIncomingBlock(0);
-
-  assert(PN->getNumIncomingValues() == 2 && "Follows from having one latch!");
-
-  // Note: not all PHI nodes in the same block have to have their incoming
-  // values in the same order, so we use the basic block to look up the incoming
-  // value, not an index.
-
   for (auto &I : *Header) {
     PHINode *PHI = dyn_cast<PHINode>(&I);
     if (!PHI) break;
-    auto *StartCST =
-        dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+    auto *StartCST = getOtherIncomingValue(PHI, Latch);
     if (!StartCST) continue;
     CurrentIterVals[PHI] = StartCST;
   }
@@ -6050,21 +6059,11 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
   BasicBlock *Latch = L->getLoopLatch();
   assert(Latch && "Should follow from NumIncomingValues == 2!");
 
-  // NonLatch is the preheader, or something equivalent.
-  BasicBlock *NonLatch = Latch == PN->getIncomingBlock(0)
-                             ? PN->getIncomingBlock(1)
-                             : PN->getIncomingBlock(0);
-
-  // Note: not all PHI nodes in the same block have to have their incoming
-  // values in the same order, so we use the basic block to look up the incoming
-  // value, not an index.
-
   for (auto &I : *Header) {
     PHINode *PHI = dyn_cast<PHINode>(&I);
     if (!PHI)
       break;
-    auto *StartCST =
-        dyn_cast<Constant>(PHI->getIncomingValueForBlock(NonLatch));
+    auto *StartCST = getOtherIncomingValue(PHI, Latch);
     if (!StartCST) continue;
     CurrentIterVals[PHI] = StartCST;
   }
diff --git a/test/Analysis/ScalarEvolution/pr25369.ll b/test/Analysis/ScalarEvolution/pr25369.ll
new file mode 100644
index 00000000000..10754867a36
--- /dev/null
+++ b/test/Analysis/ScalarEvolution/pr25369.ll
@@ -0,0 +1,78 @@
+; RUN: opt -analyze -scalar-evolution < %s | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define void @hoge1() {
+; CHECK-LABEL: Classifying expressions for: @hoge1
+bb:
+  br i1 undef, label %bb4, label %bb2
+
+bb2:                                              ; preds = %bb2, %bb
+  br i1 false, label %bb4, label %bb2
+
+bb3:                                              ; preds = %bb4
+  %tmp = add i32 %tmp10, -1
+  br label %bb13
+
+bb4:                                              ; preds = %bb4, %bb2, %bb
+  %tmp5 = phi i64 [ %tmp11, %bb4 ], [ 1, %bb2 ], [ 1, %bb ]
+  %tmp6 = phi i32 [ %tmp10, %bb4 ], [ 0, %bb2 ], [ 0, %bb ]
+  %tmp7 = load i32, i32* undef, align 4
+  %tmp8 = add i32 %tmp7, %tmp6
+  %tmp9 = add i32 undef, %tmp8
+  %tmp10 = add i32 undef, %tmp9
+  %tmp11 = add nsw i64 %tmp5, 3
+  %tmp12 = icmp eq i64 %tmp11, 64
+  br i1 %tmp12, label %bb3, label %bb4
+
+; CHECK: Loop %bb4: backedge-taken count is 20
+; CHECK: Loop %bb4: max backedge-taken count is 20
+
+bb13:                                             ; preds = %bb13, %bb3
+  %tmp14 = phi i64 [ 0, %bb3 ], [ %tmp15, %bb13 ]
+  %tmp15 = add nuw nsw i64 %tmp14, 1
+  %tmp16 = trunc i64 %tmp15 to i32
+  %tmp17 = icmp eq i32 %tmp16, %tmp
+  br i1 %tmp17, label %bb18, label %bb13
+
+bb18:                                             ; preds = %bb13
+  ret void
+}
+
+define void @hoge2() {
+; CHECK-LABEL: Classifying expressions for: @hoge2
+bb:
+  br i1 undef, label %bb4, label %bb2
+
+bb2:                                              ; preds = %bb2, %bb
+  br i1 false, label %bb4, label %bb2
+
+bb3:                                              ; preds = %bb4
+  %tmp = add i32 %tmp10, -1
+  br label %bb13
+
+bb4:                                              ; preds = %bb4, %bb2, %bb
+  %tmp5 = phi i64 [ %tmp11, %bb4 ], [ 1, %bb2 ], [ 3, %bb ]
+  %tmp6 = phi i32 [ %tmp10, %bb4 ], [ 0, %bb2 ], [ 0, %bb ]
+  %tmp7 = load i32, i32* undef, align 4
+  %tmp8 = add i32 %tmp7, %tmp6
+  %tmp9 = add i32 undef, %tmp8
+  %tmp10 = add i32 undef, %tmp9
+  %tmp11 = add nsw i64 %tmp5, 3
+  %tmp12 = icmp eq i64 %tmp11, 64
+  br i1 %tmp12, label %bb3, label %bb4
+
+; CHECK: Loop %bb4: Unpredictable backedge-taken count.
+; CHECK: Loop %bb4: Unpredictable max backedge-taken count.
+
+bb13:                                             ; preds = %bb13, %bb3
+  %tmp14 = phi i64 [ 0, %bb3 ], [ %tmp15, %bb13 ]
+  %tmp15 = add nuw nsw i64 %tmp14, 1
+  %tmp16 = trunc i64 %tmp15 to i32
+  %tmp17 = icmp eq i32 %tmp16, %tmp
+  br i1 %tmp17, label %bb18, label %bb13
+
+bb18:                                             ; preds = %bb13
+  ret void
+}
-- 
2.34.1