[x86] Re-apply a variant of the x86 side of r212324 now that the rest
authorChandler Carruth <chandlerc@gmail.com>
Wed, 9 Jul 2014 10:06:58 +0000 (10:06 +0000)
committerChandler Carruth <chandlerc@gmail.com>
Wed, 9 Jul 2014 10:06:58 +0000 (10:06 +0000)
has settled without incident, removing the x86-specific and overly
strict 'isVectorSplat' routine in favor of generic and more powerful
splat detection.

The primary motivation and result of this is that the x86 backend can
now see through splats which contain undef elements. This is essential
if we are using a widening form of legalization and I've updated a test
case to also run in that mode as before this change the generated code
for the test case was completely scalarized.

This version of the patch much more carefully handles the undef lanes.
- We aren't overly conservative about them in the shift lowering
  (where we will never use the splat itself).
- One place where the splat would have been re-used by the existing code
  now explicitly constructs a new constant splat that will be safe.
- The broadcast lowering is much more reasonable with undefs by doing
  a correct check of whether the splat is the only user of a loaded
  value, checking that the splat actually crosses multiple lanes before
  using a broadcast, and handling broadcasts of non-constant splats.

As a consequence of the last bullet, the weird usage of vpshufd instead
of vbroadcast is gone, and we actually can lower an AVX splat with
vbroadcastss where before we emitted a really strange pattern of
a vector load and a manual splat across the vector.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@212602 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/avx-splat.ll
test/CodeGen/X86/widen_cast-4.ll

index 306a659c983136bbd7eed3d4d66a9362484d7f8d..daf141bfcec181bdab2f4ba954acd5593b870ed1 100644 (file)
@@ -4858,19 +4858,6 @@ static bool ShouldXformToMOVLP(SDNode *V1, SDNode *V2,
   return true;
 }
 
   return true;
 }
 
-/// isSplatVector - Returns true if N is a BUILD_VECTOR node whose elements are
-/// all the same.
-static bool isSplatVector(SDNode *N) {
-  if (N->getOpcode() != ISD::BUILD_VECTOR)
-    return false;
-
-  SDValue SplatValue = N->getOperand(0);
-  for (unsigned i = 1, e = N->getNumOperands(); i != e; ++i)
-    if (N->getOperand(i) != SplatValue)
-      return false;
-  return true;
-}
-
 /// isZeroShuffle - Returns true if N is a VECTOR_SHUFFLE that can be resolved
 /// to an zero vector.
 /// FIXME: move to dag combiner / method on ShuffleVectorSDNode
 /// isZeroShuffle - Returns true if N is a VECTOR_SHUFFLE that can be resolved
 /// to an zero vector.
 /// FIXME: move to dag combiner / method on ShuffleVectorSDNode
@@ -5779,18 +5766,22 @@ static SDValue LowerVectorBroadcast(SDValue Op, const X86Subtarget* Subtarget,
       return SDValue();
 
     case ISD::BUILD_VECTOR: {
       return SDValue();
 
     case ISD::BUILD_VECTOR: {
-      // The BUILD_VECTOR node must be a splat.
-      if (!isSplatVector(Op.getNode()))
+      auto *BVOp = cast<BuildVectorSDNode>(Op.getNode());
+      BitVector UndefElements;
+      SDValue Splat = BVOp->getSplatValue(&UndefElements);
+
+      // We need a splat of a single value to use broadcast, and it doesn't
+      // make any sense if the value is only in one element of the vector.
+      if (!Splat || (VT.getVectorNumElements() - UndefElements.count()) <= 1)
         return SDValue();
 
         return SDValue();
 
-      Ld = Op.getOperand(0);
+      Ld = Splat;
       ConstSplatVal = (Ld.getOpcode() == ISD::Constant ||
       ConstSplatVal = (Ld.getOpcode() == ISD::Constant ||
-                     Ld.getOpcode() == ISD::ConstantFP);
+                       Ld.getOpcode() == ISD::ConstantFP);
 
 
-      // The suspected load node has several users. Make sure that all
-      // of its users are from the BUILD_VECTOR node.
-      // Constants may have multiple users.
-      if (!ConstSplatVal && !Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0))
+      // Make sure that all of the users of a non-constant load are from the
+      // BUILD_VECTOR node.
+      if (!ConstSplatVal && !BVOp->isOnlyUserOf(Ld.getNode()))
         return SDValue();
       break;
     }
         return SDValue();
       break;
     }
