R600: Custom lower frem
[oota-llvm.git] / lib / Target / R600 / AMDGPUISelLowering.cpp
index e32344238656e580ce69d53be03aaaa323a52d16..8440c39c18d35e30a3d8f745cab17f89818b3cd5 100644 (file)
@@ -130,6 +130,9 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(TargetMachine &TM) :
   setOperationAction(ISD::FROUND, MVT::f32, Legal);
   setOperationAction(ISD::FTRUNC, MVT::f32, Legal);
 
+  setOperationAction(ISD::FREM, MVT::f32, Custom);
+  setOperationAction(ISD::FREM, MVT::f64, Custom);
+
   // Lower floating point store/load to integer store/load to reduce the number
   // of patterns in tablegen.
   setOperationAction(ISD::STORE, MVT::f32, Promote);
@@ -347,6 +350,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(TargetMachine &TM) :
     setOperationAction(ISD::FDIV, VT, Expand);
     setOperationAction(ISD::FEXP2, VT, Expand);
     setOperationAction(ISD::FLOG2, VT, Expand);
+    setOperationAction(ISD::FREM, VT, Expand);
     setOperationAction(ISD::FPOW, VT, Expand);
     setOperationAction(ISD::FFLOOR, VT, Expand);
     setOperationAction(ISD::FTRUNC, VT, Expand);
@@ -386,7 +390,7 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(TargetMachine &TM) :
   // There are no integer divide instructions, and these expand to a pretty
   // large sequence of instructions.
   setIntDivIsCheap(false);
-  setPow2DivIsCheap(false);
+  setPow2SDivIsCheap(false);
 
   // TODO: Investigate this when 64-bit divides are implemented.
   addBypassSlowDiv(64, 32);
