Revert "Revert "Add Constant Hoisting Pass" (r200034)"
[oota-llvm.git] / lib / Target / X86 / X86TargetTransformInfo.cpp
index da2b021da932a54ca295f6ad9ba154802360947d..781be2fddd9556e16a423ef912465bc8c84b7a8d 100644 (file)
@@ -18,6 +18,7 @@
 #include "X86.h"
 #include "X86TargetMachine.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Target/CostTable.h"
 #include "llvm/Target/TargetLowering.h"
@@ -107,6 +108,14 @@ public:
   virtual unsigned getReductionCost(unsigned Opcode, Type *Ty,
                                     bool IsPairwiseForm) const LLVM_OVERRIDE;
 
+  virtual unsigned getIntImmCost(const APInt &Imm,
+                                 Type *Ty) const LLVM_OVERRIDE;
+
+  virtual unsigned getIntImmCost(unsigned Opcode, const APInt &Imm,
+                                 Type *Ty) const LLVM_OVERRIDE;
+  virtual unsigned getIntImmCost(Intrinsic::ID IID, const APInt &Imm,
+                                 Type *Ty) const LLVM_OVERRIDE;
+
   /// @}
 };
 
@@ -694,3 +703,89 @@ unsigned X86TTI::getReductionCost(unsigned Opcode, Type *ValTy,
   return TargetTransformInfo::getReductionCost(Opcode, ValTy, IsPairwise);
 }
 
+unsigned X86TTI::getIntImmCost(const APInt &Imm, Type *Ty) const {
+  assert(Ty->isIntegerTy());
+
+  unsigned BitSize = Ty->getPrimitiveSizeInBits();
+  if (BitSize == 0)
+    return ~0U;
+
+  if (Imm.getBitWidth() <= 64 &&
+      (isInt<32>(Imm.getSExtValue()) || isUInt<32>(Imm.getZExtValue())))
+    return TCC_Basic;
+  else
+    return 2 * TCC_Basic;
+}
+
+unsigned X86TTI::getIntImmCost(unsigned Opcode, const APInt &Imm,
+                               Type *Ty) const {
+  assert(Ty->isIntegerTy());
+
+  unsigned BitSize = Ty->getPrimitiveSizeInBits();
+  if (BitSize == 0)
+    return ~0U;
+
+  switch (Opcode) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+  case Instruction::URem:
+  case Instruction::SRem:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+  case Instruction::ICmp:
+    if (Imm.getBitWidth() <= 64 && isInt<32>(Imm.getSExtValue()))
+      return TCC_Free;
+    else
+      return X86TTI::getIntImmCost(Imm, Ty);
+  case Instruction::Trunc:
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::IntToPtr:
+  case Instruction::PtrToInt:
+  case Instruction::BitCast:
+  case Instruction::Call:
+  case Instruction::Select:
+  case Instruction::Ret:
+  case Instruction::Load:
+  case Instruction::Store:
+    return X86TTI::getIntImmCost(Imm, Ty);
+  }
+  return TargetTransformInfo::getIntImmCost(Opcode, Imm, Ty);
+}
+
+unsigned X86TTI::getIntImmCost(Intrinsic::ID IID, const APInt &Imm,
+                               Type *Ty) const {
+  assert(Ty->isIntegerTy());
+
+  unsigned BitSize = Ty->getPrimitiveSizeInBits();
+  if (BitSize == 0)
+    return ~0U;
+
+  switch (IID) {
+  default: return TargetTransformInfo::getIntImmCost(IID, Imm, Ty);
+  case Intrinsic::sadd_with_overflow:
+  case Intrinsic::uadd_with_overflow:
+  case Intrinsic::ssub_with_overflow:
+  case Intrinsic::usub_with_overflow:
+  case Intrinsic::smul_with_overflow:
+  case Intrinsic::umul_with_overflow:
+    if (Imm.getBitWidth() <= 64 && isInt<32>(Imm.getSExtValue()))
+      return TCC_Free;
+    else
+      return X86TTI::getIntImmCost(Imm, Ty);
+  case Intrinsic::experimental_stackmap:
+  case Intrinsic::experimental_patchpoint_void:
+  case Intrinsic::experimental_patchpoint_i64:
+    if (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue()))
+      return TCC_Free;
+    else
+      return X86TTI::getIntImmCost(Imm, Ty);
+  }
+}