[ARM64]Implement NEON post-increment LD1(lane) and post-increment LD1R.
[oota-llvm.git] / lib / Target / ARM64 / ARM64ISelDAGToDAG.cpp
index 2e84a268de9a090b89f76d87e6118e5115fa49ac..4a1f9717bf73a9293299fed7c901726cd1c15288 100644 (file)
@@ -988,9 +988,12 @@ SDNode *ARM64DAGToDAGISel::SelectPostLoad(SDNode *N, unsigned NumVecs,
 
   // Update uses of vector list
   SDValue SuperReg = SDValue(Ld, 1);
-  for (unsigned i = 0; i < NumVecs; ++i)
-    ReplaceUses(SDValue(N, i),
-        CurDAG->getTargetExtractSubreg(SubRegIdx + i, dl, VT, SuperReg));
+  if (NumVecs == 1)
+    ReplaceUses(SDValue(N, 0), SuperReg);
+  else
+    for (unsigned i = 0; i < NumVecs; ++i)
+      ReplaceUses(SDValue(N, i),
+          CurDAG->getTargetExtractSubreg(SubRegIdx + i, dl, VT, SuperReg));
 
   // Update the chain
   ReplaceUses(SDValue(N, NumVecs + 1), SDValue(Ld, 2));
@@ -1153,14 +1156,20 @@ SDNode *ARM64DAGToDAGISel::SelectPostLoadLane(SDNode *N, unsigned NumVecs,
 
   // Update uses of the vector list
   SDValue SuperReg = SDValue(Ld, 1);
-  EVT WideVT = RegSeq.getOperand(1)->getValueType(0);
-  static unsigned QSubs[] = { ARM64::qsub0, ARM64::qsub1, ARM64::qsub2,
-                              ARM64::qsub3 };
-  for (unsigned i = 0; i < NumVecs; ++i) {
-    SDValue NV = CurDAG->getTargetExtractSubreg(QSubs[i], dl, WideVT, SuperReg);
-    if (Narrow)
-      NV = NarrowVector(NV, *CurDAG);
-    ReplaceUses(SDValue(N, i), NV);
+  if (NumVecs == 1) {
+    ReplaceUses(SDValue(N, 0),
+                Narrow ? NarrowVector(SuperReg, *CurDAG) : SuperReg);
+  } else {
+    EVT WideVT = RegSeq.getOperand(1)->getValueType(0);
+    static unsigned QSubs[] = { ARM64::qsub0, ARM64::qsub1, ARM64::qsub2,
+                                ARM64::qsub3 };
+    for (unsigned i = 0; i < NumVecs; ++i) {
+      SDValue NV = CurDAG->getTargetExtractSubreg(QSubs[i], dl, WideVT,
+                                                  SuperReg);
+      if (Narrow)
+        NV = NarrowVector(NV, *CurDAG);
+      ReplaceUses(SDValue(N, i), NV);
+    }
   }
 
   // Update the Chain
@@ -2657,6 +2666,25 @@ SDNode *ARM64DAGToDAGISel::Select(SDNode *Node) {
       return SelectPostLoad(Node, 4, ARM64::LD1Fourv2d_POST, ARM64::qsub0);
     break;
   }
+  case ARM64ISD::LD1DUPpost: {
+    if (VT == MVT::v8i8)
+      return SelectPostLoad(Node, 1, ARM64::LD1Rv8b_POST, ARM64::dsub0);
+    else if (VT == MVT::v16i8)
+      return SelectPostLoad(Node, 1, ARM64::LD1Rv16b_POST, ARM64::qsub0);
+    else if (VT == MVT::v4i16)
+      return SelectPostLoad(Node, 1, ARM64::LD1Rv4h_POST, ARM64::dsub0);
+    else if (VT == MVT::v8i16)
+      return SelectPostLoad(Node, 1, ARM64::LD1Rv8h_POST, ARM64::qsub0);
+    else if (VT == MVT::v2i32 || VT == MVT::v2f32)
+      return SelectPostLoad(Node, 1, ARM64::LD1Rv2s_POST, ARM64::dsub0);
+    else if (VT == MVT::v4i32 || VT == MVT::v4f32)
+      return SelectPostLoad(Node, 1, ARM64::LD1Rv4s_POST, ARM64::qsub0);
+    else if (VT == MVT::v1i64 || VT == MVT::v1f64)
+      return SelectPostLoad(Node, 1, ARM64::LD1Rv1d_POST, ARM64::dsub0);
+    else if (VT == MVT::v2i64 || VT == MVT::v2f64)
+      return SelectPostLoad(Node, 1, ARM64::LD1Rv2d_POST, ARM64::qsub0);
+    break;
+  }
   case ARM64ISD::LD2DUPpost: {
     if (VT == MVT::v8i8)
       return SelectPostLoad(Node, 2, ARM64::LD2Rv8b_POST, ARM64::dsub0);
@@ -2714,6 +2742,19 @@ SDNode *ARM64DAGToDAGISel::Select(SDNode *Node) {
       return SelectPostLoad(Node, 4, ARM64::LD4Rv2d_POST, ARM64::qsub0);
     break;
   }
+  case ARM64ISD::LD1LANEpost: {
+    if (VT == MVT::v16i8 || VT == MVT::v8i8)
+      return SelectPostLoadLane(Node, 1, ARM64::LD1i8_POST);
+    else if (VT == MVT::v8i16 || VT == MVT::v4i16)
+      return SelectPostLoadLane(Node, 1, ARM64::LD1i16_POST);
+    else if (VT == MVT::v4i32 || VT == MVT::v2i32 || VT == MVT::v4f32 ||
+             VT == MVT::v2f32)
+      return SelectPostLoadLane(Node, 1, ARM64::LD1i32_POST);
+    else if (VT == MVT::v2i64 || VT == MVT::v1i64 || VT == MVT::v2f64 ||
+             VT == MVT::v1f64)
+      return SelectPostLoadLane(Node, 1, ARM64::LD1i64_POST);
+    break;
+  }
   case ARM64ISD::LD2LANEpost: {
     if (VT == MVT::v16i8 || VT == MVT::v8i8)
       return SelectPostLoadLane(Node, 2, ARM64::LD2i8_POST);