@@ -9375,8 +9366,13 @@ X86TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const {
   bool Commuted = false;
   // FIXME: This should also accept a bitcast of a splat?  Be careful, not
   // 1,1,1,1 -> v8i16 though.
   bool Commuted = false;
   // FIXME: This should also accept a bitcast of a splat?  Be careful, not
   // 1,1,1,1 -> v8i16 though.
-  V1IsSplat = isSplatVector(V1.getNode());
-  V2IsSplat = isSplatVector(V2.getNode());
+  BitVector UndefElements;
+  if (auto *BVOp = dyn_cast<BuildVectorSDNode>(V1.getNode()))
+    if (BVOp->getConstantSplatNode(&UndefElements) && UndefElements.none())
+      V1IsSplat = true;
+  if (auto *BVOp = dyn_cast<BuildVectorSDNode>(V2.getNode()))
+    if (BVOp->getConstantSplatNode(&UndefElements) && UndefElements.none())
+      V2IsSplat = true;
 
   // Canonicalize the splat or undef, if present, to be on the RHS.
   if (!V2IsUndef && V1IsSplat && !V2IsSplat) {
 
   // Canonicalize the splat or undef, if present, to be on the RHS.
   if (!V2IsUndef && V1IsSplat && !V2IsSplat) {
@@ -15171,10 +15167,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
   SDValue Amt = Op.getOperand(1);
 
   // Optimize shl/srl/sra with constant shift amount.
   SDValue Amt = Op.getOperand(1);
 
   // Optimize shl/srl/sra with constant shift amount.
-  if (isSplatVector(Amt.getNode())) {
-    SDValue SclrAmt = Amt->getOperand(0);
-    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt)) {
-      uint64_t ShiftAmt = C->getZExtValue();
+  if (auto *BVAmt = dyn_cast<BuildVectorSDNode>(Amt)) {
+    if (auto *ShiftConst = BVAmt->getConstantSplatNode()) {
+      uint64_t ShiftAmt = ShiftConst->getZExtValue();
 
       if (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 ||
           (Subtarget->hasInt256() &&
 
       if (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16 ||
           (Subtarget->hasInt256() &&
@@ -19423,28 +19418,34 @@ static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
           Other->getOpcode() == ISD::SUB && DAG.isEqualTo(OpRHS, CondRHS))
         return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
 
           Other->getOpcode() == ISD::SUB && DAG.isEqualTo(OpRHS, CondRHS))
         return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
 
-      // If the RHS is a constant we have to reverse the const canonicalization.
-      // x > C-1 ? x+-C : 0 --> subus x, C
-      if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
-          isSplatVector(CondRHS.getNode()) && isSplatVector(OpRHS.getNode())) {
-        APInt A = cast<ConstantSDNode>(OpRHS.getOperand(0))->getAPIntValue();
-        if (CondRHS.getConstantOperandVal(0) == -A-1)
-          return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS,
-                             DAG.getConstant(-A, VT));
-      }
-
-      // Another special case: If C was a sign bit, the sub has been
-      // canonicalized into a xor.
-      // FIXME: Would it be better to use computeKnownBits to determine whether
-      //        it's safe to decanonicalize the xor?
-      // x s< 0 ? x^C : 0 --> subus x, C
-      if (CC == ISD::SETLT && Other->getOpcode() == ISD::XOR &&
-          ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
-          isSplatVector(OpRHS.getNode())) {
-        APInt A = cast<ConstantSDNode>(OpRHS.getOperand(0))->getAPIntValue();
-        if (A.isSignBit())
-          return DAG.getNode(X86ISD::SUBUS, DL, VT, OpLHS, OpRHS);
-      }
+      if (auto *OpRHSBV = dyn_cast<BuildVectorSDNode>(OpRHS))
+        if (auto *OpRHSConst = OpRHSBV->getConstantSplatNode()) {
+          if (auto *CondRHSBV = dyn_cast<BuildVectorSDNode>(CondRHS))
+            if (auto *CondRHSConst = CondRHSBV->getConstantSplatNode())
+              // If the RHS is a constant we have to reverse the const
+              // canonicalization.
+              // x > C-1 ? x+-C : 0 --> subus x, C
+              if (CC == ISD::SETUGT && Other->getOpcode() == ISD::ADD &&
+                  CondRHSConst->getAPIntValue() ==
+                      (-OpRHSConst->getAPIntValue() - 1))
+                return DAG.getNode(
+                    X86ISD::SUBUS, DL, VT, OpLHS,
+                    DAG.getConstant(-OpRHSConst->getAPIntValue(), VT));
+
+          // Another special case: If C was a sign bit, the sub has been
+          // canonicalized into a xor.
+          // FIXME: Would it be better to use computeKnownBits to determine
+          //        whether it's safe to decanonicalize the xor?
+          // x s< 0 ? x^C : 0 --> subus x, C
+          if (CC == ISD::SETLT && Other->getOpcode() == ISD::XOR &&
+              ISD::isBuildVectorAllZeros(CondRHS.getNode()) &&
+              OpRHSConst->getAPIntValue().isSignBit())
+            // Note that we have to rebuild the RHS constant here to ensure we
+            // don't rely on particular values of undef lanes.
+            return DAG.getNode(
+                X86ISD::SUBUS, DL, VT, OpLHS,
+                DAG.getConstant(OpRHSConst->getAPIntValue(), VT));
+        }
     }
   }
 
     }
   }
 
@@ -20152,16 +20153,15 @@ static SDValue PerformSHLCombine(SDNode *N, SelectionDAG &DAG) {
   // vector operations in many cases. Also, on sandybridge ADD is faster than
   // shl.
   // (shl V, 1) -> add V,V
   // vector operations in many cases. Also, on sandybridge ADD is faster than
   // shl.
   // (shl V, 1) -> add V,V
-  if (isSplatVector(N1.getNode())) {
-    assert(N0.getValueType().isVector() && "Invalid vector shift type");
-    ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1->getOperand(0));
-    // We shift all of the values by one. In many cases we do not have
-    // hardware support for this operation. This is better expressed as an ADD
-    // of two values.
-    if (N1C && (1 == N1C->getZExtValue())) {
-      return DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N0);
+  if (auto *N1BV = dyn_cast<BuildVectorSDNode>(N1))
+    if (auto *N1SplatC = N1BV->getConstantSplatNode()) {
+      assert(N0.getValueType().isVector() && "Invalid vector shift type");
+      // We shift all of the values by one. In many cases we do not have
+      // hardware support for this operation. This is better expressed as an ADD
+      // of two values.
+      if (N1SplatC->getZExtValue() == 1)
+        return DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N0);
     }
     }
