Teach DAG combine to fold (extract_subvec (concat v1, ..) i) to v_i
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 6e4a772a89d19f0fade508e58708261e1f76344f..4ac6d1b5163f1cd01af2ecc4e940cc006059b614 100644 (file)
@@ -8610,8 +8610,8 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) {
       return SDValue();
 
     // Only handle cases where both indexes are constants with the same type.
-    ConstantSDNode *InsIdx = dyn_cast<ConstantSDNode>(N->getOperand(1));
-    ConstantSDNode *ExtIdx = dyn_cast<ConstantSDNode>(V->getOperand(2));
+    ConstantSDNode *ExtIdx = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    ConstantSDNode *InsIdx = dyn_cast<ConstantSDNode>(V->getOperand(2));
 
     if (InsIdx && ExtIdx &&
         InsIdx->getValueType(0).getSizeInBits() <= 64 &&
@@ -8628,6 +8628,21 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) {
     }
   }
 
+  if (V->getOpcode() == ISD::CONCAT_VECTORS) {
+    // Combine:
+    //    (extract_subvec (concat V1, V2, ...), i)
+    // Into:
+    //    Vi if possible
+    for (unsigned i = 0, e = V->getNumOperands(); i != e; ++i)
+      if (V->getOperand(i).getValueType() != NVT)
+        return SDValue();
+    unsigned Idx = dyn_cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+    unsigned NumElems = NVT.getVectorNumElements();
+    assert((Idx % NumElems) == 0 &&
+           "IDX in concat is not a multiple of the result vector length.");
+    return V->getOperand(Idx / NumElems);
+  }
+
   return SDValue();
 }