Merging r259375:
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index f53eeef1dae60cb0380e8376de948c060eeab56f..d9311a343eadb57fdcfd9d02a592e91985335a31 100644 (file)
@@ -216,8 +216,6 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero,
   Max = KnownOne|UnknownBits;
 }
 
-
-
 /// FoldCmpLoadFromIndexedGlobal - Called we see this pattern:
 ///   cmp pred (load (gep GV, ...)), cmpcst
 /// where GV is a global variable with a constant initializer.  Try to simplify
@@ -371,7 +369,6 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
       }
     }
 
-
     // If this element is in range, update our magic bitvector.
     if (i < 64 && IsTrueForElt)
       MagicBitvector |= 1ULL << i;
@@ -469,7 +466,6 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
     return new ICmpInst(ICmpInst::ICMP_UGT, Idx, End);
   }
 
-
   // If a magic bitvector captures the entire comparison state
   // of this load, replace it with computation that does:
   //   ((magic_cst >> i) & 1) != 0
@@ -496,7 +492,6 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   return nullptr;
 }
 
-
 /// EvaluateGEPOffsetExpression - Return a value that can be used to compare
 /// the *offset* implied by a GEP to zero.  For example, if we have &A[i], we
 /// want to return 'i' for "icmp ne i, 0".  Note that, in general, indices can
@@ -562,8 +557,6 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC,
     }
   }
 
-
-
   // Okay, we know we have a single variable index, which must be a
   // pointer/array/vector index.  If there is no offset, life is simple, return
   // the index.
@@ -737,6 +730,83 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
   return nullptr;
 }
 
+Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca,
+                                         Value *Other) {
+  assert(ICI.isEquality() && "Cannot fold non-equality comparison.");
+
+  // It would be tempting to fold away comparisons between allocas and any
+  // pointer not based on that alloca (e.g. an argument). However, even
+  // though such pointers cannot alias, they can still compare equal.
+  //
+  // But LLVM doesn't specify where allocas get their memory, so if the alloca
+  // doesn't escape we can argue that it's impossible to guess its value, and we
+  // can therefore act as if any such guesses are wrong.
+  //
+  // The code below checks that the alloca doesn't escape, and that it's only
+  // used in a comparison once (the current instruction). The
+  // single-comparison-use condition ensures that we're trivially folding all
+  // comparisons against the alloca consistently, and avoids the risk of
+  // erroneously folding a comparison of the pointer with itself.
+
+  unsigned MaxIter = 32; // Break cycles and bound to constant-time.
+
+  SmallVector<Use *, 32> Worklist;
+  for (Use &U : Alloca->uses()) {
+    if (Worklist.size() >= MaxIter)
+      return nullptr;
+    Worklist.push_back(&U);
+  }
+
+  unsigned NumCmps = 0;
+  while (!Worklist.empty()) {
+    assert(Worklist.size() <= MaxIter);
+    Use *U = Worklist.pop_back_val();
+    Value *V = U->getUser();
+    --MaxIter;
+
+    if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) ||
+        isa<SelectInst>(V)) {
+      // Track the uses.
+    } else if (isa<LoadInst>(V)) {
+      // Loading from the pointer doesn't escape it.
+      continue;
+    } else if (auto *SI = dyn_cast<StoreInst>(V)) {
+      // Storing *to* the pointer is fine, but storing the pointer escapes it.
+      if (SI->getValueOperand() == U->get())
+        return nullptr;
+      continue;
+    } else if (isa<ICmpInst>(V)) {
+      if (NumCmps++)
+        return nullptr; // Found more than one cmp.
+      continue;
+    } else if (auto *Intrin = dyn_cast<IntrinsicInst>(V)) {
+      switch (Intrin->getIntrinsicID()) {
+        // These intrinsics don't escape or compare the pointer. Memset is safe
+        // because we don't allow ptrtoint. Memcpy and memmove are safe because
+        // we don't allow stores, so src cannot point to V.
+        case Intrinsic::lifetime_start: case Intrinsic::lifetime_end:
+        case Intrinsic::dbg_declare: case Intrinsic::dbg_value:
+        case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset:
+          continue;
+        default:
+          return nullptr;
+      }
+    } else {
+      return nullptr;
+    }
+    for (Use &U : V->uses()) {
+      if (Worklist.size() >= MaxIter)
+        return nullptr;
+      Worklist.push_back(&U);
+    }
+  }
+
+  Type *CmpTy = CmpInst::makeCmpResultType(Other->getType());
+  return ReplaceInstUsesWith(
+      ICI,
+      ConstantInt::get(CmpTy, !CmpInst::isTrueWhenEqual(ICI.getPredicate())));
+}
+
 /// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X".
 Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI,
                                             Value *X, ConstantInt *CI,
