Revert r257064. It caused failures in some sanitizer tests.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSimplifyDemanded.cpp
index c26766a95ac4fdd8672cd4242b126e7c34573ee5..743d51483ea16048fbeb0621a71b49b2b7e54ff8 100644 (file)
@@ -12,8 +12,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "InstCombine.h"
-#include "llvm/IR/DataLayout.h"
+#include "InstCombineInternal.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
 
@@ -44,13 +44,6 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
   Demanded &= OpC->getValue();
   I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded));
 
-  // If 'nsw' is set and the constant is negative, removing *any* bits from the
-  // constant could make overflow occur.  Remove 'nsw' from the instruction in
-  // this case.
-  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I))
-    if (OBO->hasNoSignedWrap() && OpC->getValue().isNegative())
-      cast<BinaryOperator>(OBO)->setHasNoSignedWrap(false);
-
   return true;
 }
 
@@ -64,8 +57,8 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
   APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
   APInt DemandedMask(APInt::getAllOnesValue(BitWidth));
 
-  Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask,
-                                     KnownZero, KnownOne, 0);
+  Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, KnownZero, KnownOne,
+                                     0, &Inst);
   if (!V) return false;
   if (V == &Inst) return true;
   ReplaceInstUsesWith(Inst, V);
@@ -78,8 +71,9 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
 bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask,
                                         APInt &KnownZero, APInt &KnownOne,
                                         unsigned Depth) {
-  Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask,
-                                          KnownZero, KnownOne, Depth);
+  auto *UserI = dyn_cast<Instruction>(U.getUser());
+  Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, KnownZero,
+                                          KnownOne, Depth, UserI);
   if (!NewVal) return false;
   U = NewVal;
   return true;
@@ -109,20 +103,18 @@ bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask,
 /// in the context where the specified bits are demanded, but not for all users.
 Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                              APInt &KnownZero, APInt &KnownOne,
-                                             unsigned Depth) {
+                                             unsigned Depth,
+                                             Instruction *CxtI) {
   assert(V != nullptr && "Null pointer of Value???");
   assert(Depth <= 6 && "Limit Search Depth");
   uint32_t BitWidth = DemandedMask.getBitWidth();
   Type *VTy = V->getType();
-  assert((DL || !VTy->isPointerTy()) &&
-         "SimplifyDemandedBits needs to know bit widths!");
-  assert((!DL || DL->getTypeSizeInBits(VTy->getScalarType()) == BitWidth) &&
-         (!VTy->isIntOrIntVectorTy() ||
-          VTy->getScalarSizeInBits() == BitWidth) &&
-         KnownZero.getBitWidth() == BitWidth &&
-         KnownOne.getBitWidth() == BitWidth &&
-         "Value *V, DemandedMask, KnownZero and KnownOne "
-         "must have same BitWidth");
+  assert(
+      (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) &&
+      KnownZero.getBitWidth() == BitWidth &&
+      KnownOne.getBitWidth() == BitWidth &&
+      "Value *V, DemandedMask, KnownZero and KnownOne "
+      "must have same BitWidth");
   if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
     // We know all of the bits for a constant!
     KnownOne = CI->getValue() & DemandedMask;
@@ -152,7 +144,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I) {
-    computeKnownBits(V, KnownZero, KnownOne, Depth);
+    computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI);
     return nullptr;        // Only analyze instructions.
   }
 
@@ -166,8 +158,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     // this instruction has a simpler value in that context.
     if (I->getOpcode() == Instruction::And) {
       // If either the LHS or the RHS are Zero, the result is zero.
-      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1);
-      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1);
+      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+                       CxtI);
+      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+                       CxtI);
 
       // If all of the demanded bits are known 1 on one side, return the other.
       // These bits cannot contribute to the result of the 'and' in this
@@ -188,8 +182,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       // only bits from X or Y are demanded.
 
       // If either the LHS or the RHS are One, the result is One.
-      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1);
-      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1);
+      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+                       CxtI);
+      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+                       CxtI);
 
       // If all of the demanded bits are known zero on one side, return the
       // other.  These bits cannot contribute to the result of the 'or' in this
@@ -213,8 +209,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       // We can simplify (X^Y) -> X or Y in the user's context if we know that
       // only bits from X or Y are demanded.
 
-      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth+1);
-      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1);
+      computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+                       CxtI);
+      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+                       CxtI);
 
       // If all of the demanded bits are known zero on one side, return the
       // other.
@@ -225,7 +223,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     }
 
     // Compute the KnownZero/KnownOne bits to simplify things downstream.
