Add one more case we compute a max trip count.
[oota-llvm.git] / lib / Analysis / InstructionSimplify.cpp
index 9d78f8bf40441fafd023f08eebb9b126931459b7..131cc97d237965bd5efbef6fd62fbac8635a3ea1 100644 (file)
@@ -48,6 +48,26 @@ static Value *SimplifyOrInst(Value *, Value *, const TargetData *,
 static Value *SimplifyXorInst(Value *, Value *, const TargetData *,
                               const DominatorTree *, unsigned);
 
+/// getFalse - For a boolean type, or a vector of boolean type, return false, or
+/// a vector with every element false, as appropriate for the type.
+static Constant *getFalse(Type *Ty) {
+  assert((Ty->isIntegerTy(1) ||
+          (Ty->isVectorTy() &&
+           cast<VectorType>(Ty)->getElementType()->isIntegerTy(1))) &&
+         "Expected i1 type or a vector of i1!");
+  return Constant::getNullValue(Ty);
+}
+
+/// getTrue - For a boolean type, or a vector of boolean type, return true, or
+/// a vector with every element true, as appropriate for the type.
+static Constant *getTrue(Type *Ty) {
+  assert((Ty->isIntegerTy(1) ||
+          (Ty->isVectorTy() &&
+           cast<VectorType>(Ty)->getElementType()->isIntegerTy(1))) &&
+         "Expected i1 type or a vector of i1!");
+  return Constant::getAllOnesValue(Ty);
+}
+
 /// ValueDominatesPHI - Does the given value dominate the specified phi node?
 static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
   Instruction *I = dyn_cast<Instruction>(V);
@@ -526,7 +546,7 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
     if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { CLHS, CRHS };
       return ConstantFoldInstOperands(Instruction::Add, CLHS->getType(),
-                                      Ops, 2, TD);
+                                      Ops, TD);
     }
 
     // Canonicalize the constant to the RHS.
@@ -595,7 +615,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
     if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { CLHS, CRHS };
       return ConstantFoldInstOperands(Instruction::Sub, CLHS->getType(),
-                                      Ops, 2, TD);
+                                      Ops, TD);
     }
 
   // X - undef -> undef
@@ -715,7 +735,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const TargetData *TD,
     if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { CLHS, CRHS };
       return ConstantFoldInstOperands(Instruction::Mul, CLHS->getType(),
-                                      Ops, 2, TD);
+                                      Ops, TD);
     }
 
     // Canonicalize the constant to the RHS.
@@ -788,7 +808,7 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   if (Constant *C0 = dyn_cast<Constant>(Op0)) {
     if (Constant *C1 = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { C0, C1 };
-      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, 2, TD);
+      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, TD);
     }
   }
 
@@ -909,7 +929,7 @@ static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   if (Constant *C0 = dyn_cast<Constant>(Op0)) {
     if (Constant *C1 = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { C0, C1 };
-      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, 2, TD);
+      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, TD);
     }
   }
 
@@ -1012,7 +1032,7 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1,
   if (Constant *C0 = dyn_cast<Constant>(Op0)) {
     if (Constant *C1 = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { C0, C1 };
-      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, 2, TD);
+      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, TD);
     }
   }
 
@@ -1138,7 +1158,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const TargetData *TD,
     if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { CLHS, CRHS };
       return ConstantFoldInstOperands(Instruction::And, CLHS->getType(),
-                                      Ops, 2, TD);
+                                      Ops, TD);
     }
 
     // Canonicalize the constant to the RHS.
@@ -1227,7 +1247,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const TargetData *TD,
     if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { CLHS, CRHS };
       return ConstantFoldInstOperands(Instruction::Or, CLHS->getType(),
-                                      Ops, 2, TD);
+                                      Ops, TD);
     }
 
     // Canonicalize the constant to the RHS.
@@ -1321,7 +1341,7 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const TargetData *TD,
     if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
       Constant *Ops[] = { CLHS, CRHS };
       return ConstantFoldInstOperands(Instruction::Xor, CLHS->getType(),
-                                      Ops, 2, TD);
+                                      Ops, TD);
     }
 
     // Canonicalize the constant to the RHS.
@@ -1372,7 +1392,7 @@ Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const TargetData *TD,
   return ::SimplifyXorInst(Op0, Op1, TD, DT, RecursionLimit);
 }
 
