Revert r257064. It caused failures in some sanitizer tests.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSimplifyDemanded.cpp
index cd391d0385e927a8251c011fb555bda3dbd8ee2b..743d51483ea16048fbeb0621a71b49b2b7e54ff8 100644 (file)
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombineInternal.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
 
@@ -43,19 +44,6 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
   Demanded &= OpC->getValue();
   I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded));
 
-  // If either 'nsw' or 'nuw' is set and the constant is negative,
-  // removing *any* bits from the constant could make overflow occur.
-  // Remove 'nsw' and 'nuw' from the instruction in this case.
-  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(I)) {
-    assert(OBO->getOpcode() == Instruction::Add);
-    if (OBO->hasNoSignedWrap() || OBO->hasNoUnsignedWrap()) {
-      if (OpC->getValue().isNegative()) {
-        cast<BinaryOperator>(OBO)->setHasNoSignedWrap(false);
-        cast<BinaryOperator>(OBO)->setHasNoUnsignedWrap(false);
-      }
-    }
-  }
-
   return true;
 }
 
@@ -83,9 +71,9 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
 bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask,
                                         APInt &KnownZero, APInt &KnownOne,
                                         unsigned Depth) {
-  Value *NewVal =
-      SimplifyDemandedUseBits(U.get(), DemandedMask, KnownZero, KnownOne, Depth,
-                              dyn_cast<Instruction>(U.getUser()));
+  auto *UserI = dyn_cast<Instruction>(U.getUser());
+  Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, KnownZero,
+                                          KnownOne, Depth, UserI);
   if (!NewVal) return false;
   U = NewVal;
   return true;
@@ -419,6 +407,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     break;
   }
   case Instruction::Select:
+    // If this is a select as part of a min/max pattern, don't simplify any
+    // further in case we break the structure.
+    Value *LHS, *RHS;
+    if (matchSelectPattern(I, LHS, RHS).Flavor != SPF_UNKNOWN)
+      return nullptr;
+
     if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask, RHSKnownZero,
                              RHSKnownOne, Depth + 1) ||
         SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, LHSKnownZero,
@@ -528,113 +522,35 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     }
     break;
   }
