Replace vfmaddxx213 instructions with their 231-type equivalents in accumulator
authorLang Hames <lhames@gmail.com>
Thu, 23 Jan 2014 20:23:36 +0000 (20:23 +0000)
committerLang Hames <lhames@gmail.com>
Thu, 23 Jan 2014 20:23:36 +0000 (20:23 +0000)
loops. Writing back to the accumulator (231-type) allows the coalescer to
eliminate an extra copy.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@199933 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86ISelLowering.h
lib/Target/X86/X86InstrFMA.td
test/CodeGen/X86/fma.ll

index ede17c1750a35597b4c5e80f831df0b26b2e988e..3e641cdba9145d6d18e662d0aa9c9c41123905f8 100644 (file)
@@ -15963,6 +15963,81 @@ X86TargetLowering::emitEHSjLjLongJmp(MachineInstr *MI,
   return MBB;
 }
 
+// Replace 213-type (isel default) FMA3 instructions with 231-type for
+// accumulator loops. Writing back to the accumulator allows the coalescer
+// to remove extra copies in the loop.   
+MachineBasicBlock *
+X86TargetLowering::emitFMA3Instr(MachineInstr *MI,
+                                 MachineBasicBlock *MBB) const {
+  MachineOperand &AddendOp = MI->getOperand(3);
+
+  // Bail out early if the addend isn't a register - we can't switch these.
+  if (!AddendOp.isReg())
+    return MBB;
+
+  MachineFunction &MF = *MBB->getParent();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+
+  // Check whether the addend is defined by a PHI:
+  assert(MRI.hasOneDef(AddendOp.getReg()) && "Multiple defs in SSA?");
+  MachineInstr &AddendDef = *MRI.def_begin(AddendOp.getReg());
+  if (!AddendDef.isPHI())
+    return MBB;
+
+  // Look for the following pattern:
+  // loop:
+  //   %addend = phi [%entry, 0], [%loop, %result]
+  //   ...
+  //   %result<tied1> = FMA213 %m2<tied0>, %m1, %addend
+
+  // Replace with:
+  //   loop:
+  //   %addend = phi [%entry, 0], [%loop, %result]
+  //   ...
+  //   %result<tied1> = FMA231 %addend<tied0>, %m1, %m2
+
+  for (unsigned i = 1, e = AddendDef.getNumOperands(); i < e; i += 2) {
+    assert(AddendDef.getOperand(i).isReg());
+    MachineOperand PHISrcOp = AddendDef.getOperand(i);
+    MachineInstr &PHISrcInst = *MRI.def_begin(PHISrcOp.getReg());
+    if (&PHISrcInst == MI) {
+      // Found a matching instruction.
+      unsigned NewFMAOpc = 0;
+      switch (MI->getOpcode()) {
+        case X86::VFMADDPDr213r: NewFMAOpc = X86::VFMADDPDr231r; break;
+        case X86::VFMADDPSr213r: NewFMAOpc = X86::VFMADDPSr231r; break;
+        case X86::VFMADDSDr213r: NewFMAOpc = X86::VFMADDSDr231r; break;
+        case X86::VFMADDSSr213r: NewFMAOpc = X86::VFMADDSSr231r; break;
+        case X86::VFMSUBPDr213r: NewFMAOpc = X86::VFMSUBPDr231r; break;
+        case X86::VFMSUBPSr213r: NewFMAOpc = X86::VFMSUBPSr231r; break;
+        case X86::VFMSUBSDr213r: NewFMAOpc = X86::VFMSUBSDr231r; break;
+        case X86::VFMSUBSSr213r: NewFMAOpc = X86::VFMSUBSSr231r; break;
+        case X86::VFNMADDPDr213r: NewFMAOpc = X86::VFNMADDPDr231r; break;
+        case X86::VFNMADDPSr213r: NewFMAOpc = X86::VFNMADDPSr231r; break;
+        case X86::VFNMADDSDr213r: NewFMAOpc = X86::VFNMADDSDr231r; break;
+        case X86::VFNMADDSSr213r: NewFMAOpc = X86::VFNMADDSSr231r; break;
+        case X86::VFNMSUBPDr213r: NewFMAOpc = X86::VFNMSUBPDr231r; break;
+        case X86::VFNMSUBPSr213r: NewFMAOpc = X86::VFNMSUBPSr231r; break;
+        case X86::VFNMSUBSDr213r: NewFMAOpc = X86::VFNMSUBSDr231r; break;
+        case X86::VFNMSUBSSr213r: NewFMAOpc = X86::VFNMSUBSSr231r; break;
+        default: llvm_unreachable("Unrecognized FMA variant.");
+      }
+
+      const TargetInstrInfo &TII = *MF.getTarget().getInstrInfo();
+      MachineInstrBuilder MIB =
+        BuildMI(MF, MI->getDebugLoc(), TII.get(NewFMAOpc))
+        .addOperand(MI->getOperand(0))
+        .addOperand(MI->getOperand(3))
+        .addOperand(MI->getOperand(2))
+        .addOperand(MI->getOperand(1));
+      MBB->insert(MachineBasicBlock::iterator(MI), MIB);
+      MI->eraseFromParent();
+    }
+  }
+
+  return MBB;
+}
+
 MachineBasicBlock *
 X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr *MI,
                                                MachineBasicBlock *BB) const {
@@ -16194,6 +16269,32 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr *MI,
   case TargetOpcode::STACKMAP:
   case TargetOpcode::PATCHPOINT:
     return emitPatchPoint(MI, BB);
+
+  case X86::VFMADDPDr213r:
+  case X86::VFMADDPSr213r:
+  case X86::VFMADDSDr213r:
+  case X86::VFMADDSSr213r:
+  case X86::VFMSUBPDr213r:
+  case X86::VFMSUBPSr213r:
+  case X86::VFMSUBSDr213r:
+  case X86::VFMSUBSSr213r:
+  case X86::VFNMADDPDr213r:
+  case X86::VFNMADDPSr213r:
+  case X86::VFNMADDSDr213r:
+  case X86::VFNMADDSSr213r:
+  case X86::VFNMSUBPDr213r:
+  case X86::VFNMSUBPSr213r:
+  case X86::VFNMSUBSDr213r:
+  case X86::VFNMSUBSSr213r:
+  case X86::VFMADDPDr213rY:
+  case X86::VFMADDPSr213rY:
+  case X86::VFMSUBPDr213rY:
+  case X86::VFMSUBPSr213rY:
+  case X86::VFNMADDPDr213rY:
+  case X86::VFNMADDPSr213rY:
+  case X86::VFNMSUBPDr213rY:
+  case X86::VFNMSUBPSr213rY:
+    return emitFMA3Instr(MI, BB);
   }
 }
 
