[NVPTX] Add isel patterns for bit-field extract (bfe)
authorJustin Holewinski <jholewinski@nvidia.com>
Fri, 27 Jun 2014 18:35:27 +0000 (18:35 +0000)
committerJustin Holewinski <jholewinski@nvidia.com>
Fri, 27 Jun 2014 18:35:27 +0000 (18:35 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211932 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
lib/Target/NVPTX/NVPTXISelDAGToDAG.h
lib/Target/NVPTX/NVPTXInstrInfo.td
test/CodeGen/NVPTX/bfe.ll [new file with mode: 0644]

index cd308806c36a5e9a1ee493b0f97a9c8e2f481cb7..1ea47fc8d594d3b92aa2fe16bec149821f99d39c 100644 (file)
@@ -253,6 +253,12 @@ SDNode *NVPTXDAGToDAGISel::Select(SDNode *N) {
   case NVPTXISD::Suld3DV4I32Trap:
     ResNode = SelectSurfaceIntrinsic(N);
     break;
+  case ISD::AND:
+  case ISD::SRA:
+  case ISD::SRL:
+    // Try to select BFE
+    ResNode = SelectBFE(N);
+    break;
   case ISD::ADDRSPACECAST:
     ResNode = SelectAddrSpaceCast(N);
     break;
@@ -2959,6 +2965,214 @@ SDNode *NVPTXDAGToDAGISel::SelectSurfaceIntrinsic(SDNode *N) {
   return Ret;
 }
 
+/// SelectBFE - Look for instruction sequences that can be made more efficient
+/// by using the 'bfe' (bit-field extract) PTX instruction
+SDNode *NVPTXDAGToDAGISel::SelectBFE(SDNode *N) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  SDValue Len;
+  SDValue Start;
+  SDValue Val;
+  bool IsSigned = false;
+
+  if (N->getOpcode() == ISD::AND) {
+    // Canonicalize the operands
+    // We want 'and %val, %mask'
+    if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) {
+      std::swap(LHS, RHS);
+    }
+
+    ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS);
+    if (!Mask) {
+      // We need a constant mask on the RHS of the AND
+      return NULL;
+    }
+
+    // Extract the mask bits
+    uint64_t MaskVal = Mask->getZExtValue();
+    if (!isMask_64(MaskVal)) {
+      // We *could* handle shifted masks here, but doing so would require an
+      // 'and' operation to fix up the low-order bits so we would trade
+      // shr+and for bfe+and, which has the same throughput
+      return NULL;
+    }
+
+    // How many bits are in our mask?
+    uint64_t NumBits = CountTrailingOnes_64(MaskVal);
+    Len = CurDAG->getTargetConstant(NumBits, MVT::i32);
+
+    if (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SRA) {
+      // We have a 'srl/and' pair, extract the effective start bit and length
+      Val = LHS.getNode()->getOperand(0);
+      Start = LHS.getNode()->getOperand(1);
+      ConstantSDNode *StartConst = dyn_cast<ConstantSDNode>(Start);
+      if (StartConst) {
+        uint64_t StartVal = StartConst->getZExtValue();
+        // How many "good" bits do we have left?  "good" is defined here as bits
+        // that exist in the original value, not shifted in.
+        uint64_t GoodBits = Start.getValueType().getSizeInBits() - StartVal;
+        if (NumBits > GoodBits) {
+          // Do not handle the case where bits have been shifted in. In theory
+          // we could handle this, but the cost is likely higher than just
+          // emitting the srl/and pair.
+          return NULL;
+        }
+        Start = CurDAG->getTargetConstant(StartVal, MVT::i32);
+      } else {
+        // Do not handle the case where the shift amount (can be zero if no srl
+        // was found) is not constant. We could handle this case, but it would
+        // require run-time logic that would be more expensive than just
+        // emitting the srl/and pair.
+        return NULL;
+      }
+    } else {
+      // Do not handle the case where the LHS of the and is not a shift. While
+      // it would be trivial to handle this case, it would just transform
+      // 'and' -> 'bfe', but 'and' has higher-throughput.
+      return NULL;
+    }
+  } else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) {
+    if (LHS->getOpcode() == ISD::AND) {
+      ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS);
+      if (!ShiftCnst) {
+        // Shift amount must be constant
+        return NULL;
+      }
+
+      uint64_t ShiftAmt = ShiftCnst->getZExtValue();
+
+      SDValue AndLHS = LHS->getOperand(0);
+      SDValue AndRHS = LHS->getOperand(1);
+
+      // Canonicalize the AND to have the mask on the RHS
+      if (isa<ConstantSDNode>(AndLHS)) {
+        std::swap(AndLHS, AndRHS);
+      }
+
+      ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS);
+      if (!MaskCnst) {
+        // Mask must be constant
+        return NULL;
+      }
+
+      uint64_t MaskVal = MaskCnst->getZExtValue();
+      uint64_t NumZeros;
+      uint64_t NumBits;
+      if (isMask_64(MaskVal)) {
+        NumZeros = 0;
+        // The number of bits in the result bitfield will be the number of
+        // trailing ones (the AND) minus the number of bits we shift off
+        NumBits = CountTrailingOnes_64(MaskVal) - ShiftAmt;
+      } else if (isShiftedMask_64(MaskVal)) {
+        NumZeros = countTrailingZeros(MaskVal);
+        unsigned NumOnes = CountTrailingOnes_64(MaskVal >> NumZeros);
+        // The number of bits in the result bitfield will be the number of
+        // trailing zeros plus the number of set bits in the mask minus the
+        // number of bits we shift off
+        NumBits = NumZeros + NumOnes - ShiftAmt;
+      } else {
+        // This is not a mask we can handle
+        return NULL;
+      }
+
+      if (ShiftAmt < NumZeros) {
+        // Handling this case would require extra logic that would make this
+        // transformation non-profitable
+        return NULL;
+      }
+
+      Val = AndLHS;
+      Start = CurDAG->getTargetConstant(ShiftAmt, MVT::i32);
+      Len = CurDAG->getTargetConstant(NumBits, MVT::i32);
+    } else if (LHS->getOpcode() == ISD::SHL) {
+      // Here, we have a pattern like:
+      //
+      // (sra (shl val, NN), MM)
+      // or
+      // (srl (shl val, NN), MM)
+      //
+      // If MM >= NN, we can efficiently optimize this with bfe
+      Val = LHS->getOperand(0);
+
+      SDValue ShlRHS = LHS->getOperand(1);
+      ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS);
+      if (!ShlCnst) {
+        // Shift amount must be constant
+        return NULL;
+      }
+      uint64_t InnerShiftAmt = ShlCnst->getZExtValue();
+
+      SDValue ShrRHS = RHS;
+      ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS);
+      if (!ShrCnst) {
+        // Shift amount must be constant
+        return NULL;
+      }
+      uint64_t OuterShiftAmt = ShrCnst->getZExtValue();
+
+      // To avoid extra codegen and be profitable, we need Outer >= Inner
+      if (OuterShiftAmt < InnerShiftAmt) {
+        return NULL;
+      }
+
+      // If the outer shift is more than the type size, we have no bitfield to
+      // extract (since we also check that the inner shift is <= the outer shift
+      // then this also implies that the inner shift is < the type size)
+      if (OuterShiftAmt >= Val.getValueType().getSizeInBits()) {
+        return NULL;
+      }
+
+      Start =
+        CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, MVT::i32);
+      Len =
+        CurDAG->getTargetConstant(Val.getValueType().getSizeInBits() -
+                                  OuterShiftAmt, MVT::i32);
+
+      if (N->getOpcode() == ISD::SRA) {
+        // If we have a arithmetic right shift, we need to use the signed bfe
+        // variant
+        IsSigned = true;
+      }
+    } else {
+      // No can do...
+      return NULL;
+    }
+  } else {
+    // No can do...
+    return NULL;
+  }
+
+
+  unsigned Opc;
+  // For the BFE operations we form here from "and" and "srl", always use the
+  // unsigned variants.
+  if (Val.getValueType() == MVT::i32) {
+    if (IsSigned) {
+      Opc = NVPTX::BFE_S32rii;
+    } else {
+      Opc = NVPTX::BFE_U32rii;
+    }
+  } else if (Val.getValueType() == MVT::i64) {
+    if (IsSigned) {
+      Opc = NVPTX::BFE_S64rii;
+    } else {
+      Opc = NVPTX::BFE_U64rii;
+    }
+  } else {
+    // We cannot handle this type
+    return NULL;
+  }
+
+  SDValue Ops[] = {
+    Val, Start, Len
+  };
+
+  SDNode *Ret =
+    CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops);
+
+  return Ret;
+}
+
 // SelectDirectAddr - Match a direct address for DAG.
 // A direct address could be a globaladdress or externalsymbol.
 bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {
index 11f92e79d99cb48082968a4bb2a285a0041968af..3aa1c8f72f561c151fb37752b72aed14ad31e01b 100644 (file)
@@ -71,6 +71,7 @@ private:
   SDNode *SelectAddrSpaceCast(SDNode *N);
   SDNode *SelectTextureIntrinsic(SDNode *N);
   SDNode *SelectSurfaceIntrinsic(SDNode *N);
+  SDNode *SelectBFE(SDNode *N);
         
   inline SDValue getI32Imm(unsigned Imm) {
     return CurDAG->getTargetConstant(Imm, MVT::i32);
index 14de4b76f0f907e3e6cadea35acb7f5ee6ffc751..e94250b38da7e33fbfd4e694bbc85c77bd57f875 100644 (file)
@@ -1179,6 +1179,29 @@ def ROTR64reg_sw : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src,
     !strconcat("}}", ""))))))))),
     [(set Int64Regs:$dst, (rotr Int64Regs:$src, Int32Regs:$amt))]>;
 