-static const Type *GetCompareTy(Value *Op) {
+static Type *GetCompareTy(Value *Op) {
   return CmpInst::makeCmpResultType(Op->getType());
 }
 
@@ -1413,8 +1433,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     Pred = CmpInst::getSwappedPredicate(Pred);
   }
 
-  const Type *ITy = GetCompareTy(LHS); // The return type.
-  const Type *OpTy = LHS->getType();   // The operand type.
+  Type *ITy = GetCompareTy(LHS); // The return type.
+  Type *OpTy = LHS->getType();   // The operand type.
 
   // icmp X, X -> true/false
   // X icmp undef -> true/false.  For example, icmp ugt %X, undef -> false
@@ -1478,48 +1498,46 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     default:
       assert(false && "Unknown ICmp predicate!");
     case ICmpInst::ICMP_ULT:
-      // getNullValue also works for vectors, unlike getFalse.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
     case ICmpInst::ICMP_UGE:
-      // getAllOnesValue also works for vectors, unlike getTrue.
-      return ConstantInt::getAllOnesValue(ITy);
+      return getTrue(ITy);
     case ICmpInst::ICMP_EQ:
     case ICmpInst::ICMP_ULE:
       if (isKnownNonZero(LHS, TD))
-        return Constant::getNullValue(ITy);
+        return getFalse(ITy);
       break;
     case ICmpInst::ICMP_NE:
     case ICmpInst::ICMP_UGT:
       if (isKnownNonZero(LHS, TD))
-        return ConstantInt::getAllOnesValue(ITy);
+        return getTrue(ITy);
       break;
     case ICmpInst::ICMP_SLT:
       ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD);
       if (LHSKnownNegative)
-        return ConstantInt::getAllOnesValue(ITy);
+        return getTrue(ITy);
       if (LHSKnownNonNegative)
-        return Constant::getNullValue(ITy);
+        return getFalse(ITy);
       break;
     case ICmpInst::ICMP_SLE:
       ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD);
       if (LHSKnownNegative)
-        return ConstantInt::getAllOnesValue(ITy);
+        return getTrue(ITy);
       if (LHSKnownNonNegative && isKnownNonZero(LHS, TD))
-        return Constant::getNullValue(ITy);
+        return getFalse(ITy);
       break;
     case ICmpInst::ICMP_SGE:
       ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD);
       if (LHSKnownNegative)
-        return Constant::getNullValue(ITy);
+        return getFalse(ITy);
       if (LHSKnownNonNegative)
-        return ConstantInt::getAllOnesValue(ITy);
+        return getTrue(ITy);
       break;
     case ICmpInst::ICMP_SGT:
       ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, TD);
       if (LHSKnownNegative)
-        return Constant::getNullValue(ITy);
+        return getFalse(ITy);
       if (LHSKnownNonNegative && isKnownNonZero(LHS, TD))
-        return ConstantInt::getAllOnesValue(ITy);
+        return getTrue(ITy);
       break;
     }
   }
@@ -1593,8 +1611,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) {
     Instruction *LI = cast<CastInst>(LHS);
     Value *SrcOp = LI->getOperand(0);
-    const Type *SrcTy = SrcOp->getType();
-    const Type *DstTy = LI->getType();
+    Type *SrcTy = SrcOp->getType();
+    Type *DstTy = LI->getType();
 
     // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input
     // if the integer type is the same size as the pointer type.
@@ -1811,8 +1829,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     case ICmpInst::ICMP_EQ:
     case ICmpInst::ICMP_UGT:
     case ICmpInst::ICMP_UGE:
-      // getNullValue also works for vectors, unlike getFalse.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
     case ICmpInst::ICMP_SLT:
     case ICmpInst::ICMP_SLE:
       ComputeSignBit(LHS, KnownNonNegative, KnownNegative, TD);
@@ -1822,8 +1839,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     case ICmpInst::ICMP_NE:
     case ICmpInst::ICMP_ULT:
     case ICmpInst::ICMP_ULE:
-      // getAllOnesValue also works for vectors, unlike getTrue.
-      return Constant::getAllOnesValue(ITy);
+      return getTrue(ITy);
     }
   }
   if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) {
@@ -1840,8 +1856,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     case ICmpInst::ICMP_NE:
     case ICmpInst::ICMP_UGT:
     case ICmpInst::ICMP_UGE:
-      // getAllOnesValue also works for vectors, unlike getTrue.
-      return Constant::getAllOnesValue(ITy);
+      return getTrue(ITy);
     case ICmpInst::ICMP_SLT:
     case ICmpInst::ICMP_SLE:
       ComputeSignBit(RHS, KnownNonNegative, KnownNegative, TD);
@@ -1851,8 +1866,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     case ICmpInst::ICMP_EQ:
     case ICmpInst::ICMP_ULT:
     case ICmpInst::ICMP_ULE:
-      // getNullValue also works for vectors, unlike getFalse.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
     }
   }
 
