Prevent hoisting fmul from THEN/ELSE to IF if there is fmsub/fmadd opportunity.
authorChad Rosier <mcrosier@codeaurora.org>
Mon, 23 Feb 2015 19:15:16 +0000 (19:15 +0000)
committerChad Rosier <mcrosier@codeaurora.org>
Mon, 23 Feb 2015 19:15:16 +0000 (19:15 +0000)
This patch adds the isProfitableToHoist API.  For AArch64, we want to prevent a
fmul from being hoisted in cases where it is more profitable to form a
fmsub/fmadd.

Phabricator Review: http://reviews.llvm.org/D7299
Patch by Lawrence Hu <lawrence@codeaurora.org>

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@230241 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/TargetTransformInfo.h
include/llvm/Analysis/TargetTransformInfoImpl.h
include/llvm/CodeGen/BasicTTIImpl.h
include/llvm/Target/TargetLowering.h
lib/Analysis/TargetTransformInfo.cpp
lib/Target/AArch64/AArch64ISelLowering.cpp
lib/Target/AArch64/AArch64ISelLowering.h
lib/Transforms/Utils/SimplifyCFG.cpp
test/Transforms/SimplifyCFG/AArch64/lit.local.cfg [new file with mode: 0644]
test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll [new file with mode: 0644]

index 26ceac189a130d90fa4e602479f8c2736ccf61d1..49981416604f12d410445d08aff3ec76c16ea160 100644 (file)
@@ -313,6 +313,10 @@ public:
   /// by referencing its sub-register AX.
   bool isTruncateFree(Type *Ty1, Type *Ty2) const;
 
+  /// \brief Return true if it is profitable to hoist instruction in the
+  /// then/else to before if.
+  bool isProfitableToHoist(Instruction *I) const;
+
   /// \brief Return true if this type is legal.
   bool isTypeLegal(Type *Ty) const;
 
@@ -521,6 +525,7 @@ public:
                                    int64_t BaseOffset, bool HasBaseReg,
                                    int64_t Scale) = 0;
   virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0;
+  virtual bool isProfitableToHoist(Instruction *I) = 0;
   virtual bool isTypeLegal(Type *Ty) = 0;
   virtual unsigned getJumpBufAlignment() = 0;
   virtual unsigned getJumpBufSize() = 0;
@@ -633,6 +638,9 @@ public:
   bool isTruncateFree(Type *Ty1, Type *Ty2) override {
     return Impl.isTruncateFree(Ty1, Ty2);
   }
+  bool isProfitableToHoist(Instruction *I) override {
+    return Impl.isProfitableToHoist(I);
+  }
   bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); }
   unsigned getJumpBufAlignment() override { return Impl.getJumpBufAlignment(); }
   unsigned getJumpBufSize() override { return Impl.getJumpBufSize(); }
index 0254880b4e6fc071c7f0fbf3df105e92750846ac..3e02c0ce3ca92a26e24ca2050626ad97ba8f6f84 100644 (file)
@@ -225,6 +225,8 @@ public:
 
   bool isTruncateFree(Type *Ty1, Type *Ty2) { return false; }
 
+  bool isProfitableToHoist(Instruction *I) { return true; }
+
   bool isTypeLegal(Type *Ty) { return false; }
 
   unsigned getJumpBufAlignment() { return 0; }
index 25a74b331dea74d2b346905e4fc26e6a810b5c70..ff85b064bc969b4d78090eddb9f70f4da6d8d087 100644 (file)
@@ -145,6 +145,10 @@ public:
     return getTLI()->isTruncateFree(Ty1, Ty2);
   }
 
+  bool isProfitableToHoist(Instruction *I) {
+    return getTLI()->isProfitableToHoist(I);
+  }
+
   bool isTypeLegal(Type *Ty) {
     EVT VT = getTLI()->getValueType(Ty);
     return getTLI()->isTypeLegal(VT);
index d320bf1c30a5341ad7274357886fcfe2d4011b63..cd499ba5cb0f9855042b462c914df7ee94cfee72 100644 (file)
@@ -1456,6 +1456,8 @@ public:
     return false;
   }
 
