[LSR] Generate and use zero extends
authorSanjoy Das <sanjoy@playingwithpointers.com>
Mon, 27 Jul 2015 23:27:51 +0000 (23:27 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Mon, 27 Jul 2015 23:27:51 +0000 (23:27 +0000)
Summary:
If a scale or a base register can be rewritten as "Zext({A,+,1})" then
LSR will now consider a formula of that form in its normal cost
computation.

Depends on D9180

Reviewers: qcolombet, atrick

Subscribers: llvm-commits

Differential Revision: http://reviews.llvm.org/D9181

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@243348 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/LoopStrengthReduce.cpp
test/Transforms/LoopStrengthReduce/zext-of-scale.ll [new file with mode: 0644]

index 773777ac804f179302fc395bca6bc16a42e58d84..059b10ef73f2357bd95c179de7d8cb81289e1775 100644 (file)
@@ -256,9 +256,22 @@ struct Formula {
   /// live in an add immediate field rather than a register.
   int64_t UnfoldedOffset;
 
+  /// ZeroExtendScaledReg - This formula zero extends the scale register to
+  /// ZeroExtendType before its use.
+  bool ZeroExtendScaledReg;
+
+  /// ZeroExtendBaseReg - This formula zero extends all the base registers to
+  /// ZeroExtendType before their use.
+  bool ZeroExtendBaseReg;
+
+  /// ZeroExtendType - The destination type of the zero extension implied by
+  /// the above two booleans.
+  Type *ZeroExtendType;
+
   Formula()
       : BaseGV(nullptr), BaseOffset(0), HasBaseReg(false), Scale(0),
-        ScaledReg(nullptr), UnfoldedOffset(0) {}
+        ScaledReg(nullptr), UnfoldedOffset(0), ZeroExtendScaledReg(false),
+        ZeroExtendBaseReg(false), ZeroExtendType(nullptr) {}
 
   void InitialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE);
 
@@ -413,10 +426,12 @@ size_t Formula::getNumRegs() const {
 /// getType - Return the type of this formula, if it has one, or null
 /// otherwise. This type is meaningless except for the bit size.
 Type *Formula::getType() const {
-  return !BaseRegs.empty() ? BaseRegs.front()->getType() :
-         ScaledReg ? ScaledReg->getType() :
-         BaseGV ? BaseGV->getType() :
-         nullptr;
+  return ZeroExtendType
+             ? ZeroExtendType
+             : !BaseRegs.empty()
+                   ? BaseRegs.front()->getType()
+                   : ScaledReg ? ScaledReg->getType()
+                               : BaseGV ? BaseGV->getType() : nullptr;
 }
 
 /// DeleteBaseReg - Delete the given base reg from the BaseRegs list.
@@ -457,7 +472,10 @@ void Formula::print(raw_ostream &OS) const {
   }
   for (const SCEV *BaseReg : BaseRegs) {
     if (!First) OS << " + "; else First = false;
-    OS << "reg(" << *BaseReg << ')';
+    if (ZeroExtendBaseReg)
+      OS << "reg(zext " << *BaseReg << " to " << *ZeroExtendType << ')';
+    else
+      OS << "reg(" << *BaseReg << ')';
   }
   if (HasBaseReg && BaseRegs.empty()) {
     if (!First) OS << " + "; else First = false;
@@ -469,9 +487,12 @@ void Formula::print(raw_ostream &OS) const {
   if (Scale != 0) {
     if (!First) OS << " + "; else First = false;
     OS << Scale << "*reg(";
-    if (ScaledReg)
-      OS << *ScaledReg;
-    else
+    if (ScaledReg) {
+      if (ZeroExtendScaledReg)
+        OS << "(zext " << *ScaledReg << " to " << *ZeroExtendType << ')';
+      else
+        OS << *ScaledReg;
+    } else
       OS << "<unknown>";
     OS << ')';
   }
@@ -1732,6 +1753,7 @@ class LSRInstance {
   void GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, Formula Base);
   void GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base);
   void GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base);
+  void GenerateZExts(LSRUse &LU, unsigned LUIdx, Formula Base);
   void GenerateCrossUseConstantOffsets();
   void GenerateAllReuseFormulae();
 
@@ -3627,6 +3649,64 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
   }
 }
 