@@ -1874,7 +1888,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
         return V;
       break;
     case Instruction::Shl: {
-      bool NUW = LBO->hasNoUnsignedWrap() && LBO->hasNoUnsignedWrap();
+      bool NUW = LBO->hasNoUnsignedWrap() && RBO->hasNoUnsignedWrap();
       bool NSW = LBO->hasNoSignedWrap() && RBO->hasNoSignedWrap();
       if (!NUW && !NSW)
         break;
@@ -1955,10 +1969,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     }
     case CmpInst::ICMP_SGE:
       // Always true.
-      return Constant::getAllOnesValue(ITy);
+      return getTrue(ITy);
     case CmpInst::ICMP_SLT:
       // Always false.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
     }
   }
 
@@ -2025,10 +2039,10 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     }
     case CmpInst::ICMP_UGE:
       // Always true.
-      return Constant::getAllOnesValue(ITy);
+      return getTrue(ITy);
     case CmpInst::ICMP_ULT:
       // Always false.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
     }
   }
 
@@ -2040,40 +2054,40 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     // max(x, ?) pred min(x, ?).
     if (Pred == CmpInst::ICMP_SGE)
       // Always true.
-      return Constant::getAllOnesValue(ITy);
+      return getTrue(ITy);
     if (Pred == CmpInst::ICMP_SLT)
       // Always false.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
   } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) &&
              match(RHS, m_SMax(m_Value(C), m_Value(D))) &&
              (A == C || A == D || B == C || B == D)) {
     // min(x, ?) pred max(x, ?).
     if (Pred == CmpInst::ICMP_SLE)
       // Always true.
-      return Constant::getAllOnesValue(ITy);
+      return getTrue(ITy);
     if (Pred == CmpInst::ICMP_SGT)
       // Always false.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
   } else if (match(LHS, m_UMax(m_Value(A), m_Value(B))) &&
              match(RHS, m_UMin(m_Value(C), m_Value(D))) &&
              (A == C || A == D || B == C || B == D)) {
     // max(x, ?) pred min(x, ?).
     if (Pred == CmpInst::ICMP_UGE)
       // Always true.
-      return Constant::getAllOnesValue(ITy);
+      return getTrue(ITy);
     if (Pred == CmpInst::ICMP_ULT)
       // Always false.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
   } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) &&
              match(RHS, m_UMax(m_Value(C), m_Value(D))) &&
              (A == C || A == D || B == C || B == D)) {
     // min(x, ?) pred max(x, ?).
     if (Pred == CmpInst::ICMP_ULE)
       // Always true.
-      return Constant::getAllOnesValue(ITy);
+      return getTrue(ITy);
     if (Pred == CmpInst::ICMP_UGT)
       // Always false.
-      return Constant::getNullValue(ITy);
+      return getFalse(ITy);
   }
 
   // If the comparison is with the result of a select instruction, check whether
@@ -2204,58 +2218,86 @@ Value *llvm::SimplifySelectInst(Value *CondVal, Value *TrueVal, Value *FalseVal,
   if (TrueVal == FalseVal)
     return TrueVal;
 
-  if (isa<UndefValue>(TrueVal))   // select C, undef, X -> X
-    return FalseVal;
-  if (isa<UndefValue>(FalseVal))   // select C, X, undef -> X
-    return TrueVal;
   if (isa<UndefValue>(CondVal)) {  // select undef, X, Y -> X or Y
     if (isa<Constant>(TrueVal))
       return TrueVal;
     return FalseVal;
   }
+  if (isa<UndefValue>(TrueVal))   // select C, undef, X -> X
+    return FalseVal;
+  if (isa<UndefValue>(FalseVal))   // select C, X, undef -> X
+    return TrueVal;
 
   return 0;
 }
 
 /// SimplifyGEPInst - Given operands for an GetElementPtrInst, see if we can
 /// fold the result.  If not, this returns null.