-  }
 
   return SDValue();
 }
 
   return SDValue();
 }
@@ -20180,10 +20180,9 @@ static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG,
 
   SDValue Amt = N->getOperand(1);
   SDLoc DL(N);
 
   SDValue Amt = N->getOperand(1);
   SDLoc DL(N);
-  if (isSplatVector(Amt.getNode())) {
-    SDValue SclrAmt = Amt->getOperand(0);
-    if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt)) {
-      APInt ShiftAmt = C->getAPIntValue();
+  if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Amt))
+    if (auto *AmtSplat = AmtBV->getConstantSplatNode()) {
+      APInt ShiftAmt = AmtSplat->getAPIntValue();
       unsigned MaxAmount = VT.getVectorElementType().getSizeInBits();
 
       // SSE2/AVX2 logical shifts always return a vector of 0s
       unsigned MaxAmount = VT.getVectorElementType().getSizeInBits();
 
       // SSE2/AVX2 logical shifts always return a vector of 0s
@@ -20193,7 +20192,6 @@ static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG,
       if (ShiftAmt.trunc(8).uge(MaxAmount))
         return getZeroVector(VT, Subtarget, DAG, DL);
     }
       if (ShiftAmt.trunc(8).uge(MaxAmount))
         return getZeroVector(VT, Subtarget, DAG, DL);
     }
