Skip to content

add autodiff inline #139308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion compiler/rustc_codegen_llvm/src/attributes.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//! Set and unset common attributes on LLVM values.
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
use rustc_codegen_ssa::traits::*;
use rustc_hir::def_id::DefId;
Expand Down Expand Up @@ -28,6 +27,22 @@ pub(crate) fn apply_to_callsite(callsite: &Value, idx: AttributePlace, attrs: &[
}
}

pub(crate) fn has_attr(llfn: &Value, idx: AttributePlace, attr: AttributeKind) -> bool {
llvm::HasAttributeAtIndex(llfn, idx, attr)
}

pub(crate) fn has_string_attr(llfn: &Value, name: &str) -> bool {
llvm::HasStringAttribute(llfn, name)
}

pub(crate) fn remove_from_llfn(llfn: &Value, place: AttributePlace, kind: AttributeKind) {
llvm::RemoveRustEnumAttributeAtIndex(llfn, place, kind);
}

pub(crate) fn remove_string_attr_from_llfn(llfn: &Value, name: &str) {
llvm::RemoveStringAttrFromFn(llfn, name);
}

/// Get LLVM attribute for the provided inline heuristic.
#[inline]
fn inline_attr<'ll>(cx: &CodegenCx<'ll, '_>, inline: InlineAttr) -> Option<&'ll Attribute> {
Expand Down
28 changes: 27 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ use crate::back::write::{
use crate::errors::{
DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, build_string};
use crate::{LlvmCodegenBackend, ModuleLlvm};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};

/// We keep track of the computed LTO cache keys from the previous
/// session to determine which CGUs we can reuse.
Expand Down Expand Up @@ -666,6 +667,31 @@ pub(crate) fn run_pass_manager(
}

if cfg!(llvm_enzyme) && enable_ad && !thin {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);

for function in cx.get_functions() {
let enzyme_marker = "enzyme_marker";
if attributes::has_string_attr(function, enzyme_marker) {
// Sanity check: Ensure 'noinline' is present before replacing it.
assert!(
!attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
);

attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
attributes::remove_string_attr_from_llfn(function, enzyme_marker);

assert!(
!attributes::has_string_attr(function, enzyme_marker),
"Expected function to not have 'enzyme_marker'"
);

let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
attributes::apply_to_llfn(function, Function, &[always_inline]);
}
}

let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,11 @@ fn generate_enzyme_call<'ll>(
let attr = llvm::AttributeKind::NoInline.create_attr(cx.llcx);
attributes::apply_to_llfn(ad_fn, Function, &[attr]);

// We add a made-up attribute just such that we can recognize it after AD to update
// (no)-inline attributes. We'll then also remove this attribute.
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);

// first, remove all calls from fnc
let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
let br = llvm::LLVMRustGetTerminator(entry);
Expand Down
10 changes: 10 additions & 0 deletions compiler/rustc_codegen_llvm/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
})
}

pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
let mut functions = vec![];
let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
while let Some(f) = func {
functions.push(f);
func = unsafe { llvm::LLVMGetNextFunction(f) }
}
functions
}
}

impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
Expand Down
13 changes: 13 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ unsafe extern "C" {
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
pub(crate) fn LLVMRustHasFnAttribute(
F: &Value,
Name: *const c_char,
NameLen: libc::size_t,
) -> bool;
pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(
Fn: &Value,
index: c_uint,
kind: AttributeKind,
);
}

unsafe extern "C" {
Expand Down
26 changes: 26 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>(
}
}

pub(crate) fn HasAttributeAtIndex<'ll>(
llfn: &'ll Value,
idx: AttributePlace,
kind: AttributeKind,
) -> bool {
unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
}

pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: &str) -> bool {
unsafe { LLVMRustHasFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
}

pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: &str) {
unsafe { LLVMRustRemoveFnAttribute(llfn, name.as_c_char_ptr(), name.len()) }
}

pub(crate) fn RemoveRustEnumAttributeAtIndex(
llfn: &Value,
place: AttributePlace,
kind: AttributeKind,
) {
unsafe {
LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
}
}

pub(crate) fn AddCallSiteAttributes<'ll>(
callsite: &'ll Value,
idx: AttributePlace,
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_llvm/src/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
(**self).borrow().llcx
}

pub(crate) fn llmod(&self) -> &'ll llvm::Module {
(**self).borrow().llmod
}

pub(crate) fn isize_ty(&self) -> &'ll Type {
(**self).borrow().isize_ty
}
Expand Down
21 changes: 21 additions & 0 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,27 @@ extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) {
return nullptr;
}

extern "C" void
LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index,
LLVMRustAttributeKind RustAttr) {
LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr));
}

extern "C" bool LLVMRustHasFnAttribute(LLVMValueRef F, const char *Name,
size_t NameLen) {
if (auto *Fn = dyn_cast<Function>(unwrap<Value>(F))) {
return Fn->hasFnAttribute(StringRef(Name, NameLen));
}
return false;
}

extern "C" void LLVMRustRemoveFnAttribute(LLVMValueRef Fn, const char *Name,
size_t NameLen) {
if (auto *F = dyn_cast<Function>(unwrap<Value>(Fn))) {
F->removeFnAttr(StringRef(Name, NameLen));
}
}

extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind,
LLVMMetadataRef MD) {
unwrap<GlobalObject>(Global)->addMetadata(Kind, *unwrap<MDNode>(MD));
Expand Down
23 changes: 23 additions & 0 deletions tests/codegen/autodiff/inline.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt
//@ no-prefer-dynamic
//@ needs-enzyme

#![feature(autodiff)]

use std::autodiff::autodiff;

#[autodiff(d_square, Reverse, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}

// CHECK: ; inline::d_square
// CHECK-NEXT: ; Function Attrs: alwaysinline
// CHECK-NOT: noinline
// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
fn main() {
let x = std::hint::black_box(3.0);
let mut dx1 = std::hint::black_box(1.0);
let _ = d_square(&x, &mut dx1, 1.0);
assert_eq!(dx1, 6.0);
}
Loading