-  case Instruction::Add: {
-    // Figure out what the input bits are.  If the top bits of the and result
-    // are not demanded, then the add doesn't demand them from its input
-    // either.
+  case Instruction::Add:
+  case Instruction::Sub: {
+    /// If the high-bits of an ADD/SUB are not demanded, then we do not care
+    /// about the high bits of the operands.
     unsigned NLZ = DemandedMask.countLeadingZeros();
-
-    // If there is a constant on the RHS, there are a variety of xformations
-    // we can do.
-    if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      // If null, this should be simplified elsewhere.  Some of the xforms here
-      // won't work if the RHS is zero.
-      if (RHS->isZero())
-        break;
-
-      // If the top bit of the output is demanded, demand everything from the
-      // input.  Otherwise, we demand all the input bits except NLZ top bits.
-      APInt InDemandedBits(APInt::getLowBitsSet(BitWidth, BitWidth - NLZ));
-
-      // Find information about known zero/one bits in the input.
-      if (SimplifyDemandedBits(I->getOperandUse(0), InDemandedBits,
-                               LHSKnownZero, LHSKnownOne, Depth + 1))
-        return I;
-
-      // If the RHS of the add has bits set that can't affect the input, reduce
-      // the constant.
-      if (ShrinkDemandedConstant(I, 1, InDemandedBits))
-        return I;
-
-      // Avoid excess work.
-      if (LHSKnownZero == 0 && LHSKnownOne == 0)
-        break;
-
-      // Turn it into OR if input bits are zero.
-      if ((LHSKnownZero & RHS->getValue()) == RHS->getValue()) {
-        Instruction *Or =
-          BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
-                                   I->getName());
-        return InsertNewInstWith(Or, *I);
-      }
-
-      // We can say something about the output known-zero and known-one bits,
-      // depending on potential carries from the input constant and the
-      // unknowns.  For example if the LHS is known to have at most the 0x0F0F0
-      // bits set and the RHS constant is 0x01001, then we know we have a known
-      // one mask of 0x00001 and a known zero mask of 0xE0F0E.
-
-      // To compute this, we first compute the potential carry bits.  These are
-      // the bits which may be modified.  I'm not aware of a better way to do
-      // this scan.
-      const APInt &RHSVal = RHS->getValue();
-      APInt CarryBits((~LHSKnownZero + RHSVal) ^ (~LHSKnownZero ^ RHSVal));
-
-      // Now that we know which bits have carries, compute the known-1/0 sets.
-
-      // Bits are known one if they are known zero in one operand and one in the
-      // other, and there is no input carry.
-      KnownOne = ((LHSKnownZero & RHSVal) |
-                  (LHSKnownOne & ~RHSVal)) & ~CarryBits;
-
-      // Bits are known zero if they are known zero in both operands and there
-      // is no input carry.
-      KnownZero = LHSKnownZero & ~RHSVal & ~CarryBits;
-    } else {
-      // If the high-bits of this ADD are not demanded, then it does not demand
-      // the high bits of its LHS or RHS.
-      if (DemandedMask[BitWidth-1] == 0) {
-        // Right fill the mask of bits for this ADD to demand the most
-        // significant bit and all those below it.
-        APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
-        if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps,
-                                 LHSKnownZero, LHSKnownOne, Depth + 1) ||
-            SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps,
-                                 LHSKnownZero, LHSKnownOne, Depth + 1))
-          return I;
-      }
-    }
-    break;
-  }
-  case Instruction::Sub:
-    // If the high-bits of this SUB are not demanded, then it does not demand
-    // the high bits of its LHS or RHS.
-    if (DemandedMask[BitWidth-1] == 0) {
-      // Right fill the mask of bits for this SUB to demand the most
+    if (NLZ > 0) {
+      // Right fill the mask of bits for this ADD/SUB to demand the most
       // significant bit and all those below it.
-      uint32_t NLZ = DemandedMask.countLeadingZeros();
       APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
       if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps,
                                LHSKnownZero, LHSKnownOne, Depth + 1) ||
+          ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
           SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps,
-                               LHSKnownZero, LHSKnownOne, Depth + 1))
+                               LHSKnownZero, LHSKnownOne, Depth + 1)) {
+        // Disable the nsw and nuw flags here: We can no longer guarantee that
+        // we won't wrap after simplification. Removing the nsw/nuw flags is
+        // legal here because the top bit is not demanded.
+        BinaryOperator &BinOP = *cast<BinaryOperator>(I);
+        BinOP.setHasNoSignedWrap(false);
+        BinOP.setHasNoUnsignedWrap(false);
         return I;
+      }
     }
 
-    // Otherwise just hand the sub off to computeKnownBits to fill in
+    // Otherwise just hand the add/sub off to computeKnownBits to fill in
     // the known zeros and ones.
     computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI);
-
-    // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
-    // zero.
-    if (ConstantInt *C0 = dyn_cast<ConstantInt>(I->getOperand(0))) {
-      APInt I0 = C0->getValue();
-      if ((I0 + 1).isPowerOf2() && (I0 | KnownZero).isAllOnesValue()) {
-        Instruction *Xor = BinaryOperator::CreateXor(I->getOperand(1), C0);
-        return InsertNewInstWith(Xor, *I);
-      }
-    }
     break;
