Reapply r219832 - InstCombine: Narrow switch instructions using known bits.
[oota-llvm.git] / lib / Transforms / InstCombine / InstructionCombining.cpp
index 3ae9f0ddce8cfec3f363be0c8fe155061b8ba9b9..8d74976cb18b33f4407fad9f86b80933623edfaa 100644 (file)
@@ -46,6 +46,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
@@ -69,10 +70,11 @@ STATISTIC(NumExpand,    "Number of expansions");
 STATISTIC(NumFactor   , "Number of factorizations");
 STATISTIC(NumReassoc  , "Number of reassociations");
 
-static cl::opt<bool> UnsafeFPShrink("enable-double-float-shrink", cl::Hidden,
-                                   cl::init(false),
-                                   cl::desc("Enable unsafe double to float "
-                                            "shrinking for math lib calls"));
+static cl::opt<bool>
+    EnableUnsafeFPShrink("enable-double-float-shrink", cl::Hidden,
+                         cl::init(false),
+                         cl::desc("Enable unsafe double to float "
+                                  "shrinking for math lib calls"));
 
 // Initialization Routines
 void llvm::initializeInstCombine(PassRegistry &Registry) {
@@ -1317,7 +1319,7 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) {
 Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end());
 
-  if (Value *V = SimplifyGEPInst(Ops, DL))
+  if (Value *V = SimplifyGEPInst(Ops, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(GEP, V);
 
   Value *PtrOp = GEP.getOperand(0);
@@ -2003,7 +2005,25 @@ Instruction *InstCombiner::visitFree(CallInst &FI) {
   return nullptr;
 }
 
+Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) {
+  if (RI.getNumOperands() == 0) // ret void
+    return nullptr;
+
+  Value *ResultOp = RI.getOperand(0);
+  Type *VTy = ResultOp->getType();
+  if (!VTy->isIntegerTy())
+    return nullptr;
 
+  // There might be assume intrinsics dominating this return that completely
+  // determine the value. If so, constant fold it.
+  unsigned BitWidth = VTy->getPrimitiveSizeInBits();
+  APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+  computeKnownBits(ResultOp, KnownZero, KnownOne, 0, &RI);
+  if ((KnownZero|KnownOne).isAllOnesValue())
+    RI.setOperand(0, Constant::getIntegerValue(VTy, KnownOne));
+
+  return nullptr;
+}
 
 Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
   // Change br (not X), label True, label False to: br X, label False, True
@@ -2055,6 +2075,37 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
 
 Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
   Value *Cond = SI.getCondition();
+  unsigned BitWidth = cast<IntegerType>(Cond->getType())->getBitWidth();
+  APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+  computeKnownBits(Cond, KnownZero, KnownOne);
+  unsigned LeadingKnownZeros = KnownZero.countLeadingOnes();
+  unsigned LeadingKnownOnes = KnownOne.countLeadingOnes();
+
+  // Compute the number of leading bits we can ignore.
+  for (auto &C : SI.cases()) {
+    LeadingKnownZeros = std::min(
+        LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros());
+    LeadingKnownOnes = std::min(
+        LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes());
+  }
+
+  unsigned NewWidth = BitWidth - std::max(LeadingKnownZeros, LeadingKnownOnes);
+
+  // Truncate the condition operand if the new type is equal to or larger than
+  // the largest legal integer type. We need to be conservative here since
+  // x86 generates redundant zero-extenstion instructions if the operand is
+  // truncated to i8 or i16.
+  if (BitWidth > NewWidth && NewWidth >= DL->getLargestLegalIntTypeSize()) {
+    IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth);
+    Builder->SetInsertPoint(&SI);
+    Value *NewCond = Builder->CreateTrunc(SI.getCondition(), Ty, "trunc");
+    SI.setCondition(NewCond);
+
+    for (auto &C : SI.cases())
+      static_cast<SwitchInst::CaseIt *>(&C)->setValue(ConstantInt::get(
+          SI.getContext(), C.getCaseValue()->getValue().trunc(NewWidth)));
+  }
+
   if (Instruction *I = dyn_cast<Instruction>(Cond)) {
     if (I->getOpcode() == Instruction::Add)
       if (ConstantInt *AddRHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
@@ -2888,13 +2939,13 @@ bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) {
 }
 
 namespace {
-class InstCombinerLibCallSimplifier : public LibCallSimplifier {
+class InstCombinerLibCallSimplifier final : public LibCallSimplifier {
   InstCombiner *IC;
 public:
   InstCombinerLibCallSimplifier(const DataLayout *DL,
                                 const TargetLibraryInfo *TLI,
                                 InstCombiner *IC)
-    : LibCallSimplifier(DL, TLI, UnsafeFPShrink) {
+    : LibCallSimplifier(DL, TLI, EnableUnsafeFPShrink) {
     this->IC = IC;
   }
 
@@ -2914,6 +2965,11 @@ bool InstCombiner::runOnFunction(Function &F) {
   DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
   DL = DLP ? &DLP->getDataLayout() : nullptr;
   TLI = &getAnalysis<TargetLibraryInfo>();
+
+  DominatorTreeWrapperPass *DTWP =
+      getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+  DT = DTWP ? &DTWP->getDomTree() : nullptr;
+
   // Minimizing size?
   MinimizeSize = F.getAttributes().hasAttribute(AttributeSet::FunctionIndex,
                                                 Attribute::MinSize);