Apply the InstCombine fptrunc sqrt optimization to llvm.sqrt
authorHal Finkel <hfinkel@anl.gov>
Sat, 16 Nov 2013 21:29:08 +0000 (21:29 +0000)
committerHal Finkel <hfinkel@anl.gov>
Sat, 16 Nov 2013 21:29:08 +0000 (21:29 +0000)
InstCombine, in visitFPTrunc, applies the following optimization to sqrt calls:

  (fptrunc (sqrt (fpext x))) -> (sqrtf x)

but does not apply the same optimization to llvm.sqrt. This is a problem
because, to enable vectorization, Clang generates llvm.sqrt instead of sqrt in
fast-math mode, and because this optimization is being applied to sqrt and not
applied to llvm.sqrt, sometimes the fast-math code is slower.

This change makes InstCombine apply this optimization to llvm.sqrt as well.

This fixes the specific problem in PR17758, although the same underlying issue
(optimizations applied to libcalls are not applied to intrinsics) exists for
other optimizations in SimplifyLibCalls.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@194935 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineCasts.cpp
test/Transforms/InstCombine/double-float-shrink-1.ll

index a1aedd4e8f1810c20ebc2532a440f18ed66cbac7..72377dc0adcaa9b698527cbb34e5f94b66bd927b 100644 (file)
@@ -1262,9 +1262,14 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
   }
 
   // Fold (fptrunc (sqrt (fpext x))) -> (sqrtf x)
+  // Note that we restrict this transformation based on
+  // TLI->has(LibFunc::sqrtf), even for the sqrt intrinsic, because
+  // TLI->has(LibFunc::sqrtf) is sufficient to guarantee that the
+  // single-precision intrinsic can be expanded in the backend.
   CallInst *Call = dyn_cast<CallInst>(CI.getOperand(0));
   if (Call && Call->getCalledFunction() && TLI->has(LibFunc::sqrtf) &&
-      Call->getCalledFunction()->getName() == TLI->getName(LibFunc::sqrt) &&
+      (Call->getCalledFunction()->getName() == TLI->getName(LibFunc::sqrt) ||
+       Call->getCalledFunction()->getIntrinsicID() == Intrinsic::sqrt) &&
       Call->getNumArgOperands() == 1 &&
       Call->hasOneUse()) {
     CastInst *Arg = dyn_cast<CastInst>(Call->getArgOperand(0));
@@ -1275,11 +1280,11 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
         Arg->getOperand(0)->getType()->isFloatTy()) {
       Function *Callee = Call->getCalledFunction();
       Module *M = CI.getParent()->getParent()->getParent();
-      Constant *SqrtfFunc = M->getOrInsertFunction("sqrtf",
-                                                   Callee->getAttributes(),
-                                                   Builder->getFloatTy(),
-                                                   Builder->getFloatTy(),
-                                                   NULL);
+      Constant *SqrtfFunc = (Callee->getIntrinsicID() == Intrinsic::sqrt) ?
+        Intrinsic::getDeclaration(M, Intrinsic::sqrt, Builder->getFloatTy()) :
+        M->getOrInsertFunction("sqrtf", Callee->getAttributes(),
+                               Builder->getFloatTy(), Builder->getFloatTy(),
+                               NULL);
       CallInst *ret = CallInst::Create(SqrtfFunc, Arg->getOperand(0),
                                        "sqrtfcall");
       ret->setAttributes(Callee->getAttributes());
index e5448ee00765475b4544dc2c468c8474c1456e41..5cacb591e00645f3ff04937043a57c92018487fc 100644 (file)
@@ -263,6 +263,7 @@ define double @sin_test2(float %f) nounwind readnone {
    ret double %call
 ; CHECK: call double @sin(double %conv)
 }
+
 define float @sqrt_test(float %f) nounwind readnone {
 ; CHECK: sqrt_test
    %conv = fpext float %f to double
@@ -272,6 +273,15 @@ define float @sqrt_test(float %f) nounwind readnone {
 ; CHECK: call float @sqrtf(float %f)
 }
 
+define float @sqrt_int_test(float %f) nounwind readnone {
+; CHECK: sqrt_int_test
+   %conv = fpext float %f to double
+   %call = call double @llvm.sqrt.f64(double %conv)
+   %conv1 = fptrunc double %call to float
+   ret float %conv1
+; CHECK: call float @llvm.sqrt.f32(float %f)
+}
+
 define double @sqrt_test2(float %f) nounwind readnone {
 ; CHECK: sqrt_test2
    %conv = fpext float %f to double
@@ -331,3 +341,6 @@ declare double @acos(double) nounwind readnone
 declare double @acosh(double) nounwind readnone
 declare double @asin(double) nounwind readnone
 declare double @asinh(double) nounwind readnone
+
+declare double @llvm.sqrt.f64(double) nounwind readnone
+