-    computeKnownBits(I, KnownZero, KnownOne, Depth);
+    computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);
     return nullptr;
   }
 
@@ -238,18 +236,24 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
   switch (I->getOpcode()) {
   default:
-    computeKnownBits(I, KnownZero, KnownOne, Depth);
+    computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);
     break;
   case Instruction::And:
     // If either the LHS or the RHS are Zero, the result is zero.
-    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
-                             RHSKnownZero, RHSKnownOne, Depth+1) ||
+    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
+                             RHSKnownOne, Depth + 1) ||
         SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownZero,
-                             LHSKnownZero, LHSKnownOne, Depth+1))
+                             LHSKnownZero, LHSKnownOne, Depth + 1))
       return I;
     assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
     assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
 
+    // If the client is only demanding bits that we know, return the known
+    // constant.
+    if ((DemandedMask & ((RHSKnownZero | LHSKnownZero)|
+                         (RHSKnownOne & LHSKnownOne))) == DemandedMask)
+      return Constant::getIntegerValue(VTy, RHSKnownOne & LHSKnownOne);
+
     // If all of the demanded bits are known 1 on one side, return the other.
     // These bits cannot contribute to the result of the 'and'.
     if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) ==
@@ -274,14 +278,20 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     break;
   case Instruction::Or:
     // If either the LHS or the RHS are One, the result is One.
-    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
-                             RHSKnownZero, RHSKnownOne, Depth+1) ||
+    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
+                             RHSKnownOne, Depth + 1) ||
         SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownOne,
-                             LHSKnownZero, LHSKnownOne, Depth+1))
+                             LHSKnownZero, LHSKnownOne, Depth + 1))
       return I;
     assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
     assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
 
+    // If the client is only demanding bits that we know, return the known
+    // constant.
+    if ((DemandedMask & ((RHSKnownZero & LHSKnownZero)|
+                         (RHSKnownOne | LHSKnownOne))) == DemandedMask)
+      return Constant::getIntegerValue(VTy, RHSKnownOne | LHSKnownOne);
+
     // If all of the demanded bits are known zero on one side, return the other.
     // These bits cannot contribute to the result of the 'or'.
     if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) ==
@@ -310,14 +320,26 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     KnownOne = RHSKnownOne | LHSKnownOne;
     break;
   case Instruction::Xor: {
-    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
-                             RHSKnownZero, RHSKnownOne, Depth+1) ||
-        SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
-                             LHSKnownZero, LHSKnownOne, Depth+1))
+    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
+                             RHSKnownOne, Depth + 1) ||
+        SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, LHSKnownZero,
+                             LHSKnownOne, Depth + 1))
       return I;
     assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
     assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
 
+    // Output known-0 bits are known if clear or set in both the LHS & RHS.
+    APInt IKnownZero = (RHSKnownZero & LHSKnownZero) |
+                       (RHSKnownOne & LHSKnownOne);
+    // Output known-1 are known to be set if set in only one of the LHS, RHS.
+    APInt IKnownOne =  (RHSKnownZero & LHSKnownOne) |
+                       (RHSKnownOne & LHSKnownZero);
+
+    // If the client is only demanding bits that we know, return the known
+    // constant.
+    if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+      return Constant::getIntegerValue(VTy, IKnownOne);
+
     // If all of the demanded bits are known zero on one side, return the other.
     // These bits cannot contribute to the result of the 'xor'.
     if ((DemandedMask & RHSKnownZero) == DemandedMask)
@@ -385,10 +407,16 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     break;
   }
   case Instruction::Select:
-    if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask,
-                             RHSKnownZero, RHSKnownOne, Depth+1) ||
-        SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
-                             LHSKnownZero, LHSKnownOne, Depth+1))
+    // If this is a select as part of a min/max pattern, don't simplify any
+    // further in case we break the structure.
+    Value *LHS, *RHS;
+    if (matchSelectPattern(I, LHS, RHS).Flavor != SPF_UNKNOWN)
+      return nullptr;
+
+    if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask, RHSKnownZero,
+                             RHSKnownOne, Depth + 1) ||
+        SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, LHSKnownZero,
+                             LHSKnownOne, Depth + 1))
       return I;
     assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
     assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
@@ -407,8 +435,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     DemandedMask = DemandedMask.zext(truncBf);
     KnownZero = KnownZero.zext(truncBf);
     KnownOne = KnownOne.zext(truncBf);
