From: Justin Holewinski Date: Fri, 30 Sep 2011 12:54:43 +0000 (+0000) Subject: PTX: Attempt to cleanup/unify the handling of FP rounding modes. This requires X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=c90e149ee43b5ffa751a7a03fab8b1a54cc3de92;p=oota-llvm.git PTX: Attempt to cleanup/unify the handling of FP rounding modes. This requires us to manually provide Pat<> definitions for all FP instruction patterns. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140849 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Target/PTX/CMakeLists.txt b/lib/Target/PTX/CMakeLists.txt index ca24e0d5087..2d5bc452ce0 100644 --- a/lib/Target/PTX/CMakeLists.txt +++ b/lib/Target/PTX/CMakeLists.txt @@ -12,6 +12,7 @@ add_llvm_target(PTXCodeGen PTXISelDAGToDAG.cpp PTXISelLowering.cpp PTXInstrInfo.cpp + PTXFPRoundingModePass.cpp PTXFrameLowering.cpp PTXMCAsmStreamer.cpp PTXMCInstLower.cpp diff --git a/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp b/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp index ca943dee148..a4e03492a0a 100644 --- a/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp +++ b/lib/Target/PTX/InstPrinter/PTXInstPrinter.cpp @@ -19,6 +19,7 @@ #include "llvm/MC/MCInst.h" #include "llvm/MC/MCSymbol.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -146,4 +147,46 @@ void PTXInstPrinter::printMemOperand(const MCInst *MI, unsigned OpNo, printOperand(MI, OpNo+1, O); } +void PTXInstPrinter::printRoundingMode(const MCInst *MI, unsigned OpNo, + raw_ostream &O) { + const MCOperand &Op = MI->getOperand(OpNo); + assert (Op.isImm() && "Rounding modes must be immediate values"); + switch (Op.getImm()) { + default: + llvm_unreachable("Unknown rounding mode!"); + case PTXRoundingMode::RndDefault: + llvm_unreachable("FP rounding-mode pass did not handle instruction!"); + break; + case PTXRoundingMode::RndNone: + // Do not print anything. + break; + case PTXRoundingMode::RndNearestEven: + O << ".rn"; + break; + case PTXRoundingMode::RndTowardsZero: + O << ".rz"; + break; + case PTXRoundingMode::RndNegInf: + O << ".rm"; + break; + case PTXRoundingMode::RndPosInf: + O << ".rp"; + break; + case PTXRoundingMode::RndApprox: + O << ".approx"; + break; + case PTXRoundingMode::RndNearestEvenInt: + O << ".rni"; + break; + case PTXRoundingMode::RndTowardsZeroInt: + O << ".rzi"; + break; + case PTXRoundingMode::RndNegInfInt: + O << ".rmi"; + break; + case PTXRoundingMode::RndPosInfInt: + O << ".rpi"; + break; + } +} diff --git a/lib/Target/PTX/InstPrinter/PTXInstPrinter.h b/lib/Target/PTX/InstPrinter/PTXInstPrinter.h index 73a7977a658..86dfd482885 100644 --- a/lib/Target/PTX/InstPrinter/PTXInstPrinter.h +++ b/lib/Target/PTX/InstPrinter/PTXInstPrinter.h @@ -39,6 +39,7 @@ public: void printCall(const MCInst *MI, raw_ostream &O); void printOperand(const MCInst *MI, unsigned OpNo, raw_ostream &O); void printMemOperand(const MCInst *MI, unsigned OpNo, raw_ostream &O); + void printRoundingMode(const MCInst *MI, unsigned OpNo, raw_ostream &O); }; } diff --git a/lib/Target/PTX/MCTargetDesc/PTXBaseInfo.h b/lib/Target/PTX/MCTargetDesc/PTXBaseInfo.h index 58da5b38460..5339e47915b 100644 --- a/lib/Target/PTX/MCTargetDesc/PTXBaseInfo.h +++ b/lib/Target/PTX/MCTargetDesc/PTXBaseInfo.h @@ -35,6 +35,26 @@ namespace llvm { PRED_NONE = 2 }; } // namespace PTX + + /// Namespace to hold all target-specific flags. + namespace PTXRoundingMode { + // Instruction Flags + enum { + // Rounding Mode Flags + RndMask = 15, + RndDefault = 0, // --- + RndNone = 1, // + RndNearestEven = 2, // .rn + RndTowardsZero = 3, // .rz + RndNegInf = 4, // .rm + RndPosInf = 5, // .rp + RndApprox = 6, // .approx + RndNearestEvenInt = 7, // .rni + RndTowardsZeroInt = 8, // .rzi + RndNegInfInt = 9, // .rmi + RndPosInfInt = 10 // .rpi + }; + } // namespace PTXII } // namespace llvm #endif diff --git a/lib/Target/PTX/PTX.h b/lib/Target/PTX/PTX.h index fd74c1e4673..7d46cce4aec 100644 --- a/lib/Target/PTX/PTX.h +++ b/lib/Target/PTX/PTX.h @@ -31,6 +31,9 @@ namespace llvm { FunctionPass *createPTXMFInfoExtract(PTXTargetMachine &TM, CodeGenOpt::Level OptLevel); + FunctionPass *createPTXFPRoundingModePass(PTXTargetMachine &TM, + CodeGenOpt::Level OptLevel); + FunctionPass *createPTXRegisterAllocator(); void LowerPTXMachineInstrToMCInst(const MachineInstr *MI, MCInst &OutMI, diff --git a/lib/Target/PTX/PTXFPRoundingModePass.cpp b/lib/Target/PTX/PTXFPRoundingModePass.cpp new file mode 100644 index 00000000000..7fa435c6101 --- /dev/null +++ b/lib/Target/PTX/PTXFPRoundingModePass.cpp @@ -0,0 +1,155 @@ +//===-- PTXFPRoundingModePass.cpp - Assign rounding modes pass ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines a machine function pass that sets appropriate FP rounding +// modes for all relevant instructions. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "ptx-fp-rounding-mode" + +#include "PTX.h" +#include "PTXTargetMachine.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +// NOTE: PTXFPRoundingModePass should be executed just before emission. + +namespace llvm { + /// PTXFPRoundingModePass - Pass to assign appropriate FP rounding modes to + /// all FP instructions. Essentially, this pass just looks for all FP + /// instructions that have a rounding mode set to RndDefault, and sets an + /// appropriate rounding mode based on the target device. + /// + class PTXFPRoundingModePass : public MachineFunctionPass { + private: + static char ID; + PTXTargetMachine& TargetMachine; + + public: + PTXFPRoundingModePass(PTXTargetMachine &TM, CodeGenOpt::Level OptLevel) + : MachineFunctionPass(ID), + TargetMachine(TM) {} + + virtual bool runOnMachineFunction(MachineFunction &MF); + + virtual const char *getPassName() const { + return "PTX FP Rounding Mode Pass"; + } + + private: + + void processInstruction(MachineInstr &MI); + }; // class PTXFPRoundingModePass +} // namespace llvm + +using namespace llvm; + +char PTXFPRoundingModePass::ID = 0; + +bool PTXFPRoundingModePass::runOnMachineFunction(MachineFunction &MF) { + + // Look at each basic block + for (MachineFunction::iterator bbi = MF.begin(), bbe = MF.end(); bbi != bbe; + ++bbi) { + MachineBasicBlock &MBB = *bbi; + // Look at each instruction + for (MachineBasicBlock::iterator ii = MBB.begin(), ie = MBB.end(); + ii != ie; ++ii) { + MachineInstr &MI = *ii; + processInstruction(MI); + } + } + return false; +} + +void PTXFPRoundingModePass::processInstruction(MachineInstr &MI) { + // If the instruction has a rounding mode set to RndDefault, then assign an + // appropriate rounding mode based on the target device. + const PTXSubtarget& ST = TargetMachine.getSubtarget(); + switch (MI.getOpcode()) { + case PTX::FADDrr32: + case PTX::FADDri32: + case PTX::FADDrr64: + case PTX::FADDri64: + case PTX::FSUBrr32: + case PTX::FSUBri32: + case PTX::FSUBrr64: + case PTX::FSUBri64: + case PTX::FMULrr32: + case PTX::FMULri32: + case PTX::FMULrr64: + case PTX::FMULri64: + if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) { + MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven); + } + break; + case PTX::FNEGrr32: + case PTX::FNEGri32: + case PTX::FNEGrr64: + case PTX::FNEGri64: + if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) { + MI.getOperand(1).setImm(PTXRoundingMode::RndNone); + } + break; + case PTX::FDIVrr32: + case PTX::FDIVri32: + case PTX::FDIVrr64: + case PTX::FDIVri64: + if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) { + if (ST.fdivNeedsRoundingMode()) + MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven); + else + MI.getOperand(1).setImm(PTXRoundingMode::RndNone); + } + break; + case PTX::FMADrrr32: + case PTX::FMADrri32: + case PTX::FMADrii32: + case PTX::FMADrrr64: + case PTX::FMADrri64: + case PTX::FMADrii64: + if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) { + if (ST.fmadNeedsRoundingMode()) + MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven); + else + MI.getOperand(1).setImm(PTXRoundingMode::RndNone); + } + break; + case PTX::FSQRTrr32: + case PTX::FSQRTri32: + case PTX::FSQRTrr64: + case PTX::FSQRTri64: + if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) { + MI.getOperand(1).setImm(PTXRoundingMode::RndNearestEven); + } + break; + case PTX::FSINrr32: + case PTX::FSINri32: + case PTX::FSINrr64: + case PTX::FSINri64: + case PTX::FCOSrr32: + case PTX::FCOSri32: + case PTX::FCOSrr64: + case PTX::FCOSri64: + if (MI.getOperand(1).getImm() == PTXRoundingMode::RndDefault) { + MI.getOperand(1).setImm(PTXRoundingMode::RndApprox); + } + break; + } +} + +FunctionPass *llvm::createPTXFPRoundingModePass(PTXTargetMachine &TM, + CodeGenOpt::Level OptLevel) { + return new PTXFPRoundingModePass(TM, OptLevel); +} + diff --git a/lib/Target/PTX/PTXInstrFormats.td b/lib/Target/PTX/PTXInstrFormats.td index 6632bbfbc5d..397fdc319a8 100644 --- a/lib/Target/PTX/PTXInstrFormats.td +++ b/lib/Target/PTX/PTXInstrFormats.td @@ -7,12 +7,39 @@ // //===----------------------------------------------------------------------===// + +// Rounding Mode Specifier +/*class RoundingMode val> { + bits<3> Value = val; +} + +def RndDefault : RoundingMode<0>; +def RndNearestEven : RoundingMode<1>; +def RndNearestZero : RoundingMode<2>; +def RndNegInf : RoundingMode<3>; +def RndPosInf : RoundingMode<4>; +def RndApprox : RoundingMode<5>;*/ + + +// Rounding Mode Operand +def RndMode : Operand { + let PrintMethod = "printRoundingMode"; +} + +def RndDefault : PatLeaf<(i32 0)>; + // PTX Predicate operand, default to (0, 0) = (zero-reg, none). // Leave PrintMethod empty; predicate printing is defined elsewhere. def pred : PredicateOperand; +def RndModeOperand : Operand { + let MIOperandInfo = (ops i32imm); +} + +// Instruction Types let Namespace = "PTX" in { + class InstPTX pattern> : Instruction { dag OutOperandList = oops; diff --git a/lib/Target/PTX/PTXInstrInfo.td b/lib/Target/PTX/PTXInstrInfo.td index 08cb10ecf6a..0a7900f1525 100644 --- a/lib/Target/PTX/PTXInstrInfo.td +++ b/lib/Target/PTX/PTXInstrInfo.td @@ -80,75 +80,67 @@ def PTXcopyaddress // Instruction Class Templates //===----------------------------------------------------------------------===// +// For floating-point instructions, we cannot just embed the pattern into the +// instruction definition since we need to muck around with the rounding mode, +// and I do not know how to insert constants into instructions directly from +// pattern matches. + //===- Floating-Point Instructions - 2 Operand Form -----------------------===// -multiclass PTX_FLOAT_2OP { +multiclass PTX_FLOAT_2OP { def rr32 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a), - !strconcat(opcstr, ".f32\t$d, $a"), - [(set RegF32:$d, (opnode RegF32:$a))]>; + (ins RndMode:$r, RegF32:$a), + !strconcat(opcstr, "$r.f32\t$d, $a"), []>; def ri32 : InstPTX<(outs RegF32:$d), - (ins f32imm:$a), - !strconcat(opcstr, ".f32\t$d, $a"), - [(set RegF32:$d, (opnode fpimm:$a))]>; + (ins RndMode:$r, f32imm:$a), + !strconcat(opcstr, "$r.f32\t$d, $a"), []>; def rr64 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a), - !strconcat(opcstr, ".f64\t$d, $a"), - [(set RegF64:$d, (opnode RegF64:$a))]>; + (ins RndMode:$r, RegF64:$a), + !strconcat(opcstr, "$r.f64\t$d, $a"), []>; def ri64 : InstPTX<(outs RegF64:$d), - (ins f64imm:$a), - !strconcat(opcstr, ".f64\t$d, $a"), - [(set RegF64:$d, (opnode fpimm:$a))]>; + (ins RndMode:$r, f64imm:$a), + !strconcat(opcstr, "$r.f64\t$d, $a"), []>; } //===- Floating-Point Instructions - 3 Operand Form -----------------------===// -multiclass PTX_FLOAT_3OP { +multiclass PTX_FLOAT_3OP { def rr32 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a, RegF32:$b), - !strconcat(opcstr, ".f32\t$d, $a, $b"), - [(set RegF32:$d, (opnode RegF32:$a, RegF32:$b))]>; + (ins RndMode:$r, RegF32:$a, RegF32:$b), + !strconcat(opcstr, "$r.f32\t$d, $a, $b"), []>; def ri32 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a, f32imm:$b), - !strconcat(opcstr, ".f32\t$d, $a, $b"), - [(set RegF32:$d, (opnode RegF32:$a, fpimm:$b))]>; + (ins RndMode:$r, RegF32:$a, f32imm:$b), + !strconcat(opcstr, "$r.f32\t$d, $a, $b"), []>; def rr64 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a, RegF64:$b), - !strconcat(opcstr, ".f64\t$d, $a, $b"), - [(set RegF64:$d, (opnode RegF64:$a, RegF64:$b))]>; + (ins RndMode:$r, RegF64:$a, RegF64:$b), + !strconcat(opcstr, "$r.f64\t$d, $a, $b"), []>; def ri64 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a, f64imm:$b), - !strconcat(opcstr, ".f64\t$d, $a, $b"), - [(set RegF64:$d, (opnode RegF64:$a, fpimm:$b))]>; + (ins RndMode:$r, RegF64:$a, f64imm:$b), + !strconcat(opcstr, "$r.f64\t$d, $a, $b"), []>; } //===- Floating-Point Instructions - 4 Operand Form -----------------------===// -multiclass PTX_FLOAT_4OP { +multiclass PTX_FLOAT_4OP { def rrr32 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a, RegF32:$b, RegF32:$c), - !strconcat(opcstr, ".f32\t$d, $a, $b, $c"), - [(set RegF32:$d, (opnode2 (opnode1 RegF32:$a, - RegF32:$b), - RegF32:$c))]>; + (ins RndMode:$r, RegF32:$a, RegF32:$b, RegF32:$c), + !strconcat(opcstr, "$r.f32\t$d, $a, $b, $c"), []>; def rri32 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a, RegF32:$b, f32imm:$c), - !strconcat(opcstr, ".f32\t$d, $a, $b, $c"), - [(set RegF32:$d, (opnode2 (opnode1 RegF32:$a, - RegF32:$b), - fpimm:$c))]>; + (ins RndMode:$r, RegF32:$a, RegF32:$b, f32imm:$c), + !strconcat(opcstr, "$r.f32\t$d, $a, $b, $c"), []>; + def rii32 : InstPTX<(outs RegF32:$d), + (ins RndMode:$r, RegF32:$a, f32imm:$b, f32imm:$c), + !strconcat(opcstr, "$r.f32\t$d, $a, $b, $c"), []>; def rrr64 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a, RegF64:$b, RegF64:$c), - !strconcat(opcstr, ".f64\t$d, $a, $b, $c"), - [(set RegF64:$d, (opnode2 (opnode1 RegF64:$a, - RegF64:$b), - RegF64:$c))]>; + (ins RndMode:$r, RegF64:$a, RegF64:$b, RegF64:$c), + !strconcat(opcstr, "$r.f64\t$d, $a, $b, $c"), []>; def rri64 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a, RegF64:$b, f64imm:$c), - !strconcat(opcstr, ".f64\t$d, $a, $b, $c"), - [(set RegF64:$d, (opnode2 (opnode1 RegF64:$a, - RegF64:$b), - fpimm:$c))]>; + (ins RndMode:$r, RegF64:$a, RegF64:$b, f64imm:$c), + !strconcat(opcstr, "$r.f64\t$d, $a, $b, $c"), []>; + def rii64 : InstPTX<(outs RegF64:$d), + (ins RndMode:$r, RegF64:$a, f64imm:$b, f64imm:$c), + !strconcat(opcstr, "$r.f64\t$d, $a, $b, $c"), []>; } -multiclass INT3 { +//===- Integer Instructions - 3 Operand Form ------------------------------===// +multiclass PTX_INT3 { def rr16 : InstPTX<(outs RegI16:$d), (ins RegI16:$a, RegI16:$b), !strconcat(opcstr, ".u16\t$d, $a, $b"), @@ -175,6 +167,7 @@ multiclass INT3 { [(set RegI64:$d, (opnode RegI64:$a, imm:$b))]>; } +//===- Bitwise Logic Instructions - 3 Operand Form ------------------------===// multiclass PTX_LOGIC { def ripreds : InstPTX<(outs RegPred:$d), (ins RegPred:$a, i1imm:$b), @@ -210,7 +203,8 @@ multiclass PTX_LOGIC { [(set RegI64:$d, (opnode RegI64:$a, imm:$b))]>; } -multiclass INT3ntnc { +//===- Integer Shift Instructions - 3 Operand Form ------------------------===// +multiclass PTX_INT3ntnc { def rr16 : InstPTX<(outs RegI16:$d), (ins RegI16:$a, RegI16:$b), !strconcat(opcstr, "16\t$d, $a, $b"), @@ -249,6 +243,7 @@ multiclass INT3ntnc { [(set RegI64:$d, (opnode imm:$a, RegI64:$b))]>; } +//===- Set Predicate Instructions (Int) - 3/4 Operand Forms ---------------===// multiclass PTX_SETP_I { // TODO support 5-operand format: p|q, a, b, c @@ -333,6 +328,7 @@ multiclass PTX_SETP_I; } +//===- Set Predicate Instructions (FP) - 3/4 Operand Form -----------------===// multiclass PTX_SETP_FP { // TODO support 5-operand format: p|q, a, b, c @@ -432,6 +428,7 @@ multiclass PTX_SETP_FP; } +//===- Select Predicate Instructions - 4 Operand Form ---------------------===// multiclass PTX_SELP { def rr @@ -456,118 +453,60 @@ multiclass PTX_SELP; -defm SUB : INT3<"sub", sub>; -defm MUL : INT3<"mul.lo", mul>; // FIXME: Allow 32x32 -> 64 multiplies -defm DIV : INT3<"div", udiv>; -defm REM : INT3<"rem", urem>; +defm ADD : PTX_INT3<"add", add>; +defm SUB : PTX_INT3<"sub", sub>; +defm MUL : PTX_INT3<"mul.lo", mul>; // FIXME: Allow 32x32 -> 64 multiplies +defm DIV : PTX_INT3<"div", udiv>; +defm REM : PTX_INT3<"rem", urem>; ///===- Floating-Point Arithmetic Instructions ----------------------------===// -// Standard Unary Operations -defm FNEG : PTX_FLOAT_2OP<"neg", fneg>; +// FNEG +defm FNEG : PTX_FLOAT_2OP<"neg">; // Standard Binary Operations -defm FADD : PTX_FLOAT_3OP<"add.rn", fadd>; -defm FSUB : PTX_FLOAT_3OP<"sub.rn", fsub>; -defm FMUL : PTX_FLOAT_3OP<"mul.rn", fmul>; - -// For floating-point division: -// SM_13+ defaults to .rn for f32 and f64, -// SM10 must *not* provide a rounding - -// TODO: -// - Allow user selection of rounding modes for fdiv -// - Add support for -prec-div=false (.approx) - -def FDIVrr32SM13 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a, RegF32:$b), - "div.rn.f32\t$d, $a, $b", - [(set RegF32:$d, (fdiv RegF32:$a, RegF32:$b))]>, - Requires<[FDivNeedsRoundingMode]>; -def FDIVri32SM13 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a, f32imm:$b), - "div.rn.f32\t$d, $a, $b", - [(set RegF32:$d, (fdiv RegF32:$a, fpimm:$b))]>, - Requires<[FDivNeedsRoundingMode]>; -def FDIVrr32SM10 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a, RegF32:$b), - "div.f32\t$d, $a, $b", - [(set RegF32:$d, (fdiv RegF32:$a, RegF32:$b))]>, - Requires<[FDivNoRoundingMode]>; -def FDIVri32SM10 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a, f32imm:$b), - "div.f32\t$d, $a, $b", - [(set RegF32:$d, (fdiv RegF32:$a, fpimm:$b))]>, - Requires<[FDivNoRoundingMode]>; - -def FDIVrr64SM13 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a, RegF64:$b), - "div.rn.f64\t$d, $a, $b", - [(set RegF64:$d, (fdiv RegF64:$a, RegF64:$b))]>, - Requires<[FDivNeedsRoundingMode]>; -def FDIVri64SM13 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a, f64imm:$b), - "div.rn.f64\t$d, $a, $b", - [(set RegF64:$d, (fdiv RegF64:$a, fpimm:$b))]>, - Requires<[FDivNeedsRoundingMode]>; -def FDIVrr64SM10 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a, RegF64:$b), - "div.f64\t$d, $a, $b", - [(set RegF64:$d, (fdiv RegF64:$a, RegF64:$b))]>, - Requires<[FDivNoRoundingMode]>; -def FDIVri64SM10 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a, f64imm:$b), - "div.f64\t$d, $a, $b", - [(set RegF64:$d, (fdiv RegF64:$a, fpimm:$b))]>, - Requires<[FDivNoRoundingMode]>; - - +defm FADD : PTX_FLOAT_3OP<"add">; +defm FSUB : PTX_FLOAT_3OP<"sub">; +defm FMUL : PTX_FLOAT_3OP<"mul">; +defm FDIV : PTX_FLOAT_3OP<"div">; // Multi-operation hybrid instructions +defm FMAD : PTX_FLOAT_4OP<"mad">, Requires<[SupportsFMA]>; -// The selection of mad/fma is tricky. In some cases, they are the *same* -// instruction, but in other cases we may prefer one or the other. Also, -// different PTX versions differ on whether rounding mode flags are required. -// In the short term, mad is supported on all PTX versions and we use a -// default rounding mode no matter what shader model or PTX version. -// TODO: Allow the rounding mode to be selectable through llc. -defm FMADSM13 : PTX_FLOAT_4OP<"mad.rn", fmul, fadd>, - Requires<[FMadNeedsRoundingMode, SupportsFMA]>; -defm FMAD : PTX_FLOAT_4OP<"mad", fmul, fadd>, - Requires<[FMadNoRoundingMode, SupportsFMA]>; ///===- Floating-Point Intrinsic Instructions -----------------------------===// -def FSQRT32 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a), - "sqrt.rn.f32\t$d, $a", - [(set RegF32:$d, (fsqrt RegF32:$a))]>; - -def FSQRT64 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a), - "sqrt.rn.f64\t$d, $a", - [(set RegF64:$d, (fsqrt RegF64:$a))]>; +// SQRT +def FSQRTrr32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, RegF32:$a), + "sqrt$r.f32\t$d, $a", []>; +def FSQRTri32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, f32imm:$a), + "sqrt$r.f32\t$d, $a", []>; +def FSQRTrr64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, RegF64:$a), + "sqrt$r.f64\t$d, $a", []>; +def FSQRTri64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, f64imm:$a), + "sqrt$r.f64\t$d, $a", []>; + +// SIN +def FSINrr32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, RegF32:$a), + "sin$r.f32\t$d, $a", []>; +def FSINri32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, f32imm:$a), + "sin$r.f32\t$d, $a", []>; +def FSINrr64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, RegF64:$a), + "sin$r.f64\t$d, $a", []>; +def FSINri64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, f64imm:$a), + "sin$r.f64\t$d, $a", []>; + +// COS +def FCOSrr32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, RegF32:$a), + "cos$r.f32\t$d, $a", []>; +def FCOSri32 : InstPTX<(outs RegF32:$d), (ins RndMode:$r, f32imm:$a), + "cos$r.f32\t$d, $a", []>; +def FCOSrr64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, RegF64:$a), + "cos$r.f64\t$d, $a", []>; +def FCOSri64 : InstPTX<(outs RegF64:$d), (ins RndMode:$r, f64imm:$a), + "cos$r.f64\t$d, $a", []>; -def FSIN32 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a), - "sin.approx.f32\t$d, $a", - [(set RegF32:$d, (fsin RegF32:$a))]>; -def FSIN64 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a), - "sin.approx.f64\t$d, $a", - [(set RegF64:$d, (fsin RegF64:$a))]>; - -def FCOS32 : InstPTX<(outs RegF32:$d), - (ins RegF32:$a), - "cos.approx.f32\t$d, $a", - [(set RegF32:$d, (fcos RegF32:$a))]>; - -def FCOS64 : InstPTX<(outs RegF64:$d), - (ins RegF64:$a), - "cos.approx.f64\t$d, $a", - [(set RegF64:$d, (fcos RegF64:$a))]>; ///===- Comparison and Selection Instructions -----------------------------===// @@ -641,9 +580,9 @@ defm SELPf64 : PTX_SELP; ///===- Logic and Shift Instructions --------------------------------------===// -defm SHL : INT3ntnc<"shl.b", PTXshl>; -defm SRL : INT3ntnc<"shr.u", PTXsrl>; -defm SRA : INT3ntnc<"shr.s", PTXsra>; +defm SHL : PTX_INT3ntnc<"shl.b", PTXshl>; +defm SRL : PTX_INT3ntnc<"shr.u", PTXsrl>; +defm SRA : PTX_INT3ntnc<"shr.s", PTXsra>; defm AND : PTX_LOGIC<"and", and>; defm OR : PTX_LOGIC<"or", or>; @@ -798,6 +737,136 @@ def CVTf64s64 def CVTf64f32 : InstPTX<(outs RegF64:$d), (ins RegF32:$a), "cvt.f64.f32\t$d, $a", []>; + ///===- Control Flow Instructions -----------------------------------------===// + +let isBranch = 1, isTerminator = 1, isBarrier = 1 in { + def BRAd + : InstPTX<(outs), (ins brtarget:$d), "bra\t$d", [(br bb:$d)]>; +} + +let isBranch = 1, isTerminator = 1 in { + // FIXME: The pattern part is blank because I cannot (or do not yet know + // how to) use the first operand of PredicateOperand (a RegPred register) here + def BRAdp + : InstPTX<(outs), (ins brtarget:$d), "bra\t$d", + [/*(brcond pred:$_p, bb:$d)*/]>; +} + +let isReturn = 1, isTerminator = 1, isBarrier = 1 in { + def EXIT : InstPTX<(outs), (ins), "exit", [(PTXexit)]>; + def RET : InstPTX<(outs), (ins), "ret", [(PTXret)]>; +} + +let hasSideEffects = 1 in { + def CALL : InstPTX<(outs), (ins), "call", [(PTXcall)]>; +} + +///===- Parameter Passing Pseudo-Instructions -----------------------------===// + +def READPARAMPRED : InstPTX<(outs RegPred:$a), (ins i32imm:$b), + "mov.pred\t$a, %param$b", []>; +def READPARAMI16 : InstPTX<(outs RegI16:$a), (ins i32imm:$b), + "mov.b16\t$a, %param$b", []>; +def READPARAMI32 : InstPTX<(outs RegI32:$a), (ins i32imm:$b), + "mov.b32\t$a, %param$b", []>; +def READPARAMI64 : InstPTX<(outs RegI64:$a), (ins i32imm:$b), + "mov.b64\t$a, %param$b", []>; +def READPARAMF32 : InstPTX<(outs RegF32:$a), (ins i32imm:$b), + "mov.f32\t$a, %param$b", []>; +def READPARAMF64 : InstPTX<(outs RegF64:$a), (ins i32imm:$b), + "mov.f64\t$a, %param$b", []>; + +def WRITEPARAMPRED : InstPTX<(outs), (ins RegPred:$a), "//w", []>; +def WRITEPARAMI16 : InstPTX<(outs), (ins RegI16:$a), "//w", []>; +def WRITEPARAMI32 : InstPTX<(outs), (ins RegI32:$a), "//w", []>; +def WRITEPARAMI64 : InstPTX<(outs), (ins RegI64:$a), "//w", []>; +def WRITEPARAMF32 : InstPTX<(outs), (ins RegF32:$a), "//w", []>; +def WRITEPARAMF64 : InstPTX<(outs), (ins RegF64:$a), "//w", []>; + + +//===----------------------------------------------------------------------===// +// Instruction Selection Patterns +//===----------------------------------------------------------------------===// + +// FADD +def : Pat<(f32 (fadd RegF32:$a, RegF32:$b)), + (FADDrr32 RndDefault, RegF32:$a, RegF32:$b)>; +def : Pat<(f32 (fadd RegF32:$a, fpimm:$b)), + (FADDri32 RndDefault, RegF32:$a, fpimm:$b)>; +def : Pat<(f64 (fadd RegF64:$a, RegF64:$b)), + (FADDrr64 RndDefault, RegF64:$a, RegF64:$b)>; +def : Pat<(f64 (fadd RegF64:$a, fpimm:$b)), + (FADDri64 RndDefault, RegF64:$a, fpimm:$b)>; + +// FSUB +def : Pat<(f32 (fsub RegF32:$a, RegF32:$b)), + (FSUBrr32 RndDefault, RegF32:$a, RegF32:$b)>; +def : Pat<(f32 (fsub RegF32:$a, fpimm:$b)), + (FSUBri32 RndDefault, RegF32:$a, fpimm:$b)>; +def : Pat<(f64 (fsub RegF64:$a, RegF64:$b)), + (FSUBrr64 RndDefault, RegF64:$a, RegF64:$b)>; +def : Pat<(f64 (fsub RegF64:$a, fpimm:$b)), + (FSUBri64 RndDefault, RegF64:$a, fpimm:$b)>; + +// FMUL +def : Pat<(f32 (fmul RegF32:$a, RegF32:$b)), + (FMULrr32 RndDefault, RegF32:$a, RegF32:$b)>; +def : Pat<(f32 (fmul RegF32:$a, fpimm:$b)), + (FMULri32 RndDefault, RegF32:$a, fpimm:$b)>; +def : Pat<(f64 (fmul RegF64:$a, RegF64:$b)), + (FMULrr64 RndDefault, RegF64:$a, RegF64:$b)>; +def : Pat<(f64 (fmul RegF64:$a, fpimm:$b)), + (FMULri64 RndDefault, RegF64:$a, fpimm:$b)>; + +// FDIV +def : Pat<(f32 (fdiv RegF32:$a, RegF32:$b)), + (FDIVrr32 RndDefault, RegF32:$a, RegF32:$b)>; +def : Pat<(f32 (fdiv RegF32:$a, fpimm:$b)), + (FDIVri32 RndDefault, RegF32:$a, fpimm:$b)>; +def : Pat<(f64 (fdiv RegF64:$a, RegF64:$b)), + (FDIVrr64 RndDefault, RegF64:$a, RegF64:$b)>; +def : Pat<(f64 (fdiv RegF64:$a, fpimm:$b)), + (FDIVri64 RndDefault, RegF64:$a, fpimm:$b)>; + +// FMUL+FADD +def : Pat<(f32 (fadd (fmul RegF32:$a, RegF32:$b), RegF32:$c)), + (FMADrrr32 RndDefault, RegF32:$a, RegF32:$b, RegF32:$c)>; +def : Pat<(f32 (fadd (fmul RegF32:$a, RegF32:$b), fpimm:$c)), + (FMADrri32 RndDefault, RegF32:$a, RegF32:$b, fpimm:$c)>; +def : Pat<(f32 (fadd (fmul RegF32:$a, fpimm:$b), fpimm:$c)), + (FMADrrr32 RndDefault, RegF32:$a, fpimm:$b, fpimm:$c)>; +def : Pat<(f32 (fadd (fmul RegF32:$a, RegF32:$b), fpimm:$c)), + (FMADrri32 RndDefault, RegF32:$a, RegF32:$b, fpimm:$c)>; +def : Pat<(f64 (fadd (fmul RegF64:$a, RegF64:$b), RegF64:$c)), + (FMADrrr64 RndDefault, RegF64:$a, RegF64:$b, RegF64:$c)>; +def : Pat<(f64 (fadd (fmul RegF64:$a, RegF64:$b), fpimm:$c)), + (FMADrri64 RndDefault, RegF64:$a, RegF64:$b, fpimm:$c)>; +def : Pat<(f64 (fadd (fmul RegF64:$a, fpimm:$b), fpimm:$c)), + (FMADrri64 RndDefault, RegF64:$a, fpimm:$b, fpimm:$c)>; + +// FNEG +def : Pat<(f32 (fneg RegF32:$a)), (FNEGrr32 RndDefault, RegF32:$a)>; +def : Pat<(f32 (fneg fpimm:$a)), (FNEGri32 RndDefault, fpimm:$a)>; +def : Pat<(f64 (fneg RegF64:$a)), (FNEGrr64 RndDefault, RegF64:$a)>; +def : Pat<(f64 (fneg fpimm:$a)), (FNEGri64 RndDefault, fpimm:$a)>; + +// FSQRT +def : Pat<(f32 (fsqrt RegF32:$a)), (FSQRTrr32 RndDefault, RegF32:$a)>; +def : Pat<(f32 (fsqrt fpimm:$a)), (FSQRTri32 RndDefault, fpimm:$a)>; +def : Pat<(f64 (fsqrt RegF64:$a)), (FSQRTrr64 RndDefault, RegF64:$a)>; +def : Pat<(f64 (fsqrt fpimm:$a)), (FSQRTri64 RndDefault, fpimm:$a)>; + +// FSIN +def : Pat<(f32 (fsin RegF32:$a)), (FSINrr32 RndDefault, RegF32:$a)>; +def : Pat<(f32 (fsin fpimm:$a)), (FSINri32 RndDefault, fpimm:$a)>; +def : Pat<(f64 (fsin RegF64:$a)), (FSINrr64 RndDefault, RegF64:$a)>; +def : Pat<(f64 (fsin fpimm:$a)), (FSINri64 RndDefault, fpimm:$a)>; + +// FCOS +def : Pat<(f32 (fcos RegF32:$a)), (FCOSrr32 RndDefault, RegF32:$a)>; +def : Pat<(f32 (fcos fpimm:$a)), (FCOSri32 RndDefault, fpimm:$a)>; +def : Pat<(f64 (fcos RegF64:$a)), (FCOSrr64 RndDefault, RegF64:$a)>; +def : Pat<(f64 (fcos fpimm:$a)), (FCOSri64 RndDefault, fpimm:$a)>; // Type conversion notes: // - PTX does not directly support converting a predicate to a value, so we @@ -881,52 +950,6 @@ def : Pat<(f64 (fextend RegF32:$a)), (CVTf64f32 RegF32:$a)>; def : Pat<(f64 (bitconvert RegI64:$a)), (MOVf64i64 RegI64:$a)>; -///===- Control Flow Instructions -----------------------------------------===// - -let isBranch = 1, isTerminator = 1, isBarrier = 1 in { - def BRAd - : InstPTX<(outs), (ins brtarget:$d), "bra\t$d", [(br bb:$d)]>; -} - -let isBranch = 1, isTerminator = 1 in { - // FIXME: The pattern part is blank because I cannot (or do not yet know - // how to) use the first operand of PredicateOperand (a RegPred register) here - def BRAdp - : InstPTX<(outs), (ins brtarget:$d), "bra\t$d", - [/*(brcond pred:$_p, bb:$d)*/]>; -} - -let isReturn = 1, isTerminator = 1, isBarrier = 1 in { - def EXIT : InstPTX<(outs), (ins), "exit", [(PTXexit)]>; - def RET : InstPTX<(outs), (ins), "ret", [(PTXret)]>; -} - -let hasSideEffects = 1 in { - def CALL : InstPTX<(outs), (ins), "call", [(PTXcall)]>; -} - -///===- Parameter Passing Pseudo-Instructions -----------------------------===// - -def READPARAMPRED : InstPTX<(outs RegPred:$a), (ins i32imm:$b), - "mov.pred\t$a, %param$b", []>; -def READPARAMI16 : InstPTX<(outs RegI16:$a), (ins i32imm:$b), - "mov.b16\t$a, %param$b", []>; -def READPARAMI32 : InstPTX<(outs RegI32:$a), (ins i32imm:$b), - "mov.b32\t$a, %param$b", []>; -def READPARAMI64 : InstPTX<(outs RegI64:$a), (ins i32imm:$b), - "mov.b64\t$a, %param$b", []>; -def READPARAMF32 : InstPTX<(outs RegF32:$a), (ins i32imm:$b), - "mov.f32\t$a, %param$b", []>; -def READPARAMF64 : InstPTX<(outs RegF64:$a), (ins i32imm:$b), - "mov.f64\t$a, %param$b", []>; - -def WRITEPARAMPRED : InstPTX<(outs), (ins RegPred:$a), "//w", []>; -def WRITEPARAMI16 : InstPTX<(outs), (ins RegI16:$a), "//w", []>; -def WRITEPARAMI32 : InstPTX<(outs), (ins RegI32:$a), "//w", []>; -def WRITEPARAMI64 : InstPTX<(outs), (ins RegI64:$a), "//w", []>; -def WRITEPARAMF32 : InstPTX<(outs), (ins RegF32:$a), "//w", []>; -def WRITEPARAMF64 : InstPTX<(outs), (ins RegF64:$a), "//w", []>; - ///===- Intrinsic Instructions --------------------------------------------===// include "PTXIntrinsicInstrInfo.td" diff --git a/lib/Target/PTX/PTXTargetMachine.cpp b/lib/Target/PTX/PTXTargetMachine.cpp index 1f3f1721a65..449a3d9fc8d 100644 --- a/lib/Target/PTX/PTXTargetMachine.cpp +++ b/lib/Target/PTX/PTXTargetMachine.cpp @@ -367,6 +367,7 @@ bool PTXTargetMachine::addCommonCodeGenPasses(PassManagerBase &PM, printNoVerify(PM, "After PreEmit passes"); PM.add(createPTXMFInfoExtract(*this, OptLevel)); + PM.add(createPTXFPRoundingModePass(*this, OptLevel)); return false; }