+  }
   case Instruction::Shl:
     if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
       {
@@ -1141,7 +1057,13 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
     APInt LeftDemanded(DemandedElts), RightDemanded(DemandedElts);
     if (ConstantVector* CV = dyn_cast<ConstantVector>(I->getOperand(0))) {
       for (unsigned i = 0; i < VWidth; i++) {
-        if (CV->getAggregateElement(i)->isNullValue())
+        Constant *CElt = CV->getAggregateElement(i);
+        // Method isNullValue always returns false when called on a
+        // ConstantExpr. If CElt is a ConstantExpr then skip it in order to
+        // to avoid propagating incorrect information.
+        if (isa<ConstantExpr>(CElt))
+          continue;
+        if (CElt->isNullValue())
           LeftDemanded.clearBit(i);
         else
           RightDemanded.clearBit(i);
@@ -1166,6 +1088,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
     if (!VTy) break;
     unsigned InVWidth = VTy->getNumElements();
     APInt InputDemandedElts(InVWidth, 0);
+    UndefElts2 = APInt(InVWidth, 0);
     unsigned Ratio;
 
     if (VWidth == InVWidth) {
@@ -1173,29 +1096,25 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
       // elements as are demanded of us.
       Ratio = 1;
       InputDemandedElts = DemandedElts;
-    } else if (VWidth > InVWidth) {
-      // Untested so far.
-      break;
-
-      // If there are more elements in the result than there are in the source,
-      // then an input element is live if any of the corresponding output
-      // elements are live.
-      Ratio = VWidth/InVWidth;
-      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
+    } else if ((VWidth % InVWidth) == 0) {
+      // If the number of elements in the output is a multiple of the number of
+      // elements in the input then an input element is live if any of the
+      // corresponding output elements are live.
+      Ratio = VWidth / InVWidth;
+      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
         if (DemandedElts[OutIdx])
-          InputDemandedElts.setBit(OutIdx/Ratio);
-      }
-    } else {
-      // Untested so far.
-      break;
-
-      // If there are more elements in the source than there are in the result,
-      // then an input element is live if the corresponding output element is
-      // live.
-      Ratio = InVWidth/VWidth;
+          InputDemandedElts.setBit(OutIdx / Ratio);
+    } else if ((InVWidth % VWidth) == 0) {
+      // If the number of elements in the input is a multiple of the number of
+      // elements in the output then an input element is live if the
+      // corresponding output element is live.
+      Ratio = InVWidth / VWidth;
       for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
-        if (DemandedElts[InIdx/Ratio])
+        if (DemandedElts[InIdx / Ratio])
           InputDemandedElts.setBit(InIdx);
+    } else {
+      // Unsupported so far.
+      break;
     }
 
     // div/rem demand all inputs, because they don't want divide by zero.
@@ -1206,24 +1125,26 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
       MadeChange = true;
     }
 
-    UndefElts = UndefElts2;
-    if (VWidth > InVWidth) {
-      llvm_unreachable("Unimp");
-      // If there are more elements in the result than there are in the source,
-      // then an output element is undef if the corresponding input element is
-      // undef.
+    if (VWidth == InVWidth) {
+      UndefElts = UndefElts2;
+    } else if ((VWidth % InVWidth) == 0) {
+      // If the number of elements in the output is a multiple of the number of
+      // elements in the input then an output element is undef if the
+      // corresponding input element is undef.
       for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
-        if (UndefElts2[OutIdx/Ratio])
+        if (UndefElts2[OutIdx / Ratio])
           UndefElts.setBit(OutIdx);
-    } else if (VWidth < InVWidth) {
+    } else if ((InVWidth % VWidth) == 0) {
+      // If the number of elements in the input is a multiple of the number of
+      // elements in the output then an output element is undef if all of the
+      // corresponding input elements are undef.
+      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
+        APInt SubUndef = UndefElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
+        if (SubUndef.countPopulation() == Ratio)
+          UndefElts.setBit(OutIdx);
+      }
+    } else {
       llvm_unreachable("Unimp");
-      // If there are more elements in the source than there are in the result,
-      // then a result element is undef if all of the corresponding input
-      // elements are undef.
-      UndefElts = ~0ULL >> (64-VWidth);  // Start out all undef.
-      for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
-        if (!UndefElts2[InIdx])            // Not undef?
-          UndefElts.clearBit(InIdx/Ratio);    // Clear undef bit.
     }
     break;
   }
@@ -1321,6 +1242,15 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
       // like undef&0.  The result is known zero, not undef.
       UndefElts &= UndefElts2;
       break;
+
+    // SSE4A instructions leave the upper 64-bits of the 128-bit result
+    // in an undefined state.
+    case Intrinsic::x86_sse4a_extrq:
+    case Intrinsic::x86_sse4a_extrqi:
+    case Intrinsic::x86_sse4a_insertq:
+    case Intrinsic::x86_sse4a_insertqi:
+      UndefElts |= APInt::getHighBitsSet(VWidth, VWidth / 2);
+      break;
     }
     break;
   }