index 9b32d121010835888c8750849a11134e8abbf68d..d985c98875cd563a8bf68e74dafbf870e3709bef 100644 (file)
@@ -972,6 +972,9 @@ namespace llvm {
     MachineBasicBlock *emitEHSjLjLongJmp(MachineInstr *MI,
                                          MachineBasicBlock *MBB) const;
 
+    MachineBasicBlock *emitFMA3Instr(MachineInstr *MI,
+                                     MachineBasicBlock *MBB) const;
+
     /// Emit nodes that will be selected as "test Op0,Op0", or something
     /// equivalent, for use with the given x86 condition code.
     SDValue EmitTest(SDValue Op0, unsigned X86CC, SelectionDAG &DAG) const;
index b2cc8209bf966cf2043eef3c3ebb5e137aeedfc1..206f7b600f3c6474fc09cbc03b061acfaf34c695 100644 (file)
@@ -20,7 +20,7 @@ multiclass fma3p_rm<bits<8> opc, string OpcodeStr,
                     PatFrag MemFrag128, PatFrag MemFrag256,
                     ValueType OpVT128, ValueType OpVT256,
                     SDPatternOperator Op = null_frag> {
-  let isCommutable = 1 in
+  let isCommutable = 1, usesCustomInserter = 1 in
   def r     : FMA3<opc, MRMSrcReg, (outs VR128:$dst),
                    (ins VR128:$src1, VR128:$src2, VR128:$src3),
                    !strconcat(OpcodeStr,
@@ -36,7 +36,7 @@ multiclass fma3p_rm<bits<8> opc, string OpcodeStr,
                    [(set VR128:$dst, (OpVT128 (Op VR128:$src2, VR128:$src1,
                                                (MemFrag128 addr:$src3))))]>;
 
-  let isCommutable = 1 in
+  let isCommutable = 1, usesCustomInserter = 1 in
   def rY    : FMA3<opc, MRMSrcReg, (outs VR256:$dst),
                    (ins VR256:$src1, VR256:$src2, VR256:$src3),
                    !strconcat(OpcodeStr,
@@ -118,7 +118,7 @@ let Constraints = "$src1 = $dst" in {
 multiclass fma3s_rm<bits<8> opc, string OpcodeStr, X86MemOperand x86memop,
                     RegisterClass RC, ValueType OpVT, PatFrag mem_frag,
                     SDPatternOperator OpNode = null_frag> {
-  let isCommutable = 1 in
+  let isCommutable = 1, usesCustomInserter = 1 in
   def r     : FMA3<opc, MRMSrcReg, (outs RC:$dst),
                    (ins RC:$src1, RC:$src2, RC:$src3),
                    !strconcat(OpcodeStr,
index 917eac0ca32dec91fe4ae5ef5517021bc061db55..0cdf3cdf574c536953390b407bf464d7efe16c75 100644 (file)
@@ -42,6 +42,21 @@ entry:
   ret float %call
 }
 
+; Test FMA3 variant selection
+; CHECK: fma3_select231:
+; CHECK: vfmadd231ss
+define float @fma3_select231(float %x, float %y, i32 %N) #0 {
+entry:
+  br label %while.body
+while.body:                                       ; preds = %while.body, %while.body
+  %acc.01 = phi float [ 0.000000e+00, %entry ], [ %acc, %while.body ]
+  %acc = tail call float @llvm.fma.f32(float %x, float %y, float %acc.01) nounwind readnone
+  %b = fcmp ueq float %acc, 0.0
+  br i1 %b, label %while.body, label %while.end
+while.end:                                        ; preds = %while.body, %entry
+  ret float %acc
+}
+
 declare float @llvm.fma.f32(float, float, float) nounwind readnone
 declare double @llvm.fma.f64(double, double, double) nounwind readnone
 declare x86_fp80 @llvm.fma.f80(x86_fp80, x86_fp80, x86_fp80) nounwind readnone