+/// GenerateZExts - If a scale or a base register can be rewritten as
+/// "Zext({A,+,1})" then consider a formula of that form.
+void LSRInstance::GenerateZExts(LSRUse &LU, unsigned LUIdx, Formula Base) {
+  // Don't bother with symbolic values.
+  if (Base.BaseGV)
+    return;
+
+  auto CanBeNarrowed = [&](const SCEV *Reg) -> const SCEV * {
+    // Check if the register is an increment can be rewritten as zext(R) where
+    // the zext is free.
+
+    const auto *RegAR = dyn_cast_or_null<SCEVAddRecExpr>(Reg);
+    if (!RegAR)
+      return nullptr;
+
+    const auto *ZExtStart = dyn_cast<SCEVZeroExtendExpr>(RegAR->getStart());
+    const auto *ConstStep =
+        dyn_cast<SCEVConstant>(RegAR->getStepRecurrence(SE));
+    if (!ZExtStart || !ConstStep || ConstStep->getValue()->getValue() != 1)
+      return nullptr;
+
+    const SCEV *NarrowStart = ZExtStart->getOperand();
+    if (!TTI.isZExtFree(NarrowStart->getType(), ZExtStart->getType()))
+      return nullptr;
+
+    const auto *NarrowAR = dyn_cast<SCEVAddRecExpr>(
+        SE.getAddRecExpr(NarrowStart, SE.getConstant(NarrowStart->getType(), 1),
+                         RegAR->getLoop(), RegAR->getNoWrapFlags()));
+
+    if (!NarrowAR || !NarrowAR->getNoWrapFlags(SCEV::FlagNUW))
+      return nullptr;
+
+    return NarrowAR;
+  };
+
+  if (Base.ScaledReg && !Base.ZeroExtendType)
+    if (const SCEV *S = CanBeNarrowed(Base.ScaledReg)) {
+      Formula F = Base;
+      F.ZeroExtendType = Base.ScaledReg->getType();
+      F.ZeroExtendScaledReg = true;
+      F.ScaledReg = S;
+
+      if (isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F))
+        InsertFormula(LU, LUIdx, F);
+    }
+
+  if (Base.BaseRegs.size() == 1 && !Base.ZeroExtendType)
+    if (const SCEV *S = CanBeNarrowed(Base.BaseRegs[0])) {
+      Formula F = Base;
+      F.ZeroExtendType = Base.BaseRegs[0]->getType();
+      F.ZeroExtendBaseReg = true;
+      F.BaseRegs[0] = S;
+
+      if (isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F))
+        InsertFormula(LU, LUIdx, F);
+    }
+}
+
 namespace {
 
 /// WorkItem - Helper class for GenerateCrossUseConstantOffsets. It's used to
@@ -3846,6 +3926,8 @@ LSRInstance::GenerateAllReuseFormulae() {
     LSRUse &LU = Uses[LUIdx];
     for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i)
       GenerateTruncates(LU, LUIdx, LU.Formulae[i]);
+    for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i)
+      GenerateZExts(LU, LUIdx, LU.Formulae[i]);
   }
 
   GenerateCrossUseConstantOffsets();
@@ -4483,13 +4565,28 @@ Value *LSRInstance::Expand(const LSRFixup &LF,
 
     // If we're expanding for a post-inc user, make the post-inc adjustment.
     PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops);
