Move some rem transforms out of instcombine and into instsimplify.
[oota-llvm.git] / lib / Analysis / ValueTracking.cpp
index f8145cd0e3d9bca93a01d84a281112ea8b2b22eb..8f18dd278aa08d17979a148fa0cea94ada1b9b70 100644 (file)
 #include "llvm/Target/TargetData.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/PatternMatch.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include <cstring>
 using namespace llvm;
+using namespace llvm::PatternMatch;
+
+const unsigned MaxDepth = 6;
+
+/// getBitWidth - Returns the bitwidth of the given scalar or pointer type (if
+/// unknown returns 0).  For vector types, returns the element type's bitwidth.
+static unsigned getBitWidth(const Type *Ty, const TargetData *TD) {
+  if (unsigned BitWidth = Ty->getScalarSizeInBits())
+    return BitWidth;
+  assert(isa<PointerType>(Ty) && "Expected a pointer type!");
+  return TD ? TD->getPointerSizeInBits() : 0;
+}
 
 /// ComputeMaskedBits - Determine which of the bits specified in Mask are
 /// known to be either zero or one and return them in the KnownZero/KnownOne
@@ -47,7 +60,6 @@ using namespace llvm;
 void llvm::ComputeMaskedBits(Value *V, const APInt &Mask,
                              APInt &KnownZero, APInt &KnownOne,
                              const TargetData *TD, unsigned Depth) {
-  const unsigned MaxDepth = 6;
   assert(V && "No Value?");
   assert(Depth <= MaxDepth && "Limit Search Depth");
   unsigned BitWidth = Mask.getBitWidth();
@@ -337,7 +349,7 @@ void llvm::ComputeMaskedBits(Value *V, const APInt &Mask,
     // (ashr X, C1) & C2 == 0   iff  (-1 >> C1) & C2 == 0
     if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
       // Compute the new bits that are at the top now.
-      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth);
+      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
       
       // Signed shift right.
       APInt Mask2(Mask.shl(ShiftAmt));
@@ -417,6 +429,29 @@ void llvm::ComputeMaskedBits(Value *V, const APInt &Mask,
       KnownZero |= LHSKnownZero & Mask;
       KnownOne  |= LHSKnownOne & Mask;
     }
+
+    // Are we still trying to solve for the sign bit?
+    if (Mask.isNegative() && !KnownZero.isNegative() && !KnownOne.isNegative()){
+      OverflowingBinaryOperator *OBO = cast<OverflowingBinaryOperator>(I);
+      if (OBO->hasNoSignedWrap()) {
+        if (I->getOpcode() == Instruction::Add) {
+          // Adding two positive numbers can't wrap into negative
+          if (LHSKnownZero.isNegative() && KnownZero2.isNegative())
+            KnownZero |= APInt::getSignBit(BitWidth);
+          // and adding two negative numbers can't wrap into positive.
+          else if (LHSKnownOne.isNegative() && KnownOne2.isNegative())
+            KnownOne |= APInt::getSignBit(BitWidth);
+        } else {
+          // Subtracting a negative number from a positive one can't wrap
+          if (LHSKnownZero.isNegative() && KnownOne2.isNegative())
+            KnownZero |= APInt::getSignBit(BitWidth);
+          // neither can subtracting a positive number from a negative one.
+          else if (LHSKnownOne.isNegative() && KnownZero2.isNegative())
+            KnownOne |= APInt::getSignBit(BitWidth);
+        }
+      }
+    }
+
     return;
   }
   case Instruction::SRem:
@@ -448,6 +483,19 @@ void llvm::ComputeMaskedBits(Value *V, const APInt &Mask,
         assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
       }
     }
+
+    // The sign bit is the LHS's sign bit, except when the result of the
+    // remainder is zero.
+    if (Mask.isNegative() && KnownZero.isNonNegative()) {
+      APInt Mask2 = APInt::getSignBit(BitWidth);
+      APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
+      ComputeMaskedBits(I->getOperand(0), Mask2, LHSKnownZero, LHSKnownOne, TD,
+                        Depth+1);
+      // If it's known zero, our sign bit is also zero.
+      if (LHSKnownZero.isNegative())
+        KnownZero |= LHSKnownZero;
+    }
+
     break;
   case Instruction::URem: {
     if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
@@ -578,9 +626,17 @@ void llvm::ComputeMaskedBits(Value *V, const APInt &Mask,
       }
     }
 