-    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
-                             KnownZero, KnownOne, Depth+1))
+    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
+                             KnownOne, Depth + 1))
       return I;
     DemandedMask = DemandedMask.trunc(BitWidth);
     KnownZero = KnownZero.trunc(BitWidth);
@@ -433,8 +461,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       // Don't touch a vector-to-scalar bitcast.
       return nullptr;
 
-    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
-                             KnownZero, KnownOne, Depth+1))
+    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
+                             KnownOne, Depth + 1))
       return I;
     assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
     break;
@@ -445,8 +473,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     DemandedMask = DemandedMask.trunc(SrcBitWidth);
     KnownZero = KnownZero.trunc(SrcBitWidth);
     KnownOne = KnownOne.trunc(SrcBitWidth);
-    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
-                             KnownZero, KnownOne, Depth+1))
+    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
+                             KnownOne, Depth + 1))
       return I;
     DemandedMask = DemandedMask.zext(BitWidth);
     KnownZero = KnownZero.zext(BitWidth);
@@ -472,8 +500,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     InputDemandedBits = InputDemandedBits.trunc(SrcBitWidth);
     KnownZero = KnownZero.trunc(SrcBitWidth);
     KnownOne = KnownOne.trunc(SrcBitWidth);
-    if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits,
-                             KnownZero, KnownOne, Depth+1))
+    if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits, KnownZero,
+                             KnownOne, Depth + 1))
       return I;
     InputDemandedBits = InputDemandedBits.zext(BitWidth);
     KnownZero = KnownZero.zext(BitWidth);
@@ -494,113 +522,35 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     }
     break;
   }
-  case Instruction::Add: {
-    // Figure out what the input bits are.  If the top bits of the and result
-    // are not demanded, then the add doesn't demand them from its input
-    // either.
+  case Instruction::Add:
+  case Instruction::Sub: {
+    /// If the high-bits of an ADD/SUB are not demanded, then we do not care
+    /// about the high bits of the operands.
     unsigned NLZ = DemandedMask.countLeadingZeros();
-
-    // If there is a constant on the RHS, there are a variety of xformations
-    // we can do.
-    if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      // If null, this should be simplified elsewhere.  Some of the xforms here
-      // won't work if the RHS is zero.
-      if (RHS->isZero())
-        break;
-
-      // If the top bit of the output is demanded, demand everything from the
-      // input.  Otherwise, we demand all the input bits except NLZ top bits.
-      APInt InDemandedBits(APInt::getLowBitsSet(BitWidth, BitWidth - NLZ));
-
-      // Find information about known zero/one bits in the input.
-      if (SimplifyDemandedBits(I->getOperandUse(0), InDemandedBits,
-                               LHSKnownZero, LHSKnownOne, Depth+1))
-        return I;
-
-      // If the RHS of the add has bits set that can't affect the input, reduce
-      // the constant.
-      if (ShrinkDemandedConstant(I, 1, InDemandedBits))
-        return I;
-
-      // Avoid excess work.
-      if (LHSKnownZero == 0 && LHSKnownOne == 0)
-        break;
-
-      // Turn it into OR if input bits are zero.
-      if ((LHSKnownZero & RHS->getValue()) == RHS->getValue()) {
-        Instruction *Or =
-          BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
-                                   I->getName());
-        return InsertNewInstWith(Or, *I);
-      }
-
-      // We can say something about the output known-zero and known-one bits,
-      // depending on potential carries from the input constant and the
-      // unknowns.  For example if the LHS is known to have at most the 0x0F0F0
-      // bits set and the RHS constant is 0x01001, then we know we have a known
-      // one mask of 0x00001 and a known zero mask of 0xE0F0E.
-
-      // To compute this, we first compute the potential carry bits.  These are
-      // the bits which may be modified.  I'm not aware of a better way to do
-      // this scan.
-      const APInt &RHSVal = RHS->getValue();
-      APInt CarryBits((~LHSKnownZero + RHSVal) ^ (~LHSKnownZero ^ RHSVal));
-
-      // Now that we know which bits have carries, compute the known-1/0 sets.
-
-      // Bits are known one if they are known zero in one operand and one in the
-      // other, and there is no input carry.
-      KnownOne = ((LHSKnownZero & RHSVal) |
-                  (LHSKnownOne & ~RHSVal)) & ~CarryBits;
-
-      // Bits are known zero if they are known zero in both operands and there
-      // is no input carry.
-      KnownZero = LHSKnownZero & ~RHSVal & ~CarryBits;
-    } else {
-      // If the high-bits of this ADD are not demanded, then it does not demand
-      // the high bits of its LHS or RHS.
-      if (DemandedMask[BitWidth-1] == 0) {
-        // Right fill the mask of bits for this ADD to demand the most
-        // significant bit and all those below it.
-        APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
-        if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps,
-                                 LHSKnownZero, LHSKnownOne, Depth+1) ||
-            SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps,
-                                 LHSKnownZero, LHSKnownOne, Depth+1))
-          return I;
-      }
-    }
-    break;
-  }
-  case Instruction::Sub:
-    // If the high-bits of this SUB are not demanded, then it does not demand
-    // the high bits of its LHS or RHS.
-    if (DemandedMask[BitWidth-1] == 0) {
-      // Right fill the mask of bits for this SUB to demand the most
+    if (NLZ > 0) {
+      // Right fill the mask of bits for this ADD/SUB to demand the most
       // significant bit and all those below it.
-      uint32_t NLZ = DemandedMask.countLeadingZeros();
       APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
       if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps,
-                               LHSKnownZero, LHSKnownOne, Depth+1) ||
+                               LHSKnownZero, LHSKnownOne, Depth + 1) ||
+          ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
           SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps,
