R600: Fold clamp, neg, abs
authorTom Stellard <thomas.stellard@amd.com>
Thu, 31 Jan 2013 22:11:54 +0000 (22:11 +0000)
committerTom Stellard <thomas.stellard@amd.com>
Thu, 31 Jan 2013 22:11:54 +0000 (22:11 +0000)
Patch by: Vincent Lejeune

Reviewed-by: Tom Stellard <thomas.stellard@amd.com>
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@174099 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/R600/AMDILISelDAGToDAG.cpp
test/CodeGen/R600/fsub.ll

index ece26efceedf105b2ac2daa1386c935f0e666a84..84223f62e17c52fa39a7a5575df868a02a53da80 100644 (file)
@@ -272,7 +272,11 @@ SDNode *AMDGPUDAGToDAGISel::Select(SDNode *N) {
   if (ST.device()->getGeneration() <= AMDGPUDeviceInfo::HD6XXX) {
     const R600InstrInfo *TII =
         static_cast<const R600InstrInfo*>(TM.getInstrInfo());
-    if (Result && TII->isALUInstr(Result->getMachineOpcode())) {
+    if (Result && Result->isMachineOpcode()
+        && TII->isALUInstr(Result->getMachineOpcode())) {
+      // Fold FNEG/FABS/CONST_ADDRESS
+      // TODO: Isel can generate multiple MachineInst, we need to recursively
+      // parse Result
       bool IsModified = false;
       do {
         std::vector<SDValue> Ops;
@@ -281,10 +285,28 @@ SDNode *AMDGPUDAGToDAGISel::Select(SDNode *N) {
           Ops.push_back(*I);
         IsModified = FoldOperands(Result->getMachineOpcode(), TII, Ops);
         if (IsModified) {
-          Result = CurDAG->MorphNodeTo(Result, Result->getOpcode(),
-              Result->getVTList(), Ops.data(), Ops.size());
+          Result = CurDAG->UpdateNodeOperands(Result, Ops.data(), Ops.size());
         }
       } while (IsModified);
+
+      // If node has a single use which is CLAMP_R600, folds it
+      if (Result->hasOneUse() && Result->isMachineOpcode()) {
+        SDNode *PotentialClamp = *Result->use_begin();
+        if (PotentialClamp->isMachineOpcode() &&
+            PotentialClamp->getMachineOpcode() == AMDGPU::CLAMP_R600) {
+          unsigned ClampIdx =
+            TII->getOperandIdx(Result->getMachineOpcode(), R600Operands::CLAMP);
+          std::vector<SDValue> Ops;
+          unsigned NumOp = Result->getNumOperands();
+          for (unsigned i = 0; i < NumOp; ++i) {
+            Ops.push_back(Result->getOperand(i));
+          }
+          Ops[ClampIdx - 1] = CurDAG->getTargetConstant(1, MVT::i32);
+          Result = CurDAG->SelectNodeTo(PotentialClamp,
+              Result->getMachineOpcode(), PotentialClamp->getVTList(),
+              Ops.data(), NumOp);
+        }
+      }
     }
   }
 
@@ -303,6 +325,17 @@ bool AMDGPUDAGToDAGISel::FoldOperands(unsigned Opcode,
     TII->getOperandIdx(Opcode, R600Operands::SRC1_SEL),
     TII->getOperandIdx(Opcode, R600Operands::SRC2_SEL)
   };
+  int NegIdx[] = {
+    TII->getOperandIdx(Opcode, R600Operands::SRC0_NEG),
+    TII->getOperandIdx(Opcode, R600Operands::SRC1_NEG),
+    TII->getOperandIdx(Opcode, R600Operands::SRC2_NEG)
+  };
+  int AbsIdx[] = {
+    TII->getOperandIdx(Opcode, R600Operands::SRC0_ABS),
+    TII->getOperandIdx(Opcode, R600Operands::SRC1_ABS),
+    -1
+  };
+
   for (unsigned i = 0; i < 3; i++) {
     if (OperandIdx[i] < 0)
       return false;
@@ -318,6 +351,18 @@ bool AMDGPUDAGToDAGISel::FoldOperands(unsigned Opcode,
       }
       }
       break;
+    case ISD::FNEG:
+      if (NegIdx[i] < 0)
+        break;
+      Ops[OperandIdx[i] - 1] = Operand.getOperand(0);
+      Ops[NegIdx[i] - 1] = CurDAG->getTargetConstant(1, MVT::i32);
+      return true;
+    case ISD::FABS:
+      if (AbsIdx[i] < 0)
+        break;
+      Ops[OperandIdx[i] - 1] = Operand.getOperand(0);
+      Ops[AbsIdx[i] - 1] = CurDAG->getTargetConstant(1, MVT::i32);
+      return true;
     case ISD::BITCAST:
       Ops[OperandIdx[i] - 1] = Operand.getOperand(0);
       return true;
index 0ec1c376dfa38d312760e40b1dfade25a65555ec..591aa52676a4162c7bbced045653b70b7c0e7126 100644 (file)
@@ -1,7 +1,6 @@
 ;RUN: llc < %s -march=r600 -mcpu=redwood | FileCheck %s
 
-; CHECK: MOV T{{[0-9]+\.[XYZW], -T[0-9]+\.[XYZW]}}
-; CHECK: ADD T{{[0-9]+\.[XYZW], T[0-9]+\.[XYZW], T[0-9]+\.[XYZW]}}
+; CHECK: ADD T{{[0-9]+\.[XYZW], T[0-9]+\.[XYZW], -T[0-9]+\.[XYZW]}}
 
 define void @test() {
    %r0 = call float @llvm.R600.load.input(i32 0)