[SLP] Be more aggressive about reduction width selection.
authorCharlie Turner <charlie.turner@arm.com>
Tue, 27 Oct 2015 17:59:03 +0000 (17:59 +0000)
committerCharlie Turner <charlie.turner@arm.com>
Tue, 27 Oct 2015 17:59:03 +0000 (17:59 +0000)
Summary:
This change could be way off-piste, I'm looking for any feedback on whether it's an acceptable approach.

It never seems to be a problem to gobble up as many reduction values as can be found, and then to attempt to reduce the resulting tree. Some of the workloads I'm looking at have been aggressively unrolled by hand, and by selecting reduction widths that are not constrained by a vector register size, it becomes possible to profitably vectorize. My test case shows such an unrolling which SLP was not vectorizing (on neither ARM nor X86) before this patch, but with it does vectorize.

I measure no significant compile time impact of this change when combined with D13949 and D14063. There are also no significant performance regressions on ARM/AArch64 in SPEC or LNT.

The more principled approach I thought of was to generate several candidate tree's and use the cost model to pick the cheapest one. That seemed like quite a big design change (the algorithms seem very much one-shot), and would likely be a costly thing for compile time. This seemed to do the job at very little cost, but I'm worried I've misunderstood something!

Reviewers: nadav, jmolloy

Subscribers: mssimpso, llvm-commits, aemerson

Differential Revision: http://reviews.llvm.org/D14116

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@251428 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Vectorize/SLPVectorizer.cpp
test/Transforms/SLPVectorizer/AArch64/horizontal.ll

index dcaa24008ffcc61de6598370fd1ffee0119070b4..14768df8bd2e77aea176a8e47c98ce7e0d6eca08 100644 (file)
@@ -3659,16 +3659,17 @@ class HorizontalReduction {
   unsigned ReductionOpcode;
   /// The opcode of the values we perform a reduction on.
   unsigned ReducedValueOpcode;
-  /// The width of one full horizontal reduction operation.
-  unsigned ReduxWidth;
   /// Should we model this reduction as a pairwise reduction tree or a tree that
   /// splits the vector in halves and adds those halves.
   bool IsPairwiseReduction;
 
 public:
+  /// The width of one full horizontal reduction operation.
+  unsigned ReduxWidth;
+
   HorizontalReduction()
     : ReductionRoot(nullptr), ReductionPHI(nullptr), ReductionOpcode(0),
-    ReducedValueOpcode(0), ReduxWidth(0), IsPairwiseReduction(false) {}
+    ReducedValueOpcode(0), IsPairwiseReduction(false), ReduxWidth(0) {}
 
   /// \brief Try to find a reduction tree.
   bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) {
@@ -3825,8 +3826,11 @@ public:
     return VectorizedTree != nullptr;
   }
 
-private:
+  unsigned numReductionValues() const {
+    return ReducedVals.size();
+  }
 
