Move some rem transforms out of instcombine and into instsimplify.
[oota-llvm.git] / lib / Analysis / InstructionSimplify.cpp
index c23abb7808929fcb5bc6bef45a905fb0d131fe53..9d6d3398feb88b2a2a4df43a3a473a712fb52eeb 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #define DEBUG_TYPE "instsimplify"
+#include "llvm/Operator.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/Dominators.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Support/ConstantRange.h"
 #include "llvm/Support/PatternMatch.h"
 #include "llvm/Support/ValueHandle.h"
 #include "llvm/Target/TargetData.h"
@@ -899,6 +901,111 @@ Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, const TargetData *TD,
   return ::SimplifyFDivInst(Op0, Op1, TD, DT, RecursionLimit);
 }
 
+/// SimplifyRem - Given operands for an SRem or URem, see if we can
+/// fold the result.  If not, this returns null.
+static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
+                          const TargetData *TD, const DominatorTree *DT,
+                          unsigned MaxRecurse) {
+  if (Constant *C0 = dyn_cast<Constant>(Op0)) {
+    if (Constant *C1 = dyn_cast<Constant>(Op1)) {
+      Constant *Ops[] = { C0, C1 };
+      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, 2, TD);
+    }
+  }
+
+  bool isSigned = Opcode == Instruction::SRem;
+
+  // X % undef -> undef
+  if (match(Op1, m_Undef()))
+    return Op1;
+
+  // undef % X -> 0
+  if (match(Op0, m_Undef()))
+    return Constant::getNullValue(Op0->getType());
+
+  // 0 % X -> 0, we don't need to preserve faults!
+  if (match(Op0, m_Zero()))
+    return Op0;
+
+  // X % 0 -> undef, we don't need to preserve faults!
+  if (match(Op1, m_Zero()))
+    return UndefValue::get(Op0->getType());
+
+  // X % 1 -> 0
+  if (match(Op1, m_One()))
+    return Constant::getNullValue(Op0->getType());
+
+  if (Op0->getType()->isIntegerTy(1))
+    // It can't be remainder by zero, hence it must be remainder by one.
+    return Constant::getNullValue(Op0->getType());
+
+  // X % X -> 0
+  if (Op0 == Op1)
+    return Constant::getNullValue(Op0->getType());
+
+  // If the operation is with the result of a select instruction, check whether
+  // operating on either branch of the select always yields the same value.
+  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+    if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, TD, DT, MaxRecurse))
+      return V;
+
+  // If the operation is with the result of a phi instruction, check whether
+  // operating on all incoming values of the phi always yields the same value.
+  if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+    if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, TD, DT, MaxRecurse))
+      return V;
+
+  return 0;
+}
+
+/// SimplifySRemInst - Given operands for an SRem, see if we can
+/// fold the result.  If not, this returns null.
+static Value *SimplifySRemInst(Value *Op0, Value *Op1, const TargetData *TD,
+                               const DominatorTree *DT, unsigned MaxRecurse) {
+  if (Value *V = SimplifyRem(Instruction::SRem, Op0, Op1, TD, DT, MaxRecurse))
+    return V;
+
+  return 0;
+}
+
+Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const TargetData *TD,
+                              const DominatorTree *DT) {
+  return ::SimplifySRemInst(Op0, Op1, TD, DT, RecursionLimit);
+}
+
+/// SimplifyURemInst - Given operands for a URem, see if we can
+/// fold the result.  If not, this returns null.
+static Value *SimplifyURemInst(Value *Op0, Value *Op1, const TargetData *TD,
+                               const DominatorTree *DT, unsigned MaxRecurse) {
+  if (Value *V = SimplifyRem(Instruction::URem, Op0, Op1, TD, DT, MaxRecurse))
+    return V;
+
+  return 0;
+}
+
+Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const TargetData *TD,
+                              const DominatorTree *DT) {
+  return ::SimplifyURemInst(Op0, Op1, TD, DT, RecursionLimit);
+}
+
+static Value *SimplifyFRemInst(Value *Op0, Value *Op1, const TargetData *,
+                               const DominatorTree *, unsigned) {
+  // undef % X -> undef    (the undef could be a snan).
+  if (match(Op0, m_Undef()))
+    return Op0;
+
+  // X % undef -> undef
+  if (match(Op1, m_Undef()))
+    return Op1;
+
+  return 0;
+}
+
+Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, const TargetData *TD,
+                              const DominatorTree *DT) {
+  return ::SimplifyFRemInst(Op0, Op1, TD, DT, RecursionLimit);
+}
+
 /// SimplifyShift - Given operands for an Shl, LShr or AShr, see if we can
 /// fold the result.  If not, this returns null.
 static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1,