@@ -851,7 +921,6 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
       // to the same result value.
       HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false);
     }
-
   } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0.
     if (CmpRHSV == 0) {       // (X / pos) op 0
       // Can't overflow.  e.g.  X/2 op 0 --> [-1, 2)
@@ -996,7 +1065,6 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
     return Res;
   }
 
-
   // If we are comparing against bits always shifted out, the
   // comparison cannot succeed.
   APInt Comp = CmpRHSV << ShAmtVal;
@@ -1074,18 +1142,22 @@ Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A,
   if (AP1 == AP2)
     return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
 
-  // Get the distance between the highest bit that's set.
   int Shift;
-  // Both the constants are negative, take their positive to calculate log.
   if (IsAShr && AP1.isNegative())
-    // Get the ones' complement of AP2 and AP1 when computing the distance.
-    Shift = (~AP2).logBase2() - (~AP1).logBase2();
+    Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes();
   else
-    Shift = AP2.logBase2() - AP1.logBase2();
+    Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros();
 
   if (Shift > 0) {
-    if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift))
+    if (IsAShr && AP1 == AP2.ashr(Shift)) {
+      // There are multiple solutions if we are comparing against -1 and the LHS
+      // of the ashr is not a power of two.
+      if (AP1.isAllOnesValue() && !AP2.isPowerOf2())
+        return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift));
+      return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+    } else if (AP1 == AP2.lshr(Shift)) {
       return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+    }
   }
   // Shifting const2 will never be equal to const1.
   return getConstant(false);
@@ -1145,6 +1217,14 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
 
   switch (LHSI->getOpcode()) {
   case Instruction::Trunc:
+    if (RHS->isOne() && RHSV.getBitWidth() > 1) {
+      // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
+      Value *V = nullptr;
+      if (ICI.getPredicate() == ICmpInst::ICMP_SLT &&
+          match(LHSI->getOperand(0), m_Signum(m_Value(V))))
+        return new ICmpInst(ICmpInst::ICMP_SLT, V,
+                            ConstantInt::get(V->getType(), 1));
+    }
     if (ICI.isEquality() && LHSI->hasOneUse()) {
       // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all
       // of the high bits truncated out of x are known.
@@ -1447,9 +1527,35 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
           ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT
                                                   : ICmpInst::ICMP_ULE,
           LHSI->getOperand(0), SubOne(RHS));
+
+    // (icmp eq (and %A, C), 0) -> (icmp sgt (trunc %A), -1)
+    //   iff C is a power of 2
+    if (ICI.isEquality() && LHSI->hasOneUse() && match(RHS, m_Zero())) {
+      if (auto *CI = dyn_cast<ConstantInt>(LHSI->getOperand(1))) {
+        const APInt &AI = CI->getValue();
+        int32_t ExactLogBase2 = AI.exactLogBase2();
+        if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
+          Type *NTy = IntegerType::get(ICI.getContext(), ExactLogBase2 + 1);
+          Value *Trunc = Builder->CreateTrunc(LHSI->getOperand(0), NTy);
+          return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ
+                                  ? ICmpInst::ICMP_SGE
+                                  : ICmpInst::ICMP_SLT,
+                              Trunc, Constant::getNullValue(NTy));
+        }
+      }
+    }
     break;
 
   case Instruction::Or: {
+    if (RHS->isOne()) {
+      // icmp slt signum(V) 1 --> icmp slt V, 1
+      Value *V = nullptr;
+      if (ICI.getPredicate() == ICmpInst::ICMP_SLT &&
+          match(LHSI, m_Signum(m_Value(V))))
+        return new ICmpInst(ICmpInst::ICMP_SLT, V,
+                            ConstantInt::get(V->getType(), 1));
+    }
+
     if (!ICI.isEquality() || !RHS->isNullValue() || !LHSI->hasOneUse())
       break;
     Value *P, *Q;
@@ -2083,11 +2189,9 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // If the pattern matches, truncate the inputs to the narrower type and
   // use the sadd_with_overflow intrinsic to efficiently compute both the
   // result and the overflow bit.
-  Module *M = I.getParent()->getParent()->getParent();
-
   Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
-  Value *F = Intrinsic::getDeclaration(M, Intrinsic::sadd_with_overflow,
-                                       NewType);
+  Value *F = Intrinsic::getDeclaration(I.getModule(),
+                                       Intrinsic::sadd_with_overflow, NewType);
 
   InstCombiner::BuilderTy *Builder = IC.Builder;
 
@@ -2112,9 +2216,8 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
 bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS,
                                          Value *RHS, Instruction &OrigI,
                                          Value *&Result, Constant *&Overflow) {
-  assert((!OrigI.isCommutative() ||
-          !(isa<Constant>(LHS) && !isa<Constant>(RHS))) &&
-         "call with a constant RHS if possible!");
+  if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS))
+    std::swap(LHS, RHS);
 
   auto SetResult = [&](Value *OpResult, Constant *OverflowVal, bool ReuseName) {
     Result = OpResult;
@@ -2124,6 +2227,12 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS,
     return true;
   };
 
+  // If the overflow check was an add followed by a compare, the insertion point
+  // may be pointing to the compare.  We want to insert the new instructions
+  // before the add in case there are uses of the add between the add and the
+  // compare.
+  Builder->SetInsertPoint(&OrigI);
+
   switch (OCF) {
   case OCF_INVALID:
     llvm_unreachable("bad overflow check kind!");
@@ -2224,7 +2333,9 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
 
   assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal);
   assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal);
-  Instruction *MulInstr = cast<Instruction>(MulVal);
+  auto *MulInstr = dyn_cast<Instruction>(MulVal);
+  if (!MulInstr)
+    return nullptr;
   assert(MulInstr->getOpcode() == Instruction::Mul);
 
   auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)),
@@ -2358,7 +2469,6 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
 
   InstCombiner::BuilderTy *Builder = IC.Builder;
   Builder->SetInsertPoint(MulInstr);
-  Module *M = I.getParent()->getParent()->getParent();
 
   // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
   Value *MulA = A, *MulB = B;
@@ -2366,8 +2476,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     MulA = Builder->CreateZExt(A, MulType);
   if (WidthB < MulWidth)
     MulB = Builder->CreateZExt(B, MulType);
-  Value *F =
-      Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType);
+  Value *F = Intrinsic::getDeclaration(I.getModule(),
+                                       Intrinsic::umul_with_overflow, MulType);
   CallInst *Call = Builder->CreateCall(F, {MulA, MulB}, "umul");
   IC.Worklist.Add(MulInstr);
 
@@ -2469,7 +2579,6 @@ static APInt DemandedBitsLHSMask(ICmpInst &I,
   default:
     return APInt::getAllOnesValue(BitWidth);
   }
-
 }
 
 /// \brief Check if the order of \p Op0 and \p Op1 as operand in an ICmpInst
@@ -2646,7 +2755,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     Changed = true;
   }
 
-  if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC))
+  if (Value *V =
+          SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC, &I))
     return ReplaceInstUsesWith(I, V);
 
   // comparing -val or val with non-zero is the same as just comparing val
@@ -2905,7 +3015,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                               ConstantInt::get(X->getType(),
                                                CI->countTrailingZeros()));
       }
-
       break;
     }
     case ICmpInst::ICMP_NE: {
@@ -2950,7 +3059,6 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                               ConstantInt::get(X->getType(),
                                                CI->countTrailingZeros()));
       }
-
       break;
     }
     case ICmpInst::ICMP_ULT:
@@ -3103,7 +3211,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         // comparison into the select arms, which will cause one to be
         // constant folded and the select turned into a bitwise or.
         Value *Op1 = nullptr, *Op2 = nullptr;
-        ConstantInt *CI = 0;
+        ConstantInt *CI = nullptr;
         if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) {
           Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
           CI = dyn_cast<ConstantInt>(Op1);
@@ -3177,6 +3285,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                            ICmpInst::getSwappedPredicate(I.getPredicate()), I))
       return NI;
 
+  // Try to optimize equality comparisons against alloca-based pointers.
+  if (Op0->getType()->isPointerTy() && I.isEquality()) {
+    assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
+    if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op0, DL)))
+      if (Instruction *New = FoldAllocaCmp(I, Alloca, Op1))
+        return New;
+    if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op1, DL)))
+      if (Instruction *New = FoldAllocaCmp(I, Alloca, Op0))
+        return New;
+  }
+
   // Test to see if the operands of the icmp are casted versions of other
   // values.  If the ptr->ptr cast can be stripped off both arguments, we do so
   // now.
@@ -3304,6 +3423,26 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         match(B, m_One()))
       return new ICmpInst(CmpInst::ICMP_SGE, A, Op1);
 
