Skip to content

Commit fc1b0b5

Browse files
committed
Propagate musttail attribute to LLVM backend for guaranteed tail calls
This commit implements proper tail call optimization for the `become` expression by propagating LLVM's musttail attribute, which guarantees tail call optimization rather than leaving it as an optimization hint. Changes: - Add `set_tail_call` method to BuilderMethods trait - Add FFI wrapper LLVMRustSetTailCallKind to access LLVM's setTailCallKind API - Implement tail call handling in LLVM backend using musttail - Implement TailCall terminator codegen in rustc_codegen_ssa - Make GCC backend fail explicitly on tail calls (not yet supported) - Add codegen tests to verify musttail is properly emitted - Add runtime tests for deep recursion and mutual recursion The musttail attribute is critical for the correctness of the `become` expression as it guarantees the tail call will be optimized, preventing stack overflow in recursive scenarios.
1 parent 014bd82 commit fc1b0b5

File tree

11 files changed

+385
-6
lines changed

11 files changed

+385
-6
lines changed

compiler/rustc_codegen_gcc/src/builder.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,6 +1751,13 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
17511751
// FIXME(bjorn3): implement
17521752
}
17531753

1754+
fn set_tail_call(&mut self, _call_inst: RValue<'gcc>) {
1755+
// Explicitly fail when this method is called
1756+
bug!(
1757+
"Guaranteed tail calls with the 'become' keyword are not implemented in the GCC backend"
1758+
);
1759+
}
1760+
17541761
fn set_span(&mut self, _span: Span) {}
17551762

17561763
fn from_immediate(&mut self, val: Self::Value) -> Self::Value {

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,11 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
13711371
let cold_inline = llvm::AttributeKind::Cold.create_attr(self.llcx);
13721372
attributes::apply_to_callsite(llret, llvm::AttributePlace::Function, &[cold_inline]);
13731373
}
1374+
1375+
fn set_tail_call(&mut self, call_inst: &'ll Value) {
1376+
// Use musttail for guaranteed tail call optimization required by 'become'
1377+
llvm::LLVMRustSetTailCallKind(call_inst, llvm::TailCallKind::MustTail);
1378+
}
13741379
}
13751380

