Remove templates from CostTableLookup functions. All instantiations had the same...
[oota-llvm.git] / lib / Target / AArch64 / AArch64TargetTransformInfo.cpp
index a51a0674c8f9cfd339fb2b8d9253a11c4cec4921..ab17bb810d40b774d0d218d6fc70e0a35bb57b68 100644 (file)
@@ -23,7 +23,7 @@ using namespace llvm;
 /// \brief Calculate the cost of materializing a 64-bit value. This helper
 /// method might only calculate a fraction of a larger immediate. Therefore it
 /// is valid to return a cost of ZERO.
-unsigned AArch64TTIImpl::getIntImmCost(int64_t Val) {
+int AArch64TTIImpl::getIntImmCost(int64_t Val) {
   // Check if the immediate can be encoded within an instruction.
   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
     return 0;
@@ -37,7 +37,7 @@ unsigned AArch64TTIImpl::getIntImmCost(int64_t Val) {
 }
 
 /// \brief Calculate the cost of materializing the given constant.
-unsigned AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty) {
+int AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty) {
   assert(Ty->isIntegerTy());
 
   unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -51,18 +51,18 @@ unsigned AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty) {
 
   // Split the constant into 64-bit chunks and calculate the cost for each
   // chunk.
-  unsigned Cost = 0;
+  int Cost = 0;
   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
     int64_t Val = Tmp.getSExtValue();
     Cost += getIntImmCost(Val);
   }
   // We need at least one instruction to materialze the constant.
-  return std::max(1U, Cost);
+  return std::max(1, Cost);
 }
 
-unsigned AArch64TTIImpl::getIntImmCost(unsigned Opcode, unsigned Idx,
-                                       const APInt &Imm, Type *Ty) {
+int AArch64TTIImpl::getIntImmCost(unsigned Opcode, unsigned Idx,
+                                  const APInt &Imm, Type *Ty) {
   assert(Ty->isIntegerTy());
 
   unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -118,17 +118,17 @@ unsigned AArch64TTIImpl::getIntImmCost(unsigned Opcode, unsigned Idx,
   }
 
   if (Idx == ImmIdx) {
-    unsigned NumConstants = (BitSize + 63) / 64;
-    unsigned Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty);
+    int NumConstants = (BitSize + 63) / 64;
+    int Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty);
     return (Cost <= NumConstants * TTI::TCC_Basic)
-               ? static_cast<unsigned>(TTI::TCC_Free)
+               ? static_cast<int>(TTI::TCC_Free)
                : Cost;
   }
   return AArch64TTIImpl::getIntImmCost(Imm, Ty);
 }
 
-unsigned AArch64TTIImpl::getIntImmCost(Intrinsic::ID IID, unsigned Idx,
-                                       const APInt &Imm, Type *Ty) {
+int AArch64TTIImpl::getIntImmCost(Intrinsic::ID IID, unsigned Idx,
+                                  const APInt &Imm, Type *Ty) {
   assert(Ty->isIntegerTy());
 
   unsigned BitSize = Ty->getPrimitiveSizeInBits();
@@ -147,10 +147,10 @@ unsigned AArch64TTIImpl::getIntImmCost(Intrinsic::ID IID, unsigned Idx,
   case Intrinsic::smul_with_overflow:
   case Intrinsic::umul_with_overflow:
     if (Idx == 1) {
-      unsigned NumConstants = (BitSize + 63) / 64;
-      unsigned Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty);
+      int NumConstants = (BitSize + 63) / 64;
+      int Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty);
       return (Cost <= NumConstants * TTI::TCC_Basic)
-                 ? static_cast<unsigned>(TTI::TCC_Free)
+                 ? static_cast<int>(TTI::TCC_Free)
                  : Cost;
     }
     break;
@@ -176,18 +176,40 @@ AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
   return TTI::PSK_Software;
 }
 