-                               LHSKnownZero, LHSKnownOne, Depth+1))
+                               LHSKnownZero, LHSKnownOne, Depth + 1)) {
+        // Disable the nsw and nuw flags here: We can no longer guarantee that
+        // we won't wrap after simplification. Removing the nsw/nuw flags is
+        // legal here because the top bit is not demanded.
+        BinaryOperator &BinOP = *cast<BinaryOperator>(I);
+        BinOP.setHasNoSignedWrap(false);
+        BinOP.setHasNoUnsignedWrap(false);
         return I;
+      }
     }
 
-    // Otherwise just hand the sub off to computeKnownBits to fill in
+    // Otherwise just hand the add/sub off to computeKnownBits to fill in
     // the known zeros and ones.
-    computeKnownBits(V, KnownZero, KnownOne, Depth);
-
-    // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
-    // zero.
-    if (ConstantInt *C0 = dyn_cast<ConstantInt>(I->getOperand(0))) {
-      APInt I0 = C0->getValue();
-      if ((I0 + 1).isPowerOf2() && (I0 | KnownZero).isAllOnesValue()) {
-        Instruction *Xor = BinaryOperator::CreateXor(I->getOperand(1), C0);
-        return InsertNewInstWith(Xor, *I);
-      }
-    }
+    computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI);
     break;
+  }
   case Instruction::Shl:
     if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
       {
@@ -624,8 +574,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       else if (IOp->hasNoUnsignedWrap())
         DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
 
-      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
-                               KnownZero, KnownOne, Depth+1))
+      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
+                               KnownOne, Depth + 1))
         return I;
       assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
       KnownZero <<= ShiftAmt;
@@ -648,8 +598,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       if (cast<LShrOperator>(I)->isExact())
         DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
 
-      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
-                               KnownZero, KnownOne, Depth+1))
+      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
+                               KnownOne, Depth + 1))
         return I;
       assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
       KnownZero = APIntOps::lshr(KnownZero, ShiftAmt);
@@ -693,8 +643,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       if (cast<AShrOperator>(I)->isExact())
         DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
 
-      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
-                               KnownZero, KnownOne, Depth+1))
+      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
+                               KnownOne, Depth + 1))
         return I;
       assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
       // Compute the new bits that are at the top now.
@@ -734,8 +684,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
         APInt LowBits = RA - 1;
         APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
-        if (SimplifyDemandedBits(I->getOperandUse(0), Mask2,
-                                 LHSKnownZero, LHSKnownOne, Depth+1))
+        if (SimplifyDemandedBits(I->getOperandUse(0), Mask2, LHSKnownZero,
+                                 LHSKnownOne, Depth + 1))
           return I;
 
         // The low bits of LHS are unchanged by the srem.
@@ -760,7 +710,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     // remainder is zero.
     if (DemandedMask.isNegative() && KnownZero.isNonNegative()) {
       APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
-      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth+1);
+      computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+                       CxtI);
       // If it's known zero, our sign bit is also zero.
       if (LHSKnownZero.isNegative())
         KnownZero.setBit(KnownZero.getBitWidth() - 1);
@@ -769,10 +720,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
   case Instruction::URem: {
     APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0);
     APInt AllOnes = APInt::getAllOnesValue(BitWidth);
