[PM/AA] Extract the ModRef enums from the AliasAnalysis class in
[oota-llvm.git] / lib / Analysis / InstructionSimplify.cpp
index ec56d888dc2fde8741f9f29f33812efe3e7659f7..fa42b48b6cdb9a27f5e76fa395219d8b1b542c46 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
@@ -854,8 +855,8 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
       return X;
   }
 
-  // fsub nnan ninf x, x ==> 0.0
-  if (FMF.noNaNs() && FMF.noInfs() && Op0 == Op1)
+  // fsub nnan x, x ==> 0.0
+  if (FMF.noNaNs() && Op0 == Op1)
     return Constant::getNullValue(Op0->getType());
 
   return nullptr;
@@ -1126,6 +1127,21 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZero()))
     return Op0;
 
+  if (FMF.noNaNs()) {
+    // X / X -> 1.0 is legal when NaNs are ignored.
+    if (Op0 == Op1)
+      return ConstantFP::get(Op0->getType(), 1.0);
+
+    // -X /  X -> -1.0 and
+    //  X / -X -> -1.0 are legal when NaNs are ignored.
+    // We can ignore signed zeros because +-0.0/+-0.0 is NaN and ignored.
+    if ((BinaryOperator::isFNeg(Op0, /*IgnoreZeroSign=*/true) &&
+         BinaryOperator::getFNegArgument(Op0) == Op1) ||
+        (BinaryOperator::isFNeg(Op1, /*IgnoreZeroSign=*/true) &&
+         BinaryOperator::getFNegArgument(Op1) == Op0))
+      return ConstantFP::get(Op0->getType(), -1.0);
+  }
+
   return nullptr;
 }
 
@@ -3031,7 +3047,8 @@ Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
 /// SimplifyFCmpInst - Given operands for an FCmpInst, see if we can
 /// fold the result.  If not, this returns null.
 static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                               const Query &Q, unsigned MaxRecurse) {
+                               FastMathFlags FMF, const Query &Q,
+                               unsigned MaxRecurse) {
   CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
   assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!");
 
@@ -3050,6 +3067,14 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   if (Pred == FCmpInst::FCMP_TRUE)
     return ConstantInt::get(GetCompareTy(LHS), 1);
 
+  // UNO/ORD predicates can be trivially folded if NaNs are ignored.
+  if (FMF.noNaNs()) {
+    if (Pred == FCmpInst::FCMP_UNO)
+      return ConstantInt::get(GetCompareTy(LHS), 0);
+    if (Pred == FCmpInst::FCMP_ORD)
+      return ConstantInt::get(GetCompareTy(LHS), 1);
+  }
+
   // fcmp pred x, undef  and  fcmp pred undef, x
   // fold to true if unordered, false if ordered
   if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) {
@@ -3136,12 +3161,12 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
 }
 
 Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                              const DataLayout &DL,
+                              FastMathFlags FMF, const DataLayout &DL,
                               const TargetLibraryInfo *TLI,
                               const DominatorTree *DT, AssumptionCache *AC,
                               const Instruction *CxtI) {
-  return ::SimplifyFCmpInst(Predicate, LHS, RHS, Query(DL, TLI, DT, AC, CxtI),
-                            RecursionLimit);
+  return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF,
+                            Query(DL, TLI, DT, AC, CxtI), RecursionLimit);
 }
 
 /// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
@@ -3496,6 +3521,82 @@ Value *llvm::SimplifyInsertValueInst(
                                    RecursionLimit);
 }
 