-  }
 
   return SDValue();
 }
 
   return SDValue();
 }
@@ -20387,9 +20385,10 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
 
   // The right side has to be a 'trunc' or a constant vector.
   bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE;
 
   // The right side has to be a 'trunc' or a constant vector.
   bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE;
-  bool RHSConst = (isSplatVector(N1.getNode()) &&
-                   isa<ConstantSDNode>(N1->getOperand(0)));
-  if (!RHSTrunc && !RHSConst)
+  ConstantSDNode *RHSConstSplat;
+  if (auto *RHSBV = dyn_cast<BuildVectorSDNode>(N1))
+    RHSConstSplat = RHSBV->getConstantSplatNode();
+  if (!RHSTrunc && !RHSConstSplat)
     return SDValue();
 
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
     return SDValue();
 
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
@@ -20399,9 +20398,9 @@ static SDValue WidenMaskArithmetic(SDNode *N, SelectionDAG &DAG,
 
   // Set N0 and N1 to hold the inputs to the new wide operation.
   N0 = N0->getOperand(0);
 
   // Set N0 and N1 to hold the inputs to the new wide operation.
   N0 = N0->getOperand(0);
-  if (RHSConst) {
+  if (RHSConstSplat) {
     N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getScalarType(),
     N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, WideVT.getScalarType(),
-                     N1->getOperand(0));
+                     SDValue(RHSConstSplat, 0));
     SmallVector<SDValue, 8> C(WideVT.getVectorNumElements(), N1);
     N1 = DAG.getNode(ISD::BUILD_VECTOR, DL, WideVT, C);
   } else if (RHSTrunc) {
     SmallVector<SDValue, 8> C(WideVT.getVectorNumElements(), N1);
     N1 = DAG.getNode(ISD::BUILD_VECTOR, DL, WideVT, C);
   } else if (RHSTrunc) {
@@ -20547,12 +20546,9 @@ static SDValue PerformOrCombine(SDNode *N, SelectionDAG &DAG,
       unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
       unsigned SraAmt = ~0;
       if (Mask.getOpcode() == ISD::SRA) {
       unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
       unsigned SraAmt = ~0;
       if (Mask.getOpcode() == ISD::SRA) {
-        SDValue Amt = Mask.getOperand(1);
-        if (isSplatVector(Amt.getNode())) {
-          SDValue SclrAmt = Amt->getOperand(0);
-          if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(SclrAmt))
-            SraAmt = C->getZExtValue();
-        }
+        if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Mask.getOperand(1)))
+          if (auto *AmtConst = AmtBV->getConstantSplatNode())
+            SraAmt = AmtConst->getZExtValue();
       } else if (Mask.getOpcode() == X86ISD::VSRAI) {
         SDValue SraC = Mask.getOperand(1);
         SraAmt  = cast<ConstantSDNode>(SraC)->getZExtValue();
       } else if (Mask.getOpcode() == X86ISD::VSRAI) {
         SDValue SraC = Mask.getOperand(1);
         SraAmt  = cast<ConstantSDNode>(SraC)->getZExtValue();
index 5d0781531f4d944bd87d916595e5fffe6221a8a1..b1b2f8b97a7331ee801c7f373fccaf3941055be7 100644 (file)
@@ -43,13 +43,10 @@ entry:
   ret <4 x double> %vecinit6.i
 }
 
   ret <4 x double> %vecinit6.i
 }
 
