AVX-512: Handled extractelement from mask vector;
[oota-llvm.git] / lib / Target / X86 / X86ISelLowering.cpp
index 6df0fd880f2edbd11d9344e34d334b7735918cad..a878ea82ea14bf5b0c90c4fbb1cdcf0c607f05e3 100644 (file)
@@ -16323,6 +16323,44 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG,
                      EltNo);
 }
 
+/// Extract one bit from mask vector, like v16i1 or v8i1.
+/// AVX-512 feature.
+static SDValue ExtractBitFromMaskVector(SDNode *N, SelectionDAG &DAG) {
+  SDValue Vec = N->getOperand(0);
+  SDLoc dl(Vec);
+  MVT VecVT = Vec.getSimpleValueType();
+  SDValue Idx = N->getOperand(1);
+  MVT EltVT = N->getSimpleValueType(0);
+  
+  assert((VecVT.getVectorElementType() == MVT::i1 && EltVT == MVT::i8) ||
+         "Unexpected operands in ExtractBitFromMaskVector");
+
+  // variable index
+  if (!isa<ConstantSDNode>(Idx)) {
+    MVT ExtVT = (VecVT == MVT::v8i1 ?  MVT::v8i64 : MVT::v16i32);
+    SDValue Ext = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Vec);
+    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+                              ExtVT.getVectorElementType(), Ext);
+    return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
+  }
+
+  unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue();
+
+  MVT ScalarVT = MVT::getIntegerVT(VecVT.getSizeInBits());
+  unsigned MaxShift = VecVT.getSizeInBits() - 1;
+  Vec = DAG.getNode(ISD::BITCAST, dl, ScalarVT, Vec);
+  Vec = DAG.getNode(ISD::SHL, dl, ScalarVT, Vec, 
+              DAG.getConstant(MaxShift - IdxVal, ScalarVT));
+  Vec = DAG.getNode(ISD::SRL, dl, ScalarVT, Vec,
+    DAG.getConstant(MaxShift, ScalarVT));
+
+  if (VecVT == MVT::v16i1) {
+    Vec = DAG.getNode(ISD::BITCAST, dl, MVT::i16, Vec);
+    return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Vec);
+  }
+  return DAG.getNode(ISD::BITCAST, dl, MVT::i8, Vec);
+}
+
 /// PerformEXTRACT_VECTOR_ELTCombine - Detect vector gather/scatter index
 /// generation and convert it from being a bunch of shuffles and extracts
 /// to a simple store and scalar loads to extract the elements.
@@ -16333,6 +16371,11 @@ static SDValue PerformEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
     return NewOp;
 
   SDValue InputVector = N->getOperand(0);
+
+  if (InputVector.getValueType().getVectorElementType() == MVT::i1 &&
+      !DCI.isBeforeLegalize())
+    return ExtractBitFromMaskVector(N, DAG);
+
   // Detect whether we are trying to convert from mmx to i32 and the bitcast
   // from mmx to v2i32 has a single usage.
   if (InputVector.getNode()->getOpcode() == llvm::ISD::BITCAST &&