+    // Unreachable blocks may have zero-operand PHI nodes.
+    if (P->getNumIncomingValues() == 0)
+      return;
+
     // Otherwise take the unions of the known bit sets of the operands,
     // taking conservative care to avoid excessive recursion.
     if (Depth < MaxDepth - 1 && !KnownZero && !KnownOne) {
+      // Skip if every incoming value references to ourself.
+      if (P->hasConstantValue() == P)
+        break;
+
       KnownZero = APInt::getAllOnesValue(BitWidth);
       KnownOne = APInt::getAllOnesValue(BitWidth);
       for (unsigned i = 0, e = P->getNumIncomingValues(); i != e; ++i) {
@@ -620,6 +676,182 @@ void llvm::ComputeMaskedBits(Value *V, const APInt &Mask,
   }
 }
 
+/// ComputeSignBit - Determine whether the sign bit is known to be zero or
+/// one.  Convenience wrapper around ComputeMaskedBits.
+void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne,
+                          const TargetData *TD, unsigned Depth) {
+  unsigned BitWidth = getBitWidth(V->getType(), TD);
+  if (!BitWidth) {
+    KnownZero = false;
+    KnownOne = false;
+    return;
+  }
+  APInt ZeroBits(BitWidth, 0);
+  APInt OneBits(BitWidth, 0);
+  ComputeMaskedBits(V, APInt::getSignBit(BitWidth), ZeroBits, OneBits, TD,
+                    Depth);
+  KnownOne = OneBits[BitWidth - 1];
+  KnownZero = ZeroBits[BitWidth - 1];
+}
+
+/// isPowerOfTwo - Return true if the given value is known to have exactly one
+/// bit set when defined. For vectors return true if every element is known to
+/// be a power of two when defined.  Supports values with integer or pointer
+/// types and vectors of integers.
+bool llvm::isPowerOfTwo(Value *V, const TargetData *TD, unsigned Depth) {
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
+    return CI->getValue().isPowerOf2();
+  // TODO: Handle vector constants.
+
+  // 1 << X is clearly a power of two if the one is not shifted off the end.  If
+  // it is shifted off the end then the result is undefined.
+  if (match(V, m_Shl(m_One(), m_Value())))
+    return true;
+
+  // (signbit) >>l X is clearly a power of two if the one is not shifted off the
+  // bottom.  If it is shifted off the bottom then the result is undefined.
+  if (match(V, m_LShr(m_SignBit(), m_Value())))
+    return true;
+
+  // The remaining tests are all recursive, so bail out if we hit the limit.
+  if (Depth++ == MaxDepth)
+    return false;
+
+  if (ZExtInst *ZI = dyn_cast<ZExtInst>(V))
+    return isPowerOfTwo(ZI->getOperand(0), TD, Depth);
+
+  if (SelectInst *SI = dyn_cast<SelectInst>(V))
+    return isPowerOfTwo(SI->getTrueValue(), TD, Depth) &&
+      isPowerOfTwo(SI->getFalseValue(), TD, Depth);
+
+  // An exact divide or right shift can only shift off zero bits, so the result
+  // is a power of two only if the first operand is a power of two and not
+  // copying a sign bit (sdiv int_min, 2).
+  if (match(V, m_LShr(m_Value(), m_Value())) ||
+      match(V, m_UDiv(m_Value(), m_Value()))) {
+    PossiblyExactOperator *PEO = cast<PossiblyExactOperator>(V);
+    if (PEO->isExact())
+      return isPowerOfTwo(PEO->getOperand(0), TD, Depth);
+  }
+
+  return false;
+}
+
+/// isKnownNonZero - Return true if the given value is known to be non-zero
+/// when defined.  For vectors return true if every element is known to be
+/// non-zero when defined.  Supports values with integer or pointer type and
+/// vectors of integers.
+bool llvm::isKnownNonZero(Value *V, const TargetData *TD, unsigned Depth) {
+  if (Constant *C = dyn_cast<Constant>(V)) {
+    if (C->isNullValue())
+      return false;
+    if (isa<ConstantInt>(C))
+      // Must be non-zero due to null test above.
+      return true;
+    // TODO: Handle vectors
+    return false;
+  }
+
+  // The remaining tests are all recursive, so bail out if we hit the limit.
+  if (Depth++ == MaxDepth)
+    return false;
+
+  unsigned BitWidth = getBitWidth(V->getType(), TD);
+
+  // X | Y != 0 if X != 0 or Y != 0.
+  Value *X = 0, *Y = 0;
+  if (match(V, m_Or(m_Value(X), m_Value(Y))))
+    return isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth);
+
+  // ext X != 0 if X != 0.
+  if (isa<SExtInst>(V) || isa<ZExtInst>(V))
+    return isKnownNonZero(cast<Instruction>(V)->getOperand(0), TD, Depth);
+
+  // shl X, Y != 0 if X is odd.  Note that the value of the shift is undefined
+  // if the lowest bit is shifted off the end.
+  if (BitWidth && match(V, m_Shl(m_Value(X), m_Value(Y)))) {
+    // shl nuw can't remove any non-zero bits.
+    BinaryOperator *BO = cast<BinaryOperator>(V);
+    if (BO->hasNoUnsignedWrap())
+      return isKnownNonZero(X, TD, Depth);
+
+    APInt KnownZero(BitWidth, 0);
+    APInt KnownOne(BitWidth, 0);
+    ComputeMaskedBits(X, APInt(BitWidth, 1), KnownZero, KnownOne, TD, Depth);
+    if (KnownOne[0])
+      return true;
+  }
+  // shr X, Y != 0 if X is negative.  Note that the value of the shift is not
+  // defined if the sign bit is shifted off the end.
+  else if (match(V, m_Shr(m_Value(X), m_Value(Y)))) {
+    // shr exact can only shift out zero bits.
+    BinaryOperator *BO = cast<BinaryOperator>(V);
+    if (BO->isExact())
+      return isKnownNonZero(X, TD, Depth);
+
+    bool XKnownNonNegative, XKnownNegative;
+    ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth);
+    if (XKnownNegative)
+      return true;
+  }
+  // div exact can only produce a zero if the dividend is zero.
+  else if (match(V, m_IDiv(m_Value(X), m_Value()))) {
+    BinaryOperator *BO = cast<BinaryOperator>(V);
+    if (BO->isExact())
+      return isKnownNonZero(X, TD, Depth);
+  }
+  // X + Y.
+  else if (match(V, m_Add(m_Value(X), m_Value(Y)))) {
+    bool XKnownNonNegative, XKnownNegative;
+    bool YKnownNonNegative, YKnownNegative;
+    ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth);
+    ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, TD, Depth);
+
+    // If X and Y are both non-negative (as signed values) then their sum is not
+    // zero unless both X and Y are zero.
+    if (XKnownNonNegative && YKnownNonNegative)
+      if (isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth))
+        return true;
+
+    // If X and Y are both negative (as signed values) then their sum is not
+    // zero unless both X and Y equal INT_MIN.
+    if (BitWidth && XKnownNegative && YKnownNegative) {
+      APInt KnownZero(BitWidth, 0);
+      APInt KnownOne(BitWidth, 0);
+      APInt Mask = APInt::getSignedMaxValue(BitWidth);
+      // The sign bit of X is set.  If some other bit is set then X is not equal
+      // to INT_MIN.
+      ComputeMaskedBits(X, Mask, KnownZero, KnownOne, TD, Depth);
+      if ((KnownOne & Mask) != 0)
+        return true;
+      // The sign bit of Y is set.  If some other bit is set then Y is not equal
+      // to INT_MIN.
+      ComputeMaskedBits(Y, Mask, KnownZero, KnownOne, TD, Depth);
+      if ((KnownOne & Mask) != 0)
+        return true;
+    }
+
+    // The sum of a non-negative number and a power of two is not zero.
+    if (XKnownNonNegative && isPowerOfTwo(Y, TD, Depth))
+      return true;
+    if (YKnownNonNegative && isPowerOfTwo(X, TD, Depth))
+      return true;
+  }
+  // (C ? X : Y) != 0 if X != 0 and Y != 0.
+  else if (SelectInst *SI = dyn_cast<SelectInst>(V)) {
+    if (isKnownNonZero(SI->getTrueValue(), TD, Depth) &&
+        isKnownNonZero(SI->getFalseValue(), TD, Depth))
+      return true;
+  }
+
+  if (!BitWidth) return false;
+  APInt KnownZero(BitWidth, 0);
+  APInt KnownOne(BitWidth, 0);
+  ComputeMaskedBits(V, APInt::getAllOnesValue(BitWidth), KnownZero, KnownOne,
+                    TD, Depth);
+  return KnownOne != 0;
+}
+
 /// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero.  We use
 /// this predicate to simplify operations downstream.  Mask is known to be zero
 /// for bits that V cannot have.
