Fix logic error optimizing "icmp pred (urem X, Y), Y" where pred is signed.
[oota-llvm.git] / lib / Analysis / InstructionSimplify.cpp
index e296215442c6584a4dcc93757ca7bd1951b760ab..52c045658623b180982ca4fd9317ecfa14d3b301 100644 (file)
@@ -667,9 +667,9 @@ Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
 /// This is very similar to GetPointerBaseWithConstantOffset except it doesn't
 /// follow non-inbounds geps. This allows it to remain usable for icmp ult/etc.
 /// folding.
-static ConstantInt *stripAndComputeConstantOffsets(const DataLayout *TD,
-                                                   Value *&V) {
-  assert(V->getType()->isPointerTy());
+static Constant *stripAndComputeConstantOffsets(const DataLayout *TD,
+                                                Value *&V) {
+  assert(V->getType()->getScalarType()->isPointerTy());
 
   // Without DataLayout, just be conservative for now. Theoretically, more could
   // be done in this case.
@@ -697,11 +697,16 @@ static ConstantInt *stripAndComputeConstantOffsets(const DataLayout *TD,
     } else {
       break;
     }
-    assert(V->getType()->isPointerTy() && "Unexpected operand type!");
+    assert(V->getType()->getScalarType()->isPointerTy() &&
+           "Unexpected operand type!");
   } while (Visited.insert(V));
 
   Type *IntPtrTy = TD->getIntPtrType(V->getContext());
-  return cast<ConstantInt>(ConstantInt::get(IntPtrTy, Offset));
+  Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset);
+  if (V->getType()->isVectorTy())
+    return ConstantVector::getSplat(V->getType()->getVectorNumElements(),
+                                    OffsetIntPtr);
+  return OffsetIntPtr;
 }
 
 /// \brief Compute the constant difference between two pointer values.
@@ -1358,6 +1363,10 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
   if (Value *V = SimplifyShift(Instruction::LShr, Op0, Op1, Q, MaxRecurse))
     return V;
 
+  // X >> X -> 0
+  if (Op0 == Op1)
+    return Constant::getNullValue(Op0->getType());
+
   // undef >>l X -> 0
   if (match(Op0, m_Undef()))
     return Constant::getNullValue(Op0->getType());
@@ -1386,6 +1395,10 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
   if (Value *V = SimplifyShift(Instruction::AShr, Op0, Op1, Q, MaxRecurse))
     return V;
 
+  // X >> X -> 0
+  if (Op0 == Op1)
+    return Constant::getNullValue(Op0->getType());
+
   // all ones >>a X -> all ones
   if (match(Op0, m_AllOnes()))
     return Op0;
@@ -1706,7 +1719,7 @@ static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
 //    subobject at its beginning) or function, both are pointers to one past the
 //    last element of the same array object, or one is a pointer to one past the
 //    end of one array object and the other is a pointer to the start of a
-//    different array object that happens to immediately follow the rst array
+//    different array object that happens to immediately follow the first array
 //    object in the address space.)
 //
 // C11's version is more restrictive, however there's no reason why an argument