-    if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes,
-                             KnownZero2, KnownOne2, Depth+1) ||
-        SimplifyDemandedBits(I->getOperandUse(1), AllOnes,
-                             KnownZero2, KnownOne2, Depth+1))
+    if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes, KnownZero2,
+                             KnownOne2, Depth + 1) ||
+        SimplifyDemandedBits(I->getOperandUse(1), AllOnes, KnownZero2,
+                             KnownOne2, Depth + 1))
       return I;
 
     unsigned Leaders = KnownZero2.countLeadingOnes();
@@ -822,7 +773,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         return nullptr;
       }
     }
-    computeKnownBits(V, KnownZero, KnownOne, Depth);
+    computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI);
     break;
   }
 
@@ -1012,7 +963,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
       // Note that we can't propagate undef elt info, because we don't know
       // which elt is getting updated.
       TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts,
-                                        UndefElts2, Depth+1);
+                                        UndefElts2, Depth + 1);
       if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
       break;
     }
@@ -1030,7 +981,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
     APInt DemandedElts2 = DemandedElts;
     DemandedElts2.clearBit(IdxNo);
     TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts2,
-                                      UndefElts, Depth+1);
+                                      UndefElts, Depth + 1);
     if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
 
     // The inserted element is defined.
@@ -1058,12 +1009,12 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
 
     APInt UndefElts4(LHSVWidth, 0);
     TmpV = SimplifyDemandedVectorElts(I->getOperand(0), LeftDemanded,
-                                      UndefElts4, Depth+1);
+                                      UndefElts4, Depth + 1);
     if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
 
     APInt UndefElts3(LHSVWidth, 0);
     TmpV = SimplifyDemandedVectorElts(I->getOperand(1), RightDemanded,
-                                      UndefElts3, Depth+1);
+                                      UndefElts3, Depth + 1);
     if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; }
 
     bool NewUndefElts = false;
@@ -1106,19 +1057,25 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
     APInt LeftDemanded(DemandedElts), RightDemanded(DemandedElts);
     if (ConstantVector* CV = dyn_cast<ConstantVector>(I->getOperand(0))) {
       for (unsigned i = 0; i < VWidth; i++) {
-        if (CV->getAggregateElement(i)->isNullValue())
+        Constant *CElt = CV->getAggregateElement(i);
+        // Method isNullValue always returns false when called on a
+        // ConstantExpr. If CElt is a ConstantExpr then skip it in order to
+        // to avoid propagating incorrect information.
+        if (isa<ConstantExpr>(CElt))
+          continue;
+        if (CElt->isNullValue())
           LeftDemanded.clearBit(i);
         else
           RightDemanded.clearBit(i);
       }
     }
 
-    TmpV = SimplifyDemandedVectorElts(I->getOperand(1), LeftDemanded,
-                                      UndefElts, Depth+1);
+    TmpV = SimplifyDemandedVectorElts(I->getOperand(1), LeftDemanded, UndefElts,
+                                      Depth + 1);
     if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; }
 
     TmpV = SimplifyDemandedVectorElts(I->getOperand(2), RightDemanded,
-                                      UndefElts2, Depth+1);
+                                      UndefElts2, Depth + 1);
     if (TmpV) { I->setOperand(2, TmpV); MadeChange = true; }
 
     // Output elements are undefined if both are undefined.
@@ -1131,6 +1088,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
     if (!VTy) break;
     unsigned InVWidth = VTy->getNumElements();
     APInt InputDemandedElts(InVWidth, 0);
