Do not add cse-ed instructions into the visited map because we dont want to consider...
[oota-llvm.git] / lib / Transforms / Vectorize / SLPVectorizer.cpp
index 5bc3d852e79569eebaa3bf8ecac98935b5e67a5e..bb37994e9ff0ec2b6d5c8488f3e589615f7dbbb8 100644 (file)
@@ -53,7 +53,7 @@ namespace {
 
 static const unsigned MinVecRegSize = 128;
 
-static const unsigned RecursionMaxDepth = 6;
+static const unsigned RecursionMaxDepth = 12;
 
 /// RAII pattern to save the insertion point of the IR builder.
 class BuilderLocGuard {
@@ -239,6 +239,10 @@ public:
   /// NOTICE: The vectorization methods also use this set.
   ValueSet MustGather;
 
+  /// Contains PHINodes that are being processed. We use this data structure
+  /// to stop cycles in the graph.
+  ValueSet VisitedPHIs;
+
   /// Contains a list of values that are used outside the current tree. This
   /// set must be reset between runs.
   SetVector<Value *> MultiUserVals;
@@ -457,13 +461,31 @@ void FuncSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
 
   // Mark instructions with multiple users.
   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
+    if (PHINode *PN = dyn_cast<PHINode>(VL[i])) {
+      unsigned NumUses = 0;
+      // Check that PHINodes have only one external (non-self) use.
+      for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
+           U != UE; ++U) {
+        // Don't count self uses.
+        if (*U == PN)
+          continue;
+        NumUses++;
+      }
+      if (NumUses > 1) {
+        DEBUG(dbgs() << "SLP: Adding PHI to MultiUserVals "
+              "because it has " << NumUses << " users:" << *PN << " \n");
+        MultiUserVals.insert(PN);
+      }
+      continue;
+    }
+
     Instruction *I = dyn_cast<Instruction>(VL[i]);
     // Remember to check if all of the users of this instruction are vectorized
     // within our tree. At depth zero we have no local users, only external
     // users that we don't care about.
     if (Depth && I && I->getNumUses() > 1) {
       DEBUG(dbgs() << "SLP: Adding to MultiUserVals "
-                      "because it has multiple users:" << *I << " \n");
+            "because it has " << I->getNumUses() << " users:" << *I << " \n");
       MultiUserVals.insert(I);
     }
   }
@@ -483,6 +505,24 @@ void FuncSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
     return MustGather.insert(VL.begin(), VL.end());
 
   switch (Opcode) {
+  case Instruction::PHI: {
+    PHINode *PH = dyn_cast<PHINode>(VL0);
+
+    // Stop self cycles.
+    if (VisitedPHIs.count(PH))
+        return;
+
+    VisitedPHIs.insert(PH);
+    for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
+      ValueList Operands;
+      // Prepare the operand vector.
+      for (unsigned j = 0; j < VL.size(); ++j)
+        Operands.push_back(cast<PHINode>(VL[j])->getIncomingValue(i));
+
+      getTreeUses_rec(Operands, Depth + 1);
+    }
+    return;
+  }
   case Instruction::ExtractElement: {
     VectorType *VecTy = VectorType::get(VL[0]->getType(), VL.size());
     // No need to follow ExtractElements that are going to be optimized away.
@@ -605,16 +645,17 @@ int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
   if (ScalarTy->isVectorTy())
     return FuncSLP::MAX_COST;
 
-  VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
-
   if (allConstant(VL))
     return 0;
 
+  VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
+
   if (isSplat(VL))
     return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0);
 
+  int GatherCost = getGatherCost(VecTy);
   if (Depth == RecursionMaxDepth || needToGatherAny(VL))
-    return getGatherCost(VecTy);
+    return GatherCost;
 
   BasicBlock *BB = getSameBlock(VL);
   unsigned Opcode = getSameOpcode(VL);
@@ -639,6 +680,35 @@ int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
 
   Instruction *VL0 = cast<Instruction>(VL[0]);
   switch (Opcode) {
+  case Instruction::PHI: {
+    PHINode *PH = dyn_cast<PHINode>(VL0);
+
+    // Stop self cycles.
+    if (VisitedPHIs.count(PH))
+        return 0;
+
+    VisitedPHIs.insert(PH);
+    int TotalCost = 0;
+    // Calculate the cost of all of the operands.
+    for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {      
+      ValueList Operands;
+      // Prepare the operand vector.
+      for (unsigned j = 0; j < VL.size(); ++j)
+        Operands.push_back(cast<PHINode>(VL[j])->getIncomingValue(i));
+
+      int Cost = getTreeCost_rec(Operands, Depth + 1);
+      if (Cost == MAX_COST)
+        return MAX_COST;
+      TotalCost += TotalCost;
+    }
+
+    if (TotalCost > GatherCost) {
+      MustGather.insert(VL.begin(), VL.end());
+      return GatherCost;
+    }
+
+    return TotalCost;
+  }
   case Instruction::ExtractElement: {
     if (CanReuseExtract(VL, VL.size(), VecTy))
       return 0;
@@ -677,6 +747,12 @@ int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
     VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
     int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
     Cost += (VecCost - ScalarCost);
+
+    if (Cost > GatherCost) {
+      MustGather.insert(VL.begin(), VL.end());
+      return GatherCost;
+    }
+
     return Cost;
   }
   case Instruction::FCmp:
@@ -720,7 +796,7 @@ int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
       int Cost = getTreeCost_rec(Operands, Depth + 1);
       if (Cost == MAX_COST)
         return MAX_COST;
-      TotalCost += TotalCost;
+      TotalCost += Cost;
     }
 
     // Calculate the cost of this instruction.
