Add a ARM-specific SD node for VBSL so that forms with a constant first operand
authorCameron Zwarich <zwarich@apple.com>
Wed, 30 Mar 2011 23:01:21 +0000 (23:01 +0000)
committerCameron Zwarich <zwarich@apple.com>
Wed, 30 Mar 2011 23:01:21 +0000 (23:01 +0000)
can be recognized. This fixes <rdar://problem/9183078>.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@128584 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/ARM/ARMFastISel.cpp
lib/Target/ARM/ARMISelLowering.cpp
lib/Target/ARM/ARMISelLowering.h
lib/Target/ARM/ARMInstrNEON.td
test/CodeGen/ARM/vbsl-constant.ll [new file with mode: 0644]

index 2af42c90af159b87fcfa2de092df50c882288c9e..0f3e6cf2c6427fe045011f3d97c860cec16a90cd 100644 (file)
@@ -115,6 +115,11 @@ class ARMFastISel : public FastISel {
                                      const TargetRegisterClass *RC,
                                      unsigned Op0, bool Op0IsKill,
                                      unsigned Op1, bool Op1IsKill);
+    virtual unsigned FastEmitInst_rrr(unsigned MachineInstOpcode,
+                                      const TargetRegisterClass *RC,
+                                      unsigned Op0, bool Op0IsKill,
+                                      unsigned Op1, bool Op1IsKill,
+                                      unsigned Op2, bool Op2IsKill);
     virtual unsigned FastEmitInst_ri(unsigned MachineInstOpcode,
                                      const TargetRegisterClass *RC,
                                      unsigned Op0, bool Op0IsKill,
@@ -315,6 +320,31 @@ unsigned ARMFastISel::FastEmitInst_rr(unsigned MachineInstOpcode,
   return ResultReg;
 }
 
+unsigned ARMFastISel::FastEmitInst_rrr(unsigned MachineInstOpcode,
+                                       const TargetRegisterClass *RC,
+                                       unsigned Op0, bool Op0IsKill,
+                                       unsigned Op1, bool Op1IsKill,
+                                       unsigned Op2, bool Op2IsKill) {
+  unsigned ResultReg = createResultReg(RC);
+  const TargetInstrDesc &II = TII.get(MachineInstOpcode);
+
+  if (II.getNumDefs() >= 1)
+    AddOptionalDefs(BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DL, II, ResultReg)
+                   .addReg(Op0, Op0IsKill * RegState::Kill)
+                   .addReg(Op1, Op1IsKill * RegState::Kill)
+                   .addReg(Op2, Op2IsKill * RegState::Kill));
+  else {
+    AddOptionalDefs(BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DL, II)
+                   .addReg(Op0, Op0IsKill * RegState::Kill)
+                   .addReg(Op1, Op1IsKill * RegState::Kill)
+                   .addReg(Op2, Op2IsKill * RegState::Kill));
+    AddOptionalDefs(BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DL,
+                           TII.get(TargetOpcode::COPY), ResultReg)
+                   .addReg(II.ImplicitDefs[0]));
+  }
+  return ResultReg;
+}
+
 unsigned ARMFastISel::FastEmitInst_ri(unsigned MachineInstOpcode,
                                       const TargetRegisterClass *RC,
                                       unsigned Op0, bool Op0IsKill,
index 9dc103fa765796c3007dc38c4240bdb052320884..11ff6bf9654c49098dfbb90ff311549dfdf7ab90 100644 (file)
@@ -866,6 +866,7 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
   case ARMISD::BFI:           return "ARMISD::BFI";
   case ARMISD::VORRIMM:       return "ARMISD::VORRIMM";
   case ARMISD::VBICIMM:       return "ARMISD::VBICIMM";
+  case ARMISD::VBSL:          return "ARMISD::VBSL";
   case ARMISD::VLD2DUP:       return "ARMISD::VLD2DUP";
   case ARMISD::VLD3DUP:       return "ARMISD::VLD3DUP";
   case ARMISD::VLD4DUP:       return "ARMISD::VLD4DUP";
@@ -5336,6 +5337,37 @@ static SDValue PerformORCombine(SDNode *N,
     }
   }
 
