Remove unnecessary intermediate lambda. NFC
[oota-llvm.git] / lib / Analysis / InstructionSimplify.cpp
index be5ce2960ea8d49ddb4ddc8ffccc8034dbbfbb8e..0bd18c1a35cd447d4cb7c0e78b12955b89c4169f 100644 (file)
@@ -2090,8 +2090,7 @@ static Constant *computePointerICmp(const DataLayout &DL,
 
     // Is the set of underlying objects all noalias calls?
     auto IsNAC = [](SmallVectorImpl<Value *> &Objects) {
-      return std::all_of(Objects.begin(), Objects.end(),
-                         [](Value *V){ return isNoAliasCall(V); });
+      return std::all_of(Objects.begin(), Objects.end(), isNoAliasCall);
     };
 
     // Is the set of underlying objects all things which must be disjoint from
@@ -2128,77 +2127,6 @@ static Constant *computePointerICmp(const DataLayout &DL,
   return nullptr;
 }
 
-/// Return true if B is known to be implied by A.  A & B must be i1 (boolean)
-/// values or a vector of such values. Note that the truth table for
-/// implication is the same as <=u on i1 values (but not <=s!).  The truth
-/// table for both is: 
-///    | T | F (B)
-///  T | T | F
-///  F | T | T
-/// (A)
-static bool implies(Value *A, Value *B) {
-  assert(A->getType() == B->getType() && "mismatched type");
-  Type *OpTy = A->getType();
-  assert(OpTy->getScalarType()->isIntegerTy(1));
-  
-  // A ==> A by definition
-  if (A == B) return true;
-
-  if (OpTy->isVectorTy())
-    // TODO: extending the code below to handle vectors
-    return false;
-  assert(OpTy->isIntegerTy(1) && "implied by above");
-
-  ICmpInst::Predicate APred, BPred;
-  Value *I;
-  Value *L;
-  ConstantInt *CI;
-  // i +_{nsw} C_{>0} <s L ==> i <s L
-  if (match(A, m_ICmp(APred,
-                      m_NSWAdd(m_Value(I), m_ConstantInt(CI)),
-                      m_Value(L))) &&
-      APred == ICmpInst::ICMP_SLT &&
-      !CI->isNegative() &&
-      match(B, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
-      BPred == ICmpInst::ICMP_SLT)
-    return true;
-
-  // i +_{nuw} C_{>0} <u L ==> i <u L
-  if (match(A, m_ICmp(APred,
-                      m_NUWAdd(m_Value(I), m_ConstantInt(CI)),
-                      m_Value(L))) &&
-      APred == ICmpInst::ICMP_ULT &&
-      !CI->isNegative() &&
-      match(B, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
-      BPred == ICmpInst::ICMP_ULT)
-    return true;
-
-  return false;
-}
-
-static ConstantRange GetConstantRangeFromMetadata(MDNode *Ranges, uint32_t BitWidth) {
-  const unsigned NumRanges = Ranges->getNumOperands() / 2;
-  assert(NumRanges >= 1);
-
-  ConstantRange CR(BitWidth, false);
-  for (unsigned i = 0; i < NumRanges; ++i) {
-    auto *Low =
-        mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 0));
-    auto *High =
-        mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 1));
-
-    // Union will merge two ranges to one and potentially introduce a range
-    // not covered by the original two ranges. For example, [1, 5) and [8, 10)
-    // will become [1, 10). In this case, we can not fold comparison between
-    // constant 6 and a value of the above ranges. In practice, most values
-    // have only one range, so it might not be worth handling this by
-    // introducing additional complexity.
-    CR = CR.unionWith(ConstantRange(Low->getValue(), High->getValue()));
-  }
-
-  return CR;
-}
-
 /// SimplifyICmpInst - Given operands for an ICmpInst, see if we can
 /// fold the result.  If not, this returns null.
 static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
@@ -2247,7 +2175,18 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       // X >=u 1 -> X
       if (match(RHS, m_One()))
         return LHS;
-      if (implies(RHS, LHS))
+      if (isImpliedCondition(RHS, LHS, Q.DL))
+        return getTrue(ITy);
+      break;
+    case ICmpInst::ICMP_SGE:
+      /// For signed comparison, the values for an i1 are 0 and -1 
+      /// respectively. This maps into a truth table of:
+      /// LHS | RHS | LHS >=s RHS   | LHS implies RHS
+      ///  0  |  0  |  1 (0 >= 0)   |  1
+      ///  0  |  1  |  1 (0 >= -1)  |  1
+      ///  1  |  0  |  0 (-1 >= 0)  |  0
+      ///  1  |  1  |  1 (-1 >= -1) |  1
+      if (isImpliedCondition(LHS, RHS, Q.DL))
         return getTrue(ITy);
       break;
     case ICmpInst::ICMP_SLT:
@@ -2261,7 +2200,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
         return LHS;
       break;
     case ICmpInst::ICMP_ULE:
-      if (implies(LHS, RHS))
+      if (isImpliedCondition(LHS, RHS, Q.DL))
         return getTrue(ITy);
       break;
     }
@@ -2447,7 +2386,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
 
     if (auto *I = dyn_cast<Instruction>(LHS))
       if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
-        LHS_CR = LHS_CR.intersectWith(GetConstantRangeFromMetadata(Ranges, Width));
+        LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges));
 
     if (!LHS_CR.isFullSet()) {
       if (RHS_CR.contains(LHS_CR))
@@ -2465,12 +2404,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
 
     if (RHS_Instr->getMetadata(LLVMContext::MD_range) &&
         LHS_Instr->getMetadata(LLVMContext::MD_range)) {
-      uint32_t BitWidth = Q.DL.getTypeSizeInBits(RHS->getType());
-
-      auto RHS_CR = GetConstantRangeFromMetadata(
-          RHS_Instr->getMetadata(LLVMContext::MD_range), BitWidth);
-      auto LHS_CR = GetConstantRangeFromMetadata(
-          LHS_Instr->getMetadata(LLVMContext::MD_range), BitWidth);
+      auto RHS_CR = getConstantRangeFromMetadata(
+          *RHS_Instr->getMetadata(LLVMContext::MD_range));
+      auto LHS_CR = getConstantRangeFromMetadata(
+          *LHS_Instr->getMetadata(LLVMContext::MD_range));
 
       auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR);
       if (Satisfied_CR.contains(LHS_CR))
@@ -4145,6 +4082,17 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL,
     break;
   }
 
+  // In general, it is possible for computeKnownBits to determine all bits in a
+  // value even when the operands are not all constants.
+  if (!Result && I->getType()->isIntegerTy()) {
+    unsigned BitWidth = I->getType()->getScalarSizeInBits();
+    APInt KnownZero(BitWidth, 0);
+    APInt KnownOne(BitWidth, 0);
+    computeKnownBits(I, KnownZero, KnownOne, DL, /*Depth*/0, AC, I, DT);
+    if ((KnownZero | KnownOne).isAllOnesValue())
+      Result = ConstantInt::get(I->getContext(), KnownOne);
+  }
+
   /// If called on unreachable code, the above logic may report that the
   /// instruction simplified to itself.  Make life easier for users by
   /// detecting that case here, returning a safe value instead.