Improve long vector sext/zext lowering on ARM
[oota-llvm.git] / lib / Target / ARM / ARMISelLowering.cpp
index 514971f01ee82fac1b2552aeb907182e15aa5493..40d2e8d2657787f92599aff9bdcd934ab5dcb8ad 100644 (file)
@@ -564,6 +564,16 @@ ARMTargetLowering::ARMTargetLowering(TargetMachine &TM)
     setOperationAction(ISD::FP_ROUND,   MVT::v2f32, Expand);
     setOperationAction(ISD::FP_EXTEND,  MVT::v2f64, Expand);
 
+    // Custom expand long extensions to vectors.
+    setOperationAction(ISD::SIGN_EXTEND, MVT::v8i32,  Custom);
+    setOperationAction(ISD::ZERO_EXTEND, MVT::v8i32,  Custom);
+    setOperationAction(ISD::SIGN_EXTEND, MVT::v4i64,  Custom);
+    setOperationAction(ISD::ZERO_EXTEND, MVT::v4i64,  Custom);
+    setOperationAction(ISD::SIGN_EXTEND, MVT::v16i32, Custom);
+    setOperationAction(ISD::ZERO_EXTEND, MVT::v16i32, Custom);
+    setOperationAction(ISD::SIGN_EXTEND, MVT::v8i64,  Custom);
+    setOperationAction(ISD::ZERO_EXTEND, MVT::v8i64,  Custom);
+
     // NEON does not have single instruction CTPOP for vectors with element
     // types wider than 8-bits.  However, custom lowering can leverage the
     // v8i8/v16i8 vcnt instruction.
@@ -3433,6 +3443,47 @@ SDValue ARMTargetLowering::LowerFRAMEADDR(SDValue Op, SelectionDAG &DAG) const {
   return FrameAddr;
 }
 
+/// Custom Expand long vector extensions, where size(DestVec) > 2*size(SrcVec),
+/// and size(DestVec) > 128-bits.
+/// This is achieved by doing the one extension from the SrcVec, splitting the
+/// result, extending these parts, and then concatenating these into the
+/// destination.
+static SDValue ExpandVectorExtension(SDNode *N, SelectionDAG &DAG) {
+  SDValue Op = N->getOperand(0);
+  EVT SrcVT = Op.getValueType();
+  EVT DestVT = N->getValueType(0);
+
+  assert(DestVT.getSizeInBits() > 128 &&
+         "Custom sext/zext expansion needs >128-bit vector.");
+  // If this is a normal length extension, use the default expansion.
+  if (SrcVT.getSizeInBits()*4 != DestVT.getSizeInBits() &&
+      SrcVT.getSizeInBits()*8 != DestVT.getSizeInBits())
+    return SDValue();
+
+  DebugLoc dl = N->getDebugLoc();
+  unsigned SrcEltSize = SrcVT.getVectorElementType().getSizeInBits();
+  unsigned DestEltSize = DestVT.getVectorElementType().getSizeInBits();
+  unsigned NumElts = SrcVT.getVectorNumElements();
+  LLVMContext &Ctx = *DAG.getContext();
+  SDValue Mid, SplitLo, SplitHi, ExtLo, ExtHi;
+
+  EVT MidVT = EVT::getVectorVT(Ctx, EVT::getIntegerVT(Ctx, SrcEltSize*2),
+                               NumElts);
+  EVT SplitVT = EVT::getVectorVT(Ctx, EVT::getIntegerVT(Ctx, SrcEltSize*2),
+                                 NumElts/2);
+  EVT ExtVT = EVT::getVectorVT(Ctx, EVT::getIntegerVT(Ctx, DestEltSize),
+                               NumElts/2);
+
+  Mid = DAG.getNode(N->getOpcode(), dl, MidVT, Op);
+  SplitLo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SplitVT, Mid,
+                        DAG.getIntPtrConstant(0));
+  SplitHi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SplitVT, Mid,
+                        DAG.getIntPtrConstant(NumElts/2));
+  ExtLo = DAG.getNode(N->getOpcode(), dl, ExtVT, SplitLo);
+  ExtHi = DAG.getNode(N->getOpcode(), dl, ExtVT, SplitHi);
+  return DAG.getNode(ISD::CONCAT_VECTORS, dl, DestVT, ExtLo, ExtHi);
+}
+
 /// ExpandBITCAST - If the target supports VFP, this function is called to
 /// expand a bit convert where either the source or destination type is i64 to
 /// use a VMOVDRR or VMOVRRD node.  This should not be done when the non-i64
@@ -5621,6 +5672,10 @@ void ARMTargetLowering::ReplaceNodeResults(SDNode *N,
   case ISD::BITCAST:
     Res = ExpandBITCAST(N, DAG);
     break;
+  case ISD::SIGN_EXTEND:
+  case ISD::ZERO_EXTEND:
+    Res = ExpandVectorExtension(N, DAG);
+    break;
   case ISD::SRL:
   case ISD::SRA:
     Res = Expand64BitShift(N, DAG, Subtarget);