+  SDValue N0 = N->getOperand(0);
+  if (N0.getOpcode() != ISD::AND)
+    return SDValue();
+  SDValue N1 = N->getOperand(1);
+
+  // (or (and B, A), (and C, ~A)) => (VBSL A, B, C) when A is a constant.
+  if (Subtarget->hasNEON() && N1.getOpcode() == ISD::AND && VT.isVector() &&
+      DAG.getTargetLoweringInfo().isTypeLegal(VT)) {
+    APInt SplatUndef;
+    unsigned SplatBitSize;
+    bool HasAnyUndefs;
+
+    BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(1));
+    APInt SplatBits0;
+    if (BVN0 && BVN0->isConstantSplat(SplatBits0, SplatUndef, SplatBitSize,
+                                  HasAnyUndefs) && !HasAnyUndefs) {
+      BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(1));
+      APInt SplatBits1;
+      if (BVN1 && BVN1->isConstantSplat(SplatBits1, SplatUndef, SplatBitSize,
+                                    HasAnyUndefs) && !HasAnyUndefs &&
+          SplatBits0 == ~SplatBits1) {
+        // Canonicalize the vector type to make instruction selection simpler.
+        EVT CanonicalVT = VT.is128BitVector() ? MVT::v4i32 : MVT::v2i32;
+        SDValue Result = DAG.getNode(ARMISD::VBSL, dl, CanonicalVT,
+                                     N0->getOperand(1), N0->getOperand(0),
+                                     N1->getOperand(1));
+        return DAG.getNode(ISD::BITCAST, dl, VT, Result);
+      }
+    }
+  }
+
   // Try to use the ARM/Thumb2 BFI (bitfield insert) instruction when
   // reasonable.
 
