Skip to content

Implement explicit tail calls in the LLVM backend #138555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,13 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
// FIXME(bjorn3): implement
}

/// Rejects any attempt to mark `_call_inst` as a tail call.
///
/// The body does nothing but raise a compiler bug via `bug!`, so reaching this method
/// aborts compilation with the message below.
fn set_tail_call(&mut self, _call_inst: RValue<'gcc>) {
    bug!("Guaranteed tail calls with the 'become' keyword are not implemented in the GCC backend");
}

// No-op: the span is accepted and ignored (the parameter is unused, hence the `_` prefix).
fn set_span(&mut self, _span: Span) {}

fn from_immediate(&mut self, val: Self::Value) -> Self::Value {
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,11 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
let cold_inline = llvm::AttributeKind::Cold.create_attr(self.llcx);
attributes::apply_to_callsite(llret, llvm::AttributePlace::Function, &[cold_inline]);
}

/// Tags `call_inst` with the `MustTail` kind through the `LLVMRustSetTailCallKind` wrapper.
///
/// `MustTail` is the kind used for the guaranteed tail calls that `become` requires.
fn set_tail_call(&mut self, call_inst: &'ll Value) {
    let kind = llvm::TailCallKind::MustTail;
    llvm::LLVMRustSetTailCallKind(call_inst, kind);
}
}

