SLPVectorize: Add support for vectorization of casts.
[oota-llvm.git] / lib / Transforms / Vectorize / VecUtils.cpp
index 7701d080ffd859ee882a3a303d77664b9599e0e0..9b9436683b12aea176a94b28b1749321e679b7b3 100644 (file)
@@ -328,6 +328,18 @@ void BoUpSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
   }
 
   switch (Opcode) {
+    case Instruction::ZExt:
+    case Instruction::SExt:
+    case Instruction::FPToUI:
+    case Instruction::FPToSI:
+    case Instruction::FPExt:
+    case Instruction::PtrToInt:
+    case Instruction::IntToPtr:
+    case Instruction::SIToFP:
+    case Instruction::UIToFP:
+    case Instruction::Trunc:
+    case Instruction::FPTrunc:
+    case Instruction::BitCast:
     case Instruction::Add:
     case Instruction::FAdd:
     case Instruction::Sub:
@@ -445,6 +457,41 @@ int BoUpSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
   }
 
   switch (Opcode) {
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPToUI:
+  case Instruction::FPToSI:
+  case Instruction::FPExt:
+  case Instruction::PtrToInt:
+  case Instruction::IntToPtr:
+  case Instruction::SIToFP:
+  case Instruction::UIToFP:
+  case Instruction::Trunc:
+  case Instruction::FPTrunc:
+  case Instruction::BitCast: {
+    int Cost = 0;
+    ValueList Operands;
+    Type *SrcTy = VL0->getOperand(0)->getType();
+    // Prepare the operand vector.
+    for (unsigned j = 0; j < VL.size(); ++j) {
+      Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
+      // Check that the casted type is the same for all users.
+      if (cast<Instruction>(VL[j])->getOperand(0)->getType() != SrcTy)
+        return getScalarizationCost(VecTy);
+    }
+
+    Cost += getTreeCost_rec(Operands, Depth+1);
+    if (Cost >= max_cost) return max_cost;
+
+    // Calculate the cost of this instruction.
+    int ScalarCost = VL.size() * TTI->getCastInstrCost(VL0->getOpcode(),
+                                                       VL0->getType(), SrcTy);
+
+    VectorType *SrcVecTy = VectorType::get(SrcTy, VL.size());
+    int VecCost = TTI->getCastInstrCost(VL0->getOpcode(), VecTy, SrcVecTy);
+    Cost += (VecCost - ScalarCost);
+    return Cost;
+  }
   case Instruction::Add:
   case Instruction::FAdd:
   case Instruction::Sub:
@@ -583,6 +630,28 @@ Value *BoUpSLP::vectorizeTree_rec(ArrayRef<Value *> VL, int VF) {
   }
 
   switch (Opcode) {
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPToUI:
+  case Instruction::FPToSI:
+  case Instruction::FPExt:
+  case Instruction::PtrToInt:
+  case Instruction::IntToPtr:
+  case Instruction::SIToFP:
+  case Instruction::UIToFP:
+  case Instruction::Trunc:
+  case Instruction::FPTrunc:
+  case Instruction::BitCast: {
+    ValueList INVL;
+    for (int i = 0; i < VF; ++i)
+      INVL.push_back(cast<Instruction>(VL[i])->getOperand(0));
+    Value *InVec = vectorizeTree_rec(INVL, VF);
+    IRBuilder<> Builder(GetLastInstr(VL, VF));
+    CastInst *CI = dyn_cast<CastInst>(VL0);
+    Value *V = Builder.CreateCast(CI->getOpcode(), InVec, VecTy);
+    VectorizedValues[VL0] = V;
+    return V;
+  }
   case Instruction::Add:
   case Instruction::FAdd:
   case Instruction::Sub: