Reapply r219832 - InstCombine: Narrow switch instructions using known bits.
[oota-llvm.git] / lib / Transforms / InstCombine / InstructionCombining.cpp
index 08e24461a610fdfa1086eb2ebff666cf6d225fea..8d74976cb18b33f4407fad9f86b80933623edfaa 100644 (file)
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/Analysis/AssumptionTracker.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
@@ -68,10 +70,11 @@ STATISTIC(NumExpand,    "Number of expansions");
 STATISTIC(NumFactor   , "Number of factorizations");
 STATISTIC(NumReassoc  , "Number of reassociations");
 
-static cl::opt<bool> UnsafeFPShrink("enable-double-float-shrink", cl::Hidden,
-                                   cl::init(false),
-                                   cl::desc("Enable unsafe double to float "
-                                            "shrinking for math lib calls"));
+static cl::opt<bool>
+    EnableUnsafeFPShrink("enable-double-float-shrink", cl::Hidden,
+                         cl::init(false),
+                         cl::desc("Enable unsafe double to float "
+                                  "shrinking for math lib calls"));
 
 // Initialization Routines
 void llvm::initializeInstCombine(PassRegistry &Registry) {
@@ -85,12 +88,14 @@ void LLVMInitializeInstCombine(LLVMPassRegistryRef R) {
 char InstCombiner::ID = 0;
 INITIALIZE_PASS_BEGIN(InstCombiner, "instcombine",
                 "Combine redundant instructions", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionTracker)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo)
 INITIALIZE_PASS_END(InstCombiner, "instcombine",
                 "Combine redundant instructions", false, false)
 
 void InstCombiner::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.setPreservesCFG();
+  AU.addRequired<AssumptionTracker>();
   AU.addRequired<TargetLibraryInfo>();
 }
 
@@ -390,6 +395,25 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp,
                                      Instruction::BinaryOps ROp) {
   if (Instruction::isCommutative(ROp))
     return LeftDistributesOverRight(ROp, LOp);
+
+  switch (LOp) {
+  default:
+    return false;
+  // (X >> Z) & (Y >> Z)  -> (X&Y) >> Z  for all shifts.
+  // (X >> Z) | (Y >> Z)  -> (X|Y) >> Z  for all shifts.
+  // (X >> Z) ^ (Y >> Z)  -> (X^Y) >> Z  for all shifts.
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+    switch (ROp) {
+    default:
+      return false;
+    case Instruction::Shl:
+    case Instruction::LShr:
+    case Instruction::AShr:
+      return true;
+    }
+  }
   // TODO: It would be nice to handle division, aka "(X + Y)/Z = X/Z + Y/Z",
   // but this requires knowing that the addition does not overflow and other
   // such subtleties.
@@ -411,26 +435,37 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) {
 }
 
 /// This function factors binary ops which can be combined using distributive
-/// laws. This also factor SHL as MUL e.g. SHL(X, 2) ==> MUL(X, 4).
+/// laws. This function tries to transform 'Op' based TopLevelOpcode to enable
+/// factorization e.g for ADD(SHL(X , 2), MUL(X, 5)), When this function called
+/// with TopLevelOpcode == Instruction::Add and Op = SHL(X, 2), transforms
+/// SHL(X, 2) to MUL(X, 4) i.e. returns Instruction::Mul with LHS set to 'X' and
+/// RHS to 4.
 static Instruction::BinaryOps
-getBinOpsForFactorization(BinaryOperator *Op, Value *&LHS, Value *&RHS) {
+getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode,
+                          BinaryOperator *Op, Value *&LHS, Value *&RHS) {
   if (!Op)
     return Instruction::BinaryOpsEnd;
 
-  if (Op->getOpcode() == Instruction::Shl) {
-    if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) {
-      // The multiplier is really 1 << CST.
-      RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST);
-      LHS = Op->getOperand(0);
-      return Instruction::Mul;
+  LHS = Op->getOperand(0);
+  RHS = Op->getOperand(1);
+
+  switch (TopLevelOpcode) {
+  default:
+    return Op->getOpcode();
+
+  case Instruction::Add:
+  case Instruction::Sub:
+    if (Op->getOpcode() == Instruction::Shl) {
+      if (Constant *CST = dyn_cast<Constant>(Op->getOperand(1))) {
+        // The multiplier is really 1 << CST.
+        RHS = ConstantExpr::getShl(ConstantInt::get(Op->getType(), 1), CST);
+        return Instruction::Mul;
+      }
     }
+    return Op->getOpcode();
   }
 
   // TODO: We can add other conversions e.g. shr => div etc.
-
-  LHS = Op->getOperand(0);
-  RHS = Op->getOperand(1);
-  return Op->getOpcode();
 }
 
 /// This tries to simplify binary operations by factorizing out common terms
@@ -529,8 +564,9 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
 
   // Factorization.
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
-  Instruction::BinaryOps LHSOpcode = getBinOpsForFactorization(Op0, A, B);
-  Instruction::BinaryOps RHSOpcode = getBinOpsForFactorization(Op1, C, D);
+  auto TopLevelOpcode = I.getOpcode();
+  auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+  auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
 
   // The instruction has the form "(A op' B) op (C op' D)".  Try to factorize
   // a common term.
@@ -552,7 +588,6 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
     return V;
 
   // Expansion.
-  Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
   if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {
     // The instruction has the form "(A op' B) op C".  See if expanding it out
     // to "(A op C) op' (B op C)" results in simplifications.
@@ -1284,7 +1319,7 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) {
 Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end());
 
-  if (Value *V = SimplifyGEPInst(Ops, DL))
+  if (Value *V = SimplifyGEPInst(Ops, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(GEP, V);
 
   Value *PtrOp = GEP.getOperand(0);
@@ -1478,19 +1513,50 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
         GetElementPtrInst::Create(Src->getOperand(0), Indices, GEP.getName());
   }
 
-  // Canonicalize (gep i8* X, -(ptrtoint Y)) to (sub (ptrtoint X), (ptrtoint Y))
-  // The GEP pattern is emitted by the SCEV expander for certain kinds of
-  // pointer arithmetic.
-  if (DL && GEP.getNumIndices() == 1 &&
-      match(GEP.getOperand(1), m_Neg(m_PtrToInt(m_Value())))) {
+  if (DL && GEP.getNumIndices() == 1) {
     unsigned AS = GEP.getPointerAddressSpace();
-    if (GEP.getType() == Builder->getInt8PtrTy(AS) &&
-        GEP.getOperand(1)->getType()->getScalarSizeInBits() ==
+    if (GEP.getOperand(1)->getType()->getScalarSizeInBits() ==
         DL->getPointerSizeInBits(AS)) {
-      Operator *Index = cast<Operator>(GEP.getOperand(1));
-      Value *PtrToInt = Builder->CreatePtrToInt(PtrOp, Index->getType());
-      Value *NewSub = Builder->CreateSub(PtrToInt, Index->getOperand(1));
-      return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType());
+      Type *PtrTy = GEP.getPointerOperandType();
+      Type *Ty = PtrTy->getPointerElementType();
+      uint64_t TyAllocSize = DL->getTypeAllocSize(Ty);
+
+      bool Matched = false;
+      uint64_t C;
+      Value *V = nullptr;
+      if (TyAllocSize == 1) {
+        V = GEP.getOperand(1);
+        Matched = true;
+      } else if (match(GEP.getOperand(1),
+                       m_AShr(m_Value(V), m_ConstantInt(C)))) {
+        if (TyAllocSize == 1ULL << C)
+          Matched = true;
+      } else if (match(GEP.getOperand(1),
+                       m_SDiv(m_Value(V), m_ConstantInt(C)))) {
+        if (TyAllocSize == C)
+          Matched = true;
+      }
+
+      if (Matched) {
+        // Canonicalize (gep i8* X, -(ptrtoint Y))
+        // to (inttoptr (sub (ptrtoint X), (ptrtoint Y)))
+        // The GEP pattern is emitted by the SCEV expander for certain kinds of
+        // pointer arithmetic.
+        if (match(V, m_Neg(m_PtrToInt(m_Value())))) {
+          Operator *Index = cast<Operator>(V);
+          Value *PtrToInt = Builder->CreatePtrToInt(PtrOp, Index->getType());
+          Value *NewSub = Builder->CreateSub(PtrToInt, Index->getOperand(1));
+          return CastInst::Create(Instruction::IntToPtr, NewSub, GEP.getType());
+        }
+        // Canonicalize (gep i8* X, (ptrtoint Y)-(ptrtoint X))
+        // to (bitcast Y)
+        Value *Y;
+        if (match(V, m_Sub(m_PtrToInt(m_Value(Y)),
+                           m_PtrToInt(m_Specific(GEP.getOperand(0)))))) {
+          return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y,
+                                                               GEP.getType());
+        }
+      }
     }
   }
 
@@ -1582,9 +1648,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
           Builder->CreateGEP(StrippedPtr, Idx, GEP.getName());
 
         // V and GEP are both pointer types --> BitCast
-        if (StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace())
-          return new BitCastInst(NewGEP, GEP.getType());
-        return new AddrSpaceCastInst(NewGEP, GEP.getType());
+        return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP,
+                                                             GEP.getType());
       }
 
       // Transform things like:
