Recognize code for doing vector gather/scatter index calculations with
authorDan Gohman <gohman@apple.com>
Mon, 15 Mar 2010 23:23:03 +0000 (23:23 +0000)
committerDan Gohman <gohman@apple.com>
Mon, 15 Mar 2010 23:23:03 +0000 (23:23 +0000)
32-bit indices. Instead of shuffling each element out of the index vector,
when all indices are needed, just store the input vector to the stack and
load the elements out.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@98588 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
test/CodeGen/X86/gather-addresses.ll [new file with mode: 0644]

index 491c118fe3b06dae66f1c6125ddceed0eca9e366..2cf6dea0a38f7cfb6034ecefc5a6bed2854cbcbf 100644 (file)
@@ -990,6 +990,7 @@ X86TargetLowering::X86TargetLowering(X86TargetMachine &TM)
 
   // We have target-specific dag combine patterns for the following nodes:
   setTargetDAGCombine(ISD::VECTOR_SHUFFLE);
+  setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
   setTargetDAGCombine(ISD::BUILD_VECTOR);
   setTargetDAGCombine(ISD::SELECT);
   setTargetDAGCombine(ISD::SHL);
@@ -8853,6 +8854,87 @@ static SDValue PerformShuffleCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+/// PerformShuffleCombine - Detect vector gather/scatter index generation
+/// and convert it from being a bunch of shuffles and extracts to a simple
+/// store and scalar loads to extract the elements.
+static SDValue PerformEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
+                                                const TargetLowering &TLI) {
+  SDValue InputVector = N->getOperand(0);
+
+  // Only operate on vectors of 4 elements, where the alternative shuffling
+  // gets to be more expensive.
+  if (InputVector.getValueType() != MVT::v4i32)
+    return SDValue();
+
+  // Check whether every use of InputVector is an EXTRACT_VECTOR_ELT with a
+  // single use which is a sign-extend or zero-extend, and all elements are
+  // used.
+  SmallVector<SDNode *, 4> Uses;
+  unsigned ExtractedElements = 0;
+  for (SDNode::use_iterator UI = InputVector.getNode()->use_begin(),
+       UE = InputVector.getNode()->use_end(); UI != UE; ++UI) {
+    if (UI.getUse().getResNo() != InputVector.getResNo())
+      return SDValue();
+
+    SDNode *Extract = *UI;
+    if (Extract->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+      return SDValue();
+
+    if (Extract->getValueType(0) != MVT::i32)
+      return SDValue();
+    if (!Extract->hasOneUse())
+      return SDValue();
+    if (Extract->use_begin()->getOpcode() != ISD::SIGN_EXTEND &&
+        Extract->use_begin()->getOpcode() != ISD::ZERO_EXTEND)
+      return SDValue();
+    if (!isa<ConstantSDNode>(Extract->getOperand(1)))
+      return SDValue();
+
+    // Record which element was extracted.
+    ExtractedElements |=
+      1 << cast<ConstantSDNode>(Extract->getOperand(1))->getZExtValue();
+
+    Uses.push_back(Extract);
+  }
+
+  // If not all the elements were used, this may not be worthwhile.
+  if (ExtractedElements != 15)
+    return SDValue();
+
+  // Ok, we've now decided to do the transformation.
+  DebugLoc dl = InputVector.getDebugLoc();
+
+  // Store the value to a temporary stack slot.
+  SDValue StackPtr = DAG.CreateStackTemporary(InputVector.getValueType());
+  SDValue Ch = DAG.getStore(DAG.getEntryNode(), dl, InputVector, StackPtr, NULL, 0,
+                            false, false, 0);
+
+  // Replace each use (extract) with a load of the appropriate element.
+  for (SmallVectorImpl<SDNode *>::iterator UI = Uses.begin(),
+       UE = Uses.end(); UI != UE; ++UI) {
+    SDNode *Extract = *UI;
+
+    // Compute the element's address.
+    SDValue Idx = Extract->getOperand(1);
+    unsigned EltSize =
+        InputVector.getValueType().getVectorElementType().getSizeInBits()/8;
+    uint64_t Offset = EltSize * cast<ConstantSDNode>(Idx)->getZExtValue();
+    SDValue OffsetVal = DAG.getConstant(Offset, TLI.getPointerTy());
+
+    SDValue ScalarAddr = DAG.getNode(ISD::ADD, dl, Idx.getValueType(), OffsetVal, StackPtr);
+
+    // Load the scalar.
+    SDValue LoadScalar = DAG.getLoad(Extract->getValueType(0), dl, Ch, ScalarAddr,
+                          NULL, 0, false, false, 0);
+
+    // Replace the exact with the load.
+    DAG.ReplaceAllUsesOfValueWith(SDValue(Extract, 0), LoadScalar);
+  }
+
+  // The replacement was made in place; don't return anything.
+  return SDValue();
+}
+
 /// PerformSELECTCombine - Do target-specific dag combines on SELECT nodes.
 static SDValue PerformSELECTCombine(SDNode *N, SelectionDAG &DAG,
                                     const X86Subtarget *Subtarget) {
@@ -9742,6 +9824,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   switch (N->getOpcode()) {
   default: break;
   case ISD::VECTOR_SHUFFLE: return PerformShuffleCombine(N, DAG, *this);
+  case ISD::EXTRACT_VECTOR_ELT:
+                        return PerformEXTRACT_VECTOR_ELTCombine(N, DAG, *this);
   case ISD::SELECT:         return PerformSELECTCombine(N, DAG, Subtarget);
   case X86ISD::CMOV:        return PerformCMOVCombine(N, DAG, DCI);
   case ISD::MUL:            return PerformMulCombine(N, DAG, DCI);
diff --git a/test/CodeGen/X86/gather-addresses.ll b/test/CodeGen/X86/gather-addresses.ll
new file mode 100644 (file)
index 0000000..0719838
--- /dev/null
@@ -0,0 +1,39 @@
+; RUN: llc -march=x86-64 < %s | FileCheck %s
+
+; When doing vector gather-scatter index calculation with 32-bit indices,
+; bounce the vector off of cache rather than shuffling each individual
+; element out of the index vector.
+
+; CHECK: pand     (%rdx), %xmm0
+; CHECK: movaps   %xmm0, -24(%rsp)
+; CHECK: movslq   -24(%rsp), %rax
+; CHECK: movsd    (%rdi,%rax,8), %xmm0
+; CHECK: movslq   -20(%rsp), %rax
+; CHECK: movhpd   (%rdi,%rax,8), %xmm0
+; CHECK: movslq   -16(%rsp), %rax
+; CHECK: movsd    (%rdi,%rax,8), %xmm1
+; CHECK: movslq   -12(%rsp), %rax
+; CHECK: movhpd   (%rdi,%rax,8), %xmm1
+
+define <4 x double> @foo(double* %p, <4 x i32>* %i, <4 x i32>* %h) nounwind {
+  %a = load <4 x i32>* %i
+  %b = load <4 x i32>* %h
+  %j = and <4 x i32> %a, %b
+  %d0 = extractelement <4 x i32> %j, i32 0
+  %d1 = extractelement <4 x i32> %j, i32 1
+  %d2 = extractelement <4 x i32> %j, i32 2
+  %d3 = extractelement <4 x i32> %j, i32 3
+  %q0 = getelementptr double* %p, i32 %d0
+  %q1 = getelementptr double* %p, i32 %d1
+  %q2 = getelementptr double* %p, i32 %d2
+  %q3 = getelementptr double* %p, i32 %d3
+  %r0 = load double* %q0
+  %r1 = load double* %q1
+  %r2 = load double* %q2
+  %r3 = load double* %q3
+  %v0 = insertelement <4 x double> undef, double %r0, i32 0
+  %v1 = insertelement <4 x double> %v0, double %r1, i32 1
+  %v2 = insertelement <4 x double> %v1, double %r2, i32 2
+  %v3 = insertelement <4 x double> %v2, double %r3, i32 3
+  ret <4 x double> %v3
+}