+  virtual bool isProfitableToHoist(Instruction *I) const { return true; }
+
   /// Return true if any actual instruction that defines a value of type Ty1
   /// implicitly zero-extends the value to Ty2 in the result register.
   ///
index b5440e2a2c3711d56a91c987c0175ba2ed742360..7ff29b028aec5ed85410a41058f3b2ad09e509e2 100644 (file)
@@ -123,6 +123,10 @@ bool TargetTransformInfo::isTruncateFree(Type *Ty1, Type *Ty2) const {
   return TTIImpl->isTruncateFree(Ty1, Ty2);
 }
 
+bool TargetTransformInfo::isProfitableToHoist(Instruction *I) const {
+  return TTIImpl->isProfitableToHoist(I);
+}
+
 bool TargetTransformInfo::isTypeLegal(Type *Ty) const {
   return TTIImpl->isTypeLegal(Ty);
 }
index 332c8796c0a9e00adf51f83a183d3423ff04bcac..fb31d7d33763c46fb35920fdca773ce7db980567 100644 (file)
@@ -6533,6 +6533,34 @@ bool AArch64TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const {
   return NumBits1 > NumBits2;
 }
 
+/// Check if it is profitable to hoist instruction in then/else to if.
+/// Not profitable if I and it's user can form a FMA instruction
+/// because we prefer FMSUB/FMADD.
+bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const {
+  if (I->getOpcode() != Instruction::FMul)
+    return true;
+
+  if (I->getNumUses() != 1)
+    return true;
+
+  Instruction *User = I->user_back();
+
+  if (User &&
+      !(User->getOpcode() == Instruction::FSub ||
+        User->getOpcode() == Instruction::FAdd))
+    return true;
+
+  const TargetOptions &Options = getTargetMachine().Options;
+  EVT VT = getValueType(User->getOperand(0)->getType());
+
+  if (isFMAFasterThanFMulAndFAdd(VT) &&
+      isOperationLegalOrCustom(ISD::FMA, VT) &&
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath))
+    return false;
+
+  return true;
+}
+
 // All 32-bit GPR operations implicitly zero the high-half of the corresponding
 // 64-bit GPR.
 bool AArch64TargetLowering::isZExtFree(Type *Ty1, Type *Ty2) const {
index 6cbc425e71ff43a7e52b6fa5d371330091dc0784..db15538e43be6959fb868188ad3c957a9cbea989 100644 (file)
@@ -18,6 +18,7 @@
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/IR/CallingConv.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/Target/TargetLowering.h"
 
 namespace llvm {
@@ -286,6 +287,8 @@ public:
   bool isTruncateFree(Type *Ty1, Type *Ty2) const override;
   bool isTruncateFree(EVT VT1, EVT VT2) const override;
 
+  bool isProfitableToHoist(Instruction *I) const override;
+
   bool isZExtFree(Type *Ty1, Type *Ty2) const override;
   bool isZExtFree(EVT VT1, EVT VT2) const override;
   bool isZExtFree(SDValue Val, EVT VT2) const override;
index 9cbd05f897bff4295d7e5bd48d8a15429d7b256e..3248a83636c4a1abb1eca9924fc4f48549637342 100644 (file)
@@ -1053,7 +1053,8 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I);
 /// HoistThenElseCodeToIf - Given a conditional branch that goes to BB1 and
 /// BB2, hoist any common code in the two blocks up into the branch block.  The
 /// caller of this function guarantees that BI's block dominates BB1 and BB2.
-static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL) {
+static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL,
+                                  const TargetTransformInfo &TTI) {
   // This does very trivial matching, with limited scanning, to find identical
   // instructions in the two blocks.  In particular, we don't want to get into
   // O(M*N) situations here where M and N are the sizes of BB1 and BB2.  As
@@ -1088,6 +1089,9 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL) {
     if (isa<TerminatorInst>(I1))
       goto HoistTerminator;
 
+    if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2))
+      return Changed;
+
     // For a normal instruction, we just move one to right before the branch,
     // then replace all uses of the other with the first.  Finally, we remove
     // the now redundant second instruction.