@@ -989,6 +1221,80 @@ bool llvm::CannotBeNegativeZero(const Value *V, unsigned Depth) {
   return false;
 }
 
+/// isBytewiseValue - If the specified value can be set by repeating the same
+/// byte in memory, return the i8 value that it is represented with.  This is
+/// true for all i8 values obviously, but is also true for i32 0, i32 -1,
+/// i16 0xF0F0, double 0.0 etc.  If the value can't be handled with a repeated
+/// byte store (e.g. i16 0x1234), return null.
+Value *llvm::isBytewiseValue(Value *V) {
+  // All byte-wide stores are splatable, even of arbitrary variables.
+  if (V->getType()->isIntegerTy(8)) return V;
+
+  // Handle 'null' ConstantArrayZero etc.
+  if (Constant *C = dyn_cast<Constant>(V))
+    if (C->isNullValue())
+      return Constant::getNullValue(Type::getInt8Ty(V->getContext()));
+  
+  // Constant float and double values can be handled as integer values if the
+  // corresponding integer value is "byteable".  An important case is 0.0. 
+  if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
+    if (CFP->getType()->isFloatTy())
+      V = ConstantExpr::getBitCast(CFP, Type::getInt32Ty(V->getContext()));
+    if (CFP->getType()->isDoubleTy())
+      V = ConstantExpr::getBitCast(CFP, Type::getInt64Ty(V->getContext()));
+    // Don't handle long double formats, which have strange constraints.
+  }
+  
+  // We can handle constant integers that are power of two in size and a 
+  // multiple of 8 bits.
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
+    unsigned Width = CI->getBitWidth();
+    if (isPowerOf2_32(Width) && Width > 8) {
+      // We can handle this value if the recursive binary decomposition is the
+      // same at all levels.
+      APInt Val = CI->getValue();
+      APInt Val2;
+      while (Val.getBitWidth() != 8) {
+        unsigned NextWidth = Val.getBitWidth()/2;
+        Val2  = Val.lshr(NextWidth);
+        Val2 = Val2.trunc(Val.getBitWidth()/2);
+        Val = Val.trunc(Val.getBitWidth()/2);
+        
+        // If the top/bottom halves aren't the same, reject it.
+        if (Val != Val2)
+          return 0;
+      }
+      return ConstantInt::get(V->getContext(), Val);
+    }
+  }
+  
+  // A ConstantArray is splatable if all its members are equal and also
+  // splatable.
+  if (ConstantArray *CA = dyn_cast<ConstantArray>(V)) {
+    if (CA->getNumOperands() == 0)
+      return 0;
+    
+    Value *Val = isBytewiseValue(CA->getOperand(0));
+    if (!Val)
+      return 0;
+    
+    for (unsigned I = 1, E = CA->getNumOperands(); I != E; ++I)
+      if (CA->getOperand(I-1) != CA->getOperand(I))
+        return 0;
+    
+    return Val;
+  }
+  
+  // Conceptually, we could handle things like:
+  //   %a = zext i8 %X to i16
+  //   %b = shl i16 %a, 8
+  //   %c = or i16 %a, %b
+  // but until there is an example that actually needs this, it doesn't seem
+  // worth worrying about.
+  return 0;
+}
+
+
 // This is the recursive version of BuildSubAggregate. It takes a few different
 // arguments. Idxs is the index within the nested struct From that we are
 // looking at now (which is of type IndexedType). IdxSkip is the number of