+    // icmp sgt X, (Y + -1) -> icmp sge X, Y
+    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT &&
+        match(D, m_AllOnes()))
+      return new ICmpInst(CmpInst::ICMP_SGE, Op0, C);
+
+    // icmp sle X, (Y + -1) -> icmp slt X, Y
+    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE &&
+        match(D, m_AllOnes()))
+      return new ICmpInst(CmpInst::ICMP_SLT, Op0, C);
+
+    // icmp sge X, (Y + 1) -> icmp sgt X, Y
+    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE &&
+        match(D, m_One()))
+      return new ICmpInst(CmpInst::ICMP_SGT, Op0, C);
+
+    // icmp slt X, (Y + 1) -> icmp sle X, Y
+    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT &&
+        match(D, m_One()))
+      return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
+
     // if C1 has greater magnitude than C2:
     //  icmp (X + C1), (Y + C2) -> icmp (X + C3), Y
     //  s.t. C3 = C1 - C2
@@ -3421,7 +3560,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                                 BO1->getOperand(0));
           }
 
-          if (CI->isMaxValue(true)) {
+          if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) {
             ICmpInst::Predicate Pred = I.isSigned()
                                            ? I.getUnsignedPredicate()
                                            : I.getSignedPredicate();
@@ -3473,6 +3612,18 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       }
       }
     }
+
+    if (BO0) {
+      // Transform  A & (L - 1) `ult` L --> L != 0
+      auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes());
+      auto BitwiseAnd =
+          m_CombineOr(m_And(m_Value(), LSubOne), m_And(LSubOne, m_Value()));
+
+      if (match(BO0, BitwiseAnd) && I.getPredicate() == ICmpInst::ICMP_ULT) {
+        auto *Zero = Constant::getNullValue(BO0->getType());
+        return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero);
+      }
+    }
   }
 
   { Value *A, *B;
@@ -3697,15 +3848,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I,
 
   IntegerType *IntTy = cast<IntegerType>(LHSI->getOperand(0)->getType());
 
-  // Check to see that the input is converted from an integer type that is small
-  // enough that preserves all bits.  TODO: check here for "known" sign bits.
-  // This would allow us to handle (fptosi (x >>s 62) to float) if x is i64 f.e.
-  unsigned InputSize = IntTy->getScalarSizeInBits();
-
-  // If this is a uitofp instruction, we need an extra bit to hold the sign.
   bool LHSUnsigned = isa<UIToFPInst>(LHSI);
-  if (LHSUnsigned)
-    ++InputSize;
 
   if (I.isEquality()) {
     FCmpInst::Predicate P = I.getPredicate();
@@ -3732,13 +3875,30 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I,
     // equality compares as integer?
   }
 
-  // Comparisons with zero are a special case where we know we won't lose
-  // information.
-  bool IsCmpZero = RHS.isPosZero();
+  // Check to see that the input is converted from an integer type that is small
+  // enough that preserves all bits.  TODO: check here for "known" sign bits.
+  // This would allow us to handle (fptosi (x >>s 62) to float) if x is i64 f.e.
+  unsigned InputSize = IntTy->getScalarSizeInBits();
 
-  // If the conversion would lose info, don't hack on this.
-  if ((int)InputSize > MantissaWidth && !IsCmpZero)
-    return nullptr;
+  // Following test does NOT adjust InputSize downwards for signed inputs, 
+  // because the most negative value still requires all the mantissa bits 
+  // to distinguish it from one less than that value.
+  if ((int)InputSize > MantissaWidth) {
+    // Conversion would lose accuracy. Check if loss can impact comparison.
+    int Exp = ilogb(RHS);
+    if (Exp == APFloat::IEK_Inf) {
+      int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics()));
+      if (MaxExponent < (int)InputSize - !LHSUnsigned) 
+        // Conversion could create infinity.
+        return nullptr;
+    } else {
+      // Note that if RHS is zero or NaN, then Exp is negative 
+      // and first condition is trivially false.
+      if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) 
+        // Conversion could affect comparison.
+        return nullptr;
+    }
+  }
 
   // Otherwise, we can potentially simplify the comparison.  We know that it
   // will always come through as an integer value and we know the constant is
@@ -3927,7 +4087,8 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC))
+  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1,
+                                  I.getFastMathFlags(), DL, TLI, DT, AC, &I))
     return ReplaceInstUsesWith(I, V);
 
   // Simplify 'fcmp pred X, X'