@@ -1343,7 +1450,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   // the compare, and if only one of them is then we moved it to RHS already.
   if (isa<AllocaInst>(LHS) && (isa<GlobalValue>(RHS) || isa<AllocaInst>(RHS) ||
                                isa<ConstantPointerNull>(RHS)))
-    // We already know that LHS != LHS.
+    // We already know that LHS != RHS.
     return ConstantInt::get(ITy, CmpInst::isFalseWhenEqual(Pred));
 
   // If we are comparing with zero then try hard since this is a common case.
@@ -1399,44 +1506,67 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
 
   // See if we are doing a comparison with a constant integer.
   if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
-    switch (Pred) {
-    default: break;
-    case ICmpInst::ICMP_UGT:
-      if (CI->isMaxValue(false))                 // A >u MAX -> FALSE
-        return ConstantInt::getFalse(CI->getContext());
-      break;
-    case ICmpInst::ICMP_UGE:
-      if (CI->isMinValue(false))                 // A >=u MIN -> TRUE
-        return ConstantInt::getTrue(CI->getContext());
-      break;
-    case ICmpInst::ICMP_ULT:
-      if (CI->isMinValue(false))                 // A <u MIN -> FALSE
-        return ConstantInt::getFalse(CI->getContext());
-      break;
-    case ICmpInst::ICMP_ULE:
-      if (CI->isMaxValue(false))                 // A <=u MAX -> TRUE
-        return ConstantInt::getTrue(CI->getContext());
-      break;
-    case ICmpInst::ICMP_SGT:
-      if (CI->isMaxValue(true))                  // A >s MAX -> FALSE
-        return ConstantInt::getFalse(CI->getContext());
-      break;
-    case ICmpInst::ICMP_SGE:
-      if (CI->isMinValue(true))                  // A >=s MIN -> TRUE
-        return ConstantInt::getTrue(CI->getContext());
-      break;
-    case ICmpInst::ICMP_SLT:
-      if (CI->isMinValue(true))                  // A <s MIN -> FALSE
-        return ConstantInt::getFalse(CI->getContext());
-      break;
-    case ICmpInst::ICMP_SLE:
-      if (CI->isMaxValue(true))                  // A <=s MAX -> TRUE
-        return ConstantInt::getTrue(CI->getContext());
-      break;
+    // Rule out tautological comparisons (eg., ult 0 or uge 0).
+    ConstantRange RHS_CR = ICmpInst::makeConstantRange(Pred, CI->getValue());
+    if (RHS_CR.isEmptySet())
+      return ConstantInt::getFalse(CI->getContext());
+    if (RHS_CR.isFullSet())
+      return ConstantInt::getTrue(CI->getContext());
+
+    // Many binary operators with constant RHS have easy to compute constant
+    // range.  Use them to check whether the comparison is a tautology.
+    uint32_t Width = CI->getBitWidth();
+    APInt Lower = APInt(Width, 0);
+    APInt Upper = APInt(Width, 0);
+    ConstantInt *CI2;
+    if (match(LHS, m_URem(m_Value(), m_ConstantInt(CI2)))) {
+      // 'urem x, CI2' produces [0, CI2).
+      Upper = CI2->getValue();
+    } else if (match(LHS, m_SRem(m_Value(), m_ConstantInt(CI2)))) {
+      // 'srem x, CI2' produces (-|CI2|, |CI2|).
+      Upper = CI2->getValue().abs();
+      Lower = (-Upper) + 1;
+    } else if (match(LHS, m_UDiv(m_Value(), m_ConstantInt(CI2)))) {
+      // 'udiv x, CI2' produces [0, UINT_MAX / CI2].
+      APInt NegOne = APInt::getAllOnesValue(Width);
+      if (!CI2->isZero())
+        Upper = NegOne.udiv(CI2->getValue()) + 1;
+    } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) {
+      // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2].
+      APInt IntMin = APInt::getSignedMinValue(Width);
+      APInt IntMax = APInt::getSignedMaxValue(Width);
+      APInt Val = CI2->getValue().abs();
+      if (!Val.isMinValue()) {
+        Lower = IntMin.sdiv(Val);
+        Upper = IntMax.sdiv(Val) + 1;
+      }
+    } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) {
+      // 'lshr x, CI2' produces [0, UINT_MAX >> CI2].
+      APInt NegOne = APInt::getAllOnesValue(Width);
+      if (CI2->getValue().ult(Width))
+        Upper = NegOne.lshr(CI2->getValue()) + 1;
+    } else if (match(LHS, m_AShr(m_Value(), m_ConstantInt(CI2)))) {
+      // 'ashr x, CI2' produces [INT_MIN >> CI2, INT_MAX >> CI2].
+      APInt IntMin = APInt::getSignedMinValue(Width);
+      APInt IntMax = APInt::getSignedMaxValue(Width);
+      if (CI2->getValue().ult(Width)) {
+        Lower = IntMin.ashr(CI2->getValue());
+        Upper = IntMax.ashr(CI2->getValue()) + 1;
+      }
+    } else if (match(LHS, m_Or(m_Value(), m_ConstantInt(CI2)))) {
+      // 'or x, CI2' produces [CI2, UINT_MAX].
+      Lower = CI2->getValue();
+    } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) {
+      // 'and x, CI2' produces [0, CI2].
+      Upper = CI2->getValue() + 1;
+    }
+    if (Lower != Upper) {
+      ConstantRange LHS_CR = ConstantRange(Lower, Upper);
+      if (RHS_CR.contains(LHS_CR))
+        return ConstantInt::getTrue(RHS->getContext());
+      if (RHS_CR.inverse().contains(LHS_CR))
+        return ConstantInt::getFalse(RHS->getContext());
     }