-    Reg = TransformForPostIncUse(Denormalize, Reg,
-                                 LF.UserInst, LF.OperandValToReplace,
-                                 Loops, SE, DT);
-
-    Ops.push_back(SE.getUnknown(Rewriter.expandCodeFor(Reg, nullptr, IP)));
+    const SCEV *ExtendedReg =
+        F.ZeroExtendBaseReg ? SE.getZeroExtendExpr(Reg, F.ZeroExtendType) : Reg;
+
+    const SCEV *PostIncReg =
+        TransformForPostIncUse(Denormalize, ExtendedReg, LF.UserInst,
+                               LF.OperandValToReplace, Loops, SE, DT);
+    if (PostIncReg == ExtendedReg) {
+      Value *Expanded = Rewriter.expandCodeFor(Reg, nullptr, IP);
+      if (F.ZeroExtendBaseReg)
+        Expanded = new ZExtInst(Expanded, F.ZeroExtendType, "", IP);
+      Ops.push_back(SE.getUnknown(Expanded));
+    } else {
+      Ops.push_back(
+          SE.getUnknown(Rewriter.expandCodeFor(PostIncReg, nullptr, IP)));
+    }
   }
 
+  // Note on post-inc uses and zero extends -- since the no-wrap behavior for
+  // the post-inc SCEV can be different from the no-wrap behavior of the pre-inc
+  // SCEV, if a post-inc transform is required we do the zero extension on the
+  // pre-inc expression before doing the post-inc transform.
+
   // Expand the ScaledReg portion.
   Value *ICmpScaledV = nullptr;
   if (F.Scale != 0) {
@@ -4497,22 +4594,33 @@ Value *LSRInstance::Expand(const LSRFixup &LF,
 
     // If we're expanding for a post-inc user, make the post-inc adjustment.
     PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops);
-    ScaledS = TransformForPostIncUse(Denormalize, ScaledS,
-                                     LF.UserInst, LF.OperandValToReplace,
-                                     Loops, SE, DT);
+    const SCEV *ExtendedScaleS =
+        F.ZeroExtendScaledReg ? SE.getZeroExtendExpr(ScaledS, F.ZeroExtendType)
+                              : ScaledS;
+    const SCEV *PostIncScaleS =
+        TransformForPostIncUse(Denormalize, ExtendedScaleS, LF.UserInst,
+                               LF.OperandValToReplace, Loops, SE, DT);
 
     if (LU.Kind == LSRUse::ICmpZero) {
       // Expand ScaleReg as if it was part of the base regs.
+      Value *Expanded = nullptr;
+      if (PostIncScaleS == ExtendedScaleS) {
+        Expanded = Rewriter.expandCodeFor(ScaledS, nullptr, IP);
+        if (F.ZeroExtendScaledReg)
+          Expanded = new ZExtInst(Expanded, F.ZeroExtendType, "", IP);
+      } else {
+        Expanded = Rewriter.expandCodeFor(PostIncScaleS, nullptr, IP);
+      }
+
       if (F.Scale == 1)
-        Ops.push_back(
-            SE.getUnknown(Rewriter.expandCodeFor(ScaledS, nullptr, IP)));
+        Ops.push_back(SE.getUnknown(Expanded));
       else {
         // An interesting way of "folding" with an icmp is to use a negated
         // scale, which we'll implement by inserting it into the other operand
         // of the icmp.
         assert(F.Scale == -1 &&
                "The only scale supported by ICmpZero uses is -1!");
-        ICmpScaledV = Rewriter.expandCodeFor(ScaledS, nullptr, IP);
+        ICmpScaledV = Expanded;
       }
     } else {
       // Otherwise just expand the scaled register and an explicit scale,
@@ -4526,7 +4634,17 @@ Value *LSRInstance::Expand(const LSRFixup &LF,
         Ops.clear();
         Ops.push_back(SE.getUnknown(FullV));
       }
-      ScaledS = SE.getUnknown(Rewriter.expandCodeFor(ScaledS, nullptr, IP));
+
+      Value *Expanded = nullptr;
+      if (PostIncScaleS == ExtendedScaleS) {
+        Expanded = Rewriter.expandCodeFor(ScaledS, nullptr, IP);
+        if (F.ZeroExtendScaledReg)
+          Expanded = new ZExtInst(Expanded, F.ZeroExtendType, "", IP);
+      } else {
+        Expanded = Rewriter.expandCodeFor(PostIncScaleS, nullptr, IP);
+      }
+
+      ScaledS = SE.getUnknown(Expanded);
       if (F.Scale != 1)
         ScaledS =
             SE.getMulExpr(ScaledS, SE.getConstant(ScaledS->getType(), F.Scale));
diff --git a/test/Transforms/LoopStrengthReduce/zext-of-scale.ll b/test/Transforms/LoopStrengthReduce/zext-of-scale.ll
new file mode 100644 (file)
index 0000000..d0972fe
--- /dev/null
@@ -0,0 +1,70 @@
+; RUN: opt  < %s -S -loop-reduce | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+%struct = type { [8 x i8] }
+
+declare void @use_32(i32)
+declare void @use_64(i64)
+
+define void @f(i32 %tmp156, i32* %length_buf_1, i32* %length_buf_0, %struct* %b,
+                %struct* %c, %struct* %d, %struct* %e, i32* %length_buf_2,
+                i32 %tmp160) {
+; CHECK-LABEL: @f(
+entry:
+  %begin151 = getelementptr inbounds %struct, %struct* %b, i64 0, i32 0, i64 12
+  %tmp21 = bitcast i8* %begin151 to i32*
+  %begin157 = getelementptr inbounds %struct, %struct* %c, i64 0, i32 0, i64 16
+  %tmp23 = bitcast i8* %begin157 to double*
+  %begin163 = getelementptr inbounds %struct, %struct* %d, i64 0, i32 0, i64 16
+  %tmp25 = bitcast i8* %begin163 to double*
+  %length.i820 = load i32, i32* %length_buf_1, align 4, !range !0
+  %enter = icmp ne i32 %tmp156, -1
+  br i1 %enter, label %ok_146, label %block_81_2
+
+ok_146:
+  %var_13 = phi double [ %tmp186, %ok_161 ], [ 0.000000e+00, %entry ]
+  %var_17 = phi i32 [ %tmp187, %ok_161 ], [ %tmp156, %entry ]
+  %tmp174 = zext i32 %var_17 to i64
+  %tmp175 = icmp ult i32 %var_17, %length.i820
+  br i1 %tmp175, label %ok_152, label %block_81_2
+
+ok_152:
+  %tmp176 = getelementptr inbounds i32, i32* %tmp21, i64 %tmp174
+  %tmp177 = load i32, i32* %tmp176, align 4
+  %tmp178 = zext i32 %tmp177 to i64
+  %length.i836 = load i32, i32* %length_buf_2, align 4, !range !0
+  %tmp179 = icmp ult i32 %tmp177, %length.i836
+  br i1 %tmp179, label %ok_158, label %block_81_2
+
+ok_158:
+  %tmp180 = getelementptr inbounds double, double* %tmp23, i64 %tmp178
+  %tmp181 = load double, double* %tmp180, align 8
+  %length.i = load i32, i32* %length_buf_0, align 4, !range !0
+  %tmp182 = icmp slt i32 %var_17, %length.i
+  br i1 %tmp182, label %ok_161, label %block_81_2
+
+ok_161:
+; CHECK-LABEL: ok_161:
+; CHECK: add
+; CHECK-NOT: add
+  %tmp183 = getelementptr inbounds double, double* %tmp25, i64 %tmp174
+  %tmp184 = load double, double* %tmp183, align 8
+  %tmp185 = fmul double %tmp181, %tmp184
+  %tmp186 = fadd double %var_13, %tmp185
+  %tmp187 = add nsw i32 %var_17, 1
+  %tmp188 = icmp slt i32 %tmp187, %tmp160
+; CHECK: br
+  br i1 %tmp188, label %ok_146, label %block_81
+
+block_81:
+  call void @use_64(i64 %tmp174)  ;; pre-inc use
+  call void @use_32(i32 %tmp187)  ;; post-inc use
+  ret void
+
+block_81_2:
+  ret void
+}
+
+!0 = !{i32 0, i32 2147483647}