@@ -1616,9 +1681,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
               Builder->CreateGEP(StrippedPtr, NewIdx, GEP.getName());
 
             // The NewGEP must be pointer typed, so must the old one -> BitCast
-            if (StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace())
-              return new BitCastInst(NewGEP, GEP.getType());
-            return new AddrSpaceCastInst(NewGEP, GEP.getType());
+            return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP,
+                                                                 GEP.getType());
           }
         }
       }
@@ -1658,9 +1722,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
               Builder->CreateInBoundsGEP(StrippedPtr, Off, GEP.getName()) :
               Builder->CreateGEP(StrippedPtr, Off, GEP.getName());
             // The NewGEP must be pointer typed, so must the old one -> BitCast
-            if (StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace())
-              return new BitCastInst(NewGEP, GEP.getType());
-            return new AddrSpaceCastInst(NewGEP, GEP.getType());
+            return CastInst::CreatePointerBitCastOrAddrSpaceCast(NewGEP,
+                                                                 GEP.getType());
           }
         }
       }
@@ -1670,6 +1733,18 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   if (!DL)
     return nullptr;
 
+  // addrspacecast between types is canonicalized as a bitcast, then an
+  // addrspacecast. To take advantage of the below bitcast + struct GEP, look
+  // through the addrspacecast.
+  if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(PtrOp)) {
+    //   X = bitcast A addrspace(1)* to B addrspace(1)*
+    //   Y = addrspacecast A addrspace(1)* to B addrspace(2)*
+    //   Z = gep Y, <...constant indices...>
+    // Into an addrspacecasted GEP of the struct.
+    if (BitCastInst *BC = dyn_cast<BitCastInst>(ASC->getOperand(0)))
+      PtrOp = BC;
+  }
+
   /// See if we can simplify:
   ///   X = bitcast A* to B*
   ///   Y = gep X, <...constant indices...>
@@ -1678,11 +1753,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   if (BitCastInst *BCI = dyn_cast<BitCastInst>(PtrOp)) {
     Value *Operand = BCI->getOperand(0);
     PointerType *OpType = cast<PointerType>(Operand->getType());
-    unsigned OffsetBits = DL->getPointerTypeSizeInBits(OpType);
+    unsigned OffsetBits = DL->getPointerTypeSizeInBits(GEP.getType());
     APInt Offset(OffsetBits, 0);
     if (!isa<BitCastInst>(Operand) &&
-        GEP.accumulateConstantOffset(*DL, Offset) &&
-        StrippedPtrTy->getAddressSpace() == GEP.getPointerAddressSpace()) {
+        GEP.accumulateConstantOffset(*DL, Offset)) {
 
       // If this GEP instruction doesn't move the pointer, just replace the GEP
       // with a bitcast of the real input to the dest type.
@@ -1700,6 +1774,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
             return &GEP;
           }
         }
+
+        if (Operand->getType()->getPointerAddressSpace() != GEP.getAddressSpace())
+          return new AddrSpaceCastInst(Operand, GEP.getType());
         return new BitCastInst(Operand, GEP.getType());
       }
 
@@ -1715,6 +1792,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
         if (NGEP->getType() == GEP.getType())
           return ReplaceInstUsesWith(GEP, NGEP);
         NGEP->takeName(&GEP);
+
+        if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace())
+          return new AddrSpaceCastInst(NGEP, GEP.getType());
         return new BitCastInst(NGEP, GEP.getType());
       }
     }
@@ -1925,7 +2005,25 @@ Instruction *InstCombiner::visitFree(CallInst &FI) {
   return nullptr;
 }
 
+Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) {
+  if (RI.getNumOperands() == 0) // ret void
+    return nullptr;
+
+  Value *ResultOp = RI.getOperand(0);
+  Type *VTy = ResultOp->getType();
+  if (!VTy->isIntegerTy())
+    return nullptr;
 
+  // There might be assume intrinsics dominating this return that completely
+  // determine the value. If so, constant fold it.
+  unsigned BitWidth = VTy->getPrimitiveSizeInBits();
+  APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+  computeKnownBits(ResultOp, KnownZero, KnownOne, 0, &RI);
+  if ((KnownZero|KnownOne).isAllOnesValue())
+    RI.setOperand(0, Constant::getIntegerValue(VTy, KnownOne));
+
+  return nullptr;
+}
 
 Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
   // Change br (not X), label True, label False to: br X, label False, True
@@ -1977,6 +2075,37 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
 
 Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
   Value *Cond = SI.getCondition();