@@ -441,12 +445,12 @@ bool AMDGPUTargetLowering::isLoadBitCastBeneficial(EVT LoadTy,
 
 bool AMDGPUTargetLowering::isFAbsFree(EVT VT) const {
   assert(VT.isFloatingPoint());
-  return VT == MVT::f32;
+  return VT == MVT::f32 || VT == MVT::f64;
 }
 
 bool AMDGPUTargetLowering::isFNegFree(EVT VT) const {
   assert(VT.isFloatingPoint());
-  return VT == MVT::f32;
+  return VT == MVT::f32 || VT == MVT::f64;
 }
 
 bool AMDGPUTargetLowering::isTruncateFree(EVT Source, EVT Dest) const {
@@ -548,6 +552,7 @@ SDValue AMDGPUTargetLowering::LowerOperation(SDValue Op,
   case ISD::INTRINSIC_WO_CHAIN: return LowerINTRINSIC_WO_CHAIN(Op, DAG);
   case ISD::UDIVREM: return LowerUDIVREM(Op, DAG);
   case ISD::SDIVREM: return LowerSDIVREM(Op, DAG);
+  case ISD::FREM: return LowerFREM(Op, DAG);
   case ISD::FCEIL: return LowerFCEIL(Op, DAG);
   case ISD::FTRUNC: return LowerFTRUNC(Op, DAG);
   case ISD::FRINT: return LowerFRINT(Op, DAG);
@@ -853,6 +858,10 @@ SDValue AMDGPUTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     case Intrinsic::AMDGPU_rsq_clamped:
       return DAG.getNode(AMDGPUISD::RSQ_CLAMPED, DL, VT, Op.getOperand(1));
 
+    case Intrinsic::AMDGPU_ldexp:
+      return DAG.getNode(AMDGPUISD::LDEXP, DL, VT, Op.getOperand(1),
+                                                   Op.getOperand(2));
+
     case AMDGPUIntrinsic::AMDGPU_imax:
       return DAG.getNode(AMDGPUISD::SMAX, DL, VT, Op.getOperand(1),
                                                   Op.getOperand(2));
@@ -1387,7 +1396,7 @@ SDValue AMDGPUTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
 // This is a shortcut for integer division because we have fast i32<->f32
 // conversions, and fast f32 reciprocal instructions. The fractional part of a
 // float is enough to accurately represent up to a 24-bit integer.
-SDValue AMDGPUTargetLowering::LowerSDIVREM24(SDValue Op, SelectionDAG &DAG) const {
+SDValue AMDGPUTargetLowering::LowerDIVREM24(SDValue Op, SelectionDAG &DAG, bool sign) const {
   SDLoc DL(Op);
   EVT VT = Op.getValueType();
   SDValue LHS = Op.getOperand(0);
@@ -1395,6 +1404,9 @@ SDValue AMDGPUTargetLowering::LowerSDIVREM24(SDValue Op, SelectionDAG &DAG) cons
   MVT IntVT = MVT::i32;
   MVT FltVT = MVT::f32;
 
+  ISD::NodeType ToFp  = sign ? ISD::SINT_TO_FP : ISD::UINT_TO_FP;
+  ISD::NodeType ToInt = sign ? ISD::FP_TO_SINT : ISD::FP_TO_UINT;
+
   if (VT.isVector()) {
     unsigned NElts = VT.getVectorNumElements();
     IntVT = MVT::getVectorVT(MVT::i32, NElts);
@@ -1403,29 +1415,35 @@ SDValue AMDGPUTargetLowering::LowerSDIVREM24(SDValue Op, SelectionDAG &DAG) cons
 
   unsigned BitSize = VT.getScalarType().getSizeInBits();
 
-  // char|short jq = ia ^ ib;
-  SDValue jq = DAG.getNode(ISD::XOR, DL, VT, LHS, RHS);
+  SDValue jq = DAG.getConstant(1, IntVT);
 
-  // jq = jq >> (bitsize - 2)
-  jq = DAG.getNode(ISD::SRA, DL, VT, jq, DAG.getConstant(BitSize - 2, VT));
+  if (sign) {
+    // char|short jq = ia ^ ib;
+    jq = DAG.getNode(ISD::XOR, DL, VT, LHS, RHS);
 
-  // jq = jq | 0x1
-  jq = DAG.getNode(ISD::OR, DL, VT, jq, DAG.getConstant(1, VT));
+    // jq = jq >> (bitsize - 2)
+    jq = DAG.getNode(ISD::SRA, DL, VT, jq, DAG.getConstant(BitSize - 2, VT));
 
-  // jq = (int)jq
-  jq = DAG.getSExtOrTrunc(jq, DL, IntVT);
+    // jq = jq | 0x1
+    jq = DAG.getNode(ISD::OR, DL, VT, jq, DAG.getConstant(1, VT));
+
+    // jq = (int)jq
+    jq = DAG.getSExtOrTrunc(jq, DL, IntVT);
+  }
 
   // int ia = (int)LHS;
-  SDValue ia = DAG.getSExtOrTrunc(LHS, DL, IntVT);
+  SDValue ia = sign ?
+    DAG.getSExtOrTrunc(LHS, DL, IntVT) : DAG.getZExtOrTrunc(LHS, DL, IntVT);
 
   // int ib, (int)RHS;
-  SDValue ib = DAG.getSExtOrTrunc(RHS, DL, IntVT);
+  SDValue ib = sign ?
+    DAG.getSExtOrTrunc(RHS, DL, IntVT) : DAG.getZExtOrTrunc(RHS, DL, IntVT);
 
   // float fa = (float)ia;
-  SDValue fa = DAG.getNode(ISD::SINT_TO_FP, DL, FltVT, ia);
+  SDValue fa = DAG.getNode(ToFp, DL, FltVT, ia);
 
   // float fb = (float)ib;
-  SDValue fb = DAG.getNode(ISD::SINT_TO_FP, DL, FltVT, ib);
+  SDValue fb = DAG.getNode(ToFp, DL, FltVT, ib);
 
   // float fq = native_divide(fa, fb);
   SDValue fq = DAG.getNode(ISD::FMUL, DL, FltVT,
@@ -1442,7 +1460,7 @@ SDValue AMDGPUTargetLowering::LowerSDIVREM24(SDValue Op, SelectionDAG &DAG) cons
                            DAG.getNode(ISD::FMUL, DL, FltVT, fqneg, fb), fa);
 
   // int iq = (int)fq;
-  SDValue iq = DAG.getNode(ISD::FP_TO_SINT, DL, IntVT, fq);
+  SDValue iq = DAG.getNode(ToInt, DL, IntVT, fq);
 
   // fr = fabs(fr);
   fr = DAG.getNode(ISD::FABS, DL, FltVT, fr);
@@ -1458,11 +1476,13 @@ SDValue AMDGPUTargetLowering::LowerSDIVREM24(SDValue Op, SelectionDAG &DAG) cons
   // jq = (cv ? jq : 0);
   jq = DAG.getNode(ISD::SELECT, DL, VT, cv, jq, DAG.getConstant(0, VT));
 
-  // dst = iq + jq;
-  iq = DAG.getSExtOrTrunc(iq, DL, VT);
+  // dst = trunc/extend to legal type
+  iq = sign ? DAG.getSExtOrTrunc(iq, DL, VT) : DAG.getZExtOrTrunc(iq, DL, VT);
 
+  // dst = iq + jq;
   SDValue Div = DAG.getNode(ISD::ADD, DL, VT, iq, jq);
 
+  // Rem needs compensation, it's easier to recompute it
   SDValue Rem = DAG.getNode(ISD::MUL, DL, VT, Div, RHS);
   Rem = DAG.getNode(ISD::SUB, DL, VT, LHS, Rem);
 
@@ -1481,6 +1501,16 @@ SDValue AMDGPUTargetLowering::LowerUDIVREM(SDValue Op,
   SDValue Num = Op.getOperand(0);
   SDValue Den = Op.getOperand(1);
 
+  if (VT == MVT::i32) {
+    if (DAG.MaskedValueIsZero(Op.getOperand(0), APInt(32, 0xff << 24)) &&
+        DAG.MaskedValueIsZero(Op.getOperand(1), APInt(32, 0xff << 24))) {
+      // TODO: We technically could do this for i64, but shouldn't that just be
+      // handled by something generally reducing 64-bit division on 32-bit
+      // values to 32-bit?
+      return LowerDIVREM24(Op, DAG, false);
+    }
+  }
+
   // RCP =  URECIP(Den) = 2^32 / Den + e
   // e is rounding error.
   SDValue RCP = DAG.getNode(AMDGPUISD::URECIP, DL, VT, Den);
@@ -1591,7 +1621,7 @@ SDValue AMDGPUTargetLowering::LowerSDIVREM(SDValue Op,
       // TODO: We technically could do this for i64, but shouldn't that just be
       // handled by something generally reducing 64-bit division on 32-bit
       // values to 32-bit?
-      return LowerSDIVREM24(Op, DAG);
+      return LowerDIVREM24(Op, DAG, true);
     }
   }
 
@@ -1625,6 +1655,20 @@ SDValue AMDGPUTargetLowering::LowerSDIVREM(SDValue Op,
   return DAG.getMergeValues(Res, DL);
 }
 
+// (frem x, y) -> (fsub x, (fmul (ftrunc (fdiv x, y)), y))
+SDValue AMDGPUTargetLowering::LowerFREM(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  EVT VT = Op.getValueType();
+  SDValue X = Op.getOperand(0);
+  SDValue Y = Op.getOperand(1);
+
+  SDValue Div = DAG.getNode(ISD::FDIV, SL, VT, X, Y);
+  SDValue Floor = DAG.getNode(ISD::FTRUNC, SL, VT, Div);
+  SDValue Mul = DAG.getNode(ISD::FMUL, SL, VT, Floor, Y);
+
+  return DAG.getNode(ISD::FSUB, SL, VT, X, Mul);
+}
+
 SDValue AMDGPUTargetLowering::LowerFCEIL(SDValue Op, SelectionDAG &DAG) const {
   SDLoc SL(Op);
   SDValue Src = Op.getOperand(0);
@@ -2132,6 +2176,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(DWORDADDR)
   NODE_NAME_CASE(FRACT)
   NODE_NAME_CASE(CLAMP)
+  NODE_NAME_CASE(MAD)
   NODE_NAME_CASE(FMAX)
   NODE_NAME_CASE(SMAX)
   NODE_NAME_CASE(UMAX)
@@ -2147,6 +2192,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(RSQ)
   NODE_NAME_CASE(RSQ_LEGACY)
   NODE_NAME_CASE(RSQ_CLAMPED)
+  NODE_NAME_CASE(LDEXP)
   NODE_NAME_CASE(DOT4)
   NODE_NAME_CASE(BFE_U32)
   NODE_NAME_CASE(BFE_I32)