@@ -1022,7 +1328,7 @@ static Value *BuildSubAggregate(Value *From, Value* To, const Type *IndexedType,
         break;
       }
     }
-    // If we succesfully found a value for each of our subaggregates 
+    // If we successfully found a value for each of our subaggregates
     if (To)
       return To;
   }
@@ -1435,7 +1741,8 @@ uint64_t llvm::GetStringLength(Value *V) {
   return Len == ~0ULL ? 1 : Len;
 }
 
-Value *llvm::GetUnderlyingObject(Value *V, unsigned MaxLookup) {
+Value *
+llvm::GetUnderlyingObject(Value *V, const TargetData *TD, unsigned MaxLookup) {
   if (!V->getType()->isPointerTy())
     return V;
   for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) {
@@ -1450,8 +1757,8 @@ Value *llvm::GetUnderlyingObject(Value *V, unsigned MaxLookup) {
     } else {
       // See if InstructionSimplify knows any relevant tricks.
       if (Instruction *I = dyn_cast<Instruction>(V))
-        // TODO: Aquire TargetData and DominatorTree and use them.
-        if (Value *Simplified = SimplifyInstruction(I, 0, 0)) {
+        // TODO: Acquire a DominatorTree and use it.
+        if (Value *Simplified = SimplifyInstruction(I, TD, 0)) {
           V = Simplified;
           continue;
         }