-unsigned AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
-                                          Type *Src) {
+int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
-  EVT SrcTy = TLI->getValueType(Src);
-  EVT DstTy = TLI->getValueType(Dst);
+  EVT SrcTy = TLI->getValueType(DL, Src);
+  EVT DstTy = TLI->getValueType(DL, Dst);
 
   if (!SrcTy.isSimple() || !DstTy.isSimple())
     return BaseT::getCastInstrCost(Opcode, Dst, Src);
 
-  static const TypeConversionCostTblEntry<MVT> ConversionTbl[] = {
+  static const TypeConversionCostTblEntry
+  ConversionTbl[] = {
+    { ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 0 },
+    { ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i16, 0 },
+    { ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i32, 1 },
+    { ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i32, 1 },
+    { ISD::TRUNCATE,    MVT::v4i32, MVT::v4i64, 0 },
+    { ISD::TRUNCATE,    MVT::v4i16, MVT::v4i32, 1 },
+
+    // The number of shll instructions for the extension.
+    { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
+    { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3 },
+    { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i8, 3 },
+    { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i8, 3 },
+    { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i8, 7 },
+    { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i8, 7 },
+    { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i16, 6 },
+    { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i16, 6 },
+    { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
+    { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
+
+    { ISD::TRUNCATE,    MVT::v16i8, MVT::v16i32, 6 },
+    { ISD::TRUNCATE,    MVT::v8i8, MVT::v8i32, 3 },
+
     // LowerVectorINT_TO_FP:
     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
@@ -210,6 +232,16 @@ unsigned AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
 
+    // Complex: to v8f32
+    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
+    { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
+    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
+    { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
+
+    // Complex: to v16f32
+    { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
+    { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
+
     // Complex: to v2f64
     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
@@ -250,22 +282,21 @@ unsigned AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },
   };
 
-  int Idx = ConvertCostTableLookup<MVT>(
-      ConversionTbl, array_lengthof(ConversionTbl), ISD, DstTy.getSimpleVT(),
-      SrcTy.getSimpleVT());
-  if (Idx != -1)
-    return ConversionTbl[Idx].Cost;
+  if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
+                                                 DstTy.getSimpleVT(),
+                                                 SrcTy.getSimpleVT()))
+    return Entry->Cost;
 
   return BaseT::getCastInstrCost(Opcode, Dst, Src);
 }
 
-unsigned AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
-                                            unsigned Index) {
+int AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
+                                       unsigned Index) {
   assert(Val->isVectorTy() && "This must be a vector type");
 
   if (Index != -1U) {
     // Legalize the type.
-    std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(Val);
+    std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Val);
 
     // This type is legalized to a scalar type.
     if (!LT.second.isVector())
@@ -281,15 +312,15 @@ unsigned AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
   }
 
   // All other insert/extracts cost this much.
-  return 2;
+  return 3;
 }
 
-unsigned AArch64TTIImpl::getArithmeticInstrCost(
+int AArch64TTIImpl::getArithmeticInstrCost(
     unsigned Opcode, Type *Ty, TTI::OperandValueKind Opd1Info,
     TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo,
     TTI::OperandValueProperties Opd2PropInfo) {
   // Legalize the type.
-  std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(Ty);
+  std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Ty);
 
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
 
@@ -300,10 +331,9 @@ unsigned AArch64TTIImpl::getArithmeticInstrCost(
     // normally expanded to the sequence ADD + CMP + SELECT + SRA.
     // The OperandValue properties many not be same as that of previous
     // operation; conservatively assume OP_None.
-    unsigned Cost =
-      getArithmeticInstrCost(Instruction::Add, Ty, Opd1Info, Opd2Info,
-                             TargetTransformInfo::OP_None,
-                             TargetTransformInfo::OP_None);
+    int Cost = getArithmeticInstrCost(Instruction::Add, Ty, Opd1Info, Opd2Info,
+                                      TargetTransformInfo::OP_None,
+                                      TargetTransformInfo::OP_None);
     Cost += getArithmeticInstrCost(Instruction::Sub, Ty, Opd1Info, Opd2Info,
                                    TargetTransformInfo::OP_None,
                                    TargetTransformInfo::OP_None);
@@ -331,7 +361,7 @@ unsigned AArch64TTIImpl::getArithmeticInstrCost(
   }
 }
 
-unsigned AArch64TTIImpl::getAddressComputationCost(Type *Ty, bool IsComplex) {
+int AArch64TTIImpl::getAddressComputationCost(Type *Ty, bool IsComplex) {
   // Address computations in vectorized code with non-consecutive addresses will
   // likely result in more instructions compared to scalar code where the
   // computation can more often be merged into the index mode. The resulting
@@ -346,41 +376,40 @@ unsigned AArch64TTIImpl::getAddressComputationCost(Type *Ty, bool IsComplex) {
   return 1;
 }
 
-unsigned AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
-                                            Type *CondTy) {
+int AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
+                                       Type *CondTy) {
 
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
-  // We don't lower vector selects well that are wider than the register width.
+  // We don't lower some vector selects well that are wider than the register
+  // width.
   if (ValTy->isVectorTy() && ISD == ISD::SELECT) {
     // We would need this many instructions to hide the scalarization happening.
-    const unsigned AmortizationCost = 20;
-    static const TypeConversionCostTblEntry<MVT::SimpleValueType>
+    const int AmortizationCost = 20;
+    static const TypeConversionCostTblEntry
     VectorSelectTbl[] = {
-      { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 * AmortizationCost },
-      { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 * AmortizationCost },
-      { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 * AmortizationCost },
+      { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
+      { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
+      { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
       { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
       { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
       { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
     };
 
-    EVT SelCondTy = TLI->getValueType(CondTy);
-    EVT SelValTy = TLI->getValueType(ValTy);
+    EVT SelCondTy = TLI->getValueType(DL, CondTy);
+    EVT SelValTy = TLI->getValueType(DL, ValTy);
     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
-      int Idx =
-          ConvertCostTableLookup(VectorSelectTbl, ISD, SelCondTy.getSimpleVT(),
-                                 SelValTy.getSimpleVT());
-      if (Idx != -1)
-        return VectorSelectTbl[Idx].Cost;
+      if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
+                                                     SelCondTy.getSimpleVT(),
+                                                     SelValTy.getSimpleVT()))
+        return Entry->Cost;
     }
   }
   return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy);
 }
 
-unsigned AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
-                                         unsigned Alignment,
-                                         unsigned AddressSpace) {
-  std::pair<unsigned, MVT> LT = TLI->getTypeLegalizationCost(Src);
+int AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
+                                    unsigned Alignment, unsigned AddressSpace) {
+  std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, Src);
 
   if (Opcode == Instruction::Store && Src->isVectorTy() && Alignment != 16 &&
       Src->getVectorElementType()->isIntegerTy(64)) {
@@ -389,7 +418,7 @@ unsigned AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
     // practice on inlined memcpy code.
     // We make v2i64 stores expensive so that we will only vectorize if there
     // are 6 other instructions getting vectorized.
-    unsigned AmortizationCost = 6;
+    int AmortizationCost = 6;
 
     return LT.first * 2 * AmortizationCost;
   }
@@ -407,20 +436,30 @@ unsigned AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
   return LT.first;
 }
 
-unsigned AArch64TTIImpl::getInterleavedMemoryOpCost(
-    unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
-    unsigned Alignment, unsigned AddressSpace) {
-  assert(isa<VectorType>(VecTy) && "Expect vector types");
-
-  if (Factor > 1 && Factor < 5 && isTypeLegal(VecTy))
-    return Factor;
+int AArch64TTIImpl::getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy,
+                                               unsigned Factor,
+                                               ArrayRef<unsigned> Indices,
+                                               unsigned Alignment,
+                                               unsigned AddressSpace) {
+  assert(Factor >= 2 && "Invalid interleave factor");
+  assert(isa<VectorType>(VecTy) && "Expect a vector type");
+
+  if (Factor <= TLI->getMaxSupportedInterleaveFactor()) {
+    unsigned NumElts = VecTy->getVectorNumElements();
+    Type *SubVecTy = VectorType::get(VecTy->getScalarType(), NumElts / Factor);
+    unsigned SubVecSize = DL.getTypeAllocSizeInBits(SubVecTy);
+
+    // ldN/stN only support legal vector types of size 64 or 128 in bits.
+    if (NumElts % Factor == 0 && (SubVecSize == 64 || SubVecSize == 128))
+      return Factor;
+  }
 
   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
                                            Alignment, AddressSpace);
 }
 
-unsigned AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
-  unsigned Cost = 0;
+int AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
+  int Cost = 0;
   for (auto *I : Tys) {
     if (!I->isVectorTy())
       continue;