Remove unnecessary intermediate lambda. NFC
[oota-llvm.git] / lib / Analysis / InstructionSimplify.cpp
index ab5cd01ab634744a976aa8f0b7d69050703e04f5..0bd18c1a35cd447d4cb7c0e78b12955b89c4169f 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
@@ -122,9 +123,9 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
   }
 
   // Otherwise, if the instruction is in the entry block, and is not an invoke,
-  // then it obviously dominates all phi nodes.
+  // and is not a catchpad, then it obviously dominates all phi nodes.
   if (I->getParent() == &I->getParent()->getParent()->getEntryBlock() &&
-      !isa<InvokeInst>(I))
+      !isa<InvokeInst>(I) && !isa<CatchPadInst>(I))
     return true;
 
   return false;
@@ -2089,8 +2090,7 @@ static Constant *computePointerICmp(const DataLayout &DL,
 
     // Is the set of underlying objects all noalias calls?
     auto IsNAC = [](SmallVectorImpl<Value *> &Objects) {
-      return std::all_of(Objects.begin(), Objects.end(),
-                         [](Value *V){ return isNoAliasCall(V); });
+      return std::all_of(Objects.begin(), Objects.end(), isNoAliasCall);
     };
 
     // Is the set of underlying objects all things which must be disjoint from
@@ -2175,6 +2175,19 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       // X >=u 1 -> X
       if (match(RHS, m_One()))
         return LHS;
+      if (isImpliedCondition(RHS, LHS, Q.DL))
+        return getTrue(ITy);
+      break;
+    case ICmpInst::ICMP_SGE:
+      /// For signed comparison, the values for an i1 are 0 and -1 
+      /// respectively. This maps into a truth table of:
+      /// LHS | RHS | LHS >=s RHS   | LHS implies RHS
+      ///  0  |  0  |  1 (0 >= 0)   |  1
+      ///  0  |  1  |  1 (0 >= -1)  |  1
+      ///  1  |  0  |  0 (-1 >= 0)  |  0
+      ///  1  |  1  |  1 (-1 >= -1) |  1
+      if (isImpliedCondition(LHS, RHS, Q.DL))
+        return getTrue(ITy);
       break;
     case ICmpInst::ICMP_SLT:
       // X <s 0 -> X
@@ -2186,6 +2199,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       if (match(RHS, m_One()))
         return LHS;
       break;
+    case ICmpInst::ICMP_ULE:
+      if (isImpliedCondition(LHS, RHS, Q.DL))
+        return getTrue(ITy);
+      break;
     }
   }
 
@@ -2359,9 +2376,19 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) {
       // 'and x, CI2' produces [0, CI2].
       Upper = CI2->getValue() + 1;
+    } else if (match(LHS, m_NUWAdd(m_Value(), m_ConstantInt(CI2)))) {
+      // 'add nuw x, CI2' produces [CI2, UINT_MAX].
+      Lower = CI2->getValue();
     }
-    if (Lower != Upper) {
-      ConstantRange LHS_CR = ConstantRange(Lower, Upper);
+
+    ConstantRange LHS_CR = Lower != Upper ? ConstantRange(Lower, Upper)
+                                          : ConstantRange(Width, true);
+
+    if (auto *I = dyn_cast<Instruction>(LHS))
+      if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
+        LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges));
+
+    if (!LHS_CR.isFullSet()) {
       if (RHS_CR.contains(LHS_CR))
         return ConstantInt::getTrue(RHS->getContext());
       if (RHS_CR.inverse().contains(LHS_CR))
@@ -2369,6 +2396,30 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     }
   }
 
+  // If both operands have range metadata, use the metadata
+  // to simplify the comparison.
+  if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) {
+    auto RHS_Instr = dyn_cast<Instruction>(RHS);
+    auto LHS_Instr = dyn_cast<Instruction>(LHS);
+
+    if (RHS_Instr->getMetadata(LLVMContext::MD_range) &&
+        LHS_Instr->getMetadata(LLVMContext::MD_range)) {
+      auto RHS_CR = getConstantRangeFromMetadata(
+          *RHS_Instr->getMetadata(LLVMContext::MD_range));
+      auto LHS_CR = getConstantRangeFromMetadata(
+          *LHS_Instr->getMetadata(LLVMContext::MD_range));
+
+      auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR);
+      if (Satisfied_CR.contains(LHS_CR))
+        return ConstantInt::getTrue(RHS->getContext());
+
+      auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion(
+                CmpInst::getInversePredicate(Pred), RHS_CR);
+      if (InversedSatisfied_CR.contains(LHS_CR))
+        return ConstantInt::getFalse(RHS->getContext());
+    }
+  }
+
   // Compare of cast, for example (zext X) != 0 -> X != 0
   if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) {
     Instruction *LI = cast<CastInst>(LHS);
@@ -2528,6 +2579,14 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     }
   }
 