@@ -4442,7 +4446,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) {
   // can hoist it up to the branching block.
   if (BI->getSuccessor(0)->getSinglePredecessor()) {
     if (BI->getSuccessor(1)->getSinglePredecessor()) {
-      if (HoistThenElseCodeToIf(BI, DL))
+      if (HoistThenElseCodeToIf(BI, DL, TTI))
         return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true;
     } else {
       // If Successor #1 has multiple preds, we may be able to conditionally
diff --git a/test/Transforms/SimplifyCFG/AArch64/lit.local.cfg b/test/Transforms/SimplifyCFG/AArch64/lit.local.cfg
new file mode 100644 (file)
index 0000000..6642d28
--- /dev/null
@@ -0,0 +1,5 @@
+config.suffixes = ['.ll']
+
+targets = set(config.root.targets_to_build.split())
+if not 'AArch64' in targets:
+    config.unsupported = True
diff --git a/test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll b/test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll
new file mode 100644 (file)
index 0000000..076cb58
--- /dev/null
@@ -0,0 +1,72 @@
+; RUN: opt < %s -mtriple=aarch64-linux-gnu -simplifycfg -enable-unsafe-fp-math -S >%t
+; RUN: FileCheck %s < %t
+; ModuleID = 't.cc'
+
+; Function Attrs: nounwind
+define double @_Z3fooRdS_S_S_(double* dereferenceable(8) %x, double* dereferenceable(8) %y, double* dereferenceable(8) %a) #0 {
+entry:
+  %0 = load double* %y, align 8
+  %cmp = fcmp oeq double %0, 0.000000e+00
+  %1 = load double* %x, align 8
+  br i1 %cmp, label %if.then, label %if.else
+
+; fadd (const, (fmul x, y))
+if.then:                                          ; preds = %entry
+; CHECK-LABEL: if.then:
+; CHECK:   %3 = fmul fast double %1, %2
+; CHECK-NEXT:   %mul = fadd fast double 1.000000e+00, %3
+  %2 = load double* %a, align 8
+  %3 = fmul fast double %1, %2
+  %mul = fadd fast double 1.000000e+00, %3
+  store double %mul, double* %y, align 8
+  br label %if.end
+
+; fsub ((fmul x, y), z)
+if.else:                                          ; preds = %entry
+; CHECK-LABEL: if.else:
+; CHECK:   %mul1 = fmul fast double %1, %2
+; CHECK-NEXT:   %sub1 = fsub fast double %mul1, %0
+  %4 = load double* %a, align 8
+  %mul1 = fmul fast double %1, %4
+  %sub1 = fsub fast double %mul1, %0
+  store double %sub1, double* %y, align 8
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %5 = load double* %y, align 8
+  %cmp2 = fcmp oeq double %5, 2.000000e+00
+  %6 = load double* %x, align 8
+  br i1 %cmp2, label %if.then2, label %if.else2
+
+; fsub (x, (fmul y, z))
+if.then2:                                         ; preds = %entry
+; CHECK-LABEL: if.then2:
+; CHECK:   %7 = fmul fast double %5, 3.000000e+00
+; CHECK-NEXT:   %mul2 = fsub fast double %6, %7
+  %7 = load double* %a, align 8
+  %8 = fmul fast double %6, 3.0000000e+00
+  %mul2 = fsub fast double %7, %8
+  store double %mul2, double* %y, align 8
+  br label %if.end2
+
+; fsub (fneg((fmul x, y)), const)
+if.else2:                                         ; preds = %entry
+; CHECK-LABEL: if.else2:
+; CHECK:   %mul3 = fmul fast double %5, 3.000000e+00
+; CHECK-NEXT:   %neg = fsub fast double 0.000000e+00, %mul3
+; CHECK-NEXT:   %sub2 = fsub fast double %neg, 3.000000e+00
+  %mul3 = fmul fast double %6, 3.0000000e+00
+  %neg = fsub fast double 0.0000000e+00, %mul3
+  %sub2 = fsub fast double %neg, 3.0000000e+00
+  store double %sub2, double* %y, align 8
+  br label %if.end2
+
+if.end2:                                           ; preds = %if.else, %if.then
+  %9 = load double* %x, align 8
+  %10 = load double* %y, align 8
+  %add = fadd fast double %9, %10
+  %11 = load double* %a, align 8
+  %add2 = fadd fast double %add, %11
+  ret double %add2
+}
+