+private:
   /// \brief Calculate the cost of a reduction.
   int getReductionCost(TargetTransformInfo *TTI, Value *FirstReducedVal) {
     Type *ScalarTy = FirstReducedVal->getType();
@@ -3973,6 +3977,30 @@ static Value *getReductionValue(PHINode *P, BasicBlock *ParentBB,
   return Rdx;
 }
 
+/// \brief Attempt to reduce a horizontal reduction.
+/// If it is legal to match a horizontal reduction feeding
+/// the phi node P with reduction operators BI, then check if it
+/// can be done.
+/// \returns true if a horizontal reduction was matched and reduced.
+/// \returns false if a horizontal reduction was not matched.
+static bool canMatchHorizontalReduction(PHINode *P, BinaryOperator *BI,
+                                        BoUpSLP &R, TargetTransformInfo *TTI) {
+  if (!ShouldVectorizeHor)
+    return false;
+
+  HorizontalReduction HorRdx;
+  if (!HorRdx.matchAssociativeReduction(P, BI))
+    return false;
+
+  // If there is a sufficient number of reduction values, reduce
+  // to a nearby power-of-2. Can safely generate oversized
+  // vectors and rely on the backend to split them to legal sizes.
+  HorRdx.ReduxWidth =
+    std::max((uint64_t)4, PowerOf2Floor(HorRdx.numReductionValues()));
+
+  return HorRdx.tryToReduce(R, TTI);
+}
+
 bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
   bool Changed = false;
   SmallVector<Value *, 4> Incoming;
@@ -4049,9 +4077,7 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
         continue;
 
       // Try to match and vectorize a horizontal reduction.
-      HorizontalReduction HorRdx;
-      if (ShouldVectorizeHor && HorRdx.matchAssociativeReduction(P, BI) &&
-          HorRdx.tryToReduce(R, TTI)) {
+      if (canMatchHorizontalReduction(P, BI, R, TTI)) {
         Changed = true;
         it = BB->begin();
         e = BB->end();
@@ -4074,15 +4100,12 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
       continue;
     }
 
-    // Try to vectorize horizontal reductions feeding into a store.
     if (ShouldStartVectorizeHorAtStore)
       if (StoreInst *SI = dyn_cast<StoreInst>(it))
         if (BinaryOperator *BinOp =
                 dyn_cast<BinaryOperator>(SI->getValueOperand())) {
-          HorizontalReduction HorRdx;
-          if (((HorRdx.matchAssociativeReduction(nullptr, BinOp) &&
-                HorRdx.tryToReduce(R, TTI)) ||
-               tryToVectorize(BinOp, R))) {
+          if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI) ||
+              tryToVectorize(BinOp, R)) {
             Changed = true;
             it = BB->begin();
             e = BB->end();
index 80ab421c17ac1e2889bb2cdbc786d213763371ac..b31f215f1ae837df14d7fd7e32a1a490a3b8ac8f 100644 (file)
@@ -145,3 +145,126 @@ for.end:                                          ; preds = %for.end.loopexit, %
   %s.1 = phi i32 [ 0, %entry ], [ %add13, %for.end.loopexit ]
   ret i32 %s.1
 }
+
+; CHECK: test_unrolled_select
+; CHECK: load <8 x i8>
+; CHECK: load <8 x i8>
+; CHECK: select <8 x i1>
+define i32 @test_unrolled_select(i8* noalias nocapture readonly %blk1, i8* noalias nocapture readonly %blk2, i32 %lx, i32 %h, i32 %lim) #0 {
+entry:
+  %cmp.43 = icmp sgt i32 %h, 0
+  br i1 %cmp.43, label %for.body.lr.ph, label %for.end
+
+for.body.lr.ph:                                   ; preds = %entry
+  %idx.ext = sext i32 %lx to i64
+  br label %for.body
+
+for.body:                                         ; preds = %for.body.lr.ph, %if.end.86
+  %s.047 = phi i32 [ 0, %for.body.lr.ph ], [ %add82, %if.end.86 ]
+  %j.046 = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %if.end.86 ]
+  %p2.045 = phi i8* [ %blk2, %for.body.lr.ph ], [ %add.ptr88, %if.end.86 ]
+  %p1.044 = phi i8* [ %blk1, %for.body.lr.ph ], [ %add.ptr, %if.end.86 ]
+  %0 = load i8, i8* %p1.044, align 1
+  %conv = zext i8 %0 to i32
+  %1 = load i8, i8* %p2.045, align 1
+  %conv2 = zext i8 %1 to i32
+  %sub = sub nsw i32 %conv, %conv2
+  %cmp3 = icmp slt i32 %sub, 0
+  %sub5 = sub nsw i32 0, %sub
+  %sub5.sub = select i1 %cmp3, i32 %sub5, i32 %sub
+  %add = add nsw i32 %sub5.sub, %s.047
+  %arrayidx6 = getelementptr inbounds i8, i8* %p1.044, i64 1
+  %2 = load i8, i8* %arrayidx6, align 1
+  %conv7 = zext i8 %2 to i32
+  %arrayidx8 = getelementptr inbounds i8, i8* %p2.045, i64 1
+  %3 = load i8, i8* %arrayidx8, align 1
+  %conv9 = zext i8 %3 to i32
+  %sub10 = sub nsw i32 %conv7, %conv9
+  %cmp11 = icmp slt i32 %sub10, 0
+  %sub14 = sub nsw i32 0, %sub10
+  %v.1 = select i1 %cmp11, i32 %sub14, i32 %sub10
+  %add16 = add nsw i32 %add, %v.1
+  %arrayidx17 = getelementptr inbounds i8, i8* %p1.044, i64 2
+  %4 = load i8, i8* %arrayidx17, align 1
+  %conv18 = zext i8 %4 to i32
+  %arrayidx19 = getelementptr inbounds i8, i8* %p2.045, i64 2
+  %5 = load i8, i8* %arrayidx19, align 1
+  %conv20 = zext i8 %5 to i32
+  %sub21 = sub nsw i32 %conv18, %conv20
+  %cmp22 = icmp slt i32 %sub21, 0
+  %sub25 = sub nsw i32 0, %sub21
+  %sub25.sub21 = select i1 %cmp22, i32 %sub25, i32 %sub21
+  %add27 = add nsw i32 %add16, %sub25.sub21
+  %arrayidx28 = getelementptr inbounds i8, i8* %p1.044, i64 3
+  %6 = load i8, i8* %arrayidx28, align 1
+  %conv29 = zext i8 %6 to i32
+  %arrayidx30 = getelementptr inbounds i8, i8* %p2.045, i64 3
+  %7 = load i8, i8* %arrayidx30, align 1
+  %conv31 = zext i8 %7 to i32
+  %sub32 = sub nsw i32 %conv29, %conv31
+  %cmp33 = icmp slt i32 %sub32, 0
+  %sub36 = sub nsw i32 0, %sub32
+  %v.3 = select i1 %cmp33, i32 %sub36, i32 %sub32
+  %add38 = add nsw i32 %add27, %v.3
+  %arrayidx39 = getelementptr inbounds i8, i8* %p1.044, i64 4
+  %8 = load i8, i8* %arrayidx39, align 1
+  %conv40 = zext i8 %8 to i32
+  %arrayidx41 = getelementptr inbounds i8, i8* %p2.045, i64 4
+  %9 = load i8, i8* %arrayidx41, align 1
+  %conv42 = zext i8 %9 to i32
+  %sub43 = sub nsw i32 %conv40, %conv42
+  %cmp44 = icmp slt i32 %sub43, 0
+  %sub47 = sub nsw i32 0, %sub43
+  %sub47.sub43 = select i1 %cmp44, i32 %sub47, i32 %sub43
+  %add49 = add nsw i32 %add38, %sub47.sub43
+  %arrayidx50 = getelementptr inbounds i8, i8* %p1.044, i64 5
+  %10 = load i8, i8* %arrayidx50, align 1
+  %conv51 = zext i8 %10 to i32
+  %arrayidx52 = getelementptr inbounds i8, i8* %p2.045, i64 5
+  %11 = load i8, i8* %arrayidx52, align 1
+  %conv53 = zext i8 %11 to i32
+  %sub54 = sub nsw i32 %conv51, %conv53
+  %cmp55 = icmp slt i32 %sub54, 0
+  %sub58 = sub nsw i32 0, %sub54
+  %v.5 = select i1 %cmp55, i32 %sub58, i32 %sub54
+  %add60 = add nsw i32 %add49, %v.5
+  %arrayidx61 = getelementptr inbounds i8, i8* %p1.044, i64 6
+  %12 = load i8, i8* %arrayidx61, align 1
+  %conv62 = zext i8 %12 to i32
+  %arrayidx63 = getelementptr inbounds i8, i8* %p2.045, i64 6
+  %13 = load i8, i8* %arrayidx63, align 1
+  %conv64 = zext i8 %13 to i32
+  %sub65 = sub nsw i32 %conv62, %conv64
+  %cmp66 = icmp slt i32 %sub65, 0
+  %sub69 = sub nsw i32 0, %sub65
+  %sub69.sub65 = select i1 %cmp66, i32 %sub69, i32 %sub65
+  %add71 = add nsw i32 %add60, %sub69.sub65
+  %arrayidx72 = getelementptr inbounds i8, i8* %p1.044, i64 7
+  %14 = load i8, i8* %arrayidx72, align 1
+  %conv73 = zext i8 %14 to i32
+  %arrayidx74 = getelementptr inbounds i8, i8* %p2.045, i64 7
+  %15 = load i8, i8* %arrayidx74, align 1
+  %conv75 = zext i8 %15 to i32
+  %sub76 = sub nsw i32 %conv73, %conv75
+  %cmp77 = icmp slt i32 %sub76, 0
+  %sub80 = sub nsw i32 0, %sub76
+  %v.7 = select i1 %cmp77, i32 %sub80, i32 %sub76
+  %add82 = add nsw i32 %add71, %v.7
+  %cmp83 = icmp slt i32 %add82, %lim
+  br i1 %cmp83, label %if.end.86, label %for.end.loopexit
+
+if.end.86:                                        ; preds = %for.body
+  %add.ptr = getelementptr inbounds i8, i8* %p1.044, i64 %idx.ext
+  %add.ptr88 = getelementptr inbounds i8, i8* %p2.045, i64 %idx.ext
+  %inc = add nuw nsw i32 %j.046, 1
+  %cmp = icmp slt i32 %inc, %h
+  br i1 %cmp, label %for.body, label %for.end.loopexit
+
+for.end.loopexit:                                 ; preds = %for.body, %if.end.86
+  br label %for.end
+
+for.end:                                          ; preds = %for.end.loopexit, %entry
+  %s.1 = phi i32 [ 0, %entry ], [ %add82, %for.end.loopexit ]
+  ret i32 %s.1
+}
+