R600/SI: Use v_cvt_f32_ubyte* instructions
[oota-llvm.git] / lib / Target / R600 / SIISelLowering.cpp
index 608aad2d3ca0d1376184be0163b2c5f4ae6ec11d..846aeb63093a13587bf1c9ff6cbf870c26ef86f8 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/IR/Function.h"
+#include "llvm/ADT/SmallString.h"
 
 using namespace llvm;
 
@@ -214,6 +215,8 @@ SITargetLowering::SITargetLowering(TargetMachine &TM) :
   setTargetDAGCombine(ISD::SELECT_CC);
   setTargetDAGCombine(ISD::SETCC);
 
+  setTargetDAGCombine(ISD::UINT_TO_FP);
+
   setSchedulingPreference(Sched::RegPressure);
 }
 
@@ -979,6 +982,96 @@ SDValue SITargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
 // Custom DAG optimizations
 //===----------------------------------------------------------------------===//
 
+SDValue SITargetLowering::performUCharToFloatCombine(SDNode *N,
+                                                     DAGCombinerInfo &DCI) {
+  EVT VT = N->getValueType(0);
+  EVT ScalarVT = VT.getScalarType();
+  if (ScalarVT != MVT::f32)
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+
+  SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src.getValueType();
+
+  // TODO: We could try to match extracting the higher bytes, which would be
+  // easier if i8 vectors weren't promoted to i32 vectors, particularly after
+  // types are legalized. v4i8 -> v4f32 is probably the only case to worry
+  // about in practice.
+  if (DCI.isAfterLegalizeVectorOps() && SrcVT == MVT::i32) {
+    if (DAG.MaskedValueIsZero(Src, APInt::getHighBitsSet(32, 24))) {
+      SDValue Cvt = DAG.getNode(AMDGPUISD::CVT_F32_UBYTE0, DL, VT, Src);
+      DCI.AddToWorklist(Cvt.getNode());
+      return Cvt;
+    }
+  }
+
+  // We are primarily trying to catch operations on illegal vector types
+  // before they are expanded.
+  // For scalars, we can use the more flexible method of checking masked bits
+  // after legalization.
+  if (!DCI.isBeforeLegalize() ||
+      !SrcVT.isVector() ||
+      SrcVT.getVectorElementType() != MVT::i8) {
+    return SDValue();
+  }
+
+  assert(DCI.isBeforeLegalize() && "Unexpected legal type");
+
+  // Weird sized vectors are a pain to handle, but we know 3 is really the same
+  // size as 4.
+  unsigned NElts = SrcVT.getVectorNumElements();
+  if (!SrcVT.isSimple() && NElts != 3)
+    return SDValue();
+
+  // Handle v4i8 -> v4f32 extload. Replace the v4i8 with a legal i32 load to
+  // prevent a mess from expanding to v4i32 and repacking.
+  if (ISD::isNormalLoad(Src.getNode()) && Src.hasOneUse()) {
+    EVT LoadVT = getEquivalentMemType(*DAG.getContext(), SrcVT);
+    EVT RegVT = getEquivalentLoadRegType(*DAG.getContext(), SrcVT);
+    EVT FloatVT = EVT::getVectorVT(*DAG.getContext(), MVT::f32, NElts);
+
+    LoadSDNode *Load = cast<LoadSDNode>(Src);
+    SDValue NewLoad = DAG.getExtLoad(ISD::ZEXTLOAD, DL, RegVT,
+                                     Load->getChain(),
+                                     Load->getBasePtr(),
+                                     LoadVT,
+                                     Load->getMemOperand());
+
+    // Make sure successors of the original load stay after it by updating
+    // them to use the new Chain.
+    DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 1), NewLoad.getValue(1));
+
+    SmallVector<SDValue, 4> Elts;
+    if (RegVT.isVector())
+      DAG.ExtractVectorElements(NewLoad, Elts);
+    else
+      Elts.push_back(NewLoad);
+
+    SmallVector<SDValue, 4> Ops;
+
+    unsigned EltIdx = 0;
+    for (SDValue Elt : Elts) {
+      unsigned ComponentsInElt = std::min(4u, NElts - 4 * EltIdx);
+      for (unsigned I = 0; I < ComponentsInElt; ++I) {
+        unsigned Opc = AMDGPUISD::CVT_F32_UBYTE0 + I;
+        SDValue Cvt = DAG.getNode(Opc, DL, MVT::f32, Elt);
+        DCI.AddToWorklist(Cvt.getNode());
+        Ops.push_back(Cvt);
+      }
+
+      ++EltIdx;
+    }
+
+    assert(Ops.size() == NElts);
+
+    return DAG.getNode(ISD::BUILD_VECTOR, DL, FloatVT, Ops);
+  }
+
+  return SDValue();
+}
+
 SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -1020,6 +1113,31 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
       }
       break;
     }
+
+  case AMDGPUISD::CVT_F32_UBYTE0:
+  case AMDGPUISD::CVT_F32_UBYTE1:
+  case AMDGPUISD::CVT_F32_UBYTE2:
+  case AMDGPUISD::CVT_F32_UBYTE3: {
+    unsigned Offset = N->getOpcode() - AMDGPUISD::CVT_F32_UBYTE0;
+
+    SDValue Src = N->getOperand(0);
+    APInt Demanded = APInt::getBitsSet(32, 8 * Offset, 8 * Offset + 8);
+
+    APInt KnownZero, KnownOne;
+    TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
+                                          !DCI.isBeforeLegalizeOps());
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    if (TLO.ShrinkDemandedConstant(Src, Demanded) ||
+        TLI.SimplifyDemandedBits(Src, Demanded, KnownZero, KnownOne, TLO)) {
+      DCI.CommitTargetLoweringOpt(TLO);
+    }
+
+    break;
+  }
+
+  case ISD::UINT_TO_FP: {
+    return performUCharToFloatCombine(N, DCI);
+  }
   }
 
   return AMDGPUTargetLowering::PerformDAGCombine(N, DCI);