Add AVX2 vpbroadcast support
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index 4986aac04f23e1cd45c3f927bab25a69ccb4a899..6a14f220a472c8f2b222742029ba3acbb7e09ad9 100644 (file)
@@ -5115,9 +5115,9 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, SmallVectorImpl<SDValue> &Elts,
 /// 1. A splat BUILD_VECTOR which uses a single scalar load.
 /// 2. A splat shuffle which uses a scalar_to_vector node which comes from
 /// a scalar load.
-/// The scalar load node is returned when a pattern is found, 
-/// or SDValue() otherwise. 
-static SDValue isVectorBroadcast(SDValue &Op) {
+/// The scalar load node is returned when a pattern is found,
+/// or SDValue() otherwise.
+static SDValue isVectorBroadcast(SDValue &Op, bool hasAVX2) {
   EVT VT = Op.getValueType();
   SDValue V = Op;
 
@@ -5134,16 +5134,16 @@ static SDValue isVectorBroadcast(SDValue &Op) {
 
     case ISD::BUILD_VECTOR: {
       // The BUILD_VECTOR node must be a splat.
-      if (!isSplatVector(V.getNode())) 
+      if (!isSplatVector(V.getNode()))
         return SDValue();
 
       Ld = V.getOperand(0);
-    
-      // The suspected load node has several users. Make sure that all 
+
+      // The suspected load node has several users. Make sure that all
       // of its users are from the BUILD_VECTOR node.
-      if (!Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0)) 
+      if (!Ld->hasNUsesOfValue(VT.getVectorNumElements(), 0))
         return SDValue();
-      break; 
+      break;
     }
 
     case ISD::VECTOR_SHUFFLE: {
@@ -5151,11 +5151,11 @@ static SDValue isVectorBroadcast(SDValue &Op) {
 
       // Shuffles must have a splat mask where the first element is
       // broadcasted.
-      if ((!SVOp->isSplat()) || SVOp->getMaskElt(0) != 0) 
+      if ((!SVOp->isSplat()) || SVOp->getMaskElt(0) != 0)
         return SDValue();
 
       SDValue Sc = Op.getOperand(0);
-      if (Sc.getOpcode() != ISD::SCALAR_TO_VECTOR) 
+      if (Sc.getOpcode() != ISD::SCALAR_TO_VECTOR)
         return SDValue();
 
       Ld = Sc.getOperand(0);
@@ -5167,15 +5167,27 @@ static SDValue isVectorBroadcast(SDValue &Op) {
       break;
     }
   }
-  
+
   // The scalar source must be a normal load.
-  if (!ISD::isNormalLoad(Ld.getNode())) 
+  if (!ISD::isNormalLoad(Ld.getNode()))
     return SDValue();
-  
+
   bool Is256 = VT.getSizeInBits() == 256;
   bool Is128 = VT.getSizeInBits() == 128;
   unsigned ScalarSize = Ld.getValueType().getSizeInBits();
 
+  if (hasAVX2) {
+    // VBroadcast to YMM
+    if (Is256 && (ScalarSize == 8  || ScalarSize == 16 ||
+                  ScalarSize == 32 || ScalarSize == 64 ))
+      return Ld;
+
+    // VBroadcast to XMM
+    if (Is128 && (ScalarSize ==  8 || ScalarSize == 32 ||
+                  ScalarSize == 16 || ScalarSize == 64 ))
+      return Ld;
+  }
+
   // VBroadcast to YMM
   if (Is256 && (ScalarSize == 32 || ScalarSize == 64))
     return Ld;
@@ -5184,6 +5196,7 @@ static SDValue isVectorBroadcast(SDValue &Op) {
   if (Is128 && (ScalarSize == 32))
     return Ld;
 
+
   // Unsupported broadcast.
   return SDValue();
 }
@@ -5216,7 +5229,7 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
     return getOnesVector(Op.getValueType(), DAG, dl);
   }
 
-  SDValue LD = isVectorBroadcast(Op);
+  SDValue LD = isVectorBroadcast(Op, Subtarget->hasAVX2());
   if (Subtarget->hasAVX() && LD.getNode())
       return DAG.getNode(X86ISD::VBROADCAST, dl, VT, LD);
 
@@ -6613,7 +6626,7 @@ SDValue NormalizeVectorShuffle(SDValue Op, SelectionDAG &DAG,
       return Op;
 
     // Use vbroadcast whenever the splat comes from a foldable load
-    SDValue LD = isVectorBroadcast(Op);
+    SDValue LD = isVectorBroadcast(Op, Subtarget->hasAVX2());
     if (Subtarget->hasAVX() && LD.getNode())
       return DAG.getNode(X86ISD::VBROADCAST, dl, VT, LD);