#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/CallSite.h"
+#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
}
}
+/// Returns a musttail call instruction if one immediately precedes the given
+/// return instruction with an optional bitcast instruction between them.
+static CallInst *getPrecedingMustTailCall(ReturnInst *RI) {
+ Instruction *Prev = RI->getPrevNode();
+ if (!Prev)
+ return nullptr;
+
+ if (Value *RV = RI->getReturnValue()) {
+ if (RV != Prev)
+ return nullptr;
+
+ // Look through the optional bitcast.
+ if (auto *BI = dyn_cast<BitCastInst>(Prev)) {
+ RV = BI->getOperand(0);
+ Prev = BI->getPrevNode();
+ if (!Prev || RV != Prev)
+ return nullptr;
+ }
+ }
+
+ if (auto *CI = dyn_cast<CallInst>(Prev)) {
+ if (CI->isMustTailCall())
+ return CI;
+ }
+ return nullptr;
+}
+
/// InlineFunction - This function inlines the called function into the basic
/// block of the caller. This returns false if it is not possible to inline
/// this call. The program is still in a well defined state if this occurs
// If the call to the callee is not a tail call, we must clear the 'tail'
// flags on any calls that we inline.
- bool MustClearTailCallFlags =
- !(isa<CallInst>(TheCall) && cast<CallInst>(TheCall)->isTailCall());
+ CallInst::TailCallKind CallSiteTailKind = CallInst::TCK_None;
+ if (CallInst *CI = dyn_cast<CallInst>(TheCall))
+ CallSiteTailKind = CI->getTailCallKind();
+ bool MustClearTailCallFlags = false;
// If the call to the callee cannot throw, set the 'nounwind' flag on any
// calls that we inline.
}
}
+ bool InlinedMustTailCalls = false;
+ if (InlinedFunctionInfo.ContainsCalls) {
+ for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E;
+ ++BB) {
+ for (Instruction &I : *BB) {
+ CallInst *CI = dyn_cast<CallInst>(&I);
+ if (!CI)
+ continue;
+
+ // We need to reduce the strength of any inlined tail calls. For
+ // musttail, we have to avoid introducing potential unbounded stack
+ // growth. For example, if functions 'f' and 'g' are mutually recursive
+ // with musttail, we can inline 'g' into 'f' so long as we preserve
+ // musttail on the cloned call to 'f'. If either the inlined call site
+ // or the cloned call site is *not* musttail, the program already has
+ // one frame of stack growth, so it's safe to remove musttail. Here is
+ // a table of example transformations:
+ //
+ // f -> musttail g -> musttail f ==> f -> musttail f
+ // f -> musttail g -> tail f ==> f -> tail f
+ // f -> g -> musttail f ==> f -> f
+ // f -> g -> tail f ==> f -> f
+ CallInst::TailCallKind ChildTCK = CI->getTailCallKind();
+ ChildTCK = std::min(CallSiteTailKind, ChildTCK);
+ CI->setTailCallKind(ChildTCK);
+ InlinedMustTailCalls |= CI->isMustTailCall();
+
+ // Calls inlined through a 'nounwind' call site should be marked
+ // 'nounwind'.
+ if (MarkNoUnwind)
+ CI->setDoesNotThrow();
+ }
+ }
+ }
+
// Leave lifetime markers for the static alloca's, scoping them to the
// function we just inlined.
if (InsertLifetime && !IFI.StaticAllocas.empty()) {
}
builder.CreateLifetimeStart(AI, AllocaSize);
- for (unsigned ri = 0, re = Returns.size(); ri != re; ++ri) {
- IRBuilder<> builder(Returns[ri]);
- builder.CreateLifetimeEnd(AI, AllocaSize);
- }
+ for (ReturnInst *RI : Returns)
+ IRBuilder<>(RI).CreateLifetimeEnd(AI, AllocaSize);
}
}
// Insert a call to llvm.stackrestore before any return instructions in the
// inlined function.
- for (unsigned i = 0, e = Returns.size(); i != e; ++i) {
- IRBuilder<>(Returns[i]).CreateCall(StackRestore, SavedPtr);
- }
- }
-
- // If we are inlining tail call instruction through a call site that isn't
- // marked 'tail', we must remove the tail marker for any calls in the inlined
- // code. Also, calls inlined through a 'nounwind' call site should be marked
- // 'nounwind'.
- if (InlinedFunctionInfo.ContainsCalls &&
- (MustClearTailCallFlags || MarkNoUnwind)) {
- for (Function::iterator BB = FirstNewBlock, E = Caller->end();
- BB != E; ++BB)
- for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
- if (CallInst *CI = dyn_cast<CallInst>(I)) {
- if (MustClearTailCallFlags)
- CI->setTailCall(false);
- if (MarkNoUnwind)
- CI->setDoesNotThrow();
- }
+ for (ReturnInst *RI : Returns)
+ IRBuilder<>(RI).CreateCall(StackRestore, SavedPtr);
}
// If we are inlining for an invoke instruction, we must make sure to rewrite
if (InvokeInst *II = dyn_cast<InvokeInst>(TheCall))
HandleInlinedInvoke(II, FirstNewBlock, InlinedFunctionInfo);
+ // Handle any inlined musttail call sites. In order for a new call site to be
+ // musttail, the source of the clone and the inlined call site must have been
+ // musttail. Therefore it's safe to return without merging control into the
+ // phi below.
+ if (InlinedMustTailCalls) {
+ // Check if we need to bitcast the result of any musttail calls.
+ Type *NewRetTy = Caller->getReturnType();
+ bool NeedBitCast = !TheCall->use_empty() && TheCall->getType() != NewRetTy;
+
+ // Handle the returns preceded by musttail calls separately.
+ SmallVector<ReturnInst *, 8> NormalReturns;
+ for (ReturnInst *RI : Returns) {
+ CallInst *ReturnedMustTail = getPrecedingMustTailCall(RI);
+ if (!ReturnedMustTail) {
+ NormalReturns.push_back(RI);
+ continue;
+ }
+ if (!NeedBitCast)
+ continue;
+
+ // Delete the old return and any preceding bitcast.
+ BasicBlock *CurBB = RI->getParent();
+ auto *OldCast = dyn_cast_or_null<BitCastInst>(RI->getReturnValue());
+ RI->eraseFromParent();
+ if (OldCast)
+ OldCast->eraseFromParent();
+
+ // Insert a new bitcast and return with the right type.
+ IRBuilder<> Builder(CurBB);
+ Builder.CreateRet(Builder.CreateBitCast(ReturnedMustTail, NewRetTy));
+ }
+
+ // Leave behind the normal returns so we can merge control flow.
+ std::swap(Returns, NormalReturns);
+ }
+
// If we cloned in _exactly one_ basic block, and if that block ends in a
// return instruction, we splice the body of the inlined callee directly into
// the calling basic block.
// Since we are now done with the Call/Invoke, we can delete it.
TheCall->eraseFromParent();
+ // If we inlined any musttail calls and the original return is now
+ // unreachable, delete it. It can only contain a bitcast and ret.
+ if (InlinedMustTailCalls && pred_begin(AfterCallBB) == pred_end(AfterCallBB))
+ AfterCallBB->eraseFromParent();
+
// We should always be able to fold the entry block of the function into the
// single predecessor of the block...
assert(cast<BranchInst>(Br)->isUnconditional() && "splitBasicBlock broken!");
-; RUN: opt < %s -inline -S | not grep tail
+; RUN: opt < %s -inline -S | FileCheck %s
-declare void @bar(i32*)
+; We have to apply the less restrictive TailCallKind of the call site being
+; inlined and any call sites cloned into the caller.
-define internal void @foo(i32* %P) {
- tail call void @bar( i32* %P )
- ret void
+; No tail marker after inlining, since test_capture_c captures an alloca.
+; CHECK: define void @test_capture_a(
+; CHECK-NOT: tail
+; CHECK: call void @test_capture_c(
+
+declare void @test_capture_c(i32*)
+define internal void @test_capture_b(i32* %P) {
+ tail call void @test_capture_c(i32* %P)
+ ret void
+}
+define void @test_capture_a() {
+ %A = alloca i32 ; captured by test_capture_b
+ call void @test_capture_b(i32* %A)
+ ret void
+}
+
+; No musttail marker after inlining, since the prototypes don't match.
+; CHECK: define void @test_proto_mismatch_a(
+; CHECK-NOT: musttail
+; CHECK: call void @test_proto_mismatch_c(
+
+declare void @test_proto_mismatch_c(i32*)
+define internal void @test_proto_mismatch_b(i32* %p) {
+ musttail call void @test_proto_mismatch_c(i32* %p)
+ ret void
+}
+define void @test_proto_mismatch_a() {
+ call void @test_proto_mismatch_b(i32* null)
+ ret void
}
-define void @caller() {
- %A = alloca i32 ; <i32*> [#uses=1]
- call void @foo( i32* %A )
- ret void
+; After inlining through a musttail call site, we need to keep musttail markers
+; to prevent unbounded stack growth.
+; CHECK: define void @test_musttail_basic_a(
+; CHECK: musttail call void @test_musttail_basic_c(
+
+declare void @test_musttail_basic_c(i32* %p)
+define internal void @test_musttail_basic_b(i32* %p) {
+ musttail call void @test_musttail_basic_c(i32* %p)
+ ret void
+}
+define void @test_musttail_basic_a(i32* %p) {
+ musttail call void @test_musttail_basic_b(i32* %p)
+ ret void
+}
+
+; We can't merge the returns.
+; CHECK: define void @test_multiret_a(
+; CHECK: musttail call void @test_multiret_c(
+; CHECK-NEXT: ret void
+; CHECK: musttail call void @test_multiret_d(
+; CHECK-NEXT: ret void
+
+declare void @test_multiret_c(i1 zeroext %b)
+declare void @test_multiret_d(i1 zeroext %b)
+define internal void @test_multiret_b(i1 zeroext %b) {
+ br i1 %b, label %c, label %d
+c:
+ musttail call void @test_multiret_c(i1 zeroext %b)
+ ret void
+d:
+ musttail call void @test_multiret_d(i1 zeroext %b)
+ ret void
+}
+define void @test_multiret_a(i1 zeroext %b) {
+ musttail call void @test_multiret_b(i1 zeroext %b)
+ ret void
}
+; We have to avoid bitcast chains.
+; CHECK: define i32* @test_retptr_a(
+; CHECK: musttail call i8* @test_retptr_c(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+
+declare i8* @test_retptr_c()
+define internal i16* @test_retptr_b() {
+ %rv = musttail call i8* @test_retptr_c()
+ %v = bitcast i8* %rv to i16*
+ ret i16* %v
+}
+define i32* @test_retptr_a() {
+ %rv = musttail call i16* @test_retptr_b()
+ %v = bitcast i16* %rv to i32*
+ ret i32* %v
+}
+
+; Combine the last two cases: multiple returns with pointer bitcasts.
+; CHECK: define i32* @test_multiptrret_a(
+; CHECK: musttail call i8* @test_multiptrret_c(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+; CHECK: musttail call i8* @test_multiptrret_d(
+; CHECK-NEXT: bitcast i8* {{.*}} to i32*
+; CHECK-NEXT: ret i32*
+
+declare i8* @test_multiptrret_c(i1 zeroext %b)
+declare i8* @test_multiptrret_d(i1 zeroext %b)
+define internal i16* @test_multiptrret_b(i1 zeroext %b) {
+ br i1 %b, label %c, label %d
+c:
+ %c_rv = musttail call i8* @test_multiptrret_c(i1 zeroext %b)
+ %c_v = bitcast i8* %c_rv to i16*
+ ret i16* %c_v
+d:
+ %d_rv = musttail call i8* @test_multiptrret_d(i1 zeroext %b)
+ %d_v = bitcast i8* %d_rv to i16*
+ ret i16* %d_v
+}
+define i32* @test_multiptrret_a(i1 zeroext %b) {
+ %rv = musttail call i16* @test_multiptrret_b(i1 zeroext %b)
+ %v = bitcast i16* %rv to i32*
+ ret i32* %v
+}
+
+; Inline a musttail call site which contains a normal return and a musttail call.
+; CHECK: define i32 @test_mixedret_a(
+; CHECK: br i1 %b
+; CHECK: musttail call i32 @test_mixedret_c(
+; CHECK-NEXT: ret i32
+; CHECK: call i32 @test_mixedret_d(i1 zeroext %b)
+; CHECK: add i32 1,
+; CHECK-NOT: br
+; CHECK: ret i32
+
+declare i32 @test_mixedret_c(i1 zeroext %b)
+declare i32 @test_mixedret_d(i1 zeroext %b)
+define internal i32 @test_mixedret_b(i1 zeroext %b) {
+ br i1 %b, label %c, label %d
+c:
+ %c_rv = musttail call i32 @test_mixedret_c(i1 zeroext %b)
+ ret i32 %c_rv
+d:
+ %d_rv = call i32 @test_mixedret_d(i1 zeroext %b)
+ %d_rv1 = add i32 1, %d_rv
+ ret i32 %d_rv1
+}
+define i32 @test_mixedret_a(i1 zeroext %b) {
+ %rv = musttail call i32 @test_mixedret_b(i1 zeroext %b)
+ ret i32 %rv
+}