-
-    // TODO: integer div and rem with constant RHS all produce values in a
-    // constant range. We should check whether the comparison is a tautology.
   }
 
   // Compare of cast, for example (zext X) != 0 -> X != 0
@@ -1648,19 +1778,91 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   }
 
   if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
+    bool KnownNonNegative, KnownNegative;
     switch (Pred) {
     default:
       break;
+    case ICmpInst::ICMP_SGT:
+    case ICmpInst::ICMP_SGE:
+      ComputeSignBit(LHS, KnownNonNegative, KnownNegative, TD);
+      if (!KnownNonNegative)
+        break;
+      // fall-through
     case ICmpInst::ICMP_EQ:
     case ICmpInst::ICMP_UGT:
     case ICmpInst::ICMP_UGE:
       return ConstantInt::getFalse(RHS->getContext());
+    case ICmpInst::ICMP_SLT:
+    case ICmpInst::ICMP_SLE:
+      ComputeSignBit(LHS, KnownNonNegative, KnownNegative, TD);
+      if (!KnownNonNegative)
+        break;
+      // fall-through
     case ICmpInst::ICMP_NE:
     case ICmpInst::ICMP_ULT:
     case ICmpInst::ICMP_ULE:
       return ConstantInt::getTrue(RHS->getContext());
     }
   }