impl<'ll> StaticBuilderMethods for Builder<'_, 'll, '_> {
Expand Down
12 changes: 12 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ pub(crate) enum ModuleFlagMergeBehavior {

// Consts for the LLVM CallConv type, pre-cast to usize.

/// Tail-call kind handed to `LLVMRustSetTailCallKind`.
///
/// The discriminants are written out explicitly because the value crosses the FFI boundary
/// as a `#[repr(C)]` enum; the wrapper-side `LLVMRustTailCallKind` declares the same numbers.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
// NOTE(review): presumably allowed because only `MustTail` is constructed on the Rust side.
#[allow(dead_code)]
pub(crate) enum TailCallKind {
    /// Translated by the wrapper to `CallInst::TCK_None`.
    None = 0,
    /// Translated by the wrapper to `CallInst::TCK_Tail`.
    Tail = 1,
    /// Translated by the wrapper to `CallInst::TCK_MustTail`.
    MustTail = 2,
    /// Translated by the wrapper to `CallInst::TCK_NoTail`.
    NoTail = 3,
}

/// LLVM CallingConv::ID. Should we wrap this?
///
/// See <https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/CallingConv.h>
Expand Down Expand Up @@ -1181,6 +1192,7 @@ unsafe extern "C" {
pub(crate) safe fn LLVMIsGlobalConstant(GlobalVar: &Value) -> Bool;
pub(crate) safe fn LLVMSetGlobalConstant(GlobalVar: &Value, IsConstant: Bool);
pub(crate) safe fn LLVMSetTailCall(CallInst: &Value, IsTailCall: Bool);
pub(crate) safe fn LLVMRustSetTailCallKind(CallInst: &Value, Kind: TailCallKind);

// Operations on attributes
pub(crate) fn LLVMCreateStringAttribute(
Expand Down
100 changes: 94 additions & 6 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,97 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> {

/// Codegen implementations for some terminator variants.
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
/// Lowers a tail-call terminator: resolves the callee, lowers each argument, emits the call,
/// marks it through `set_tail_call`, and immediately returns the call's result.
///
/// * `func` - operand naming the callee; its type must be `FnDef` or `FnPtr`, anything else
///   hits `bug!`.
/// * `args` - the call arguments, lowered in order.
/// * `fn_span` - span passed to `Instance::expect_resolve` when the callee is a `FnDef`.
fn codegen_tail_call_terminator(
&mut self,
bx: &mut Bx,
func: &mir::Operand<'tcx>,
args: &[Spanned<mir::Operand<'tcx>>],
fn_span: Span,
) {
// We don't need source_info as we already have fn_span for diagnostics
let func = self.codegen_operand(bx, func);
let fn_ty = func.layout.ty;

// Create the callee. This is a fn ptr or zero-sized and hence a kind of scalar.
// Both arms compute the ABI with an empty list of extra argument types; only the `FnDef`
// arm yields an `Instance`.
let (fn_ptr, fn_abi, instance) = match *fn_ty.kind() {
ty::FnDef(def_id, substs) => {
let instance = ty::Instance::expect_resolve(
bx.tcx(),
bx.typing_env(),
def_id,
substs,
fn_span,
);
let fn_ptr = bx.get_fn_addr(instance);
let fn_abi = bx.fn_abi_of_instance(instance, ty::List::empty());
(fn_ptr, fn_abi, Some(instance))
}
ty::FnPtr(..) => {
let sig = fn_ty.fn_sig(bx.tcx());
let extra_args = bx.tcx().mk_type_list(&[]);
let fn_ptr = func.immediate();
let fn_abi = bx.fn_abi_of_fn_ptr(sig, extra_args);
(fn_ptr, fn_abi, None)
}
_ => bug!("{} is not callable", func.layout.ty),
};

let mut llargs = Vec::with_capacity(args.len());
// NOTE(review): this vector is handed to `codegen_argument` below but is never read again
// in this function - confirm whether the recorded lifetime ends need to be emitted.
let mut lifetime_ends_after_call = Vec::new();

// Process arguments
for arg in args {
let op = self.codegen_operand(bx, &arg.node);
// NOTE(review): the ABI slot is chosen by how many values have been pushed so far, not by
// the argument's position in `args`; if `codegen_argument` can push zero or several values
// for one operand the two diverge - confirm.
let arg_idx = llargs.len();

if arg_idx < fn_abi.args.len() {
self.codegen_argument(
bx,
op,
&mut llargs,
&fn_abi.args[arg_idx],
&mut lifetime_ends_after_call,
);
} else {
// This can happen in case of C-variadic functions
let is_immediate = match op.val {
Immediate(_) => true,
_ => false,
};

if is_immediate {
llargs.push(op.immediate());
} else {
// Non-immediate operands are spilled to a fresh stack slot and loaded back as one value.
let temp = PlaceRef::alloca(bx, op.layout);
op.val.store(bx, temp);
llargs.push(bx.load(
bx.backend_type(op.layout),
temp.val.llval,
temp.val.align,
));
}
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing an sret argument when PassMode::Indirect is used. Also please deduplicate this code with regular calls where possible to prevent divergence.


// Call the function
// Codegen attributes are looked up only when an instance is known and its definition kind
// reports having them; function-pointer callees get `None`.
let fn_ty = bx.fn_decl_backend_type(fn_abi);
let fn_attrs = if let Some(instance) = instance
&& bx.tcx().def_kind(instance.def_id()).has_codegen_attrs()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would we ever have a tail call for something that doesn't have codegen attrs?

{
Some(bx.tcx().codegen_fn_attrs(instance.def_id()))
} else {
None
};

// Perform the actual function call
let llret = bx.call(fn_ty, fn_attrs, Some(fn_abi), fn_ptr, &llargs, None, instance);

// Mark as tail call - this is the critical part
bx.set_tail_call(llret);
Comment on lines +427 to +431
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we use normal function codegen and just set_tail_call at the end?

Can we at least share some of the argument processing code and instance resolution code?


// Return the result - musttail requires ret immediately after the call
bx.ret(llret);
}
/// Generates code for a `Resume` terminator.
fn codegen_resume_terminator(&mut self, helper: TerminatorCodegenHelper<'tcx>, bx: &mut Bx) {
if let Some(funclet) = helper.funclet(self) {
Expand Down Expand Up @@ -1390,12 +1481,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
fn_span,
mergeable_succ(),
),
mir::TerminatorKind::TailCall { .. } => {
// FIXME(explicit_tail_calls): implement tail calls in ssa backend
span_bug!(
terminator.source_info.span,
"`TailCall` terminator is not yet supported by `rustc_codegen_ssa`"
)
mir::TerminatorKind::TailCall { ref func, ref args, fn_span } => {
self.codegen_tail_call_terminator(bx, func, args, fn_span);
MergingSucc::False
}
mir::TerminatorKind::CoroutineDrop | mir::TerminatorKind::Yield { .. } => {
bug!("coroutine ops in codegen")
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_ssa/src/traits/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,10 @@ pub trait BuilderMethods<'a, 'tcx>:
funclet: Option<&Self::Funclet>,
instance: Option<Instance<'tcx>>,
) -> Self::Value;

/// Marks a call instruction as a guaranteed tail call.
/// Used for implementing the `become` expression.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All those references to being used to implement become feel a bit redundant to me.

fn set_tail_call(&mut self, call_inst: Self::Value);
fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value;

fn apply_attrs_to_cleanup_callsite(&mut self, llret: Self::Value);
Expand Down
26 changes: 26 additions & 0 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@
//
//===----------------------------------------------------------------------===

// Define TailCallKind enum values to match LLVM's.
//
// This is the C-side parameter type of LLVMRustSetTailCallKind (defined later
// in this file), which translates each enumerator to the CallInst::TCK_*
// constant of the same name. The numeric values are spelled out explicitly;
// the Rust-side `TailCallKind` enum in rustc_codegen_llvm declares the same
// numbers.
enum LLVMRustTailCallKind {
LLVMRustTailCallKindNone = 0,
LLVMRustTailCallKindTail = 1,
LLVMRustTailCallKindMustTail = 2,
LLVMRustTailCallKindNoTail = 3
};

using namespace llvm;
using namespace llvm::sys;
using namespace llvm::object;
Expand Down Expand Up @@ -1949,3 +1957,21 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
MD.NoHWAddress = true;
GV.setSanitizerMetadata(MD);
}

// Sets the tail-call kind on the call instruction behind `Call`.
//
// `Kind` is translated enumerator by enumerator to the `CallInst::TCK_*`
// constant of the same name. A value that matches none of the four
// enumerators leaves the instruction unchanged.
extern "C" void LLVMRustSetTailCallKind(LLVMValueRef Call,
                                        LLVMRustTailCallKind Kind) {
  auto *Inst = unwrap<CallInst>(Call);
  if (Kind == LLVMRustTailCallKindNone) {
    Inst->setTailCallKind(CallInst::TCK_None);
  } else if (Kind == LLVMRustTailCallKindTail) {
    Inst->setTailCallKind(CallInst::TCK_Tail);
  } else if (Kind == LLVMRustTailCallKindMustTail) {
    Inst->setTailCallKind(CallInst::TCK_MustTail);
  } else if (Kind == LLVMRustTailCallKindNoTail) {
    Inst->setTailCallKind(CallInst::TCK_NoTail);
  }
}
68 changes: 68 additions & 0 deletions tests/codegen/tail-call-become.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
//@ compile-flags: -C opt-level=0 -Cpanic=abort -C no-prepopulate-passes
//@ needs-llvm-components: x86

#![feature(explicit_tail_calls)]
#![crate_type = "lib"]

// Self-recursive function whose recursive step is written with `become`.
// CHECK-LABEL: define {{.*}}@with_tail(
#[no_mangle]
#[inline(never)]
pub fn with_tail(n: u32) -> u32 {
    // The directive spells out the full `musttail` keyword: pattern matching is by
    // substring, so a bare `tail call` pattern would also accept a plain `tail` hint.
    // CHECK: musttail call {{.*}}@with_tail(
    if n == 0 { 0 } else { become with_tail(n - 1) }
}

// Self-recursive function written with an ordinary call (no `become`).
// CHECK-LABEL: define {{.*}}@no_tail(
#[no_mangle]
#[inline(never)]
pub fn no_tail(n: u32) -> u32 {
    // CHECK-NOT: tail call
    // CHECK: call {{.*}}@no_tail(
    match n {
        0 => 0,
        _ => no_tail(n - 1),
    }
}

// Half of a mutually recursive pair; the recursive step into `odd_with_tail` uses `become`.
// CHECK-LABEL: define {{.*}}@even_with_tail(
#[no_mangle]
#[inline(never)]
pub fn even_with_tail(n: u32) -> bool {
    // The directive spells out the full `musttail` keyword: pattern matching is by
    // substring, so a bare `tail call` pattern would also accept a plain `tail` hint.
    // CHECK: musttail call {{.*}}@odd_with_tail(
    match n {
        0 => true,
        _ => become odd_with_tail(n - 1),
    }
}

// Half of a mutually recursive pair; the recursive step into `even_with_tail` uses `become`.
// CHECK-LABEL: define {{.*}}@odd_with_tail(
#[no_mangle]
#[inline(never)]
pub fn odd_with_tail(n: u32) -> bool {
    // The directive spells out the full `musttail` keyword: pattern matching is by
    // substring, so a bare `tail call` pattern would also accept a plain `tail` hint.
    // CHECK: musttail call {{.*}}@even_with_tail(
    match n {
        0 => false,
        _ => become even_with_tail(n - 1),
    }
}

// Half of a mutually recursive pair that uses an ordinary call into `odd_no_tail`.
// CHECK-LABEL: define {{.*}}@even_no_tail(
#[no_mangle]
#[inline(never)]
pub fn even_no_tail(n: u32) -> bool {
    // CHECK-NOT: tail call
    // CHECK: call {{.*}}@odd_no_tail(
    if n == 0 {
        return true;
    }
    odd_no_tail(n - 1)
}

// Half of a mutually recursive pair that uses an ordinary call into `even_no_tail`.
// CHECK-LABEL: define {{.*}}@odd_no_tail(
#[no_mangle]
#[inline(never)]
pub fn odd_no_tail(n: u32) -> bool {
    // CHECK-NOT: tail call
    // CHECK: call {{.*}}@even_no_tail(
    if n == 0 {
        return false;
    }
    even_no_tail(n - 1)
}
33 changes: 33 additions & 0 deletions tests/codegen/tail-call-musttail.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//@ compile-flags: -C opt-level=0 -Cpanic=abort -C no-prepopulate-passes
//@ needs-unwind

#![crate_type = "lib"]
#![feature(explicit_tail_calls)]

// Ensure that explicit tail calls use musttail in LLVM

// Single-argument self-recursion through `become`; the call must be followed directly
// by the return.
// CHECK-LABEL: define {{.*}}@simple_tail_call(
#[no_mangle]
#[inline(never)]
pub fn simple_tail_call(n: i32) -> i32 {
    // CHECK: musttail call {{.*}}@simple_tail_call(
    // CHECK-NEXT: ret i32
    if n <= 0 {
        return 0;
    }
    become simple_tail_call(n - 1)
}

// Three-argument self-recursion through `become`; the call must be followed directly
// by the return.
// CHECK-LABEL: define {{.*}}@tail_call_with_args(
#[no_mangle]
#[inline(never)]
pub fn tail_call_with_args(a: i32, b: i32, c: i32) -> i32 {
    // CHECK: musttail call {{.*}}@tail_call_with_args(
    // CHECK-NEXT: ret i32
    if a == 0 {
        return b + c;
    }
    become tail_call_with_args(a - 1, b + 1, c)
}
35 changes: 35 additions & 0 deletions tests/ui/explicit-tail-calls/llvm-ir-tail-call.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//@ compile-flags: -O
//@ run-pass
#![expect(incomplete_features)]
#![feature(explicit_tail_calls)]

// Counts `n` down to zero, taking every recursive step through `become`.
// Always returns 0.
fn deep_recursion(n: u32) -> u32 {
    if n == 0 {
        return 0;
    }
    become deep_recursion(n - 1)
}

// Counts `n` down to zero using an ordinary recursive call (no `become`).
// Always returns 0; `main` only calls it with a small argument.
fn deep_recursion_no_tail(n: u32) -> u32 {
    if n == 0 { 0 } else { deep_recursion_no_tail(n - 1) }
}

fn main() {
    const SHALLOW: u32 = 10;
    const DEEP: u32 = 50_000;

    // Both variants must agree on a shallow input.
    assert_eq!(deep_recursion(SHALLOW), 0);
    assert_eq!(deep_recursion_no_tail(SHALLOW), 0);

    // Only the `become` variant is exercised at depth.
    // NOTE(review): whether this depth would actually exhaust the stack without tail calls
    // depends on frame size and the platform's stack limit - confirm it is deep enough.
    println!("Starting deep recursion with 50,000 calls");
    assert_eq!(deep_recursion(DEEP), 0);
    println!("Successfully completed 50,000 recursive calls with tail call optimization");
}
Loading
Loading