Correct cost model for vector shift on AVX2
[oota-llvm.git] / lib / Target / X86 / X86TargetTransformInfo.cpp
index 777ef508ec3390675e4e38879e30f2b1818cd814..3e3b86edbb082019037e8a998b81861cf18a41b1 100644 (file)
@@ -169,6 +169,29 @@ unsigned X86TTI::getArithmeticInstrCost(unsigned Opcode, Type *Ty) const {
   int ISD = TLI->InstructionOpcodeToISD(Opcode);
   assert(ISD && "Invalid opcode");
 
+  static const CostTblEntry<MVT> AVX2CostTable[] = {
+    // Shifts on v4i64/v8i32 on AVX2 is legal even though we declare to
+    // customize them to detect the cases where shift amount is a scalar one.
+    { ISD::SHL,     MVT::v4i32,    1 },
+    { ISD::SRL,     MVT::v4i32,    1 },
+    { ISD::SRA,     MVT::v4i32,    1 },
+    { ISD::SHL,     MVT::v8i32,    1 },
+    { ISD::SRL,     MVT::v8i32,    1 },
+    { ISD::SRA,     MVT::v8i32,    1 },
+    { ISD::SHL,     MVT::v2i64,    1 },
+    { ISD::SRL,     MVT::v2i64,    1 },
+    { ISD::SHL,     MVT::v4i64,    1 },
+    { ISD::SRL,     MVT::v4i64,    1 },
+  };
+
+  // Look for AVX2 lowering tricks.
+  if (ST->hasAVX2()) {
+    int Idx = CostTableLookup<MVT>(AVX2CostTable, array_lengthof(AVX2CostTable),
+                                   ISD, LT.second);
+    if (Idx != -1)
+      return LT.first * AVX2CostTable[Idx].Cost;
+  }
+
   static const CostTblEntry<MVT> AVX1CostTable[] = {
     // We don't have to scalarize unsupported ops. We can issue two half-sized
     // operations and we only need to extract the upper YMM half.