@@ -1758,8 +1771,8 @@ static Constant *computePointerICmp(const DataLayout *TD,
   // numerous hazards. AliasAnalysis and its utilities rely on special rules
   // governing loads and stores which don't apply to icmps. Also, AliasAnalysis
   // doesn't need to guarantee pointer inequality when it says NoAlias.
-  ConstantInt *LHSOffset = stripAndComputeConstantOffsets(TD, LHS);
-  ConstantInt *RHSOffset = stripAndComputeConstantOffsets(TD, RHS);
+  Constant *LHSOffset = stripAndComputeConstantOffsets(TD, LHS);
+  Constant *RHSOffset = stripAndComputeConstantOffsets(TD, RHS);
 
   // If LHS and RHS are related via constant offsets to the same base
   // value, we can replace it with an icmp which just compares the offsets.
@@ -1799,11 +1812,14 @@ static Constant *computePointerICmp(const DataLayout *TD,
     // address, due to canonicalization and constant folding.
     if (isa<AllocaInst>(LHS) &&
         (isa<AllocaInst>(RHS) || isa<GlobalVariable>(RHS))) {
+      ConstantInt *LHSOffsetCI = dyn_cast<ConstantInt>(LHSOffset);
+      ConstantInt *RHSOffsetCI = dyn_cast<ConstantInt>(RHSOffset);
       uint64_t LHSSize, RHSSize;
-      if (getObjectSize(LHS, LHSSize, TD, TLI) &&
+      if (LHSOffsetCI && RHSOffsetCI &&
+          getObjectSize(LHS, LHSSize, TD, TLI) &&
           getObjectSize(RHS, RHSSize, TD, TLI)) {
-        const APInt &LHSOffsetValue = LHSOffset->getValue();
-        const APInt &RHSOffsetValue = RHSOffset->getValue();
+        const APInt &LHSOffsetValue = LHSOffsetCI->getValue();
+        const APInt &RHSOffsetValue = RHSOffsetCI->getValue();
         if (!LHSOffsetValue.isNegative() &&
             !RHSOffsetValue.isNegative() &&
             LHSOffsetValue.ult(LHSSize) &&
@@ -2230,6 +2246,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     }
   }
 
+  // icmp pred (urem X, Y), Y
   if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
     bool KnownNonNegative, KnownNegative;
     switch (Pred) {
@@ -2237,7 +2254,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       break;
     case ICmpInst::ICMP_SGT:
     case ICmpInst::ICMP_SGE:
-      ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.TD);
+      ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.TD);
       if (!KnownNonNegative)
         break;
       // fall-through
@@ -2247,7 +2264,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       return getFalse(ITy);
     case ICmpInst::ICMP_SLT:
     case ICmpInst::ICMP_SLE:
-      ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.TD);
+      ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.TD);
       if (!KnownNonNegative)
         break;
       // fall-through
@@ -2257,6 +2274,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       return getTrue(ITy);
     }
   }
+
+  // icmp pred X, (urem Y, X)
   if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) {
     bool KnownNonNegative, KnownNegative;
     switch (Pred) {
@@ -2264,7 +2283,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       break;
     case ICmpInst::ICMP_SGT:
     case ICmpInst::ICMP_SGE:
-      ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.TD);
+      ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.TD);
       if (!KnownNonNegative)
         break;
       // fall-through
@@ -2274,7 +2293,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       return getTrue(ITy);
     case ICmpInst::ICMP_SLT:
     case ICmpInst::ICMP_SLE:
-      ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.TD);
+      ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.TD);
       if (!KnownNonNegative)
         break;
       // fall-through
@@ -2917,6 +2936,37 @@ Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
                            RecursionLimit);
 }
 
+static bool IsIdempotent(Intrinsic::ID ID) {
+  switch (ID) {
+  default: return false;
+
+  // Unary idempotent: f(f(x)) = f(x)
+  case Intrinsic::fabs:
+  case Intrinsic::floor:
+  case Intrinsic::ceil:
+  case Intrinsic::trunc:
+  case Intrinsic::rint:
+  case Intrinsic::nearbyint:
+    return true;
+  }
+}
+
+template <typename IterTy>
+static Value *SimplifyIntrinsic(Intrinsic::ID IID, IterTy ArgBegin, IterTy ArgEnd,
+                                const Query &Q, unsigned MaxRecurse) {
+  // Perform idempotent optimizations
+  if (!IsIdempotent(IID))
+    return 0;
+
+  // Unary Ops
+  if (std::distance(ArgBegin, ArgEnd) == 1)
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(*ArgBegin))
+      if (II->getIntrinsicID() == IID)
+        return II;
+
+  return 0;
+}
+
 template <typename IterTy>
 static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd,
                            const Query &Q, unsigned MaxRecurse) {
@@ -2933,6 +2983,11 @@ static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd,
   if (!F)
     return 0;
 
+  if (unsigned IID = F->getIntrinsicID())
+    if (Value *Ret =
+        SimplifyIntrinsic((Intrinsic::ID) IID, ArgBegin, ArgEnd, Q, MaxRecurse))
+      return Ret;
+
   if (!canConstantFoldCallTo(F))
     return 0;