+  // icmp eq|ne X, Y -> false|true if X != Y
+  if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) &&
+      isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) {
+    LLVMContext &Ctx = LHS->getType()->getContext();
+    return Pred == ICmpInst::ICMP_NE ?
+      ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx);
+  }
+  
   // Special logic for binary operators.
   BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
   BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS);
@@ -3520,6 +3579,73 @@ Value *llvm::SimplifyInsertValueInst(
                                    RecursionLimit);
 }
 
+/// SimplifyExtractValueInst - Given operands for an ExtractValueInst, see if we
+/// can fold the result.  If not, this returns null.
+static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
+                                       const Query &, unsigned) {
+  if (auto *CAgg = dyn_cast<Constant>(Agg))
+    return ConstantFoldExtractValueInstruction(CAgg, Idxs);
+
+  // extractvalue x, (insertvalue y, elt, n), n -> elt
+  unsigned NumIdxs = Idxs.size();
+  for (auto *IVI = dyn_cast<InsertValueInst>(Agg); IVI != nullptr;
+       IVI = dyn_cast<InsertValueInst>(IVI->getAggregateOperand())) {
+    ArrayRef<unsigned> InsertValueIdxs = IVI->getIndices();
+    unsigned NumInsertValueIdxs = InsertValueIdxs.size();
+    unsigned NumCommonIdxs = std::min(NumInsertValueIdxs, NumIdxs);
+    if (InsertValueIdxs.slice(0, NumCommonIdxs) ==
+        Idxs.slice(0, NumCommonIdxs)) {
+      if (NumIdxs == NumInsertValueIdxs)
+        return IVI->getInsertedValueOperand();
+      break;
+    }
+  }
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
+                                      const DataLayout &DL,
+                                      const TargetLibraryInfo *TLI,
+                                      const DominatorTree *DT,
+                                      AssumptionCache *AC,
+                                      const Instruction *CxtI) {
+  return ::SimplifyExtractValueInst(Agg, Idxs, Query(DL, TLI, DT, AC, CxtI),
+                                    RecursionLimit);
+}
+
+/// SimplifyExtractElementInst - Given operands for an ExtractElementInst, see if we
+/// can fold the result.  If not, this returns null.
+static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &,
+                                         unsigned) {
+  if (auto *CVec = dyn_cast<Constant>(Vec)) {
+    if (auto *CIdx = dyn_cast<Constant>(Idx))
+      return ConstantFoldExtractElementInstruction(CVec, CIdx);
+
+    // The index is not relevant if our vector is a splat.
+    if (auto *Splat = CVec->getSplatValue())
+      return Splat;
+
+    if (isa<UndefValue>(Vec))
+      return UndefValue::get(Vec->getType()->getVectorElementType());
+  }
+
+  // If extracting a specified index from the vector, see if we can recursively
+  // find a previously computed scalar that was inserted into the vector.
+  if (auto *IdxC = dyn_cast<ConstantInt>(Idx))
+    if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
+      return Elt;
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyExtractElementInst(
+    Value *Vec, Value *Idx, const DataLayout &DL, const TargetLibraryInfo *TLI,
+    const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) {
+  return ::SimplifyExtractElementInst(Vec, Idx, Query(DL, TLI, DT, AC, CxtI),
+                                      RecursionLimit);
+}
+
 /// SimplifyPHINode - See if we can fold the given phi.  If not, returns null.
 static Value *SimplifyPHINode(PHINode *PN, const Query &Q) {
   // If all of the PHI's incoming values are the same then replace the PHI node
@@ -3929,6 +4055,18 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL,
                                      IV->getIndices(), DL, TLI, DT, AC, I);
     break;
   }
+  case Instruction::ExtractValue: {
+    auto *EVI = cast<ExtractValueInst>(I);
+    Result = SimplifyExtractValueInst(EVI->getAggregateOperand(),
+                                      EVI->getIndices(), DL, TLI, DT, AC, I);
+    break;
+  }
+  case Instruction::ExtractElement: {
+    auto *EEI = cast<ExtractElementInst>(I);
+    Result = SimplifyExtractElementInst(
+        EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I);
+    break;
+  }
   case Instruction::PHI:
     Result = SimplifyPHINode(cast<PHINode>(I), Query(DL, TLI, DT, AC, I));
     break;
@@ -3944,6 +4082,17 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL,
     break;
   }
 
+  // In general, it is possible for computeKnownBits to determine all bits in a
+  // value even when the operands are not all constants.
+  if (!Result && I->getType()->isIntegerTy()) {
+    unsigned BitWidth = I->getType()->getScalarSizeInBits();
+    APInt KnownZero(BitWidth, 0);
+    APInt KnownOne(BitWidth, 0);
+    computeKnownBits(I, KnownZero, KnownOne, DL, /*Depth*/0, AC, I, DT);
+    if ((KnownZero | KnownOne).isAllOnesValue())
+      Result = ConstantInt::get(I->getContext(), KnownOne);
+  }
+
   /// If called on unreachable code, the above logic may report that the
   /// instruction simplified to itself.  Make life easier for users by
   /// detecting that case here, returning a safe value instead.