+/// SimplifyExtractValueInst - Given operands for an ExtractValueInst, see if we
+/// can fold the result.  If not, this returns null.
+static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
+                                       const Query &, unsigned) {
+  if (auto *CAgg = dyn_cast<Constant>(Agg))
+    return ConstantFoldExtractValueInstruction(CAgg, Idxs);
+
+  // extractvalue x, (insertvalue y, elt, n), n -> elt
+  unsigned NumIdxs = Idxs.size();
+  for (auto *IVI = dyn_cast<InsertValueInst>(Agg); IVI != nullptr;
+       IVI = dyn_cast<InsertValueInst>(IVI->getAggregateOperand())) {
+    ArrayRef<unsigned> InsertValueIdxs = IVI->getIndices();
+    unsigned NumInsertValueIdxs = InsertValueIdxs.size();
+    unsigned NumCommonIdxs = std::min(NumInsertValueIdxs, NumIdxs);
+    if (InsertValueIdxs.slice(0, NumCommonIdxs) ==
+        Idxs.slice(0, NumCommonIdxs)) {
+      if (NumIdxs == NumInsertValueIdxs)
+        return IVI->getInsertedValueOperand();
+      break;
+    }
+  }
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
+                                      const DataLayout &DL,
+                                      const TargetLibraryInfo *TLI,
+                                      const DominatorTree *DT,
+                                      AssumptionCache *AC,
+                                      const Instruction *CxtI) {
+  return ::SimplifyExtractValueInst(Agg, Idxs, Query(DL, TLI, DT, AC, CxtI),
+                                    RecursionLimit);
+}
+
+/// SimplifyExtractElementInst - Given operands for an ExtractElementInst, see if we
+/// can fold the result.  If not, this returns null.
+static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &,
+                                         unsigned) {
+  if (auto *CVec = dyn_cast<Constant>(Vec)) {
+    if (auto *CIdx = dyn_cast<Constant>(Idx))
+      return ConstantFoldExtractElementInstruction(CVec, CIdx);
+
+    // The index is not relevant if our vector is a splat.
+    if (auto *Splat = CVec->getSplatValue())
+      return Splat;
+
+    if (isa<UndefValue>(Vec))
+      return UndefValue::get(Vec->getType()->getVectorElementType());
+  }
+
+  // If extracting a specified index from the vector, see if we can recursively
+  // find a previously computed scalar that was inserted into the vector.
+  if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {
+    unsigned IndexVal = IdxC->getZExtValue();
+    unsigned VectorWidth = Vec->getType()->getVectorNumElements();
+
+    // If this is extracting an invalid index, turn this into undef, to avoid
+    // crashing the code below.
+    if (IndexVal >= VectorWidth)
+      return UndefValue::get(Vec->getType()->getVectorElementType());
+
+    if (Value *Elt = findScalarElement(Vec, IndexVal))
+      return Elt;
+  }
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyExtractElementInst(
+    Value *Vec, Value *Idx, const DataLayout &DL, const TargetLibraryInfo *TLI,
+    const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) {
+  return ::SimplifyExtractElementInst(Vec, Idx, Query(DL, TLI, DT, AC, CxtI),
+                                      RecursionLimit);
+}
+
 /// SimplifyPHINode - See if we can fold the given phi.  If not, returns null.
 static Value *SimplifyPHINode(PHINode *PN, const Query &Q) {
   // If all of the PHI's incoming values are the same then replace the PHI node
@@ -3655,7 +3756,7 @@ static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
                               const Query &Q, unsigned MaxRecurse) {
   if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate))
     return SimplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse);
-  return SimplifyFCmpInst(Predicate, LHS, RHS, Q, MaxRecurse);
+  return SimplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse);
 }
 
 Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
@@ -3885,9 +3986,9 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL,
                          I->getOperand(1), DL, TLI, DT, AC, I);
     break;
   case Instruction::FCmp:
-    Result =
-        SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0),
-                         I->getOperand(1), DL, TLI, DT, AC, I);
+    Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(),
+                              I->getOperand(0), I->getOperand(1),
+                              I->getFastMathFlags(), DL, TLI, DT, AC, I);
     break;
   case Instruction::Select:
     Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1),
@@ -3905,6 +4006,18 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL,
                                      IV->getIndices(), DL, TLI, DT, AC, I);
     break;
   }
+  case Instruction::ExtractValue: {
+    auto *EVI = cast<ExtractValueInst>(I);
+    Result = SimplifyExtractValueInst(EVI->getAggregateOperand(),
+                                      EVI->getIndices(), DL, TLI, DT, AC, I);
+    break;
+  }
+  case Instruction::ExtractElement: {
+    auto *EEI = cast<ExtractElementInst>(I);
+    Result = SimplifyExtractElementInst(
+        EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I);
+    break;
+  }
   case Instruction::PHI:
     Result = SimplifyPHINode(cast<PHINode>(I), Query(DL, TLI, DT, AC, I));
     break;