+  unsigned BitWidth = cast<IntegerType>(Cond->getType())->getBitWidth();
+  APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+  computeKnownBits(Cond, KnownZero, KnownOne);
+  unsigned LeadingKnownZeros = KnownZero.countLeadingOnes();
+  unsigned LeadingKnownOnes = KnownOne.countLeadingOnes();
+
+  // Compute the number of leading bits we can ignore.
+  for (auto &C : SI.cases()) {
+    LeadingKnownZeros = std::min(
+        LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros());
+    LeadingKnownOnes = std::min(
+        LeadingKnownOnes, C.getCaseValue()->getValue().countLeadingOnes());
+  }
+
+  unsigned NewWidth = BitWidth - std::max(LeadingKnownZeros, LeadingKnownOnes);
+
+  // Truncate the condition operand if the new type is equal to or larger than
+  // the largest legal integer type. We need to be conservative here since
+  // x86 generates redundant zero-extenstion instructions if the operand is
+  // truncated to i8 or i16.
+  if (BitWidth > NewWidth && NewWidth >= DL->getLargestLegalIntTypeSize()) {
+    IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth);
+    Builder->SetInsertPoint(&SI);
+    Value *NewCond = Builder->CreateTrunc(SI.getCondition(), Ty, "trunc");
+    SI.setCondition(NewCond);
+
+    for (auto &C : SI.cases())
+      static_cast<SwitchInst::CaseIt *>(&C)->setValue(ConstantInt::get(
+          SI.getContext(), C.getCaseValue()->getValue().trunc(NewWidth)));
+  }
+
   if (Instruction *I = dyn_cast<Instruction>(Cond)) {
     if (I->getOpcode() == Instruction::Add)
       if (ConstantInt *AddRHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
@@ -2534,7 +2663,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) {
 /// whose condition is a known constant, we only visit the reachable successors.
 ///
 static bool AddReachableCodeToWorklist(BasicBlock *BB,
-                                       SmallPtrSet<BasicBlock*, 64> &Visited,
+                                       SmallPtrSetImpl<BasicBlock*> &Visited,
                                        InstCombiner &IC,
                                        const DataLayout *DL,
                                        const TargetLibraryInfo *TLI) {
@@ -2730,9 +2859,18 @@ bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) {
         // If the user is one of our immediate successors, and if that successor
         // only has us as a predecessors (we'd have to split the critical edge
         // otherwise), we can keep going.
-        if (UserIsSuccessor && UserParent->getSinglePredecessor())
+        if (UserIsSuccessor && UserParent->getSinglePredecessor()) {
           // Okay, the CFG is simple enough, try to sink this instruction.
-          MadeIRChange |= TryToSinkInstruction(I, UserParent);
+          if (TryToSinkInstruction(I, UserParent)) {
+            MadeIRChange = true;
+            // We'll add uses of the sunk instruction below, but since sinking
+            // can expose opportunities for it's *operands* add them to the
+            // worklist
+            for (Use &U : I->operands())
+              if (Instruction *OpI = dyn_cast<Instruction>(U.get()))
+                Worklist.Add(OpI);
+          }
+        }
       }
     }
 
@@ -2801,13 +2939,13 @@ bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) {
 }
 
 namespace {
-class InstCombinerLibCallSimplifier : public LibCallSimplifier {
+class InstCombinerLibCallSimplifier final : public LibCallSimplifier {
   InstCombiner *IC;
 public:
   InstCombinerLibCallSimplifier(const DataLayout *DL,
                                 const TargetLibraryInfo *TLI,
                                 InstCombiner *IC)
-    : LibCallSimplifier(DL, TLI, UnsafeFPShrink) {
+    : LibCallSimplifier(DL, TLI, EnableUnsafeFPShrink) {
     this->IC = IC;
   }
 
@@ -2823,9 +2961,15 @@ bool InstCombiner::runOnFunction(Function &F) {
   if (skipOptnoneFunction(F))
     return false;
 
+  AT = &getAnalysis<AssumptionTracker>();
   DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
   DL = DLP ? &DLP->getDataLayout() : nullptr;
   TLI = &getAnalysis<TargetLibraryInfo>();
+
+  DominatorTreeWrapperPass *DTWP =
+      getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+  DT = DTWP ? &DTWP->getDomTree() : nullptr;
+
   // Minimizing size?
   MinimizeSize = F.getAttributes().hasAttribute(AttributeSet::FunctionIndex,
                                                 Attribute::MinSize);
@@ -2834,7 +2978,7 @@ bool InstCombiner::runOnFunction(Function &F) {
   /// instructions into the worklist when they are created.
   IRBuilder<true, TargetFolder, InstCombineIRInserter>
     TheBuilder(F.getContext(), TargetFolder(DL),
-               InstCombineIRInserter(Worklist));
+               InstCombineIRInserter(Worklist, AT));
   Builder = &TheBuilder;
 
   InstCombinerLibCallSimplifier TheSimplifier(DL, TLI, this);