From d3c965d6251e6d939f7797f8704d4e3a82f7e274 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Wed, 16 Jan 2013 21:29:55 +0000 Subject: [PATCH] Change CostTable model to be global to all targets Moving the X86CostTable to a common place, so that other back-ends can share the code. Also simplifying it a bit and commoning up tables with one and two types on operations. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@172658 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Analysis/TargetTransformInfo.h | 53 ++++++ lib/Analysis/TargetTransformInfo.cpp | 41 +++++ lib/Target/X86/X86TargetTransformInfo.cpp | 168 ++++++++------------ 3 files changed, 160 insertions(+), 102 deletions(-) diff --git a/include/llvm/Analysis/TargetTransformInfo.h b/include/llvm/Analysis/TargetTransformInfo.h index 4f6b9b2d26d..a9793a06319 100644 --- a/include/llvm/Analysis/TargetTransformInfo.h +++ b/include/llvm/Analysis/TargetTransformInfo.h @@ -27,6 +27,7 @@ #include "llvm/IR/Type.h" #include "llvm/Pass.h" #include "llvm/Support/DataTypes.h" +#include "llvm/CodeGen/ValueTypes.h" namespace llvm { @@ -209,6 +210,58 @@ public: /// satisfy the queries. ImmutablePass *createNoTargetTransformInfoPass(); +//======================================= COST TABLES == + +/// \brief An entry in a cost table +/// +/// Use it as a static array and call the CostTable below to +/// iterate through it and find the elements you're looking for. +/// +/// Leaving Types with fixed size to avoid complications during +/// static destruction. +struct CostTableEntry { + int ISD; // instruction ID + MVT Types[2]; // Types { dest, source } + unsigned Cost; // ideal cost +}; + +/// \brief Cost table, containing one or more costs for different instructions +/// +/// This class implement the cost table lookup, to simplify +/// how targets declare their own costs. +class CostTable { + const CostTableEntry *table; + const size_t size; + const unsigned numTypes; + +protected: + /// Searches for costs on the table + unsigned _findCost(int ISD, MVT *Types) const; + + // We don't want to expose a multi-type cost table, since types are not + // sequential by nature. If you need more cost table types, implement + // them below. + CostTable(const CostTableEntry *table, const size_t size, unsigned numTypes); + +public: + /// Cost Not found while searching + static const unsigned COST_NOT_FOUND = -1; +}; + +/// Specialisation for one-type cost table +class UnaryCostTable : public CostTable { +public: + UnaryCostTable(const CostTableEntry *table, const size_t size); + unsigned findCost(int ISD, MVT Type) const; +}; + +/// Specialisation for two-type cost table +class BinaryCostTable : public CostTable { +public: + BinaryCostTable(const CostTableEntry *table, const size_t size); + unsigned findCost(int ISD, MVT Type, MVT SrcType) const; +}; + } // End llvm namespace #endif diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp index 3ef74eb2d64..344be719cb5 100644 --- a/lib/Analysis/TargetTransformInfo.cpp +++ b/lib/Analysis/TargetTransformInfo.cpp @@ -286,3 +286,44 @@ char NoTTI::ID = 0; ImmutablePass *llvm::createNoTargetTransformInfoPass() { return new NoTTI(); } + +//======================================= COST TABLES == + +CostTable::CostTable(const CostTableEntry *table, const size_t size, unsigned numTypes) + : table(table), size(size), numTypes(numTypes) { + assert(table && "missing cost table"); + assert(size > 0 && "empty cost table"); +} + +unsigned CostTable::_findCost(int ISD, MVT *Types) const { + for (unsigned i = 0; i < size; ++i) { + if (table[i].ISD == ISD) { + bool found = true; + for (unsigned t=0; tInstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); - static const X86CostTblEntry AVX1CostTable[] = { - // We don't have to scalarize unsupported ops. We can issue two half-sized - // operations and we only need to extract the upper YMM half. - // Two ops + 1 extract + 1 insert = 4. - { ISD::MUL, MVT::v8i32, 4 }, - { ISD::SUB, MVT::v8i32, 4 }, - { ISD::ADD, MVT::v8i32, 4 }, - { ISD::MUL, MVT::v4i64, 4 }, - { ISD::SUB, MVT::v4i64, 4 }, - { ISD::ADD, MVT::v4i64, 4 }, - }; + // We don't have to scalarize unsupported ops. We can issue two half-sized + // operations and we only need to extract the upper YMM half. + // Two ops + 1 extract + 1 insert = 4. + static const CostTableEntry AVX1CostTable[] = { + { ISD::MUL, { MVT::v8i32 }, 4 }, + { ISD::SUB, { MVT::v8i32 }, 4 }, + { ISD::ADD, { MVT::v8i32 }, 4 }, + { ISD::MUL, { MVT::v4i64 }, 4 }, + { ISD::SUB, { MVT::v4i64 }, 4 }, + { ISD::ADD, { MVT::v4i64 }, 4 }, + }; + UnaryCostTable costTable (AVX1CostTable, array_lengthof(AVX1CostTable)); // Look for AVX1 lowering tricks. if (ST->hasAVX()) { - int Idx = FindInTable(AVX1CostTable, array_lengthof(AVX1CostTable), ISD, - LT.second); - if (Idx != -1) - return LT.first * AVX1CostTable[Idx].Cost; + unsigned cost = costTable.findCost(ISD, LT.second); + if (cost != BinaryCostTable::COST_NOT_FOUND) + return LT.first * cost; } // Fallback to the default implementation. return TargetTransformInfo::getArithmeticInstrCost(Opcode, Ty); @@ -254,30 +216,29 @@ unsigned X86TTI::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src) const { if (!SrcTy.isSimple() || !DstTy.isSimple()) return TargetTransformInfo::getCastInstrCost(Opcode, Dst, Src); - static const X86TypeConversionCostTblEntry AVXConversionTbl[] = { - { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 1 }, - { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 1 }, - { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 1 }, - { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 1 }, - { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 1 }, - { ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 1 }, - { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 1 }, - { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 1 }, - { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 1 }, - { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 1 }, - { ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f32, 1 }, - { ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 1 }, - { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i1, 6 }, - { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i1, 9 }, - { ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 3 }, + static const CostTableEntry AVXConversionTbl[] = { + { ISD::SIGN_EXTEND, { MVT::v8i32, MVT::v8i16 }, 1 }, + { ISD::ZERO_EXTEND, { MVT::v8i32, MVT::v8i16 }, 1 }, + { ISD::SIGN_EXTEND, { MVT::v4i64, MVT::v4i32 }, 1 }, + { ISD::ZERO_EXTEND, { MVT::v4i64, MVT::v4i32 }, 1 }, + { ISD::TRUNCATE, { MVT::v4i32, MVT::v4i64 }, 1 }, + { ISD::TRUNCATE, { MVT::v8i16, MVT::v8i32 }, 1 }, + { ISD::SINT_TO_FP, { MVT::v8f32, MVT::v8i8 }, 1 }, + { ISD::SINT_TO_FP, { MVT::v4f32, MVT::v4i8 }, 1 }, + { ISD::UINT_TO_FP, { MVT::v8f32, MVT::v8i8 }, 1 }, + { ISD::UINT_TO_FP, { MVT::v4f32, MVT::v4i8 }, 1 }, + { ISD::FP_TO_SINT, { MVT::v8i8, MVT::v8f32 }, 1 }, + { ISD::FP_TO_SINT, { MVT::v4i8, MVT::v4f32 }, 1 }, + { ISD::ZERO_EXTEND, { MVT::v8i32, MVT::v8i1 }, 6 }, + { ISD::SIGN_EXTEND, { MVT::v8i32, MVT::v8i1 }, 9 }, + { ISD::TRUNCATE, { MVT::v8i32, MVT::v8i64 }, 3 } }; + BinaryCostTable costTable (AVXConversionTbl, array_lengthof(AVXConversionTbl)); if (ST->hasAVX()) { - int Idx = FindInConvertTable(AVXConversionTbl, - array_lengthof(AVXConversionTbl), - ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()); - if (Idx != -1) - return AVXConversionTbl[Idx].Cost; + unsigned cost = costTable.findCost(ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()); + if (cost != BinaryCostTable::COST_NOT_FOUND) + return cost; } return TargetTransformInfo::getCastInstrCost(Opcode, Dst, Src); @@ -293,48 +254,51 @@ unsigned X86TTI::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); - static const X86CostTblEntry SSE42CostTbl[] = { - { ISD::SETCC, MVT::v2f64, 1 }, - { ISD::SETCC, MVT::v4f32, 1 }, - { ISD::SETCC, MVT::v2i64, 1 }, - { ISD::SETCC, MVT::v4i32, 1 }, - { ISD::SETCC, MVT::v8i16, 1 }, - { ISD::SETCC, MVT::v16i8, 1 }, + static const CostTableEntry SSE42CostTbl[] = { + { ISD::SETCC, { MVT::v2f64 }, 1 }, + { ISD::SETCC, { MVT::v4f32 }, 1 }, + { ISD::SETCC, { MVT::v2i64 }, 1 }, + { ISD::SETCC, { MVT::v4i32 }, 1 }, + { ISD::SETCC, { MVT::v8i16 }, 1 }, + { ISD::SETCC, { MVT::v16i8 }, 1 }, }; + UnaryCostTable costTableSSE4 (SSE42CostTbl, array_lengthof(SSE42CostTbl)); - static const X86CostTblEntry AVX1CostTbl[] = { - { ISD::SETCC, MVT::v4f64, 1 }, - { ISD::SETCC, MVT::v8f32, 1 }, + static const CostTableEntry AVX1CostTbl[] = { + { ISD::SETCC, { MVT::v4f64 }, 1 }, + { ISD::SETCC, { MVT::v8f32 }, 1 }, // AVX1 does not support 8-wide integer compare. - { ISD::SETCC, MVT::v4i64, 4 }, - { ISD::SETCC, MVT::v8i32, 4 }, - { ISD::SETCC, MVT::v16i16, 4 }, - { ISD::SETCC, MVT::v32i8, 4 }, + { ISD::SETCC, { MVT::v4i64 }, 4 }, + { ISD::SETCC, { MVT::v8i32 }, 4 }, + { ISD::SETCC, { MVT::v16i16 }, 4 }, + { ISD::SETCC, { MVT::v32i8 }, 4 }, }; + UnaryCostTable costTableAVX1 (AVX1CostTbl, array_lengthof(AVX1CostTbl)); - static const X86CostTblEntry AVX2CostTbl[] = { - { ISD::SETCC, MVT::v4i64, 1 }, - { ISD::SETCC, MVT::v8i32, 1 }, - { ISD::SETCC, MVT::v16i16, 1 }, - { ISD::SETCC, MVT::v32i8, 1 }, + static const CostTableEntry AVX2CostTbl[] = { + { ISD::SETCC, { MVT::v4i64 }, 1 }, + { ISD::SETCC, { MVT::v8i32 }, 1 }, + { ISD::SETCC, { MVT::v16i16 }, 1 }, + { ISD::SETCC, { MVT::v32i8 }, 1 }, }; + UnaryCostTable costTableAVX2 (AVX2CostTbl, array_lengthof(AVX2CostTbl)); if (ST->hasAVX2()) { - int Idx = FindInTable(AVX2CostTbl, array_lengthof(AVX2CostTbl), ISD, MTy); - if (Idx != -1) - return LT.first * AVX2CostTbl[Idx].Cost; + unsigned cost = costTableAVX2.findCost(ISD, MTy); + if (cost != BinaryCostTable::COST_NOT_FOUND) + return LT.first * cost; } if (ST->hasAVX()) { - int Idx = FindInTable(AVX1CostTbl, array_lengthof(AVX1CostTbl), ISD, MTy); - if (Idx != -1) - return LT.first * AVX1CostTbl[Idx].Cost; + unsigned cost = costTableAVX1.findCost(ISD, MTy); + if (cost != BinaryCostTable::COST_NOT_FOUND) + return LT.first * cost; } if (ST->hasSSE42()) { - int Idx = FindInTable(SSE42CostTbl, array_lengthof(SSE42CostTbl), ISD, MTy); - if (Idx != -1) - return LT.first * SSE42CostTbl[Idx].Cost; + unsigned cost = costTableSSE4.findCost(ISD, MTy); + if (cost != BinaryCostTable::COST_NOT_FOUND) + return LT.first * cost; } return TargetTransformInfo::getCmpSelInstrCost(Opcode, ValTy, CondTy); -- 2.34.1