SLPVectorizer: Add support for trees with external users.
authorNadav Rotem <nrotem@apple.com>
Fri, 10 May 2013 22:59:33 +0000 (22:59 +0000)
committerNadav Rotem <nrotem@apple.com>
Fri, 10 May 2013 22:59:33 +0000 (22:59 +0000)
For example:
bar() {
  int a = A[i];
  int b = A[i+1];
  B[i] = a;
  B[i+1] = b;
  foo(a);  <--- a is used outside the vectorized expression.
}

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@181648 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Vectorize/VecUtils.cpp
lib/Transforms/Vectorize/VecUtils.h
test/Transforms/SLPVectorizer/X86/diamond.ll

index 9b9436683b12aea176a94b28b1749321e679b7b3..55adf8a8161ce58b8907694af631efdf7ba4bf4d 100644 (file)
@@ -243,6 +243,10 @@ int BoUpSLP::getTreeCost(ArrayRef<Value *> VL) {
   LaneMap.clear();
   MultiUserVals.clear();
   MustScalarize.clear();
+  MustExtract.clear();
+
+  // Find the location of the last root.
+  unsigned LastRootIndex = InstrIdx[GetLastInstr(VL, VL.size())];
 
   // Scan the tree and find which value is used by which lane, and which values
   // must be scalarized.
@@ -258,15 +262,31 @@ int BoUpSLP::getTreeCost(ArrayRef<Value *> VL) {
     for (Value::use_iterator I = (*it)->use_begin(), E = (*it)->use_end();
          I != E; ++I) {
       if (LaneMap.find(*I) == LaneMap.end()) {
-        MustScalarize.insert(*it);
-        DEBUG(dbgs()<<"SLP: Adding " << **it <<
-              " to MustScalarize because of an out of tree usage.\n");
-        break;
+        DEBUG(dbgs()<<"SLP: Instr " << **it << " has multiple users.\n");
+
+        // We don't have an ordering problem if the user is not in this basic
+        // block.
+        Instruction *Inst = cast<Instruction>(*I);
+        if (Inst->getParent() == BB) {
+          // We don't have an ordering problem if the user is after the last
+          // root.
+          unsigned Idx = InstrIdx[Inst];
+          if (Idx < LastRootIndex) {
+            MustScalarize.insert(*it);
+            DEBUG(dbgs()<<"SLP: Adding to MustScalarize "
+                  "because of an unsafe out of tree usage.\n");
+            break;
+          }
+        }
+
+        DEBUG(dbgs()<<"SLP: Adding to MustExtract "
+              "because of a safe out of tree usage.\n");
+        MustExtract.insert(*it);
       }
       if (Lane == -1) Lane = LaneMap[*I];
       if (Lane != LaneMap[*I]) {
         MustScalarize.insert(*it);
-        DEBUG(dbgs()<<"Adding " << **it <<
+        DEBUG(dbgs()<<"SLP: Adding " << **it <<
               " to MustScalarize because multiple lane use it: "
               << Lane << " and " << LaneMap[*I] << ".\n");
         break;
@@ -456,6 +476,13 @@ int BoUpSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
     }
   }
 
+  // Calculate the extract cost.
+  unsigned ExternalUserExtractCost = 0;
+  for (unsigned i = 0, e = VL.size(); i < e; ++i)
+    if (MustExtract.count(VL[i]))
+      ExternalUserExtractCost +=
+        TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i);
+
   switch (Opcode) {
   case Instruction::ZExt:
   case Instruction::SExt:
@@ -469,7 +496,7 @@ int BoUpSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
   case Instruction::Trunc:
   case Instruction::FPTrunc:
   case Instruction::BitCast: {
-    int Cost = 0;
+    int Cost = ExternalUserExtractCost;
     ValueList Operands;
     Type *SrcTy = VL0->getOperand(0)->getType();
     // Prepare the operand vector.
@@ -510,7 +537,7 @@ int BoUpSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor: {
-    int Cost = 0;
+    int Cost = ExternalUserExtractCost;
     // Calculate the cost of all of the operands.
     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
       ValueList Operands;
@@ -540,7 +567,7 @@ int BoUpSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
     int ScalarLdCost = VecTy->getNumElements() *
       TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
     int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
-    return VecLdCost - ScalarLdCost;
+    return VecLdCost - ScalarLdCost + ExternalUserExtractCost;
   }
   case Instruction::Store: {
     // We know that we can merge the stores. Calculate the cost.
@@ -556,7 +583,7 @@ int BoUpSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
     }
 
     int TotalCost = StoreCost + getTreeCost_rec(Operands, Depth + 1);
-    return TotalCost;
+    return TotalCost + ExternalUserExtractCost;
   }
   default:
     // Unable to vectorize unknown instructions.
@@ -588,10 +615,24 @@ Value *BoUpSLP::Scalarize(ArrayRef<Value *> VL, VectorType *Ty) {
 
 Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL, int VF) {
   Value *V = vectorizeTree_rec(VL, VF);
+
+  Instruction *LastInstr = GetLastInstr(VL, VL.size());
+  IRBuilder<> Builder(LastInstr);
+  for (ValueSet::iterator it = MustExtract.begin(), e = MustExtract.end();
+       it != e; ++it) {
+    Instruction *I = cast<Instruction>(*it);
+    Value *Vec = VectorizedValues[I];
+    assert(LaneMap.count(I) && "Unable to find the lane for the external use");
+    Value *Idx = Builder.getInt32(LaneMap[I]);
+    Value *Extract = Builder.CreateExtractElement(Vec, Idx);
+    I->replaceAllUsesWith(Extract);
+  }
+
   // We moved some instructions around. We have to number them again
   // before we can do any analysis.
   numberInstructions();
   MustScalarize.clear();
+  MustExtract.clear();
   return V;
 }
 
index 5456c6c7795992701ebc1b19dd72ee0d534bb8b4..abb35840e93a36b938b699fb27d9295f2112f386 100644 (file)
@@ -127,6 +127,11 @@ private:
   /// NOTICE: The vectorization methods also use this set.
   ValueSet MustScalarize;
 
+  /// Contains values that have users outside of the vectorized graph.
+  /// We need to generate extract instructions for these values.
+  /// NOTICE: The vectorization methods also use this set.
+  ValueSet MustExtract;
+
   /// Contains a list of values that are used outside the current tree. This
   /// set must be reset between runs.
   ValueSet MultiUserVals;
index 8e85cb6c9b8f84da9db5741218befc344bd67c2a..49c8712d2027dc03138d850e481e430929b1f679 100644 (file)
@@ -41,7 +41,7 @@ entry:
 }
 
 
-; int foo_fail(int * restrict B,  int * restrict A, int n, int m) {
+; int extr_user(int * restrict B,  int * restrict A, int n, int m) {
 ;   B[0] = n * A[0] + m * A[0];
 ;   B[1] = n * A[1] + m * A[1];
 ;   B[2] = n * A[2] + m * A[2];
@@ -49,10 +49,11 @@ entry:
 ;   return A[0];
 ; }
 
-; CHECK: @foo_fail
-; CHECK-NOT: load <4 x i32>
+; CHECK: @extr_user
+; CHECK: store <4 x i32>
+; CHECK-NEXT: extractelement <4 x i32>
 ; CHECK: ret
-define i32 @foo_fail(i32* noalias nocapture %B, i32* noalias nocapture %A, i32 %n, i32 %m) {
+define i32 @extr_user(i32* noalias nocapture %B, i32* noalias nocapture %A, i32 %n, i32 %m) {
 entry:
   %0 = load i32* %A, align 4
   %mul238 = add i32 %m, %n