Skip to content

Commit 8697dc4

Browse files
authored
Unrolled build for #149271
Rollup merge of #149271 - sgasho:enzyme-dlopen, r=bjorn3 feat: dlopen Enzyme related issue: #145899 related pr: #146623 This PR is a continuation of #146623 I refactored some code for #146623 and added the functions shown in #144197 r? ````@bjorn3```` cc: ````@ZuseZ4```` Zulip link: https://rust-lang.zulipchat.com/#narrow/channel/182449-t-compiler.2Fhelp/topic/libload.20.2F.20dlopen.20Enzyme.2Fautodiff/near/553647912
2 parents cec7008 + 58aeab5 commit 8697dc4

File tree

12 files changed

+527
-243
lines changed

12 files changed

+527
-243
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3613,6 +3613,7 @@ dependencies = [
36133613
"gimli 0.31.1",
36143614
"itertools",
36153615
"libc",
3616+
"libloading 0.9.0",
36163617
"measureme",
36173618
"object 0.37.3",
36183619
"rustc-demangle",

compiler/rustc_codegen_llvm/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ bitflags = "2.4.1"
1414
gimli = "0.31"
1515
itertools = "0.12"
1616
libc = "0.2"
17+
libloading = { version = "0.9.0", optional = true }
1718
measureme = "12.0.1"
1819
object = { version = "0.37.0", default-features = false, features = ["std", "read"] }
1920
rustc-demangle = "0.1.21"
@@ -46,7 +47,7 @@ tracing = "0.1"
4647
[features]
4748
# tidy-alphabetical-start
4849
check_only = ["rustc_llvm/check_only"]
49-
llvm_enzyme = []
50+
llvm_enzyme = ["dep:libloading"]
5051
llvm_offload = []
5152
# tidy-alphabetical-end
5253

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -528,31 +528,34 @@ fn thin_lto(
528528
}
529529
}
530530

531-
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
531+
#[cfg(feature = "llvm_enzyme")]
532+
pub(crate) fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
533+
let mut enzyme = llvm::EnzymeWrapper::get_instance();
534+
532535
for val in ad {
533536
// We intentionally don't use a wildcard, to not forget handling anything new.
534537
match val {
535538
config::AutoDiff::PrintPerf => {
536-
llvm::set_print_perf(true);
539+
enzyme.set_print_perf(true);
537540
}
538541
config::AutoDiff::PrintAA => {
539-
llvm::set_print_activity(true);
542+
enzyme.set_print_activity(true);
540543
}
541544
config::AutoDiff::PrintTA => {
542-
llvm::set_print_type(true);
545+
enzyme.set_print_type(true);
543546
}
544547
config::AutoDiff::PrintTAFn(fun) => {
545-
llvm::set_print_type(true); // Enable general type printing
546-
llvm::set_print_type_fun(&fun); // Set specific function to analyze
548+
enzyme.set_print_type(true); // Enable general type printing
549+
enzyme.set_print_type_fun(&fun); // Set specific function to analyze
547550
}
548551
config::AutoDiff::Inline => {
549-
llvm::set_inline(true);
552+
enzyme.set_inline(true);
550553
}
551554
config::AutoDiff::LooseTypes => {
552-
llvm::set_loose_types(true);
555+
enzyme.set_loose_types(true);
553556
}
554557
config::AutoDiff::PrintSteps => {
555-
llvm::set_print(true);
558+
enzyme.set_print(true);
556559
}
557560
// We handle this in the PassWrapper.cpp
558561
config::AutoDiff::PrintPasses => {}
@@ -571,9 +574,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
571574
}
572575
}
573576
// This helps with handling enums for now.
574-
llvm::set_strict_aliasing(false);
577+
enzyme.set_strict_aliasing(false);
575578
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
576-
llvm::set_rust_rules(true);
579+
enzyme.set_rust_rules(true);
577580
}
578581

579582
pub(crate) fn run_pass_manager(
@@ -607,10 +610,6 @@ pub(crate) fn run_pass_manager(
607610
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
608611
};
609612

610-
if enable_ad {
611-
enable_autodiff_settings(&config.autodiff);
612-
}
613-
614613
unsafe {
615614
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage);
616615
}

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,13 @@ pub(crate) unsafe fn llvm_optimize(
730730

731731
let llvm_plugins = config.llvm_plugins.join(",");
732732

733+
let enzyme_fn = if consider_ad {
734+
let wrapper = llvm::EnzymeWrapper::get_instance();
735+
wrapper.registerEnzymeAndPassPipeline
736+
} else {
737+
std::ptr::null()
738+
};
739+
733740
let result = unsafe {
734741
llvm::LLVMRustOptimize(
735742
module.module_llvm.llmod(),
@@ -749,7 +756,7 @@ pub(crate) unsafe fn llvm_optimize(
749756
vectorize_loop,
750757
config.no_builtins,
751758
config.emit_lifetime_markers,
752-
run_enzyme,
759+
enzyme_fn,
753760
print_before_enzyme,
754761
print_after_enzyme,
755762
print_passes,

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,17 @@ impl CodegenBackend for LlvmCodegenBackend {
240240

241241
fn init(&self, sess: &Session) {
242242
llvm_util::init(sess); // Make sure llvm is inited
243+
244+
#[cfg(feature = "llvm_enzyme")]
245+
{
246+
use rustc_session::config::AutoDiff;
247+
248+
use crate::back::lto::enable_autodiff_settings;
249+
if sess.opts.unstable_opts.autodiff.contains(&AutoDiff::Enable) {
250+
drop(llvm::EnzymeWrapper::get_or_init(&sess.opts.sysroot));
251+
enable_autodiff_settings(&sess.opts.unstable_opts.autodiff);
252+
}
253+
}
243254
}
244255

245256
fn provide(&self, providers: &mut Providers) {

0 commit comments

Comments
 (0)