13761381
impl<'ll> StaticBuilderMethods for Builder<'_, 'll, '_> {

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,17 @@ pub(crate) enum ModuleFlagMergeBehavior {
9797

9898
// Consts for the LLVM CallConv type, pre-cast to usize.
9999

100+
/// LLVM TailCallKind for musttail support
101+
#[derive(Copy, Clone, PartialEq, Debug)]
102+
#[repr(C)]
103+
#[allow(dead_code)]
104+
pub(crate) enum TailCallKind {
105+
None = 0,
106+
Tail = 1,
107+
MustTail = 2,
108+
NoTail = 3,
109+
}
110+
100111
/// LLVM CallingConv::ID. Should we wrap this?
101112
///
102113
/// See <https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/CallingConv.h>
@@ -1181,6 +1192,7 @@ unsafe extern "C" {
11811192
pub(crate) safe fn LLVMIsGlobalConstant(GlobalVar: &Value) -> Bool;
11821193
pub(crate) safe fn LLVMSetGlobalConstant(GlobalVar: &Value, IsConstant: Bool);
11831194
pub(crate) safe fn LLVMSetTailCall(CallInst: &Value, IsTailCall: Bool);
1195+
pub(crate) safe fn LLVMRustSetTailCallKind(CallInst: &Value, Kind: TailCallKind);
11841196

11851197
// Operations on attributes
11861198
pub(crate) fn LLVMCreateStringAttribute(

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,97 @@ impl<'a, 'tcx> TerminatorCodegenHelper<'tcx> {
342342

343343
/// Codegen implementations for some terminator variants.
344344
impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
345+
fn codegen_tail_call_terminator(
346+
&mut self,
347+
bx: &mut Bx,
348+
func: &mir::Operand<'tcx>,
349+
args: &[Spanned<mir::Operand<'tcx>>],
350+
fn_span: Span,
351+
) {
352+
// We don't need source_info as we already have fn_span for diagnostics
353+
let func = self.codegen_operand(bx, func);
354+
let fn_ty = func.layout.ty;
355+
356+
// Create the callee. This is a fn ptr or zero-sized and hence a kind of scalar.
357+
let (fn_ptr, fn_abi, instance) = match *fn_ty.kind() {
358+
ty::FnDef(def_id, substs) => {
359+
let instance = ty::Instance::expect_resolve(
360+
bx.tcx(),
361+
bx.typing_env(),
362+
def_id,
363+
substs,
364+
fn_span,
365+
);
366+
let fn_ptr = bx.get_fn_addr(instance);
367+
let fn_abi = bx.fn_abi_of_instance(instance, ty::List::empty());
368+
(fn_ptr, fn_abi, Some(instance))
369+
}
370+
ty::FnPtr(..) => {
371+
let sig = fn_ty.fn_sig(bx.tcx());
372+
let extra_args = bx.tcx().mk_type_list(&[]);
373+
let fn_ptr = func.immediate();
374+
let fn_abi = bx.fn_abi_of_fn_ptr(sig, extra_args);
375+
(fn_ptr, fn_abi, None)
376+
}
377+
_ => bug!("{} is not callable", func.layout.ty),
378+
};
379+
380+
let mut llargs = Vec::with_capacity(args.len());
381+
let mut lifetime_ends_after_call = Vec::new();
382+
383+
// Process arguments
384+
for arg in args {
385+
let op = self.codegen_operand(bx, &arg.node);
386+
let arg_idx = llargs.len();
387+
388+
if arg_idx < fn_abi.args.len() {
389+
self.codegen_argument(
390+
bx,
391+
op,
392+
&mut llargs,
393+
&fn_abi.args[arg_idx],
394+
&mut lifetime_ends_after_call,
395+
);
396+
} else {
397+
// This can happen in case of C-variadic functions
398+
let is_immediate = match op.val {
399+
Immediate(_) => true,
400+
_ => false,
401+
};
402+
403+
if is_immediate {
404+
llargs.push(op.immediate());
405+
} else {
406+
let temp = PlaceRef::alloca(bx, op.layout);
407+
op.val.store(bx, temp);
408+
llargs.push(bx.load(
409+
bx.backend_type(op.layout),
410+
temp.val.llval,
411+
temp.val.align,
412+
));
413+
}
414+
}
415+
}
416+
417+
// Call the function
418+
let fn_ty = bx.fn_decl_backend_type(fn_abi);
419+
let fn_attrs = if let Some(instance) = instance
420+
&& bx.tcx().def_kind(instance.def_id()).has_codegen_attrs()
421+
{
422+
Some(bx.tcx().codegen_fn_attrs(instance.def_id()))
423+
} else {
424+
None
425+
};
426+
427+
// Perform the actual function call
428+
let llret = bx.call(fn_ty, fn_attrs, Some(fn_abi), fn_ptr, &llargs, None, instance);
429+
430+
// Mark as tail call - this is the critical part
431+
bx.set_tail_call(llret);
432+
433+
// Return the result - musttail requires ret immediately after the call
434+
bx.ret(llret);
435+
}
345436
/// Generates code for a `Resume` terminator.
346437
fn codegen_resume_terminator(&mut self, helper: TerminatorCodegenHelper<'tcx>, bx: &mut Bx) {
347438
if let Some(funclet) = helper.funclet(self) {
@@ -1390,12 +1481,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
13901481
fn_span,
13911482
mergeable_succ(),
13921483
),
1393-
mir::TerminatorKind::TailCall { .. } => {
1394-
// FIXME(explicit_tail_calls): implement tail calls in ssa backend
1395-
span_bug!(
1396-
terminator.source_info.span,
1397-
"`TailCall` terminator is not yet supported by `rustc_codegen_ssa`"
1398-
)
1484+
mir::TerminatorKind::TailCall { ref func, ref args, fn_span } => {
1485+
self.codegen_tail_call_terminator(bx, func, args, fn_span);
1486+
MergingSucc::False
13991487
}
14001488
mir::TerminatorKind::CoroutineDrop | mir::TerminatorKind::Yield { .. } => {
14011489
bug!("coroutine ops in codegen")

compiler/rustc_codegen_ssa/src/traits/builder.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,10 @@ pub trait BuilderMethods<'a, 'tcx>:
595595
funclet: Option<&Self::Funclet>,
596596
instance: Option<Instance<'tcx>>,
597597
) -> Self::Value;
598+
599+
/// Mark a call instruction as a tail call (guaranteed tail call optimization)
600+
/// Used for implementing the `become` expression
601+
fn set_tail_call(&mut self, call_inst: Self::Value);
598602
fn zext(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value;
599603

600604
fn apply_attrs_to_cleanup_callsite(&mut self, llret: Self::Value);

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@
5151
//
5252
//===----------------------------------------------------------------------===
5353

54+
// Define TailCallKind enum values to match LLVM's
55+
enum LLVMRustTailCallKind {
56+
LLVMRustTailCallKindNone = 0,
57+
LLVMRustTailCallKindTail = 1,
58+
LLVMRustTailCallKindMustTail = 2,
59+
LLVMRustTailCallKindNoTail = 3
60+
};
61+
5462
using namespace llvm;
5563
using namespace llvm::sys;
5664
using namespace llvm::object;
@@ -1949,3 +1957,21 @@ extern "C" void LLVMRustSetNoSanitizeHWAddress(LLVMValueRef Global) {
19491957
MD.NoHWAddress = true;
19501958
GV.setSanitizerMetadata(MD);
19511959
}
1960+
1961+
extern "C" void LLVMRustSetTailCallKind(LLVMValueRef Call, LLVMRustTailCallKind Kind) {
1962+
CallInst *CI = unwrap<CallInst>(Call);
1963+
switch (Kind) {
1964+
case LLVMRustTailCallKindNone:
1965+
CI->setTailCallKind(CallInst::TCK_None);
1966+
break;
1967+
case LLVMRustTailCallKindTail:
1968+
CI->setTailCallKind(CallInst::TCK_Tail);
1969+
break;
1970+
case LLVMRustTailCallKindMustTail:
1971+
CI->setTailCallKind(CallInst::TCK_MustTail);
1972+
break;
1973+
case LLVMRustTailCallKindNoTail:
1974+
CI->setTailCallKind(CallInst::TCK_NoTail);
1975+
break;
1976+
}
1977+
}

tests/codegen/tail-call-become.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
//@ compile-flags: -C opt-level=0 -Cpanic=abort -C no-prepopulate-passes
2+
//@ needs-llvm-components: x86
3+
4+
#![feature(explicit_tail_calls)]
5+
#![crate_type = "lib"]
6+
7+
// CHECK-LABEL: define {{.*}}@with_tail(
8+
#[no_mangle]
9+
#[inline(never)]
10+
pub fn with_tail(n: u32) -> u32 {
11+
// CHECK: tail call {{.*}}@with_tail(
12+
if n == 0 { 0 } else { become with_tail(n - 1) }
13+
}
14+
15+
// CHECK-LABEL: define {{.*}}@no_tail(
16+
#[no_mangle]
17+
#[inline(never)]
18+
pub fn no_tail(n: u32) -> u32 {
19+
// CHECK-NOT: tail call
20+
// CHECK: call {{.*}}@no_tail(
21+
if n == 0 { 0 } else { no_tail(n - 1) }
22+
}
23+
24+
// CHECK-LABEL: define {{.*}}@even_with_tail(
25+
#[no_mangle]
26+
#[inline(never)]
27+
pub fn even_with_tail(n: u32) -> bool {
28+
// CHECK: tail call {{.*}}@odd_with_tail(
29+
match n {
30+
0 => true,
31+
_ => become odd_with_tail(n - 1),
32+
}
33+
}
34+
35+
// CHECK-LABEL: define {{.*}}@odd_with_tail(
36+
#[no_mangle]
37+
#[inline(never)]
38+
pub fn odd_with_tail(n: u32) -> bool {
39+
// CHECK: tail call {{.*}}@even_with_tail(
40+
match n {
41+
0 => false,
42+
_ => become even_with_tail(n - 1),
43+
}
44+
}
45+
46+
// CHECK-LABEL: define {{.*}}@even_no_tail(
47+
#[no_mangle]
48+
#[inline(never)]
49+
pub fn even_no_tail(n: u32) -> bool {
50+
// CHECK-NOT: tail call
51+
// CHECK: call {{.*}}@odd_no_tail(
52+
match n {
53+
0 => true,
54+
_ => odd_no_tail(n - 1),
55+
}
56+
}
57+
58+
// CHECK-LABEL: define {{.*}}@odd_no_tail(
59+
#[no_mangle]
60+
#[inline(never)]
61+
pub fn odd_no_tail(n: u32) -> bool {
62+
// CHECK-NOT: tail call
63+
// CHECK: call {{.*}}@even_no_tail(
64+
match n {
65+
0 => false,
66+
_ => even_no_tail(n - 1),
67+
}
68+
}

tests/codegen/tail-call-musttail.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//@ compile-flags: -C opt-level=0 -Cpanic=abort -C no-prepopulate-passes
2+
//@ needs-unwind
3+
4+
#![crate_type = "lib"]
5+
#![feature(explicit_tail_calls)]
6+
7+
// Ensure that explicit tail calls use musttail in LLVM
8+
9+
// CHECK-LABEL: define {{.*}}@simple_tail_call(
10+
#[no_mangle]
11+
#[inline(never)]
12+
pub fn simple_tail_call(n: i32) -> i32 {
13+
// CHECK: musttail call {{.*}}@simple_tail_call(
14+
// CHECK-NEXT: ret i32
15+
if n <= 0 {
16+
0
17+
} else {
18+
become simple_tail_call(n - 1)
19+
}
20+
}
21+
22+
// CHECK-LABEL: define {{.*}}@tail_call_with_args(
23+
#[no_mangle]
24+
#[inline(never)]
25+
pub fn tail_call_with_args(a: i32, b: i32, c: i32) -> i32 {
26+
// CHECK: musttail call {{.*}}@tail_call_with_args(
27+
// CHECK-NEXT: ret i32
28+
if a == 0 {
29+
b + c
30+
} else {
31+
become tail_call_with_args(a - 1, b + 1, c)
32+
}
33+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//@ compile-flags: -O
2+
//@ run-pass
3+
#![expect(incomplete_features)]
4+
#![feature(explicit_tail_calls)]
5+
6+
// A deep recursive function that uses explicit tail calls
7+
// This will cause stack overflow without tail call optimization
8+
fn deep_recursion(n: u32) -> u32 {
9+
match n {
10+
0 => 0,
11+
_ => become deep_recursion(n - 1)
12+
}
13+
}
14+
15+
// A deep recursive function without explicit tail calls
16+
// This will overflow the stack for large values
17+
fn deep_recursion_no_tail(n: u32) -> u32 {
18+
match n {
19+
0 => 0,
20+
_ => deep_recursion_no_tail(n - 1)
21+
}
22+
}
23+
24+
fn main() {
25+
// Verify correctness for small values
26+
assert_eq!(deep_recursion(10), 0);
27+
assert_eq!(deep_recursion_no_tail(10), 0);
28+
29+
// This will succeed only if tail call optimization is working
30+
// It would overflow the stack otherwise
31+
println!("Starting deep recursion with 50,000 calls");
32+
let result = deep_recursion(50_000);
33+
assert_eq!(result, 0);
34+
println!("Successfully completed 50,000 recursive calls with tail call optimization");
35+
}

0 commit comments

Comments
 (0)