@@ -739,6 +815,12 @@ int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
       VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy);
     }
     TotalCost += (VecCost - ScalarCost);
+
+    if (TotalCost > GatherCost) {
+      MustGather.insert(VL.begin(), VL.end());
+      return GatherCost;
+    }
+
     return TotalCost;
   }
   case Instruction::Load: {
@@ -751,7 +833,14 @@ int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
     int ScalarLdCost = VecTy->getNumElements() *
                        TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
     int VecLdCost = TTI->getMemoryOpCost(Instruction::Load, ScalarTy, 1, 0);
-    return VecLdCost - ScalarLdCost;
+    int TotalCost = VecLdCost - ScalarLdCost;
+
+    if (TotalCost > GatherCost) {
+      MustGather.insert(VL.begin(), VL.end());
+      return GatherCost;
+    }
+
+    return TotalCost;
   }
   case Instruction::Store: {
     // We know that we can merge the stores. Calculate the cost.
@@ -786,6 +875,7 @@ int FuncSLP::getTreeCost(ArrayRef<Value *> VL) {
   LaneMap.clear();
   MultiUserVals.clear();
   MustGather.clear();
+  VisitedPHIs.clear();
 
   if (!getSameBlock(VL))
     return MAX_COST;
@@ -852,8 +942,8 @@ bool FuncSLP::vectorizeStoreChain(ArrayRef<Value *> Chain, int CostThreshold) {
       vectorizeTree(Operands);
 
       // Remove the scalar stores.
-      for (int i = 0, e = VF; i < e; ++i)
-        cast<Instruction>(Operands[i])->eraseFromParent();
+      for (int j = 0, e = VF; j < e; ++j)
+        cast<Instruction>(Operands[j])->eraseFromParent();
 
       // Move to the next bundle.
       i += VF - 1;
@@ -970,6 +1060,30 @@ Value *FuncSLP::vectorizeTree_rec(ArrayRef<Value *> VL) {
   assert(Opcode == getSameOpcode(VL) && "Invalid opcode");
 
   switch (Opcode) {
+  case Instruction::PHI: {
+    PHINode *PH = dyn_cast<PHINode>(VL0);
+    Builder.SetInsertPoint(PH->getParent()->getFirstInsertionPt());
+    PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues());
+    VectorizedValues[VL0] = NewPhi;
+
+    for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
+      ValueList Operands;
+      BasicBlock *IBB = PH->getIncomingBlock(i);
+
+      // Prepare the operand vector.
+      for (unsigned j = 0; j < VL.size(); ++j)
+        Operands.push_back(cast<PHINode>(VL[j])->getIncomingValueForBlock(IBB));
+
+      Builder.SetInsertPoint(IBB->getTerminator());
+      Value *Vec = vectorizeTree_rec(Operands);
+      NewPhi->addIncoming(Vec, IBB);
+    }
+
+    assert(NewPhi->getNumIncomingValues() == PH->getNumIncomingValues() &&
+           "Invalid number of incoming values");
+    return NewPhi;
+  }
+
   case Instruction::ExtractElement: {
     if (CanReuseExtract(VL, VL.size(), VecTy))
       return VL0->getOperand(0);
@@ -1130,6 +1244,7 @@ Value *FuncSLP::vectorizeTree(ArrayRef<Value *> VL) {
     BlocksNumbers[it].forget();
   // Clear the state.
   MustGather.clear();
+  VisitedPHIs.clear();
   VectorizedValues.clear();
   MemBarrierIgnoreList.clear();
   return V;
@@ -1143,6 +1258,8 @@ Value *FuncSLP::vectorizeArith(ArrayRef<Value *> Operands) {
   for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
     Value *S = Builder.CreateExtractElement(Vec, Builder.getInt32(i));
     Operands[i]->replaceAllUsesWith(S);
+    Instruction *I = cast<Instruction>(Operands[i]);
+    I->eraseFromParent();
   }
 
   return Vec;
@@ -1165,7 +1282,7 @@ void FuncSLP::optimizeGatherSequence() {
     // Check if it has a preheader.
     BasicBlock *PreHeader = L->getLoopPreheader();
     if (!PreHeader)
-      return;
+      continue;
 
     // If the vector or the element that we insert into it are
     // instructions that are defined in this basic block then we can't
@@ -1195,17 +1312,19 @@ void FuncSLP::optimizeGatherSequence() {
       if (!Insert || !GatherSeq.count(Insert))
         continue;
 
-     // Check if we can replace this instruction with any of the
-     // visited instructions.
+      // Check if we can replace this instruction with any of the
+      // visited instructions.
       for (SmallPtrSet<Instruction*, 16>::iterator v = Visited.begin(),
            ve = Visited.end(); v != ve; ++v) {
         if (Insert->isIdenticalTo(*v) &&
-          DT->dominates((*v)->getParent(), Insert->getParent())) {
+            DT->dominates((*v)->getParent(), Insert->getParent())) {
           Insert->replaceAllUsesWith(*v);
+          Insert = 0;
           break;
         }
       }
-      Visited.insert(Insert);
+      if (Insert)
+        Visited.insert(Insert);
     }
   }
 }