@@ -5343,7 +5375,6 @@ static SDValue PerformORCombine(SDNode *N,
   if (Subtarget->isThumb1Only() || !Subtarget->hasV6T2Ops())
     return SDValue();
 
-  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
   DebugLoc DL = N->getDebugLoc();
   // 1) or (and A, mask), val => ARMbfi A, val, mask
   //      iff (val & mask) == val
@@ -5354,8 +5385,6 @@ static SDValue PerformORCombine(SDNode *N,
   //  2b) iff isBitFieldInvertedMask(~mask) && isBitFieldInvertedMask(mask2)
   //          && ~mask == mask2
   //  (i.e., copy a bitfield value into another bitfield of the same width)
-  if (N0.getOpcode() != ISD::AND)
-    return SDValue();
 
   if (VT != MVT::i32)
     return SDValue();
index e09c1dacfa8e8cd1864138b9e1a3fbe07b5f1ab0..e37855da3331969cfbbe9bf7420a221811e92ac8 100644 (file)
@@ -179,6 +179,9 @@ namespace llvm {
       // Vector AND with NOT of immediate
       VBICIMM,
 
+      // Vector bitwise select
+      VBSL,
+
       // Vector load N-element structure to all lanes:
       VLD2DUP = ISD::FIRST_TARGET_MEMORY_OPCODE,
       VLD3DUP,
index 374dac0bc20dfb7af1caf207f34eac1d61b2bdd4..c4fb3749b80e9304db1caa0725f16d842fd0283c 100644 (file)
@@ -80,6 +80,12 @@ def SDTARMVORRIMM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>,
 def NEONvorrImm   : SDNode<"ARMISD::VORRIMM", SDTARMVORRIMM>;
 def NEONvbicImm   : SDNode<"ARMISD::VBICIMM", SDTARMVORRIMM>;
 
+def NEONvbsl      : SDNode<"ARMISD::VBSL",
+                           SDTypeProfile<1, 3, [SDTCisVec<0>,
+                                                SDTCisSameAs<0, 1>,
+                                                SDTCisSameAs<0, 2>,
+                                                SDTCisSameAs<0, 3>]>>;
+
 def NEONvdup      : SDNode<"ARMISD::VDUP", SDTypeProfile<1, 1, [SDTCisVec<0>]>>;
 
 // VDUPLANE can produce a quad-register result from a double-register source,
@@ -3767,16 +3773,21 @@ def  VBSLd    : N3VX<1, 0, 0b01, 0b0001, 0, 1, (outs DPR:$Vd),
                      (ins DPR:$src1, DPR:$Vn, DPR:$Vm),
                      N3RegFrm, IIC_VCNTiD,
                      "vbsl", "$Vd, $Vn, $Vm", "$src1 = $Vd",
-                     [(set DPR:$Vd,
-                       (v2i32 (or (and DPR:$Vn, DPR:$src1),
-                                  (and DPR:$Vm, (vnotd DPR:$src1)))))]>;
+                     [(set DPR:$Vd, (v2i32 (NEONvbsl DPR:$src1, DPR:$Vn, DPR:$Vm)))]>;
+
+def : Pat<(v2i32 (or (and DPR:$Vn, DPR:$Vd),
+                     (and DPR:$Vm, (vnotd DPR:$Vd)))),
+          (VBSLd DPR:$Vd, DPR:$Vn, DPR:$Vm)>;
+
 def  VBSLq    : N3VX<1, 0, 0b01, 0b0001, 1, 1, (outs QPR:$Vd),
                      (ins QPR:$src1, QPR:$Vn, QPR:$Vm),
                      N3RegFrm, IIC_VCNTiQ,
                      "vbsl", "$Vd, $Vn, $Vm", "$src1 = $Vd",
-                     [(set QPR:$Vd,
-                       (v4i32 (or (and QPR:$Vn, QPR:$src1),
-                                  (and QPR:$Vm, (vnotq QPR:$src1)))))]>;
+                     [(set QPR:$Vd, (v4i32 (NEONvbsl QPR:$src1, QPR:$Vn, QPR:$Vm)))]>;
+
+def : Pat<(v4i32 (or (and QPR:$Vn, QPR:$Vd),
+                     (and QPR:$Vm, (vnotq QPR:$Vd)))),
+          (VBSLq QPR:$Vd, QPR:$Vn, QPR:$Vm)>;
 
 //   VBIF     : Vector Bitwise Insert if False
 //              like VBSL but with: "vbif $dst, $src3, $src1", "$src2 = $dst",
diff --git a/test/CodeGen/ARM/vbsl-constant.ll b/test/CodeGen/ARM/vbsl-constant.ll
new file mode 100644 (file)
index 0000000..9a9cc5b
--- /dev/null
@@ -0,0 +1,97 @@
+; RUN: llc < %s -march=arm -mattr=+neon | FileCheck %s
+
+define <8 x i8> @v_bsli8(<8 x i8>* %A, <8 x i8>* %B, <8 x i8>* %C) nounwind {
+;CHECK: v_bsli8:
+;CHECK: vbsl
+       %tmp1 = load <8 x i8>* %A
+       %tmp2 = load <8 x i8>* %B
+       %tmp3 = load <8 x i8>* %C
+       %tmp4 = and <8 x i8> %tmp1, <i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3>
+       %tmp6 = and <8 x i8> %tmp3, <i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4>
+       %tmp7 = or <8 x i8> %tmp4, %tmp6
+       ret <8 x i8> %tmp7
+}
+
+define <4 x i16> @v_bsli16(<4 x i16>* %A, <4 x i16>* %B, <4 x i16>* %C) nounwind {
+;CHECK: v_bsli16:
+;CHECK: vbsl
+       %tmp1 = load <4 x i16>* %A
+       %tmp2 = load <4 x i16>* %B
+       %tmp3 = load <4 x i16>* %C
+       %tmp4 = and <4 x i16> %tmp1, <i16 3, i16 3, i16 3, i16 3>
+       %tmp6 = and <4 x i16> %tmp3, <i16 -4, i16 -4, i16 -4, i16 -4>
+       %tmp7 = or <4 x i16> %tmp4, %tmp6
+       ret <4 x i16> %tmp7
+}
+
+define <2 x i32> @v_bsli32(<2 x i32>* %A, <2 x i32>* %B, <2 x i32>* %C) nounwind {
+;CHECK: v_bsli32:
+;CHECK: vbsl
+       %tmp1 = load <2 x i32>* %A
+       %tmp2 = load <2 x i32>* %B
+       %tmp3 = load <2 x i32>* %C
+       %tmp4 = and <2 x i32> %tmp1, <i32 3, i32 3>
+       %tmp6 = and <2 x i32> %tmp3, <i32 -4, i32 -4>
+       %tmp7 = or <2 x i32> %tmp4, %tmp6
+       ret <2 x i32> %tmp7
+}
+
+define <1 x i64> @v_bsli64(<1 x i64>* %A, <1 x i64>* %B, <1 x i64>* %C) nounwind {
+;CHECK: v_bsli64:
+;CHECK: vbsl
+       %tmp1 = load <1 x i64>* %A
+       %tmp2 = load <1 x i64>* %B
+       %tmp3 = load <1 x i64>* %C
+       %tmp4 = and <1 x i64> %tmp1, <i64 3>
+       %tmp6 = and <1 x i64> %tmp3, <i64 -4>
+       %tmp7 = or <1 x i64> %tmp4, %tmp6
+       ret <1 x i64> %tmp7
+}
+
+define <16 x i8> @v_bslQi8(<16 x i8>* %A, <16 x i8>* %B, <16 x i8>* %C) nounwind {
+;CHECK: v_bslQi8:
+;CHECK: vbsl
+       %tmp1 = load <16 x i8>* %A
+       %tmp2 = load <16 x i8>* %B
+       %tmp3 = load <16 x i8>* %C
+       %tmp4 = and <16 x i8> %tmp1, <i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3>
+       %tmp6 = and <16 x i8> %tmp3, <i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4, i8 -4>
+       %tmp7 = or <16 x i8> %tmp4, %tmp6
+       ret <16 x i8> %tmp7
+}
+
+define <8 x i16> @v_bslQi16(<8 x i16>* %A, <8 x i16>* %B, <8 x i16>* %C) nounwind {
+;CHECK: v_bslQi16:
+;CHECK: vbsl
+       %tmp1 = load <8 x i16>* %A
+       %tmp2 = load <8 x i16>* %B
+       %tmp3 = load <8 x i16>* %C
+       %tmp4 = and <8 x i16> %tmp1, <i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3, i16 3>
+       %tmp6 = and <8 x i16> %tmp3, <i16 -4, i16 -4, i16 -4, i16 -4, i16 -4, i16 -4, i16 -4, i16 -4>
+       %tmp7 = or <8 x i16> %tmp4, %tmp6
+       ret <8 x i16> %tmp7
+}
+
+define <4 x i32> @v_bslQi32(<4 x i32>* %A, <4 x i32>* %B, <4 x i32>* %C) nounwind {
+;CHECK: v_bslQi32:
+;CHECK: vbsl
+       %tmp1 = load <4 x i32>* %A
+       %tmp2 = load <4 x i32>* %B
+       %tmp3 = load <4 x i32>* %C
+       %tmp4 = and <4 x i32> %tmp1, <i32 3, i32 3, i32 3, i32 3>
+       %tmp6 = and <4 x i32> %tmp3, <i32 -4, i32 -4, i32 -4, i32 -4>
+       %tmp7 = or <4 x i32> %tmp4, %tmp6
+       ret <4 x i32> %tmp7
+}
+
+define <2 x i64> @v_bslQi64(<2 x i64>* %A, <2 x i64>* %B, <2 x i64>* %C) nounwind {
+;CHECK: v_bslQi64:
+;CHECK: vbsl
+       %tmp1 = load <2 x i64>* %A
+       %tmp2 = load <2 x i64>* %B
+       %tmp3 = load <2 x i64>* %C
+       %tmp4 = and <2 x i64> %tmp1, <i64 3, i64 3>
+       %tmp6 = and <2 x i64> %tmp3, <i64 -4, i64 -4>
+       %tmp7 = or <2 x i64> %tmp4, %tmp6
+       ret <2 x i64> %tmp7
+}