diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 71705ecb4d0f5..eab43c10ffdcd 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -1,7 +1,6 @@ use std::ptr; use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode}; -use rustc_codegen_ssa::ModuleCodegen; use rustc_codegen_ssa::back::write::ModuleConfig; use rustc_codegen_ssa::traits::BaseTypeCodegenMethods as _; use rustc_errors::FatalError; @@ -277,7 +276,7 @@ fn generate_enzyme_call<'ll>( } pub(crate) fn differentiate<'ll>( - module: &'ll ModuleCodegen, + module_llvm: &'ll ModuleLlvm, cgcx: &CodegenContext, diff_items: Vec, _config: &ModuleConfig, @@ -287,8 +286,7 @@ pub(crate) fn differentiate<'ll>( } let diag_handler = cgcx.create_dcx(); - - let cx = SimpleCx::new(module.module_llvm.llmod(), module.module_llvm.llcx, cgcx.pointer_size); + let cx = SimpleCx { llmod: module_llvm.llmod(), llcx: module_llvm.llcx }; // First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag? if !diff_items.is_empty() diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 4c5a78ca74fe4..55870e386035a 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -90,9 +90,9 @@ impl Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { } } -#[derive(Diagnostic)] -#[diag(codegen_llvm_autodiff_without_lto)] -pub(crate) struct AutoDiffWithoutLTO; +// #[derive(Diagnostic)] +// #[diag(codegen_llvm_autodiff_without_lto)] +// pub(crate) struct AutoDiffWithoutLTO; #[derive(Diagnostic)] #[diag(codegen_llvm_autodiff_without_enable)] diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index f622646a5d9d0..5eda88ed04f16 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -27,8 +27,7 @@ use std::mem::ManuallyDrop; use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; -use context::SimpleCx; -use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig}; +use errors::ParseTargetMachineConfig; pub(crate) use llvm_util::target_features_cfg; use rustc_ast::expand::allocator::AllocatorKind; use rustc_ast::expand::autodiff_attrs::AutoDiffItem; @@ -45,7 +44,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::Session; -use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest}; +use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_span::Symbol; mod back { @@ -238,11 +237,32 @@ impl WriteBackendMethods for LlvmCodegenBackend { diff_fncs: Vec, config: &ModuleConfig, ) -> Result<(), FatalError> { - if cgcx.lto != Lto::Fat { - let dcx = cgcx.create_dcx(); - return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO)); - } - builder::autodiff::differentiate(module, cgcx, diff_fncs, config) + //if cgcx.lto != Lto::Fat { + // let dcx = cgcx.create_dcx(); + // return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO)); + //} + let module_llvm = &module.module_llvm; + builder::autodiff::differentiate(module_llvm, cgcx, diff_fncs, config) + } + fn autodiff_thin( + cgcx: &CodegenContext, + thin_module: &ThinModule, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + let dcx = cgcx.create_dcx(); + let dcx = dcx.handle(); + + let module_name = &thin_module.shared.module_names[thin_module.idx]; + + // Right now the implementation we've got only works over serialized + // modules, so we create a fresh new LLVM context and parse the module + // into that context. One day, however, we may do this for upstream + // crates but for locally codegened modules we may be able to reuse + // that LLVM Context and Module. + let module_llvm = ModuleLlvm::parse(cgcx, module_name, thin_module.data(), dcx)?; + + builder::autodiff::differentiate(&module_llvm, cgcx, diff_fncs, config) } } diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index 9fd984b6419ee..97fe21ad4d552 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -94,10 +94,11 @@ impl LtoModuleCodegen { match &self { LtoModuleCodegen::Fat(module) => { B::autodiff(cgcx, &module, diff_fncs, config)?; + }, + LtoModuleCodegen::Thin(module) => { + B::autodiff_thin(cgcx, module, diff_fncs, config)?; } - _ => panic!("autodiff called with non-fat LTO module"), } - Ok(self) } } diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index 9cc737d194ce7..bb8380bb3d55d 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -42,7 +42,7 @@ use tracing::debug; use super::link::{self, ensure_removed}; use super::lto::{self, SerializedModule}; use super::symbol_export::symbol_name_for_instance_in_crate; -use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir}; +use crate::errors::ErrorCreatingRemarkDir; use crate::traits::*; use crate::{ CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind, @@ -418,15 +418,18 @@ fn generate_lto_work( vec![(WorkItem::LTO(module), 0)] } else { if !autodiff.is_empty() { - let dcx = cgcx.create_dcx(); - dcx.handle().emit_fatal(AutodiffWithoutLto {}); + //let dcx = cgcx.create_dcx(); + //dcx.handle().emit_fatal(AutodiffWithoutLto {}); } + let config = cgcx.config(ModuleKind::Regular); assert!(needs_fat_lto.is_empty()); let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) .unwrap_or_else(|e| e.raise()); lto_modules .into_iter() .map(|module| { + let module = + unsafe { module.autodiff(cgcx, autodiff.clone(), config).unwrap_or_else(|e| e.raise()) }; let cost = module.cost(); (WorkItem::LTO(module), cost) }) @@ -1466,6 +1469,7 @@ fn start_executing_work( if needs_fat_lto.is_empty() && needs_thin_lto.is_empty() && lto_import_only_modules.is_empty() + && autodiff_items.is_empty() { // Nothing more to do! break; @@ -1479,13 +1483,14 @@ fn start_executing_work( assert!(!started_lto); started_lto = true; + let autodiff_items = mem::take(&mut autodiff_items); let needs_fat_lto = mem::take(&mut needs_fat_lto); let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); for (work, cost) in generate_lto_work( &cgcx, - autodiff_items.clone(), + autodiff_items, needs_fat_lto, needs_thin_lto, import_only_modules, diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index c77efdd172876..9fa07f352efc5 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -68,6 +68,12 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { diff_fncs: Vec, config: &ModuleConfig, ) -> Result<(), FatalError>; + fn autodiff_thin( + cgcx: &CodegenContext, + thin: &ThinModule, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result<(), FatalError>; } pub trait ThinBufferMethods: Send + Sync {