[DAG] Teach DAG to also reassociate vector operations
authorJuergen Ributzka <juergen@apple.com>
Mon, 13 Jan 2014 20:51:35 +0000 (20:51 +0000)
committerJuergen Ributzka <juergen@apple.com>
Mon, 13 Jan 2014 20:51:35 +0000 (20:51 +0000)
This commit teaches DAG to reassociate vector ops, which in turn enables
constant folding of vector op chains that appear later on during custom lowering
and DAG combine.

Reviewed by Andrea Di Biagio

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@199135 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/SelectionDAGNodes.h
lib/CodeGen/SelectionDAG/DAGCombiner.cpp
lib/CodeGen/SelectionDAG/SelectionDAG.cpp
test/CodeGen/X86/vector-gep.ll

index 08eda723c6bc55b044080213b7967c986a104254..00773b3e6612b89e012f687add0bfcb2e4cc825d 100644 (file)
@@ -1492,6 +1492,8 @@ public:
                        unsigned &SplatBitSize, bool &HasAnyUndefs,
                        unsigned MinSplatBits = 0, bool isBigEndian = false);
 
+  bool isConstant() const;
+
   static inline bool classof(const SDNode *N) {
     return N->getOpcode() == ISD::BUILD_VECTOR;
   }
index 76f1bc857099287471a556d168349808118fe68a..8b697bcb35ada1f91bcd061ee81fdc650c28e0ea 100644 (file)
@@ -610,6 +610,51 @@ static bool isOneUseSetCC(SDValue N) {
 SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL,
                                     SDValue N0, SDValue N1) {
   EVT VT = N0.getValueType();
+  if (VT.isVector()) {
+    if (N0.getOpcode() == Opc) {
+      BuildVectorSDNode *L = dyn_cast<BuildVectorSDNode>(N0.getOperand(1));
+      if(L && L->isConstant()) {
+        BuildVectorSDNode *R = dyn_cast<BuildVectorSDNode>(N1);
+        if (R && R->isConstant()) {
+          // reassoc. (op (op x, c1), c2) -> (op x, (op c1, c2))
+          SDValue OpNode = DAG.FoldConstantArithmetic(Opc, VT, L, R);
+          return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode);
+        }
+
+        if (N0.hasOneUse()) {
+          // reassoc. (op (op x, c1), y) -> (op (op x, y), c1) iff x+c1 has one
+          // use
+          SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT,
+                                       N0.getOperand(0), N1);
+          AddToWorkList(OpNode.getNode());
+          return DAG.getNode(Opc, DL, VT, OpNode, N0.getOperand(1));
+        }
+      }
+    }
+
+    if (N1.getOpcode() == Opc) {
+      BuildVectorSDNode *R = dyn_cast<BuildVectorSDNode>(N1.getOperand(1));
+      if (R && R->isConstant()) {
+        BuildVectorSDNode *L = dyn_cast<BuildVectorSDNode>(N0);
+        if (L && L->isConstant()) {
+          // reassoc. (op c2, (op x, c1)) -> (op x, (op c1, c2))
+          SDValue OpNode = DAG.FoldConstantArithmetic(Opc, VT, R, L);
+          return DAG.getNode(Opc, DL, VT, N1.getOperand(0), OpNode);
+        }
+        if (N1.hasOneUse()) {
+          // reassoc. (op y, (op x, c1)) -> (op (op x, y), c1) iff x+c1 has one
+          // use
+          SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT,
+                                       N1.getOperand(0), N0);
+          AddToWorkList(OpNode.getNode());
+          return DAG.getNode(Opc, DL, VT, OpNode, N1.getOperand(1));
+        }
+      }
+    }
+
+    return SDValue();
+  }
+
   if (N0.getOpcode() == Opc && isa<ConstantSDNode>(N0.getOperand(1))) {
     if (isa<ConstantSDNode>(N1)) {
       // reassoc. (op (op x, c1), c2) -> (op x, (op c1, c2))
@@ -5868,14 +5913,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
   if (!LegalTypes &&
       N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() &&
       VT.isVector()) {
-    bool isSimple = true;
-    for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i)
-      if (N0.getOperand(i).getOpcode() != ISD::UNDEF &&
-          N0.getOperand(i).getOpcode() != ISD::Constant &&
-          N0.getOperand(i).getOpcode() != ISD::ConstantFP) {
-        isSimple = false;
-        break;
-      }
+    bool isSimple = cast<BuildVectorSDNode>(N0)->isConstant();
 
     EVT DestEltVT = N->getValueType(0).getVectorElementType();
     assert(!DestEltVT.isVector() &&
@@ -10381,18 +10419,15 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
   // this operation.
   if (LHS.getOpcode() == ISD::BUILD_VECTOR &&
       RHS.getOpcode() == ISD::BUILD_VECTOR) {
+    // Check if both vectors are constants. If not bail out.
+    if (!cast<BuildVectorSDNode>(LHS)->isConstant() &&
+        !cast<BuildVectorSDNode>(RHS)->isConstant())
+      return SDValue();
+
     SmallVector<SDValue, 8> Ops;
     for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
       SDValue LHSOp = LHS.getOperand(i);
       SDValue RHSOp = RHS.getOperand(i);
-      // If these two elements can't be folded, bail out.
-      if ((LHSOp.getOpcode() != ISD::UNDEF &&
-           LHSOp.getOpcode() != ISD::Constant &&
-           LHSOp.getOpcode() != ISD::ConstantFP) ||
-          (RHSOp.getOpcode() != ISD::UNDEF &&
-           RHSOp.getOpcode() != ISD::Constant &&
-           RHSOp.getOpcode() != ISD::ConstantFP))
-        break;
 
       // Can't fold divide by zero.
       if (N->getOpcode() == ISD::SDIV || N->getOpcode() == ISD::UDIV ||
index 8a1dfdc39e2d2ca8b0c284a7fb6e6ceabd967023..e003caeddb177231ec9896d72aadc71a08d47155 100644 (file)
@@ -6533,6 +6533,15 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue,
   return true;
 }
 
+bool BuildVectorSDNode::isConstant() const {
+  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
+    unsigned Opc = getOperand(i).getOpcode();
+    if (Opc != ISD::UNDEF && Opc != ISD::Constant && Opc != ISD::ConstantFP)
+      return false;
+  }
+  return true;
+}
+
 bool ShuffleVectorSDNode::isSplatMask(const int *Mask, EVT VT) {
   // Find the first non-undef value in the shuffle mask.
   unsigned i, e;
index b87d8447e543d110e6295d3dcbbbc41a7c3ab629..762c8a81286b9f2fd1fddc3b1f779238daf2c6e5 100644 (file)
@@ -4,22 +4,26 @@
 ;CHECK-LABEL: AGEP0:
 define <4 x i32*> @AGEP0(i32* %ptr) nounwind {
 entry:
+;CHECK-LABEL: AGEP0
+;CHECK: vbroadcast
+;CHECK-NEXT: vpaddd
+;CHECK-NEXT: ret
   %vecinit.i = insertelement <4 x i32*> undef, i32* %ptr, i32 0
   %vecinit2.i = insertelement <4 x i32*> %vecinit.i, i32* %ptr, i32 1
   %vecinit4.i = insertelement <4 x i32*> %vecinit2.i, i32* %ptr, i32 2
   %vecinit6.i = insertelement <4 x i32*> %vecinit4.i, i32* %ptr, i32 3
-;CHECK: padd
   %A2 = getelementptr <4 x i32*> %vecinit6.i, <4 x i32> <i32 1, i32 2, i32 3, i32 4>
-;CHECK: padd
   %A3 = getelementptr <4 x i32*> %A2, <4 x i32> <i32 10, i32 14, i32 19, i32 233>
   ret <4 x i32*> %A3
-;CHECK: ret
 }
 
 ;CHECK-LABEL: AGEP1:
 define i32 @AGEP1(<4 x i32*> %param) nounwind {
 entry:
-;CHECK: padd
+;CHECK-LABEL: AGEP1
+;CHECK: vpaddd
+;CHECK-NEXT: vpextrd
+;CHECK-NEXT: movl
   %A2 = getelementptr <4 x i32*> %param, <4 x i32> <i32 1, i32 2, i32 3, i32 4>
   %k = extractelement <4 x i32*> %A2, i32 3
   %v = load i32* %k
@@ -30,8 +34,9 @@ entry:
 ;CHECK-LABEL: AGEP2:
 define i32 @AGEP2(<4 x i32*> %param, <4 x i32> %off) nounwind {
 entry:
-;CHECK: pslld $2
-;CHECK: padd
+;CHECK_LABEL: AGEP2
+;CHECK: vpslld $2
+;CHECK-NEXT: vpadd
   %A2 = getelementptr <4 x i32*> %param, <4 x i32> %off
   %k = extractelement <4 x i32*> %A2, i32 3
   %v = load i32* %k
@@ -42,8 +47,9 @@ entry:
 ;CHECK-LABEL: AGEP3:
 define <4 x i32*> @AGEP3(<4 x i32*> %param, <4 x i32> %off) nounwind {
 entry:
-;CHECK: pslld $2
-;CHECK: padd
+;CHECK-LABEL: AGEP3
+;CHECK: vpslld $2
+;CHECK-NEXT: vpadd
   %A2 = getelementptr <4 x i32*> %param, <4 x i32> %off
   %v = alloca i32
   %k = insertelement <4 x i32*> %A2, i32* %v, i32 3
@@ -54,10 +60,11 @@ entry:
 ;CHECK-LABEL: AGEP4:
 define <4 x i16*> @AGEP4(<4 x i16*> %param, <4 x i32> %off) nounwind {
 entry:
+;CHECK-LABEL: AGEP4
 ; Multiply offset by two (add it to itself).
-;CHECK: padd
+;CHECK: vpadd
 ; add the base to the offset
-;CHECKpadd
+;CHECK-NEXT: vpadd
   %A = getelementptr <4 x i16*> %param, <4 x i32> %off
   ret <4 x i16*> %A
 ;CHECK: ret
@@ -66,7 +73,8 @@ entry:
 ;CHECK-LABEL: AGEP5:
 define <4 x i8*> @AGEP5(<4 x i8*> %param, <4 x i8> %off) nounwind {
 entry:
-;CHECK: paddd
+;CHECK-LABEL: AGEP5
+;CHECK: vpaddd
   %A = getelementptr <4 x i8*> %param, <4 x i8> %off
   ret <4 x i8*> %A
 ;CHECK: ret
@@ -77,6 +85,7 @@ entry:
 ;CHECK-LABEL: AGEP6:
 define <4 x i8*> @AGEP6(<4 x i8*> %param, <4 x i32> %off) nounwind {
 entry:
+;CHECK-LABEL: AGEP6
 ;CHECK-NOT: pslld
   %A = getelementptr <4 x i8*> %param, <4 x i32> %off
   ret <4 x i8*> %A