+  if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) {
+    bool KnownNonNegative, KnownNegative;
+    switch (Pred) {
+    default:
+      break;
+    case ICmpInst::ICMP_SGT:
+    case ICmpInst::ICMP_SGE:
+      ComputeSignBit(RHS, KnownNonNegative, KnownNegative, TD);
+      if (!KnownNonNegative)
+        break;
+      // fall-through
+    case ICmpInst::ICMP_NE:
+    case ICmpInst::ICMP_UGT:
+    case ICmpInst::ICMP_UGE:
+      return ConstantInt::getTrue(RHS->getContext());
+    case ICmpInst::ICMP_SLT:
+    case ICmpInst::ICMP_SLE:
+      ComputeSignBit(RHS, KnownNonNegative, KnownNegative, TD);
+      if (!KnownNonNegative)
+        break;
+      // fall-through
+    case ICmpInst::ICMP_EQ:
+    case ICmpInst::ICMP_ULT:
+    case ICmpInst::ICMP_ULE:
+      return ConstantInt::getFalse(RHS->getContext());
+    }
+  }
+
+  if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() &&
+      LBO->getOperand(1) == RBO->getOperand(1)) {
+    switch (LBO->getOpcode()) {
+    default: break;
+    case Instruction::UDiv:
+    case Instruction::LShr:
+      if (ICmpInst::isSigned(Pred))
+        break;
+      // fall-through
+    case Instruction::SDiv:
+    case Instruction::AShr:
+      if (!LBO->isExact() && !RBO->isExact())
+        break;
+      if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
+                                      RBO->getOperand(0), TD, DT, MaxRecurse-1))
+        return V;
+      break;
+    case Instruction::Shl: {
+      bool NUW = LBO->hasNoUnsignedWrap() && LBO->hasNoUnsignedWrap();
+      bool NSW = LBO->hasNoSignedWrap() && RBO->hasNoSignedWrap();
+      if (!NUW && !NSW)
+        break;
+      if (!NSW && ICmpInst::isSigned(Pred))
+        break;
+      if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
+                                      RBO->getOperand(0), TD, DT, MaxRecurse-1))
+        return V;
+      break;
+    }
+    }
+  }
 
   // If the comparison is with the result of a select instruction, check whether
   // comparing with either branch of the select always yields the same value.
@@ -1897,6 +2099,9 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
   case Instruction::SDiv: return SimplifySDivInst(LHS, RHS, TD, DT, MaxRecurse);
   case Instruction::UDiv: return SimplifyUDivInst(LHS, RHS, TD, DT, MaxRecurse);
   case Instruction::FDiv: return SimplifyFDivInst(LHS, RHS, TD, DT, MaxRecurse);
+  case Instruction::SRem: return SimplifySRemInst(LHS, RHS, TD, DT, MaxRecurse);
+  case Instruction::URem: return SimplifyURemInst(LHS, RHS, TD, DT, MaxRecurse);
+  case Instruction::FRem: return SimplifyFRemInst(LHS, RHS, TD, DT, MaxRecurse);
   case Instruction::Shl:
     return SimplifyShlInst(LHS, RHS, /*isNSW*/false, /*isNUW*/false,
                            TD, DT, MaxRecurse);
@@ -1991,6 +2196,15 @@ Value *llvm::SimplifyInstruction(Instruction *I, const TargetData *TD,
   case Instruction::FDiv:
     Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), TD, DT);
     break;
+  case Instruction::SRem:
+    Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), TD, DT);
+    break;
+  case Instruction::URem:
+    Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), TD, DT);
+    break;
+  case Instruction::FRem:
+    Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), TD, DT);
+    break;
   case Instruction::Shl:
     Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1),
                              cast<BinaryOperator>(I)->hasNoSignedWrap(),