-; Test this simple opt:
+; Test this turns into a broadcast:
 ;   shuffle (scalar_to_vector (load (ptr + 4))), undef, <0, 0, 0, 0>
 ;   shuffle (scalar_to_vector (load (ptr + 4))), undef, <0, 0, 0, 0>
-; To:
-;   shuffle (vload ptr)), undef, <1, 1, 1, 1>
-; CHECK: vmovdqa
-; CHECK-NEXT: vpshufd $-1
-; CHECK-NEXT: vinsertf128  $1
+;   
+; CHECK: vbroadcastss
 define <8 x float> @funcE() nounwind {
 allocas:
   %udx495 = alloca [18 x [18 x float]], align 32
 define <8 x float> @funcE() nounwind {
 allocas:
   %udx495 = alloca [18 x [18 x float]], align 32
index 1bc06a77cbf7b2f1f0ccff9dbb7873d4920bb665..19b84f19a4ff96fa061fc3fa589bd60d6a43f309 100644 (file)
@@ -1,8 +1,9 @@
 ; RUN: llc < %s -march=x86 -mattr=+sse4.2 | FileCheck %s
 ; RUN: llc < %s -march=x86 -mattr=+sse4.2 | FileCheck %s
-; CHECK: psraw
-; CHECK: psraw
+; RUN: llc < %s -march=x86 -mattr=+sse4.2 -x86-experimental-vector-widening-legalization | FileCheck %s --check-prefix=CHECK-WIDE
 
 define void @update(i64* %dst_i, i64* %src_i, i32 %n) nounwind {
 
 define void @update(i64* %dst_i, i64* %src_i, i32 %n) nounwind {
+; CHECK-LABEL: update:
+; CHECK-WIDE-LABEL: update:
 entry:
        %dst_i.addr = alloca i64*               ; <i64**> [#uses=2]
        %src_i.addr = alloca i64*               ; <i64**> [#uses=2]
 entry:
        %dst_i.addr = alloca i64*               ; <i64**> [#uses=2]
        %src_i.addr = alloca i64*               ; <i64**> [#uses=2]
@@ -44,6 +45,26 @@ forbody:             ; preds = %forcond
        %shr = ashr <8 x i8> %add, < i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2 >           ; <<8 x i8>> [#uses=1]
        store <8 x i8> %shr, <8 x i8>* %arrayidx10
        br label %forinc
        %shr = ashr <8 x i8> %add, < i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2, i8 2 >           ; <<8 x i8>> [#uses=1]
        store <8 x i8> %shr, <8 x i8>* %arrayidx10
        br label %forinc
+; CHECK: %forbody
+; CHECK:      pmovzxbw
+; CHECK-NEXT: paddw
+; CHECK-NEXT: psllw $8
+; CHECK-NEXT: psraw $8
+; CHECK-NEXT: psraw $2
+; CHECK-NEXT: pshufb
+; CHECK-NEXT: movlpd
+;
+; FIXME: We shouldn't require both a movd and an insert.
+; CHECK-WIDE: %forbody
+; CHECK-WIDE:      movd
+; CHECK-WIDE-NEXT: pinsrd
+; CHECK-WIDE-NEXT: paddb
+; CHECK-WIDE-NEXT: psrlw $2
+; CHECK-WIDE-NEXT: pand
+; CHECK-WIDE-NEXT: pxor
+; CHECK-WIDE-NEXT: psubb
+; CHECK-WIDE-NEXT: pextrd
+; CHECK-WIDE-NEXT: movd
 
 forinc:                ; preds = %forbody
        %tmp15 = load i32* %i           ; <i32> [#uses=1]
 
 forinc:                ; preds = %forbody
        %tmp15 = load i32* %i           ; <i32> [#uses=1]