[SCEV] Generalize the SCEV algorithm for creating expressions for PHI nodes
authorSilviu Baranga <silviu.baranga@arm.com>
Fri, 30 Oct 2015 15:02:28 +0000 (15:02 +0000)
committerSilviu Baranga <silviu.baranga@arm.com>
Fri, 30 Oct 2015 15:02:28 +0000 (15:02 +0000)
Summary:
When forming expressions for phi nodes having an incoming value from
outside the loop A and a value coming from the previous iteration B
we were forming an AddRec if:
  - B was an AddRec
  - the value A was equal to the value for B at iteration -1 (or equal
    to the value of B shifted by one iteration, at iteration 0)

In this case, we were computing the expression to be the expression of
B, shifted by one iteration.

This changes generalizes the logic above by removing the restriction that
B needs to be an AddRec. For this we introduce two expression rewriters
that allow us to
  - shift an expression by one iteration
  - get the value of an expression at iteration 0

This allows us to get SCEV expressions for PHI nodes when these expressions
are not AddRecExprs.

Reviewers: sanjoy

Subscribers: llvm-commits, sanjoy

Differential Revision: http://reviews.llvm.org/D14175

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@251700 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Analysis/ScalarEvolution.cpp
test/Analysis/ScalarEvolution/non-IV-phi.ll [new file with mode: 0644]

index e0658c210506f42bad811b53a5c003d8c47046a5..441a5a1da6c8af22dc17cabc07c489a7fbe720a7 100644 (file)
@@ -3629,6 +3629,71 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
   }
 }
 
+class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+                             ScalarEvolution &SE) {
+    SCEVInitRewriter Rewriter(L, SE);
+    const SCEV *Result = Rewriter.visit(Scev);
+    return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+  }
+
+  SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
+      : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
+      Valid = false;
+    return Expr;
+  }
+
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+    // Only allow AddRecExprs for this loop.
+    if (Expr->getLoop() == L)
+      return Expr->getStart();
+    Valid = false;
+    return Expr;
+  }
+
+  bool isValid() { return Valid; }
+
+private:
+  const Loop *L;
+  bool Valid;
+};
+
+class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *Scev, const Loop *L,
+                             ScalarEvolution &SE) {
+    SCEVShiftRewriter Rewriter(L, SE);
+    const SCEV *Result = Rewriter.visit(Scev);
+    return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+  }
+
+  SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
+      : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    // Only allow AddRecExprs for this loop.
+    if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant))
+      Valid = false;
+    return Expr;
+  }
+
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+    if (Expr->getLoop() == L && Expr->isAffine())
+      return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
+    Valid = false;
+    return Expr;
+  }
+  bool isValid() { return Valid; }
+
+private:
+  const Loop *L;
+  bool Valid;
+};
+
 const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
   const Loop *L = LI.getLoopFor(PN->getParent());
   if (!L || L->getHeader() != PN->getParent())
@@ -3741,30 +3806,28 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
           return PHISCEV;
         }
       }
-    } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(BEValue)) {
+    } else {
       // Otherwise, this could be a loop like this:
       //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
       // In this case, j = {1,+,1}  and BEValue is j.
       // Because the other in-value of i (0) fits the evolution of BEValue
       // i really is an addrec evolution.
-      if (AddRec->getLoop() == L && AddRec->isAffine()) {
+      //
+      // We can generalize this saying that i is the shifted value of BEValue
+      // by one iteration:
+      //   PHI(f(0), f({1,+,1})) --> f({0,+,1})
+      const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
+      const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this);
+      if (Shifted != getCouldNotCompute() &&
+          Start != getCouldNotCompute()) {
         const SCEV *StartVal = getSCEV(StartValueV);
-
-        // If StartVal = j.start - j.stride, we can use StartVal as the
-        // initial step of the addrec evolution.
-        if (StartVal ==
-            getMinusSCEV(AddRec->getOperand(0), AddRec->getOperand(1))) {
-          // FIXME: For constant StartVal, we should be able to infer
-          // no-wrap flags.
-          const SCEV *PHISCEV = getAddRecExpr(StartVal, AddRec->getOperand(1),
-                                              L, SCEV::FlagAnyWrap);
-
+        if (Start == StartVal) {
           // Okay, for the entire analysis of this edge we assumed the PHI
           // to be symbolic.  We now need to go back and purge all of the
           // entries for the scalars that use the symbolic expression.
           ForgetSymbolicName(PN, SymbolicName);
-          ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
-          return PHISCEV;
+          ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
+          return Shifted;
         }
       }
     }
diff --git a/test/Analysis/ScalarEvolution/non-IV-phi.ll b/test/Analysis/ScalarEvolution/non-IV-phi.ll
new file mode 100644 (file)
index 0000000..f0d6c2f
--- /dev/null
@@ -0,0 +1,59 @@
+; RUN: opt -scalar-evolution -analyze < %s | FileCheck %s
+
+define void @test1(i8 %t, i32 %len) {
+; CHECK-LABEL: test1
+; CHECK: %sphi = phi i32 [ %ext, %entry ], [ %idx.inc.ext, %loop ]
+; CHECK-NEXT:  -->  (zext i8 {%t,+,1}<%loop> to i32)
+
+ entry:
+  %st = zext i8 %t to i16
+  %ext = zext i8 %t to i32
+  %ecmp = icmp ult i16 %st, 42
+  br i1 %ecmp, label %loop, label %exit
+
+ loop:
+
+  %idx = phi i8 [ %t, %entry ], [ %idx.inc, %loop ]
+  %sphi = phi i32 [ %ext, %entry ], [%idx.inc.ext, %loop]
+
+  %idx.inc = add i8 %idx, 1
+  %idx.inc.ext = zext i8 %idx.inc to i32
+  %idx.ext = zext i8 %idx to i32
+
+  %c = icmp ult i32 %idx.inc.ext, %len
+  br i1 %c, label %loop, label %exit
+
+ exit:
+  ret void
+}
+
+define void @test2(i8 %t, i32 %len) {
+; CHECK-LABEL: test2
+; CHECK: %sphi = phi i32 [ %ext.mul, %entry ], [ %mul, %loop ]
+; CHECK-NEXT:  -->  (4 * (zext i8 {%t,+,1}<%loop> to i32))
+
+ entry:
+  %st = zext i8 %t to i16
+  %ext = zext i8 %t to i32
+  %ext.mul = mul i32 %ext, 4
+
+  %ecmp = icmp ult i16 %st, 42
+  br i1 %ecmp, label %loop, label %exit
+
+ loop:
+
+  %idx = phi i8 [ %t, %entry ], [ %idx.inc, %loop ]
+  %sphi = phi i32 [ %ext.mul, %entry ], [%mul, %loop]
+
+  %idx.inc = add i8 %idx, 1
+  %idx.inc.ext = zext i8 %idx.inc to i32
+  %mul = mul i32 %idx.inc.ext, 4
+
+  %idx.ext = zext i8 %idx to i32
+
+  %c = icmp ult i32 %idx.inc.ext, %len
+  br i1 %c, label %loop, label %exit
+
+ exit:
+  ret void
+}