Support pattern matching various x86 sse shifts.
authorNate Begeman <natebegeman@mac.com>
Mon, 26 Jan 2009 00:52:55 +0000 (00:52 +0000)
committerNate Begeman <natebegeman@mac.com>
Mon, 26 Jan 2009 00:52:55 +0000 (00:52 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@62979 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp

index a2ae6095f504a26fdb7a59ee5a61215bc98f1055..8edd5a4df6c66ddd5f3fb016173b1dee7d5f0467 100644 (file)
@@ -806,6 +806,9 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
   setTargetDAGCombine(ISD::VECTOR_SHUFFLE);
   setTargetDAGCombine(ISD::BUILD_VECTOR);
   setTargetDAGCombine(ISD::SELECT);
+  setTargetDAGCombine(ISD::SHL);
+  setTargetDAGCombine(ISD::SRA);
+  setTargetDAGCombine(ISD::SRL);
   setTargetDAGCombine(ISD::STORE);
 
   computeRegisterProperties();
@@ -7654,6 +7657,93 @@ static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+/// PerformShiftCombine - Transforms vector shift nodes to use vector shifts
+///                       when possible.
+static SDValue PerformShiftCombine(SDNode* N, SelectionDAG &DAG,
+                                   const X86Subtarget *Subtarget) {
+  // On X86 with SSE2 support, we can transform this to a vector shift if
+  // all elements are shifted by the same amount.  We can't do this in legalize
+  // because the a constant vector is typically transformed to a constant pool
+  // so we have no knowledge of the shift amount.
+  MVT VT = N->getValueType(0);
+  if (Subtarget->hasSSE2() &&
+      (VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16)) {
+    SDValue  ValOp = N->getOperand(0);
+    SDValue  ShAmtOp = N->getOperand(1);
+    unsigned NumElts = VT.getVectorNumElements();
+
+    if (ShAmtOp.getOpcode() == ISD::BUILD_VECTOR) {
+      unsigned i = 0;
+      SDValue BaseShAmt;
+      for (; i != NumElts; ++i) {
+        SDValue Arg = ShAmtOp.getOperand(i);
+        if (Arg.getOpcode() == ISD::UNDEF) continue;
+        BaseShAmt = Arg;
+        break;
+      }
+      for (; i != NumElts; ++i) {
+        SDValue Arg = ShAmtOp.getOperand(i);
+        if (Arg.getOpcode() == ISD::UNDEF) continue;
+        if (Arg != BaseShAmt) {
+          return SDValue();
+        }
+      }
+
+      MVT EltVT = VT.getVectorElementType();
+      if (EltVT.bitsGT(MVT::i32))
+        BaseShAmt = DAG.getNode(ISD::TRUNCATE, MVT::i32, BaseShAmt);
+      else if (EltVT.bitsLT(MVT::i32))
+        BaseShAmt = DAG.getNode(ISD::ANY_EXTEND, MVT::i32, BaseShAmt);
+
+      // The shift amount is identical so we can do a vector shift.
+      switch (N->getOpcode()) {
+      default:
+        assert(0 && "Unknown shift opcode!");
+        break;
+      case ISD::SHL:
+        if (VT == MVT::v2i64)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                         DAG.getConstant(Intrinsic::x86_sse2_pslli_q, MVT::i32),
+                         ValOp, BaseShAmt);
+        else if (VT == MVT::v4i32)
+            return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                         DAG.getConstant(Intrinsic::x86_sse2_pslli_d, MVT::i32),
+                         ValOp, BaseShAmt);
+        else if (VT == MVT::v8i16)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                         DAG.getConstant(Intrinsic::x86_sse2_pslli_w, MVT::i32),
+                         ValOp, BaseShAmt);
+        break;
+      case ISD::SRA:
+        if (VT == MVT::v4i32)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                         DAG.getConstant(Intrinsic::x86_sse2_psrai_d, MVT::i32),
+                         ValOp, BaseShAmt);
+        else if (VT == MVT::v8i16)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                         DAG.getConstant(Intrinsic::x86_sse2_psrai_w, MVT::i32),
+                         ValOp, BaseShAmt);
+        break;
+      case ISD::SRL:
+        if (VT == MVT::v2i64)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                         DAG.getConstant(Intrinsic::x86_sse2_psrli_q, MVT::i32),
+                         ValOp, BaseShAmt);
+        else if (VT == MVT::v4i32)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                         DAG.getConstant(Intrinsic::x86_sse2_psrli_d, MVT::i32),
+                         ValOp, BaseShAmt);
+        else if (VT ==  MVT::v8i16)
+          return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, VT,
+                         DAG.getConstant(Intrinsic::x86_sse2_psrli_w, MVT::i32),
+                         ValOp, BaseShAmt);
+        break;
+      }
+    }
+  }
+  return SDValue();
+}
+
 /// PerformSTORECombine - Do target-specific dag combines on STORE nodes.
 static SDValue PerformSTORECombine(SDNode *N, SelectionDAG &DAG,
                                      const X86Subtarget *Subtarget) {
@@ -7782,6 +7872,9 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::BUILD_VECTOR:
     return PerformBuildVectorCombine(N, DAG, Subtarget, *this);
   case ISD::SELECT:         return PerformSELECTCombine(N, DAG, Subtarget);
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:            return PerformShiftCombine(N, DAG, Subtarget);
   case ISD::STORE:          return PerformSTORECombine(N, DAG, Subtarget);
   case X86ISD::FXOR:
   case X86ISD::FOR:         return PerformFORCombine(N, DAG);