+// BFE - bit-field extract
+
+multiclass BFE<string TyStr, RegisterClass RC> {
+  // BFE supports both 32-bit and 64-bit values, but the start and length
+  // operands are always 32-bit
+  def rrr
+    : NVPTXInst<(outs RC:$d),
+                (ins RC:$a, Int32Regs:$b, Int32Regs:$c),
+                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+  def rri
+    : NVPTXInst<(outs RC:$d),
+                (ins RC:$a, Int32Regs:$b, i32imm:$c),
+                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+  def rii
+    : NVPTXInst<(outs RC:$d),
+                (ins RC:$a, i32imm:$b, i32imm:$c),
+                !strconcat("bfe.", TyStr, " \t$d, $a, $b, $c;"), []>;
+}
+
+defm BFE_S32 : BFE<"s32", Int32Regs>;
+defm BFE_U32 : BFE<"u32", Int32Regs>;
+defm BFE_S64 : BFE<"s64", Int64Regs>;
+defm BFE_U64 : BFE<"u64", Int64Regs>;
 
 //-----------------------------------
 // General Comparison
diff --git a/test/CodeGen/NVPTX/bfe.ll b/test/CodeGen/NVPTX/bfe.ll
new file mode 100644 (file)
index 0000000..2e816fe
--- /dev/null
@@ -0,0 +1,32 @@
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+
+
+; CHECK: bfe0
+define i32 @bfe0(i32 %a) {
+; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 4, 4
+; CHECK-NOT: shr
+; CHECK-NOT: and
+  %val0 = ashr i32 %a, 4
+  %val1 = and i32 %val0, 15
+  ret i32 %val1
+}
+
+; CHECK: bfe1
+define i32 @bfe1(i32 %a) {
+; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 3, 3
+; CHECK-NOT: shr
+; CHECK-NOT: and
+  %val0 = ashr i32 %a, 3
+  %val1 = and i32 %val0, 7
+  ret i32 %val1
+}
+
+; CHECK: bfe2
+define i32 @bfe2(i32 %a) {
+; CHECK: bfe.u32 %r{{[0-9]+}}, %r{{[0-9]+}}, 5, 3
+; CHECK-NOT: shr
+; CHECK-NOT: and
+  %val0 = ashr i32 %a, 5
+  %val1 = and i32 %val0, 7
+  ret i32 %val1
+}