rdar://12753946
authorShuxin Yang <shuxin.llvm@gmail.com>
Fri, 14 Dec 2012 18:46:06 +0000 (18:46 +0000)
committerShuxin Yang <shuxin.llvm@gmail.com>
Fri, 14 Dec 2012 18:46:06 +0000 (18:46 +0000)
Implement rule : "x * (select cond 1.0, 0.0) -> select cond x, 0.0"

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@170226 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
test/Transforms/InstCombine/fast-math.ll

index b95da85a7f33c983f498f8e01eec3c09133acc2c..964297a5eab9f2c1d94faf796fb5bc84222e9f22 100644 (file)
@@ -341,6 +341,38 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
     }
   }
 
+  // X * cond ? 1.0 : 0.0 => cond ? X : 0.0
+  if (I.hasNoNaNs() && I.hasNoSignedZeros()) {
+    Value *V0 = I.getOperand(0);
+    Value *V1 = I.getOperand(1);
+    Value *Cond, *SLHS, *SRHS;
+    bool Match = false;
+
+    if (match(V0, m_Select(m_Value(Cond), m_Value(SLHS), m_Value(SRHS)))) {
+      Match = true;
+    } else if (match(V1, m_Select(m_Value(Cond), m_Value(SLHS), 
+                     m_Value(SRHS)))) {
+      Match = true;
+      std::swap(V0, V1);
+    }
+
+    if (Match) {
+      ConstantFP *C0 = dyn_cast<ConstantFP>(SLHS);
+      ConstantFP *C1 = dyn_cast<ConstantFP>(SRHS);
+
+      if (C0 && C1 &&
+          ((C0->isZero() && C1->isExactlyValue(1.0)) ||
+           (C1->isZero() && C0->isExactlyValue(1.0)))) {
+        Value *T;
+        if (C0->isZero())
+          T = Builder->CreateSelect(Cond, SLHS, V1);
+        else
+          T = Builder->CreateSelect(Cond, V1, SRHS);
+        return ReplaceInstUsesWith(I, T);
+      }
+    }
+  }
+
   return Changed ? &I : 0;
 }
 
index b6a15677bb71a0faf796505f14bc59a629af078e..0b87cd95d9cca4a26e897bc90a4c69228f59c3ed 100644 (file)
@@ -3,19 +3,17 @@
 ; testing-case "float fold(float a) { return 1.2f * a * 2.3f; }"
 ; 1.2f and 2.3f is supposed to be fold.
 define float @fold(float %a) {
-fold:
   %mul = fmul fast float %a, 0x3FF3333340000000
   %mul1 = fmul fast float %mul, 0x4002666660000000
   ret float %mul1
-; CHECK: fold
+; CHECK: @fold
 ; CHECK: fmul float %a, 0x4006147AE0000000
 }
 
 ; Same testing-case as the one used in fold() except that the operators have
 ; fixed FP mode.
 define float @notfold(float %a) {
-notfold:
-; CHECK: notfold
+; CHECK: @notfold
 ; CHECK: %mul = fmul fast float %a, 0x3FF3333340000000
   %mul = fmul fast float %a, 0x3FF3333340000000
   %mul1 = fmul float %mul, 0x4002666660000000
@@ -23,10 +21,40 @@ notfold:
 }
 
 define float @fold2(float %a) {
-fold2:
-; CHECK: fold2
+; CHECK: @fold2
 ; CHECK: fmul float %a, 0x4006147AE0000000
   %mul = fmul float %a, 0x3FF3333340000000
   %mul1 = fmul fast float %mul, 0x4002666660000000
   ret float %mul1
 }
+
+; rdar://12753946:  x * cond ? 1.0 : 0.0 => cond ? x : 0.0
+define double @select1(i32 %cond, double %x, double %y) {
+  %tobool = icmp ne i32 %cond, 0
+  %cond1 = select i1 %tobool, double 1.000000e+00, double 0.000000e+00
+  %mul = fmul nnan nsz double %cond1, %x
+  %add = fadd double %mul, %y
+  ret double %add
+; CHECK: @select1
+; CHECK: select i1 %tobool, double %x, double 0.000000e+00
+}
+
+define double @select2(i32 %cond, double %x, double %y) {
+  %tobool = icmp ne i32 %cond, 0
+  %cond1 = select i1 %tobool, double 0.000000e+00, double 1.000000e+00
+  %mul = fmul nnan nsz double %cond1, %x
+  %add = fadd double %mul, %y
+  ret double %add
+; CHECK: @select2
+; CHECK: select i1 %tobool, double 0.000000e+00, double %x
+}
+
+define double @select3(i32 %cond, double %x, double %y) {
+  %tobool = icmp ne i32 %cond, 0
+  %cond1 = select i1 %tobool, double 0.000000e+00, double 2.000000e+00
+  %mul = fmul nnan nsz double %cond1, %x
+  %add = fadd double %mul, %y
+  ret double %add
+; CHECK: @select3
+; CHECK: fmul nnan nsz double %cond1, %x
+}