[InstCombine] SSE/AVX vector shifts demanded shift amount bits
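
Most of the SSE2/AVX2 vector shift intrinsics (psll/psrl/psra by a scalar
count) only read the lower 64 bits of their 128-bit count operand, so for
the non-immediate forms this patch calls SimplifyDemandedVectorElts on the
count operand to drop anything that only feeds the upper half.

The existing shift-by-constant folding is refactored into a common
SimplifyX86immshift helper that also handles shift-by-zero (return the first
operand unchanged), out-of-range logical shifts (fold to zero) and
out-of-range arithmetic shifts (clamp the count to BitWidth - 1), and that
builds the 64-bit count from all of the low sub-elements instead of just
element 0.

As a rough illustration of the semantics being relied on (example code only,
not part of the patch; intrinsics from the SSE2 headers):

    #include <emmintrin.h>

    __m128i shl32(__m128i v, long long amt, long long junk) {
      // psll.d only reads the low 64 bits of the count operand, so the
      // 'junk' value placed in the upper half never affects the result.
      __m128i count = _mm_set_epi64x(junk, amt);
      return _mm_sll_epi32(v, count);
    }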
[oota-llvm.git] lib/Transforms/InstCombine/InstCombineCalls.cpp
index 9b491641e35add2fe4b5bcba44899d0fe37e5e6a..1b9abfdacbf5cb4c2de9e847390bfad1d029f85c 100644
@@ -67,8 +67,7 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
   unsigned CopyAlign = MI->getAlignment();
 
   if (CopyAlign < MinAlign) {
-    MI->setAlignment(ConstantInt::get(MI->getAlignmentType(),
-                                             MinAlign, false));
+    MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), MinAlign, false));
     return MI;
   }
 
@@ -198,6 +197,116 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
   return nullptr;
 }
 
+static Value *SimplifyX86immshift(const IntrinsicInst &II,
+                                  InstCombiner::BuilderTy &Builder) {
+  bool LogicalShift = false;
+  bool ShiftLeft = false;
+
+  switch (II.getIntrinsicID()) {
+  default:
+    return nullptr;
+  case Intrinsic::x86_sse2_psra_d:
+  case Intrinsic::x86_sse2_psra_w:
+  case Intrinsic::x86_sse2_psrai_d:
+  case Intrinsic::x86_sse2_psrai_w:
+  case Intrinsic::x86_avx2_psra_d:
+  case Intrinsic::x86_avx2_psra_w:
+  case Intrinsic::x86_avx2_psrai_d:
+  case Intrinsic::x86_avx2_psrai_w:
+    LogicalShift = false; ShiftLeft = false;
+    break;
+  case Intrinsic::x86_sse2_psrl_d:
+  case Intrinsic::x86_sse2_psrl_q:
+  case Intrinsic::x86_sse2_psrl_w:
+  case Intrinsic::x86_sse2_psrli_d:
+  case Intrinsic::x86_sse2_psrli_q:
+  case Intrinsic::x86_sse2_psrli_w:
+  case Intrinsic::x86_avx2_psrl_d:
+  case Intrinsic::x86_avx2_psrl_q:
+  case Intrinsic::x86_avx2_psrl_w:
+  case Intrinsic::x86_avx2_psrli_d:
+  case Intrinsic::x86_avx2_psrli_q:
+  case Intrinsic::x86_avx2_psrli_w:
+    LogicalShift = true; ShiftLeft = false;
+    break;
+  case Intrinsic::x86_sse2_psll_d:
+  case Intrinsic::x86_sse2_psll_q:
+  case Intrinsic::x86_sse2_psll_w:
+  case Intrinsic::x86_sse2_pslli_d:
+  case Intrinsic::x86_sse2_pslli_q:
+  case Intrinsic::x86_sse2_pslli_w:
+  case Intrinsic::x86_avx2_psll_d:
+  case Intrinsic::x86_avx2_psll_q:
+  case Intrinsic::x86_avx2_psll_w:
+  case Intrinsic::x86_avx2_pslli_d:
+  case Intrinsic::x86_avx2_pslli_q:
+  case Intrinsic::x86_avx2_pslli_w:
+    LogicalShift = true; ShiftLeft = true;
+    break;
+  }
+  assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
+
+  // Simplify if count is constant.
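+  // The immediate forms (e.g. pslli) take the count as an i32 (CInt); the
+  // other forms take it in a 128-bit vector operand (CDV, or CAZ when it is
+  // all zeros).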
+  auto Arg1 = II.getArgOperand(1);
+  auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1);
+  auto CDV = dyn_cast<ConstantDataVector>(Arg1);
+  auto CInt = dyn_cast<ConstantInt>(Arg1);
+  if (!CAZ && !CDV && !CInt)
+    return nullptr;
+
+  APInt Count(64, 0);
+  if (CDV) {
+    // SSE2/AVX2 uses all of the lower 64 bits of the 128-bit vector
+    // operand to compute the shift amount.
+    auto VT = cast<VectorType>(CDV->getType());
+    unsigned BitWidth = VT->getElementType()->getPrimitiveSizeInBits();
+    assert((64 % BitWidth) == 0 && "Unexpected packed shift size");
+    unsigned NumSubElts = 64 / BitWidth;
+
+    // Concatenate the sub-elements to create the 64-bit value.
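+    // (e.g. for a <8 x i16> count the low four i16 elements are concatenated,
+    // with element 0 providing the least significant bits).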
+    for (unsigned i = 0; i != NumSubElts; ++i) {
+      unsigned SubEltIdx = (NumSubElts - 1) - i;
+      auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx));
+      Count = Count.shl(BitWidth);
+      Count |= SubElt->getValue().zextOrTrunc(64);
+    }
+  } else if (CInt)
+    Count = CInt->getValue();
+
+  auto Vec = II.getArgOperand(0);
+  auto VT = cast<VectorType>(Vec->getType());
+  auto SVT = VT->getElementType();
+  unsigned VWidth = VT->getNumElements();
+  unsigned BitWidth = SVT->getPrimitiveSizeInBits();
+
+  // If shift-by-zero then just return the original value.
+  if (Count == 0)
+    return Vec;
+
+  // Handle cases when Shift >= BitWidth.
+  if (Count.uge(BitWidth)) {
+    // If LogicalShift - just return zero.
+    if (LogicalShift)
+      return ConstantAggregateZero::get(VT);
+
+    // If ArithmeticShift - clamp Shift to (BitWidth - 1).
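+    // (shifting by BitWidth or more just replicates the sign bit, so it is
+    // equivalent to a shift by BitWidth - 1).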
+    Count = APInt(64, BitWidth - 1);
+  }
+
+  // Get a constant vector of the same type as the first operand.
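+  // (Count is < BitWidth at this point, so the truncation below is lossless.)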
+  auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth));
+  auto ShiftVec = Builder.CreateVectorSplat(VWidth, ShiftAmt);
+
+  if (ShiftLeft)
+    return Builder.CreateShl(Vec, ShiftVec);
+
+  if (LogicalShift)
+    return Builder.CreateLShr(Vec, ShiftVec);
+
+  return Builder.CreateAShr(Vec, ShiftVec);
+}
+
 static Value *SimplifyX86extend(const IntrinsicInst &II,
                                 InstCombiner::BuilderTy &Builder,
                                 bool SignExtend) {
@@ -207,7 +316,7 @@ static Value *SimplifyX86extend(const IntrinsicInst &II,
 
   // Extract a subvector of the first NumDstElts lanes and sign/zero extend.
   SmallVector<int, 8> ShuffleMask;
-  for (int i = 0; i != NumDstElts; ++i)
+  for (int i = 0; i != (int)NumDstElts; ++i)
     ShuffleMask.push_back(i);
 
   Value *SV = Builder.CreateShuffleVector(II.getArgOperand(0),
@@ -722,78 +831,64 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     break;
   }
 
-  // Constant fold <A x Bi> << Ci.
-  // FIXME: We don't handle _dq because it's a shift of an i128, but is
-  // represented in the IR as <2 x i64>. A per element shift is wrong.
-  case Intrinsic::x86_sse2_psll_d:
-  case Intrinsic::x86_sse2_psll_q:
-  case Intrinsic::x86_sse2_psll_w:
+  // Constant fold ashr( <A x Bi>, Ci ).
+  // Constant fold lshr( <A x Bi>, Ci ).
+  // Constant fold shl( <A x Bi>, Ci ).
+  case Intrinsic::x86_sse2_psrai_d:
+  case Intrinsic::x86_sse2_psrai_w:
+  case Intrinsic::x86_avx2_psrai_d:
+  case Intrinsic::x86_avx2_psrai_w:
+  case Intrinsic::x86_sse2_psrli_d:
+  case Intrinsic::x86_sse2_psrli_q:
+  case Intrinsic::x86_sse2_psrli_w:
+  case Intrinsic::x86_avx2_psrli_d:
+  case Intrinsic::x86_avx2_psrli_q:
+  case Intrinsic::x86_avx2_psrli_w:
   case Intrinsic::x86_sse2_pslli_d:
   case Intrinsic::x86_sse2_pslli_q:
   case Intrinsic::x86_sse2_pslli_w:
-  case Intrinsic::x86_avx2_psll_d:
-  case Intrinsic::x86_avx2_psll_q:
-  case Intrinsic::x86_avx2_psll_w:
   case Intrinsic::x86_avx2_pslli_d:
   case Intrinsic::x86_avx2_pslli_q:
   case Intrinsic::x86_avx2_pslli_w:
+    if (Value *V = SimplifyX86immshift(*II, *Builder))
+      return ReplaceInstUsesWith(*II, V);
+    break;
+
+  case Intrinsic::x86_sse2_psra_d:
+  case Intrinsic::x86_sse2_psra_w:
+  case Intrinsic::x86_avx2_psra_d:
+  case Intrinsic::x86_avx2_psra_w:
   case Intrinsic::x86_sse2_psrl_d:
   case Intrinsic::x86_sse2_psrl_q:
   case Intrinsic::x86_sse2_psrl_w:
-  case Intrinsic::x86_sse2_psrli_d:
-  case Intrinsic::x86_sse2_psrli_q:
-  case Intrinsic::x86_sse2_psrli_w:
   case Intrinsic::x86_avx2_psrl_d:
   case Intrinsic::x86_avx2_psrl_q:
   case Intrinsic::x86_avx2_psrl_w:
-  case Intrinsic::x86_avx2_psrli_d:
-  case Intrinsic::x86_avx2_psrli_q:
-  case Intrinsic::x86_avx2_psrli_w: {
-    // Simplify if count is constant. To 0 if >= BitWidth,
-    // otherwise to shl/lshr.
-    auto CDV = dyn_cast<ConstantDataVector>(II->getArgOperand(1));
-    auto CInt = dyn_cast<ConstantInt>(II->getArgOperand(1));
-    if (!CDV && !CInt)
-      break;
-    ConstantInt *Count;
-    if (CDV)
-      Count = cast<ConstantInt>(CDV->getElementAsConstant(0));
-    else
-      Count = CInt;
-
-    auto Vec = II->getArgOperand(0);
-    auto VT = cast<VectorType>(Vec->getType());
-    if (Count->getZExtValue() >
-        VT->getElementType()->getPrimitiveSizeInBits() - 1)
-      return ReplaceInstUsesWith(
-          CI, ConstantAggregateZero::get(Vec->getType()));
-
-    bool isPackedShiftLeft = true;
-    switch (II->getIntrinsicID()) {
-    default : break;
-    case Intrinsic::x86_sse2_psrl_d:
-    case Intrinsic::x86_sse2_psrl_q:
-    case Intrinsic::x86_sse2_psrl_w:
-    case Intrinsic::x86_sse2_psrli_d:
-    case Intrinsic::x86_sse2_psrli_q:
-    case Intrinsic::x86_sse2_psrli_w:
-    case Intrinsic::x86_avx2_psrl_d:
-    case Intrinsic::x86_avx2_psrl_q:
-    case Intrinsic::x86_avx2_psrl_w:
-    case Intrinsic::x86_avx2_psrli_d:
-    case Intrinsic::x86_avx2_psrli_q:
-    case Intrinsic::x86_avx2_psrli_w: isPackedShiftLeft = false; break;
-    }
+  case Intrinsic::x86_sse2_psll_d:
+  case Intrinsic::x86_sse2_psll_q:
+  case Intrinsic::x86_sse2_psll_w:
+  case Intrinsic::x86_avx2_psll_d:
+  case Intrinsic::x86_avx2_psll_q:
+  case Intrinsic::x86_avx2_psll_w: {
+    if (Value *V = SimplifyX86immshift(*II, *Builder))
+      return ReplaceInstUsesWith(*II, V);
 
-    unsigned VWidth = VT->getNumElements();
-    // Get a constant vector of the same type as the first operand.
-    auto VTCI = ConstantInt::get(VT->getElementType(), Count->getZExtValue());
-    if (isPackedShiftLeft)
-      return BinaryOperator::CreateShl(Vec,
-          Builder->CreateVectorSplat(VWidth, VTCI));
+    // SSE2/AVX2 uses only the lower 64 bits of the 128-bit vector
+    // operand to compute the shift amount.
+    auto ShiftAmt = II->getArgOperand(1);
+    auto ShiftType = cast<VectorType>(ShiftAmt->getType());
+    assert(ShiftType->getPrimitiveSizeInBits() == 128 &&
+           "Unexpected packed shift size");
+    unsigned VWidth = ShiftType->getNumElements();
 
-    return BinaryOperator::CreateLShr(Vec,
-        Builder->CreateVectorSplat(VWidth, VTCI));
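+    // Demand only the low VWidth/2 elements, i.e. the low 64 bits of the
+    // 128-bit count operand; anything feeding the upper half can be dropped.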
+    APInt DemandedElts = APInt::getLowBitsSet(VWidth, VWidth / 2);
+    APInt UndefElts(VWidth, 0);
+    if (Value *V =
+            SimplifyDemandedVectorElts(ShiftAmt, DemandedElts, UndefElts)) {
+      II->setArgOperand(1, V);
+      return II;
+    }
+    break;
   }
 
   case Intrinsic::x86_sse41_pmovsxbd:
@@ -922,7 +1017,20 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     // This optimization is convoluted because the intrinsic is defined as
     // getting a vector of floats or doubles for the ps and pd versions.
     // FIXME: That should be changed.
+
+    Value *Op0 = II->getArgOperand(0);
+    Value *Op1 = II->getArgOperand(1);
     Value *Mask = II->getArgOperand(2);
+
+    // fold (blend A, A, Mask) -> A
+    if (Op0 == Op1)
+      return ReplaceInstUsesWith(CI, Op0);
+
+    // Zero Mask - select 1st argument.
+    if (isa<ConstantAggregateZero>(Mask))
+      return ReplaceInstUsesWith(CI, Op0);
+
+    // Constant Mask - select 1st/2nd argument lane based on top bit of mask.
     if (auto C = dyn_cast<ConstantDataVector>(Mask)) {
       auto Tyi1 = Builder->getInt1Ty();
       auto SelectorType = cast<VectorType>(Mask->getType());
@@ -945,11 +1053,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
         Selectors.push_back(ConstantInt::get(Tyi1, Selector >> (BitWidth - 1)));
       }
       auto NewSelector = ConstantVector::get(Selectors);
-      return SelectInst::Create(NewSelector, II->getArgOperand(1),
-                                II->getArgOperand(0), "blendv");
-    } else {
-      break;
+      return SelectInst::Create(NewSelector, Op1, Op0, "blendv");
     }
+    break;
   }
 
   case Intrinsic::x86_avx_vpermilvar_ps: