[NVPTX] Generate a more optimal sequence for select of i1
[oota-llvm.git] / lib / Target / NVPTX / NVPTXISelLowering.cpp
index b7ca3f2bd68d30a13bf6ea4cdc3169bcea8dade6..3a13dc05b6751101c4adc45f1484bfa7fcd59798 100644 (file)
@@ -107,7 +107,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
 
 // NVPTXTargetLowering Constructor.
 NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
-    : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
+    : TargetLowering(TM), nvTM(&TM),
       nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {
 
   // always lower memset, memcpy, and memmove intrinsics to load/store
@@ -203,8 +203,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
 
   // Turn FP extload into load/fextend
-  setLoadExtAction(ISD::EXTLOAD, MVT::f16, Expand);
-  setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
   // Turn FP truncstore into trunc + store.
   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f16, Expand);
@@ -214,12 +215,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
   setOperationAction(ISD::LOAD, MVT::i1, Custom);
   setOperationAction(ISD::STORE, MVT::i1, Custom);
 
-  setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
-  setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
-  setTruncStoreAction(MVT::i64, MVT::i1, Expand);
-  setTruncStoreAction(MVT::i32, MVT::i1, Expand);
-  setTruncStoreAction(MVT::i16, MVT::i1, Expand);
-  setTruncStoreAction(MVT::i8, MVT::i1, Expand);
+  for (MVT VT : MVT::integer_valuetypes()) {
+    setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote);
+    setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::i1, Promote);
+    setTruncStoreAction(VT, MVT::i1, Expand);
+  }
 
   // This is legal in NVPTX
   setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
@@ -232,9 +232,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
   setOperationAction(ISD::ADDE, MVT::i64, Expand);
 
   // Register custom handling for vector loads/stores
-  for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
-       ++i) {
-    MVT VT = (MVT::SimpleValueType) i;
+  for (MVT VT : MVT::vector_valuetypes()) {
     if (IsPTXVectorType(VT)) {
       setOperationAction(ISD::LOAD, VT, Custom);
       setOperationAction(ISD::STORE, VT, Custom);
@@ -261,6 +259,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM)
   setOperationAction(ISD::CTPOP, MVT::i32, Legal);
   setOperationAction(ISD::CTPOP, MVT::i64, Legal);
 
+  // PTX does not directly support SELP of i1, so promote to i32 first
+  setOperationAction(ISD::SELECT, MVT::i1, Custom);
+
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine(ISD::ADD);
   setTargetDAGCombine(ISD::AND);
@@ -905,16 +906,14 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
       O << ".param .b" << getPointerTy().getSizeInBits() << " _";
+    } else if ((retTy->getTypeID() == Type::StructTyID) ||
+               isa<VectorType>(retTy)) {
+      O << ".param .align "
+        << retAlignment
+        << " .b8 _["
+        << getDataLayout()->getTypeAllocSize(retTy) << "]";
     } else {
-      if((retTy->getTypeID() == Type::StructTyID) ||
-         isa<VectorType>(retTy)) {
-        O << ".param .align "
-          << retAlignment
-          << " .b8 _["
-          << getDataLayout()->getTypeAllocSize(retTy) << "]";
-      } else {
-        assert(false && "Unknown return type");
-      }
+      llvm_unreachable("Unknown return type");
     }
     O << ") ";
   }
@@ -1355,7 +1354,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     //  .param .align 16 .b8 retval0[<size-in-bytes>], or
     //  .param .b<size-in-bits> retval0
     unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
-    if (retTy->isSingleValueType()) {
+    // Emit ".param .b<size-in-bits> retval0" instead of byte arrays only for
+    // these three types to match the logic in
+    // NVPTXAsmPrinter::printReturnValStr and NVPTXTargetLowering::getPrototype.
+    // Plus, this behavior is consistent with nvcc's.
+    if (retTy->isFloatingPointTy() || retTy->isIntegerTy() ||
+        retTy->isPointerTy()) {
       // Scalar needs to be at least 32bit wide
       if (resultsz < 32)
         resultsz = 32;
@@ -1451,8 +1455,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
       EVT ObjectVT = getValueType(retTy);
       unsigned NumElts = ObjectVT.getVectorNumElements();
       EVT EltVT = ObjectVT.getVectorElementType();
-      assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(),
-                                                        ObjectVT) == NumElts &&
+      assert(nvTM->getSubtargetImpl()->getTargetLowering()->getNumRegisters(
+                 F->getContext(), ObjectVT) == NumElts &&
              "Vector was not scalarized");
       unsigned sz = EltVT.getSizeInBits();
       bool needTruncate = sz < 8 ? true : false;
@@ -1802,11 +1806,29 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::SRA_PARTS:
   case ISD::SRL_PARTS:
     return LowerShiftRightParts(Op, DAG);
+  case ISD::SELECT:
+    return LowerSelect(Op, DAG);
   default:
     llvm_unreachable("Custom lowering not defined for operation");
   }
 }
 
+SDValue NVPTXTargetLowering::LowerSelect(SDValue Op, SelectionDAG &DAG) const {
+  SDValue Op0 = Op->getOperand(0);
+  SDValue Op1 = Op->getOperand(1);
+  SDValue Op2 = Op->getOperand(2);
+  SDLoc DL(Op.getNode());
+
+  assert(Op.getValueType() == MVT::i1 && "Custom lowering enabled only for i1");
+
+  Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
+  Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
+  SDValue Select = DAG.getNode(ISD::SELECT, DL, MVT::i32, Op0, Op1, Op2);
+  SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Select);
+
+  return Trunc;
+}
+
 SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);
@@ -2028,7 +2050,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
   const Function *F = MF.getFunction();
   const AttributeSet &PAL = F->getAttributes();
-  const TargetLowering *TLI = DAG.getTarget().getTargetLowering();
+  const TargetLowering *TLI = DAG.getSubtarget().getTargetLowering();
 
   SDValue Root = DAG.getRoot();
   std::vector<SDValue> OutChains;
@@ -2163,7 +2185,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
         unsigned NumElts = ObjectVT.getVectorNumElements();
         assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
                "Vector was not scalarized");
-        unsigned Ofst = 0;
         EVT EltVT = ObjectVT.getVectorElementType();
 
         // V1 load
@@ -2172,10 +2193,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           // We only have one element, so just directly load it
           Value *SrcValue = Constant::getNullValue(PointerType::get(
               EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
-          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
-                                        DAG.getConstant(Ofst, getPointerTy()));
           SDValue P = DAG.getLoad(
-              EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
+              EltVT, dl, Root, Arg, MachinePointerInfo(SrcValue), false,
               false, true,
               TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
           if (P.getNode())
@@ -2184,7 +2203,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
             P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P);
           InVals.push_back(P);
-          Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
           ++InsIdx;
         } else if (NumElts == 2) {
           // V2 load
@@ -2192,10 +2210,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
           Value *SrcValue = Constant::getNullValue(PointerType::get(
               VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
-          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
-                                        DAG.getConstant(Ofst, getPointerTy()));
           SDValue P = DAG.getLoad(
-              VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
+              VecVT, dl, Root, Arg, MachinePointerInfo(SrcValue), false,
               false, true,
               TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
           if (P.getNode())
@@ -2213,7 +2229,6 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
           InVals.push_back(Elt0);
           InVals.push_back(Elt1);
-          Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
           InsIdx += 2;
         } else {
           // V4 loads
@@ -2231,6 +2246,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
             VecSize = 2;
           }
           EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
+          unsigned Ofst = 0;
           for (unsigned i = 0; i < NumElts; i += VecSize) {
             Value *SrcValue = Constant::getNullValue(
                 PointerType::get(VecVT.getTypeForEVT(F->getContext()),
@@ -3270,16 +3286,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     Info.vol = 0;
     Info.readMem = true;
     Info.writeMem = false;
-
-    // alignment is available as metadata.
-    // Grab it and set the alignment.
-    assert(I.hasMetadataOtherThanDebugLoc() && "Must have alignment metadata");
-    MDNode *AlignMD = I.getMetadata("align");
-    assert(AlignMD && "Must have a non-null MDNode");
-    assert(AlignMD->getNumOperands() == 1 && "Must have a single operand");
-    Value *Align = AlignMD->getOperand(0);
-    int64_t Alignment = cast<ConstantInt>(Align)->getZExtValue();
-    Info.align = Alignment;
+    Info.align = cast<ConstantInt>(I.getArgOperand(1))->getZExtValue();
 
     return true;
   }
@@ -3299,16 +3306,7 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     Info.vol = 0;
     Info.readMem = true;
     Info.writeMem = false;
-
-    // alignment is available as metadata.
-    // Grab it and set the alignment.
-    assert(I.hasMetadataOtherThanDebugLoc() && "Must have alignment metadata");
-    MDNode *AlignMD = I.getMetadata("align");
-    assert(AlignMD && "Must have a non-null MDNode");
-    assert(AlignMD->getNumOperands() == 1 && "Must have a single operand");
-    Value *Align = AlignMD->getOperand(0);
-    int64_t Alignment = cast<ConstantInt>(Align)->getZExtValue();
-    Info.align = Alignment;
+    Info.align = cast<ConstantInt>(I.getArgOperand(1))->getZExtValue();
 
     return true;
   }
@@ -4515,3 +4513,10 @@ NVPTXTargetObjectFile::~NVPTXTargetObjectFile() {
   delete DwarfRangesSection;
   delete DwarfMacroInfoSection;
 }
+
+const MCSection *
+NVPTXTargetObjectFile::SelectSectionForGlobal(const GlobalValue *GV,
+                                              SectionKind Kind, Mangler &Mang,
+                                              const TargetMachine &TM) const {
+  return getDataSection();
+}