Simplified BLEND pattern matching for shuffles.
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index 81e8a7bd88d1004c74d88fe7dfeae9f57e2ef8d6..b3ff4ee98e3dbcfcfddc7c536b73c7c7638defa4 100644 (file)
@@ -5641,64 +5641,53 @@ LowerVECTOR_SHUFFLEtoBlend(ShuffleVectorSDNode *SVOp,
   SDValue V1 = SVOp->getOperand(0);
   SDValue V2 = SVOp->getOperand(1);
   DebugLoc dl = SVOp->getDebugLoc();
-  MVT VT = SVOp->getValueType(0).getSimpleVT();
+  EVT VT = SVOp->getValueType(0);
+  EVT EltVT = VT.getVectorElementType();
   unsigned NumElems = VT.getVectorNumElements();
 
-  if (!Subtarget->hasSSE41())
+  if (!Subtarget->hasSSE41() || EltVT == MVT::i8)
+    return SDValue();
+  if (!Subtarget->hasInt256() && VT == MVT::v16i16)
     return SDValue();
 
-  unsigned ISDNo = 0;
-  MVT OpTy;
-
-  switch (VT.SimpleTy) {
-  default: return SDValue();
-  case MVT::v8i16:
-    ISDNo = X86ISD::BLENDPW;
-    OpTy = MVT::v8i16;
-    break;
-  case MVT::v4i32:
-  case MVT::v4f32:
-    ISDNo = X86ISD::BLENDPS;
-    OpTy = MVT::v4f32;
-    break;
-  case MVT::v2i64:
-  case MVT::v2f64:
-    ISDNo = X86ISD::BLENDPD;
-    OpTy = MVT::v2f64;
-    break;
-  case MVT::v8i32:
-  case MVT::v8f32:
-    if (!Subtarget->hasFp256())
-      return SDValue();
-    ISDNo = X86ISD::BLENDPS;
-    OpTy = MVT::v8f32;
-    break;
-  case MVT::v4i64:
-  case MVT::v4f64:
-    if (!Subtarget->hasFp256())
-      return SDValue();
-    ISDNo = X86ISD::BLENDPD;
-    OpTy = MVT::v4f64;
-    break;
-  }
-  assert(ISDNo && "Invalid Op Number");
+  // Check the mask for BLEND and build the value.
+  unsigned MaskValue = 0;
+  // There are 2 lanes if (NumElems > 8), and 1 lane otherwise.
+  unsigned NumLanes = (NumElems-1)/8 + 1; 
+  unsigned NumElemsInLane = NumElems / NumLanes;
 
-  unsigned MaskVals = 0;
+  // Blend for v16i16 should be symetric for the both lanes.
+  for (unsigned i = 0; i < NumElemsInLane; ++i) {
 
-  for (unsigned i = 0; i != NumElems; ++i) {
+    int SndLaneEltIdx = (NumLanes == 2) ? 
+      SVOp->getMaskElt(i + NumElemsInLane) : -1;
     int EltIdx = SVOp->getMaskElt(i);
-    if (EltIdx == (int)i || EltIdx < 0)
-      MaskVals |= (1<<i);
-    else if (EltIdx == (int)(i + NumElems))
-      continue; // Bit is set to zero;
-    else
+
+    if ((EltIdx == -1 || EltIdx == (int)i) && 
+        (SndLaneEltIdx == -1 || SndLaneEltIdx == (int)(i + NumElemsInLane)))
+      continue;
+
+    if (((unsigned)EltIdx == (i + NumElems)) && 
+        (SndLaneEltIdx == -1 || 
+         (unsigned)SndLaneEltIdx == i + NumElems + NumElemsInLane))
+      MaskValue |= (1<<i);
+    else 
       return SDValue();
   }
 
-  V1 = DAG.getNode(ISD::BITCAST, dl, OpTy, V1);
-  V2 = DAG.getNode(ISD::BITCAST, dl, OpTy, V2);
-  SDValue Ret =  DAG.getNode(ISDNo, dl, OpTy, V1, V2,
-                             DAG.getConstant(MaskVals, MVT::i32));
+  // Convert i32 vectors to floating point if it is not AVX2.
+  // AVX2 introduced VPBLENDD instruction for 128 and 256-bit vectors.
+  EVT BlendVT = VT;
+  if (EltVT == MVT::i64 || (EltVT == MVT::i32 && !Subtarget->hasInt256())) {
+    BlendVT = EVT::getVectorVT(*DAG.getContext(), 
+                              EVT::getFloatingPointVT(EltVT.getSizeInBits()), 
+                              NumElems);
+    V1 = DAG.getNode(ISD::BITCAST, dl, VT, V1);
+    V2 = DAG.getNode(ISD::BITCAST, dl, VT, V2);
+  }
+  
+  SDValue Ret =  DAG.getNode(X86ISD::BLENDI, dl, BlendVT, V1, V2,
+                             DAG.getConstant(MaskValue, MVT::i32));
   return DAG.getNode(ISD::BITCAST, dl, VT, Ret);
 }
 
@@ -11972,9 +11961,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   case X86ISD::ANDNP:              return "X86ISD::ANDNP";
   case X86ISD::PSIGN:              return "X86ISD::PSIGN";
   case X86ISD::BLENDV:             return "X86ISD::BLENDV";
-  case X86ISD::BLENDPW:            return "X86ISD::BLENDPW";
-  case X86ISD::BLENDPS:            return "X86ISD::BLENDPS";
-  case X86ISD::BLENDPD:            return "X86ISD::BLENDPD";
+  case X86ISD::BLENDI:             return "X86ISD::BLENDI";
   case X86ISD::HADD:               return "X86ISD::HADD";
   case X86ISD::HSUB:               return "X86ISD::HSUB";
   case X86ISD::FHADD:              return "X86ISD::FHADD";