[InstCombine] SSE/AVX vector shifts demanded shift amount bits
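
Fold the SSE2/AVX2 vector shift intrinsics when the shift amount is constant,
and demand only the low 64 bits of the 128-bit shift-amount operand when it is
not. An illustrative example (assumed IR, not taken from the in-tree tests):

  %1 = call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %v,
                                            <4 x i32> <i32 3, i32 0, i32 0, i32 0>)
  ; becomes
  %1 = lshr <4 x i32> %v, <i32 3, i32 3, i32 3, i32 3>

Shift-by-zero returns the first operand unchanged, logical shifts by the bit
width or more fold to zero, and arithmetic shifts are clamped to BitWidth - 1.
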
diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 668789b516d0704606936c9a27c7517e6fed7321..1b9abfdacbf5cb4c2de9e847390bfad1d029f85c 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -198,35 +198,113 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
 }
 
 static Value *SimplifyX86immshift(const IntrinsicInst &II,
-                                  InstCombiner::BuilderTy &Builder,
-                                  bool ShiftLeft) {
-  // Simplify if count is constant. To 0 if >= BitWidth,
-  // otherwise to shl/lshr.
-  auto CDV = dyn_cast<ConstantDataVector>(II.getArgOperand(1));
-  auto CInt = dyn_cast<ConstantInt>(II.getArgOperand(1));
-  if (!CDV && !CInt)
+                                  InstCombiner::BuilderTy &Builder) {
+  bool LogicalShift = false;
+  bool ShiftLeft = false;
+
+  switch (II.getIntrinsicID()) {
+  default:
+    return nullptr;
+  case Intrinsic::x86_sse2_psra_d:
+  case Intrinsic::x86_sse2_psra_w:
+  case Intrinsic::x86_sse2_psrai_d:
+  case Intrinsic::x86_sse2_psrai_w:
+  case Intrinsic::x86_avx2_psra_d:
+  case Intrinsic::x86_avx2_psra_w:
+  case Intrinsic::x86_avx2_psrai_d:
+  case Intrinsic::x86_avx2_psrai_w:
+    LogicalShift = false; ShiftLeft = false;
+    break;
+  case Intrinsic::x86_sse2_psrl_d:
+  case Intrinsic::x86_sse2_psrl_q:
+  case Intrinsic::x86_sse2_psrl_w:
+  case Intrinsic::x86_sse2_psrli_d:
+  case Intrinsic::x86_sse2_psrli_q:
+  case Intrinsic::x86_sse2_psrli_w:
+  case Intrinsic::x86_avx2_psrl_d:
+  case Intrinsic::x86_avx2_psrl_q:
+  case Intrinsic::x86_avx2_psrl_w:
+  case Intrinsic::x86_avx2_psrli_d:
+  case Intrinsic::x86_avx2_psrli_q:
+  case Intrinsic::x86_avx2_psrli_w:
+    LogicalShift = true; ShiftLeft = false;
+    break;
+  case Intrinsic::x86_sse2_psll_d:
+  case Intrinsic::x86_sse2_psll_q:
+  case Intrinsic::x86_sse2_psll_w:
+  case Intrinsic::x86_sse2_pslli_d:
+  case Intrinsic::x86_sse2_pslli_q:
+  case Intrinsic::x86_sse2_pslli_w:
+  case Intrinsic::x86_avx2_psll_d:
+  case Intrinsic::x86_avx2_psll_q:
+  case Intrinsic::x86_avx2_psll_w:
+  case Intrinsic::x86_avx2_pslli_d:
+  case Intrinsic::x86_avx2_pslli_q:
+  case Intrinsic::x86_avx2_pslli_w:
+    LogicalShift = true; ShiftLeft = true;
+    break;
+  }
+  assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left");
+
+  // Simplify if count is constant.
+  auto Arg1 = II.getArgOperand(1);
+  auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1);
+  auto CDV = dyn_cast<ConstantDataVector>(Arg1);
+  auto CInt = dyn_cast<ConstantInt>(Arg1);
+  if (!CAZ && !CDV && !CInt)
     return nullptr;
-  ConstantInt *Count;
-  if (CDV)
-    Count = cast<ConstantInt>(CDV->getElementAsConstant(0));
-  else
-    Count = CInt;
+
+  APInt Count(64, 0);
+  if (CDV) {
+    // SSE2/AVX2 uses only the first 64 bits of the 128-bit vector
+    // operand to compute the shift amount.
+    auto VT = cast<VectorType>(CDV->getType());
+    unsigned BitWidth = VT->getElementType()->getPrimitiveSizeInBits();
+    assert((64 % BitWidth) == 0 && "Unexpected packed shift size");
+    unsigned NumSubElts = 64 / BitWidth;
+
+    // Concatenate the sub-elements to create the 64-bit value.
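+    // Element 0 forms the least significant bits, e.g. for <8 x i16> the
+    // constant <1, 2, 3, 4, ...> yields the 64-bit count 0x0004000300020001.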
+    for (unsigned i = 0; i != NumSubElts; ++i) {
+      unsigned SubEltIdx = (NumSubElts - 1) - i;
+      auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx));
+      Count = Count.shl(BitWidth);
+      Count |= SubElt->getValue().zextOrTrunc(64);
+    }
+  } else if (CInt)
+    Count = CInt->getValue();
 
   auto Vec = II.getArgOperand(0);
   auto VT = cast<VectorType>(Vec->getType());
   auto SVT = VT->getElementType();
-  if (Count->getZExtValue() > (SVT->getPrimitiveSizeInBits() - 1))
-    return ConstantAggregateZero::get(VT);
-
   unsigned VWidth = VT->getNumElements();
+  unsigned BitWidth = SVT->getPrimitiveSizeInBits();
+
+  // If the shift amount is zero, just return the original value.
+  if (Count == 0)
+    return Vec;
+
+  // Handle cases when Shift >= BitWidth.
+  if (Count.uge(BitWidth)) {
+    // If LogicalShift - just return zero.
+    if (LogicalShift)
+      return ConstantAggregateZero::get(VT);
+
+    // If ArithmeticShift - clamp Shift to (BitWidth - 1).
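+    // (a shift by BitWidth - 1 already fills every bit with the sign bit, so
+    // larger amounts produce the same result).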
+    Count = APInt(64, BitWidth - 1);
+  }
 
   // Get a constant vector of the same type as the first operand.
-  auto VTCI = ConstantInt::get(VT->getElementType(), Count->getZExtValue());
+  auto ShiftAmt = ConstantInt::get(SVT, Count.zextOrTrunc(BitWidth));
+  auto ShiftVec = Builder.CreateVectorSplat(VWidth, ShiftAmt);
 
   if (ShiftLeft)
-    return Builder.CreateShl(Vec, Builder.CreateVectorSplat(VWidth, VTCI));
+    return Builder.CreateShl(Vec, ShiftVec);
+
+  if (LogicalShift)
+    return Builder.CreateLShr(Vec, ShiftVec);
 
-  return Builder.CreateLShr(Vec, Builder.CreateVectorSplat(VWidth, VTCI));
+  return Builder.CreateAShr(Vec, ShiftVec);
 }
 
 static Value *SimplifyX86extend(const IntrinsicInst &II,
@@ -753,39 +831,65 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     break;
   }
 
+  // Constant fold ashr( <A x Bi>, Ci ).
   // Constant fold lshr( <A x Bi>, Ci ).
-  case Intrinsic::x86_sse2_psrl_d:
-  case Intrinsic::x86_sse2_psrl_q:
-  case Intrinsic::x86_sse2_psrl_w:
+  // Constant fold shl( <A x Bi>, Ci ).
+  case Intrinsic::x86_sse2_psrai_d:
+  case Intrinsic::x86_sse2_psrai_w:
+  case Intrinsic::x86_avx2_psrai_d:
+  case Intrinsic::x86_avx2_psrai_w:
   case Intrinsic::x86_sse2_psrli_d:
   case Intrinsic::x86_sse2_psrli_q:
   case Intrinsic::x86_sse2_psrli_w:
-  case Intrinsic::x86_avx2_psrl_d:
-  case Intrinsic::x86_avx2_psrl_q:
-  case Intrinsic::x86_avx2_psrl_w:
   case Intrinsic::x86_avx2_psrli_d:
   case Intrinsic::x86_avx2_psrli_q:
   case Intrinsic::x86_avx2_psrli_w:
-    if (Value *V = SimplifyX86immshift(*II, *Builder, false))
+  case Intrinsic::x86_sse2_pslli_d:
+  case Intrinsic::x86_sse2_pslli_q:
+  case Intrinsic::x86_sse2_pslli_w:
+  case Intrinsic::x86_avx2_pslli_d:
+  case Intrinsic::x86_avx2_pslli_q:
+  case Intrinsic::x86_avx2_pslli_w:
+    if (Value *V = SimplifyX86immshift(*II, *Builder))
       return ReplaceInstUsesWith(*II, V);
     break;
 
-  // Constant fold shl( <A x Bi>, Ci ).
+  case Intrinsic::x86_sse2_psra_d:
+  case Intrinsic::x86_sse2_psra_w:
+  case Intrinsic::x86_avx2_psra_d:
+  case Intrinsic::x86_avx2_psra_w:
+  case Intrinsic::x86_sse2_psrl_d:
+  case Intrinsic::x86_sse2_psrl_q:
+  case Intrinsic::x86_sse2_psrl_w:
+  case Intrinsic::x86_avx2_psrl_d:
+  case Intrinsic::x86_avx2_psrl_q:
+  case Intrinsic::x86_avx2_psrl_w:
   case Intrinsic::x86_sse2_psll_d:
   case Intrinsic::x86_sse2_psll_q:
   case Intrinsic::x86_sse2_psll_w:
-  case Intrinsic::x86_sse2_pslli_d:
-  case Intrinsic::x86_sse2_pslli_q:
-  case Intrinsic::x86_sse2_pslli_w:
   case Intrinsic::x86_avx2_psll_d:
   case Intrinsic::x86_avx2_psll_q:
-  case Intrinsic::x86_avx2_psll_w:
-  case Intrinsic::x86_avx2_pslli_d:
-  case Intrinsic::x86_avx2_pslli_q:
-  case Intrinsic::x86_avx2_pslli_w:
-    if (Value *V = SimplifyX86immshift(*II, *Builder, true))
+  case Intrinsic::x86_avx2_psll_w: {
+    if (Value *V = SimplifyX86immshift(*II, *Builder))
       return ReplaceInstUsesWith(*II, V);
+
+    // SSE2/AVX2 uses only the first 64 bits of the 128-bit vector
+    // operand to compute the shift amount.
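+    // Demanding only the lower half of its elements (element 0 of a <2 x i64>,
+    // elements 0 and 1 of a <4 x i32>, etc.) lets SimplifyDemandedVectorElts
+    // strip off work that only feeds the ignored upper elements.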
+    auto ShiftAmt = II->getArgOperand(1);
+    auto ShiftType = cast<VectorType>(ShiftAmt->getType());
+    assert(ShiftType->getPrimitiveSizeInBits() == 128 &&
+           "Unexpected packed shift size");
+    unsigned VWidth = ShiftType->getNumElements();
+
+    APInt DemandedElts = APInt::getLowBitsSet(VWidth, VWidth / 2);
+    APInt UndefElts(VWidth, 0);
+    if (Value *V =
+            SimplifyDemandedVectorElts(ShiftAmt, DemandedElts, UndefElts)) {
+      II->setArgOperand(1, V);
+      return II;
+    }
     break;
+  }
 
   case Intrinsic::x86_sse41_pmovsxbd:
   case Intrinsic::x86_sse41_pmovsxbq:
@@ -913,7 +1017,20 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     // This optimization is convoluted because the intrinsic is defined as
     // getting a vector of floats or doubles for the ps and pd versions.
     // FIXME: That should be changed.
+
+    Value *Op0 = II->getArgOperand(0);
+    Value *Op1 = II->getArgOperand(1);
     Value *Mask = II->getArgOperand(2);
+
+    // fold (blend A, A, Mask) -> A
+    if (Op0 == Op1)
+      return ReplaceInstUsesWith(CI, Op0);
+
+    // Zero Mask - select 1st argument.
+    if (isa<ConstantAggregateZero>(Mask))
+      return ReplaceInstUsesWith(CI, Op0);
+
+    // Constant Mask - select 1st/2nd argument lane based on top bit of mask.
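+    // A set sign bit picks the lane from the 2nd argument, which is why the
+    // select built below takes Op1 as its true operand and Op0 as its false one.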
     if (auto C = dyn_cast<ConstantDataVector>(Mask)) {
       auto Tyi1 = Builder->getInt1Ty();
       auto SelectorType = cast<VectorType>(Mask->getType());
@@ -936,11 +1053,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
         Selectors.push_back(ConstantInt::get(Tyi1, Selector >> (BitWidth - 1)));
       }
       auto NewSelector = ConstantVector::get(Selectors);
-      return SelectInst::Create(NewSelector, II->getArgOperand(1),
-                                II->getArgOperand(0), "blendv");
-    } else {
-      break;
+      return SelectInst::Create(NewSelector, Op1, Op0, "blendv");
     }
+    break;
   }
 
   case Intrinsic::x86_avx_vpermilvar_ps: