@@ -3,13 +3,13 @@ use std::ptr;
33use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
44use rustc_codegen_ssa:: ModuleCodegen ;
55use rustc_codegen_ssa:: common:: TypeKind ;
6- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
6+ use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
77use rustc_errors:: FatalError ;
88use rustc_middle:: bug;
99use tracing:: { debug, trace} ;
1010
1111use crate :: back:: write:: llvm_err;
12- use crate :: builder:: { SBuilder , UNNAMED } ;
12+ use crate :: builder:: { Builder , OperandRef , PlaceRef , UNNAMED } ;
1313use crate :: context:: SimpleCx ;
1414use crate :: declare:: declare_simple_fn;
1515use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
@@ -18,7 +18,7 @@ use crate::llvm::{Metadata, True};
1818use crate :: value:: Value ;
1919use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
2020
21- fn get_params ( fnc : & Value ) -> Vec < & Value > {
21+ fn _get_params ( fnc : & Value ) -> Vec < & Value > {
2222 let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
2323 let mut fnc_args: Vec < & Value > = vec ! [ ] ;
2424 fnc_args. reserve ( param_num) ;
@@ -48,9 +48,9 @@ fn has_sret(fnc: &Value) -> bool {
4848// need to match those.
4949// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
5050// using iterators and peek()?
51- fn match_args_from_caller_to_enzyme < ' ll > (
51+ fn match_args_from_caller_to_enzyme < ' ll , ' tcx > (
5252 cx : & SimpleCx < ' ll > ,
53- builder : & SBuilder < ' ll , ' ll > ,
53+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
5454 width : u32 ,
5555 args : & mut Vec < & ' ll llvm:: Value > ,
5656 inputs : & [ DiffActivity ] ,
@@ -288,11 +288,14 @@ fn compute_enzyme_fn_ty<'ll>(
288288/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
289289// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
290290// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
291- fn generate_enzyme_call < ' ll > (
291+ pub ( crate ) fn generate_enzyme_call < ' ll , ' tcx > (
292+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
292293 cx : & SimpleCx < ' ll > ,
293294 fn_to_diff : & ' ll Value ,
294295 outer_fn : & ' ll Value ,
296+ fn_args : & [ OperandRef < ' tcx , & ' ll Value > ] ,
295297 attrs : AutoDiffAttrs ,
298+ dest : PlaceRef < ' tcx , & ' ll Value > ,
296299) {
297300 // We have to pick the name depending on whether we want forward or reverse mode autodiff.
298301 let mut ad_name: String = match attrs. mode {
@@ -365,14 +368,6 @@ fn generate_enzyme_call<'ll>(
365368 let enzyme_marker_attr = llvm:: CreateAttrString ( cx. llcx , "enzyme_marker" ) ;
366369 attributes:: apply_to_llfn ( outer_fn, Function , & [ enzyme_marker_attr] ) ;
367370
368- // first, remove all calls from fnc
369- let entry = llvm:: LLVMGetFirstBasicBlock ( outer_fn) ;
370- let br = llvm:: LLVMRustGetTerminator ( entry) ;
371- llvm:: LLVMRustEraseInstFromParent ( br) ;
372-
373- let last_inst = llvm:: LLVMRustGetLastInstruction ( entry) . unwrap ( ) ;
374- let mut builder = SBuilder :: build ( cx, entry) ;
375-
376371 let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
377372 let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
378373 args. push ( fn_to_diff) ;
@@ -388,40 +383,20 @@ fn generate_enzyme_call<'ll>(
388383 }
389384
390385 let has_sret = has_sret ( outer_fn) ;
391- let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn ) ;
386+ let outer_args: Vec < & llvm:: Value > = fn_args . iter ( ) . map ( |op| op . immediate ( ) ) . collect ( ) ;
392387 match_args_from_caller_to_enzyme (
393388 & cx,
394- & builder,
389+ builder,
395390 attrs. width ,
396391 & mut args,
397392 & attrs. input_activity ,
398393 & outer_args,
399394 has_sret,
400395 ) ;
401396
402- let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
403-
404- // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
405- // metadata attached to it, but we just created this code oota. Given that the
406- // differentiated function already has partly confusing metadata, and given that this
407- // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
408- // dummy code which we inserted at a higher level.
409- // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
410- // and how to best improve it for enzyme core and rust-enzyme.
411- let md_ty = cx. get_md_kind_id ( "dbg" ) ;
412- if llvm:: LLVMRustHasMetadata ( last_inst, md_ty) {
413- let md = llvm:: LLVMRustDIGetInstMetadata ( last_inst)
414- . expect ( "failed to get instruction metadata" ) ;
415- let md_todiff = cx. get_metadata_value ( md) ;
416- llvm:: LLVMSetMetadata ( call, md_ty, md_todiff) ;
417- } else {
418- // We don't panic, since depending on whether we are in debug or release mode, we might
419- // have no debug info to copy, which would then be ok.
420- trace ! ( "no dbg info" ) ;
421- }
397+ let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
422398
423- // Now that we copied the metadata, get rid of dummy code.
424- llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
399+ builder. store_to_place ( call, dest. val ) ;
425400
426401 if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
427402 if has_sret {
@@ -444,10 +419,10 @@ fn generate_enzyme_call<'ll>(
444419 llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
445420 }
446421 builder. ret_void ( ) ;
447- } else {
448- builder. ret ( call) ;
449422 }
450423
424+ builder. store_to_place ( call, dest. val ) ;
425+
451426 // Let's crash in case that we messed something up above and generated invalid IR.
452427 llvm:: LLVMRustVerifyFunction (
453428 outer_fn,
@@ -461,6 +436,7 @@ pub(crate) fn differentiate<'ll>(
461436 cgcx : & CodegenContext < LlvmCodegenBackend > ,
462437 diff_items : Vec < AutoDiffItem > ,
463438) -> Result < ( ) , FatalError > {
439+ // TODO(Sa4dUs): delete all this logic
464440 for item in & diff_items {
465441 trace ! ( "{}" , item) ;
466442 }
@@ -480,7 +456,7 @@ pub(crate) fn differentiate<'ll>(
480456 for item in diff_items. iter ( ) {
481457 let name = item. source . clone ( ) ;
482458 let fn_def: Option < & llvm:: Value > = cx. get_function ( & name) ;
483- let Some ( fn_def ) = fn_def else {
459+ let Some ( _fn_def ) = fn_def else {
484460 return Err ( llvm_err (
485461 diag_handler. handle ( ) ,
486462 LlvmError :: PrepareAutoDiff {
@@ -492,7 +468,7 @@ pub(crate) fn differentiate<'ll>(
492468 } ;
493469 debug ! ( ?item. target) ;
494470 let fn_target: Option < & llvm:: Value > = cx. get_function ( & item. target ) ;
495- let Some ( fn_target ) = fn_target else {
471+ let Some ( _fn_target ) = fn_target else {
496472 return Err ( llvm_err (
497473 diag_handler. handle ( ) ,
498474 LlvmError :: PrepareAutoDiff {
@@ -503,7 +479,7 @@ pub(crate) fn differentiate<'ll>(
503479 ) ) ;
504480 } ;
505481
506- generate_enzyme_call ( & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
482+ // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
507483 }
508484
509485 // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
0 commit comments