+    UndefElts2 = APInt(InVWidth, 0);
     unsigned Ratio;
 
     if (VWidth == InVWidth) {
@@ -1138,57 +1096,55 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
       // elements as are demanded of us.
       Ratio = 1;
       InputDemandedElts = DemandedElts;
-    } else if (VWidth > InVWidth) {
-      // Untested so far.
-      break;
-
-      // If there are more elements in the result than there are in the source,
-      // then an input element is live if any of the corresponding output
-      // elements are live.
-      Ratio = VWidth/InVWidth;
-      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
+    } else if ((VWidth % InVWidth) == 0) {
+      // If the number of elements in the output is a multiple of the number of
+      // elements in the input then an input element is live if any of the
+      // corresponding output elements are live.
+      Ratio = VWidth / InVWidth;
+      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
         if (DemandedElts[OutIdx])
-          InputDemandedElts.setBit(OutIdx/Ratio);
-      }
-    } else {
-      // Untested so far.
-      break;
-
-      // If there are more elements in the source than there are in the result,
-      // then an input element is live if the corresponding output element is
-      // live.
-      Ratio = InVWidth/VWidth;
+          InputDemandedElts.setBit(OutIdx / Ratio);
+    } else if ((InVWidth % VWidth) == 0) {
+      // If the number of elements in the input is a multiple of the number of
+      // elements in the output then an input element is live if the
+      // corresponding output element is live.
+      Ratio = InVWidth / VWidth;
       for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
-        if (DemandedElts[InIdx/Ratio])
+        if (DemandedElts[InIdx / Ratio])
           InputDemandedElts.setBit(InIdx);
+    } else {
+      // Unsupported so far.
+      break;
     }
 
     // div/rem demand all inputs, because they don't want divide by zero.
     TmpV = SimplifyDemandedVectorElts(I->getOperand(0), InputDemandedElts,
-                                      UndefElts2, Depth+1);
+                                      UndefElts2, Depth + 1);
     if (TmpV) {
       I->setOperand(0, TmpV);
       MadeChange = true;
     }
 
-    UndefElts = UndefElts2;
-    if (VWidth > InVWidth) {
-      llvm_unreachable("Unimp");
-      // If there are more elements in the result than there are in the source,
-      // then an output element is undef if the corresponding input element is
-      // undef.
+    if (VWidth == InVWidth) {
+      UndefElts = UndefElts2;
+    } else if ((VWidth % InVWidth) == 0) {
+      // If the number of elements in the output is a multiple of the number of
+      // elements in the input then an output element is undef if the
+      // corresponding input element is undef.
       for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
-        if (UndefElts2[OutIdx/Ratio])
+        if (UndefElts2[OutIdx / Ratio])
           UndefElts.setBit(OutIdx);
-    } else if (VWidth < InVWidth) {
+    } else if ((InVWidth % VWidth) == 0) {
+      // If the number of elements in the input is a multiple of the number of
+      // elements in the output then an output element is undef if all of the
+      // corresponding input elements are undef.
+      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
+        APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
+        if (SubUndef.countPopulation() == Ratio)
+          UndefElts.setBit(OutIdx);
+      }
+    } else {
       llvm_unreachable("Unimp");
-      // If there are more elements in the source than there are in the result,
-      // then a result element is undef if all of the corresponding input
-      // elements are undef.
-      UndefElts = ~0ULL >> (64-VWidth);  // Start out all undef.
-      for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
-        if (!UndefElts2[InIdx])            // Not undef?
-          UndefElts.clearBit(InIdx/Ratio);    // Clear undef bit.
     }
     break;
   }
@@ -1199,11 +1155,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
   case Instruction::Sub:
   case Instruction::Mul:
     // div/rem demand all inputs, because they don't want divide by zero.
-    TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts,
-                                      UndefElts, Depth+1);
+    TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts, UndefElts,
+                                      Depth + 1);
     if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
     TmpV = SimplifyDemandedVectorElts(I->getOperand(1), DemandedElts,
-                                      UndefElts2, Depth+1);
+                                      UndefElts2, Depth + 1);
     if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; }
 
     // Output elements are undefined if both are undefined.  Consider things
@@ -1212,8 +1168,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
     break;
   case Instruction::FPTrunc:
   case Instruction::FPExt:
-    TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts,
-                                      UndefElts, Depth+1);
+    TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts, UndefElts,
+                                      Depth + 1);
     if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
     break;
 
@@ -1234,10 +1190,10 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
     case Intrinsic::x86_sse2_min_sd:
     case Intrinsic::x86_sse2_max_sd:
       TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts,
-                                        UndefElts, Depth+1);
+                                        UndefElts, Depth + 1);
       if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; }
       TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts,
-                                        UndefElts2, Depth+1);
+                                        UndefElts2, Depth + 1);
       if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; }
 
       // If only the low elt is demanded and this is a scalarizable intrinsic,
@@ -1286,6 +1242,15 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
       // like undef&0.  The result is known zero, not undef.
       UndefElts &= UndefElts2;
       break;
+
+    // SSE4A instructions leave the upper 64-bits of the 128-bit result
+    // in an undefined state.
+    case Intrinsic::x86_sse4a_extrq:
+    case Intrinsic::x86_sse4a_extrqi:
+    case Intrinsic::x86_sse4a_insertq:
+    case Intrinsic::x86_sse4a_insertqi:
+      UndefElts |= APInt::getHighBitsSet(VWidth, VWidth / 2);
+      break;
     }
     break;
   }