[NVPTX] Generate a more optimal sequence for select of i1
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
index e605547b8e36e501b0021ec24ab460446d99fb97..3a13dc05b6751101c4adc45f1484bfa7fcd59798 100644 (file)
@@ -203,8 +203,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
 
   // Turn FP extload into load/fextend
-  setLoadExtAction(ISD::EXTLOAD, MVT::f16, Expand);
-  setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
   // Turn FP truncstore into trunc + store.
   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f16, Expand);
@@ -214,12 +215,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
   setOperationAction(ISD::LOAD, MVT::i1, Custom);
   setOperationAction(ISD::STORE, MVT::i1, Custom);
 
-  setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
-  setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
-  setTruncStoreAction(MVT::i64, MVT::i1, Expand);
-  setTruncStoreAction(MVT::i32, MVT::i1, Expand);
-  setTruncStoreAction(MVT::i16, MVT::i1, Expand);
-  setTruncStoreAction(MVT::i8, MVT::i1, Expand);
+  for (MVT VT : MVT::integer_valuetypes()) {
+    setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote);
+    setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::i1, Promote);
+    setTruncStoreAction(VT, MVT::i1, Expand);
+  }
 
   // This is legal in NVPTX
   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
@@ -232,9 +232,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
   setOperationAction(ISD::ADDE, MVT::i64, Expand);
 
   // Register custom handling for vector loads/stores
-  for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
-       ++i) {
-    MVT VT = (MVT::SimpleValueType) i;
+  for (MVT VT : MVT::vector_valuetypes()) {
     if (IsPTXVectorType(VT)) {
       setOperationAction(ISD::LOAD, VT, Custom);
       setOperationAction(ISD::STORE, VT, Custom);
@@ -261,6 +259,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
   setOperationAction(ISD::CTPOP, MVT::i32, Legal);
   setOperationAction(ISD::CTPOP, MVT::i64, Legal);
 
+  // PTX does not directly support SELP of i1, so promote to i32 first
+  setOperationAction(ISD::SELECT, MVT::i1, Custom);
+
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine(ISD::ADD);
   setTargetDAGCombine(ISD::AND);
@@ -1805,11 +1806,29 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::SRA_PARTS:
   case ISD::SRL_PARTS:
     return LowerShiftRightParts(Op, DAG);
+  case ISD::SELECT:
+    return LowerSelect(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
 }
 
+SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
+  SDValue Op0 = Op->getOperand(0);
+  SDValue Op1 = Op->getOperand(1);
+  SDValue Op2 = Op->getOperand(2);
+  SDLoc DL(Op.getNode());
+
+  assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
+
+  Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
+  Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
+  SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
+  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
+
+  return Trunc;
+}
+
 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);