-Value *llvm::SimplifyGEPInst(Value *const *Ops, unsigned NumOps,
+Value *llvm::SimplifyGEPInst(ArrayRef<Value *> Ops,
                              const TargetData *TD, const DominatorTree *) {
   // The type of the GEP pointer operand.
-  const PointerType *PtrTy = cast<PointerType>(Ops[0]->getType());
+  PointerType *PtrTy = cast<PointerType>(Ops[0]->getType());
 
   // getelementptr P -> P.
-  if (NumOps == 1)
+  if (Ops.size() == 1)
     return Ops[0];
 
   if (isa<UndefValue>(Ops[0])) {
     // Compute the (pointer) type returned by the GEP instruction.
-    const Type *LastType = GetElementPtrInst::getIndexedType(PtrTy, &Ops[1],
-                                                             NumOps-1);
-    const Type *GEPTy = PointerType::get(LastType, PtrTy->getAddressSpace());
+    Type *LastType = GetElementPtrInst::getIndexedType(PtrTy, Ops.slice(1));
+    Type *GEPTy = PointerType::get(LastType, PtrTy->getAddressSpace());
     return UndefValue::get(GEPTy);
   }
 
-  if (NumOps == 2) {
+  if (Ops.size() == 2) {
     // getelementptr P, 0 -> P.
     if (ConstantInt *C = dyn_cast<ConstantInt>(Ops[1]))
       if (C->isZero())
         return Ops[0];
     // getelementptr P, N -> P if P points to a type of zero size.
     if (TD) {
-      const Type *Ty = PtrTy->getElementType();
+      Type *Ty = PtrTy->getElementType();
       if (Ty->isSized() && TD->getTypeAllocSize(Ty) == 0)
         return Ops[0];
     }
   }
 
   // Check to see if this is constant foldable.
-  for (unsigned i = 0; i != NumOps; ++i)
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
     if (!isa<Constant>(Ops[i]))
       return 0;
 
-  return ConstantExpr::getGetElementPtr(cast<Constant>(Ops[0]),
-                                        (Constant *const*)Ops+1, NumOps-1);
+  return ConstantExpr::getGetElementPtr(cast<Constant>(Ops[0]), Ops.slice(1));
+}
+
+/// SimplifyInsertValueInst - Given operands for an InsertValueInst, see if we
+/// can fold the result.  If not, this returns null.
+Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val,
+                                     ArrayRef<unsigned> Idxs,
+                                     const TargetData *,
+                                     const DominatorTree *) {
+  if (Constant *CAgg = dyn_cast<Constant>(Agg))
+    if (Constant *CVal = dyn_cast<Constant>(Val))
+      return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs);
+
+  // insertvalue x, undef, n -> x
+  if (match(Val, m_Undef()))
+    return Agg;
+
+  // insertvalue x, (extractvalue y, n), n
+  if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(Val))
+    if (EV->getAggregateOperand()->getType() == Agg->getType() &&
+        EV->getIndices() == Idxs) {
+      // insertvalue undef, (extractvalue y, n), n -> y
+      if (match(Agg, m_Undef()))
+        return EV->getAggregateOperand();
+
+      // insertvalue y, (extractvalue y, n), n -> y
+      if (Agg == EV->getAggregateOperand())
+        return Agg;
+    }
+
+  return 0;
 }
 
 /// SimplifyPHINode - See if we can fold the given phi.  If not, returns null.
@@ -2328,7 +2370,7 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
     if (Constant *CLHS = dyn_cast<Constant>(LHS))
       if (Constant *CRHS = dyn_cast<Constant>(RHS)) {
         Constant *COps[] = {CLHS, CRHS};
-        return ConstantFoldInstOperands(Opcode, LHS->getType(), COps, 2, TD);
+        return ConstantFoldInstOperands(Opcode, LHS->getType(), COps, TD);
       }
 
     // If the operation is associative, try some generic simplifications.
@@ -2456,7 +2498,14 @@ Value *llvm::SimplifyInstruction(Instruction *I, const TargetData *TD,
     break;
   case Instruction::GetElementPtr: {
     SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end());
-    Result = SimplifyGEPInst(&Ops[0], Ops.size(), TD, DT);
+    Result = SimplifyGEPInst(Ops, TD, DT);
+    break;
+  }
+  case Instruction::InsertValue: {
+    InsertValueInst *IV = cast<InsertValueInst>(I);
+    Result = SimplifyInsertValueInst(IV->getAggregateOperand(),
+                                     IV->getInsertedValueOperand(),
+                                     IV->getIndices(), TD, DT);
     break;
   }
   case Instruction::PHI: