@@ -14,7 +14,7 @@ use crate::context::SimpleCx;
1414use crate :: declare:: declare_simple_fn;
1515use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
1616use crate :: llvm:: AttributePlace :: Function ;
17- use crate :: llvm:: { Metadata , True } ;
17+ use crate :: llvm:: { Metadata , True , Type } ;
1818use crate :: value:: Value ;
1919use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
2020
@@ -29,7 +29,7 @@ fn _get_params(fnc: &Value) -> Vec<&Value> {
2929 fnc_args
3030}
3131
32- fn has_sret ( fnc : & Value ) -> bool {
32+ fn _has_sret ( fnc : & Value ) -> bool {
3333 let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
3434 if num_args == 0 {
3535 false
@@ -55,7 +55,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
5555 args : & mut Vec < & ' ll llvm:: Value > ,
5656 inputs : & [ DiffActivity ] ,
5757 outer_args : & [ & ' ll llvm:: Value ] ,
58- has_sret : bool ,
5958) {
6059 debug ! ( "matching autodiff arguments" ) ;
6160 // We now handle the issue that Rust level arguments not always match the llvm-ir level
@@ -67,20 +66,12 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
6766 let mut outer_pos: usize = 0 ;
6867 let mut activity_pos = 0 ;
6968
70- if has_sret {
71- // Then the first outer arg is the sret pointer. Enzyme doesn't know about sret, so the
72- // inner function will still return something. We increase our outer_pos by one,
73- // and once we're done with all other args we will take the return of the inner call and
74- // update the sret pointer with it
75- outer_pos = 1 ;
76- }
77-
78- let enzyme_const = cx. create_metadata ( b"enzyme_const" ) ;
79- let enzyme_out = cx. create_metadata ( b"enzyme_out" ) ;
80- let enzyme_dup = cx. create_metadata ( b"enzyme_dup" ) ;
81- let enzyme_dupv = cx. create_metadata ( b"enzyme_dupv" ) ;
82- let enzyme_dupnoneed = cx. create_metadata ( b"enzyme_dupnoneed" ) ;
83- let enzyme_dupnoneedv = cx. create_metadata ( b"enzyme_dupnoneedv" ) ;
69+ let enzyme_const = cx. create_metadata ( "enzyme_const" . to_string ( ) ) . unwrap ( ) ;
70+ let enzyme_out = cx. create_metadata ( "enzyme_out" . to_string ( ) ) . unwrap ( ) ;
71+ let enzyme_dup = cx. create_metadata ( "enzyme_dup" . to_string ( ) ) . unwrap ( ) ;
72+ let enzyme_dupv = cx. create_metadata ( "enzyme_dupv" . to_string ( ) ) . unwrap ( ) ;
73+ let enzyme_dupnoneed = cx. create_metadata ( "enzyme_dupnoneed" . to_string ( ) ) . unwrap ( ) ;
74+ let enzyme_dupnoneedv = cx. create_metadata ( "enzyme_dupnoneedv" . to_string ( ) ) . unwrap ( ) ;
8475
8576 while activity_pos < inputs. len ( ) {
8677 let diff_activity = inputs[ activity_pos as usize ] ;
@@ -193,92 +184,6 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
193184 }
194185}
195186
196- // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input
197- // arguments. We do however need to declare them with their correct return type.
198- // We already figured the correct return type out in our frontend, when generating the outer_fn,
199- // so we can now just go ahead and use that. This is not always trivial, e.g. because sret.
200- // Beyond sret, this article describes our challenges nicely:
201- // <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
202- // I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
203- fn compute_enzyme_fn_ty < ' ll > (
204- cx : & SimpleCx < ' ll > ,
205- attrs : & AutoDiffAttrs ,
206- fn_to_diff : & ' ll Value ,
207- outer_fn : & ' ll Value ,
208- ) -> & ' ll llvm:: Type {
209- let fn_ty = cx. get_type_of_global ( outer_fn) ;
210- let mut ret_ty = cx. get_return_type ( fn_ty) ;
211-
212- let has_sret = has_sret ( outer_fn) ;
213-
214- if has_sret {
215- // Now we don't just forward the return type, so we have to figure it out based on the
216- // primal return type, in combination with the autodiff settings.
217- let fn_ty = cx. get_type_of_global ( fn_to_diff) ;
218- let inner_ret_ty = cx. get_return_type ( fn_ty) ;
219-
220- let void_ty = unsafe { llvm:: LLVMVoidTypeInContext ( cx. llcx ) } ;
221- if inner_ret_ty == void_ty {
222- // This indicates that even the inner function has an sret.
223- // Right now I only look for an sret in the outer function.
224- // This *probably* needs some extra handling, but I never ran
225- // into such a case. So I'll wait for user reports to have a test case.
226- bug ! ( "sret in inner function" ) ;
227- }
228-
229- if attrs. width == 1 {
230- // Enzyme returns a struct of style:
231- // `{ original_ret(if requested), float, float, ... }`
232- let mut struct_elements = vec ! [ ] ;
233- if attrs. has_primal_ret ( ) {
234- struct_elements. push ( inner_ret_ty) ;
235- }
236- // Next, we push the list of active floats, since they will be lowered to `enzyme_out`,
237- // and therefore part of the return struct.
238- let param_tys = cx. func_params_types ( fn_ty) ;
239- for ( act, param_ty) in attrs. input_activity . iter ( ) . zip ( param_tys) {
240- if matches ! ( act, DiffActivity :: Active ) {
241- // Now find the float type at position i based on the fn_ty,
242- // to know what (f16/f32/f64/...) to add to the struct.
243- struct_elements. push ( param_ty) ;
244- }
245- }
246- ret_ty = cx. type_struct ( & struct_elements, false ) ;
247- } else {
248- // First we check if we also have to deal with the primal return.
249- match attrs. mode {
250- DiffMode :: Forward => match attrs. ret_activity {
251- DiffActivity :: Dual => {
252- let arr_ty =
253- unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
254- ret_ty = arr_ty;
255- }
256- DiffActivity :: DualOnly => {
257- let arr_ty =
258- unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 ) } ;
259- ret_ty = arr_ty;
260- }
261- DiffActivity :: Const => {
262- todo ! ( "Not sure, do we need to do something here?" ) ;
263- }
264- _ => {
265- bug ! ( "unreachable" ) ;
266- }
267- } ,
268- DiffMode :: Reverse => {
269- todo ! ( "Handle sret for reverse mode" ) ;
270- }
271- _ => {
272- bug ! ( "unreachable" ) ;
273- }
274- }
275- }
276- }
277-
278- // LLVM can figure out the input types on it's own, so we take a shortcut here.
279- unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) }
280- }
281-
282187/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
283188/// function with expected naming and calling conventions[^1] which will be
284189/// discovered by the enzyme LLVM pass and its body populated with the differentiated
@@ -292,7 +197,8 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
292197 builder : & mut Builder < ' _ , ' ll , ' tcx > ,
293198 cx : & SimpleCx < ' ll > ,
294199 fn_to_diff : & ' ll Value ,
295- outer_fn : & ' ll Value ,
200+ outer_name : & str ,
201+ ret_ty : & ' ll Type ,
296202 fn_args : & [ OperandRef < ' tcx , & ' ll Value > ] ,
297203 attrs : AutoDiffAttrs ,
298204 dest : PlaceRef < ' tcx , & ' ll Value > ,
@@ -305,11 +211,9 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
305211 }
306212 . to_string ( ) ;
307213
308- // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
214+ // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
309215 // functions. Unwrap will only panic, if LLVM gave us an invalid string.
310- let name = llvm:: get_value_name ( outer_fn) ;
311- let outer_fn_name = std:: str:: from_utf8 ( & name) . unwrap ( ) ;
312- ad_name. push_str ( outer_fn_name) ;
216+ ad_name. push_str ( outer_name) ;
313217
314218 // Let us assume the user wrote the following function square:
315219 //
@@ -320,13 +224,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
320224 // ret double %0
321225 // }
322226 // ```
323- //
324- // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`.
325- // Our macro generates the following placeholder code (slightly simplified):
326- //
327- // ```llvm
328227 // define double @dsquare(double %x) {
329- // ; placeholder code
330228 // return 0.0;
331229 // }
332230 // ```
@@ -343,92 +241,54 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
343241 // ret double %0
344242 // }
345243 // ```
346- unsafe {
347- let enzyme_ty = compute_enzyme_fn_ty ( cx, & attrs, fn_to_diff, outer_fn) ;
348-
349- // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
350- // think a bit more about what should go here.
351- let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
352- let ad_fn = declare_simple_fn (
353- cx,
354- & ad_name,
355- llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
356- llvm:: UnnamedAddr :: No ,
357- llvm:: Visibility :: Default ,
358- enzyme_ty,
359- ) ;
360-
361- // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
362- // do it's work.
363- let attr = llvm:: AttributeKind :: NoInline . create_attr ( cx. llcx ) ;
364- attributes:: apply_to_llfn ( ad_fn, Function , & [ attr] ) ;
365-
366- // We add a made-up attribute just such that we can recognize it after AD to update
367- // (no)-inline attributes. We'll then also remove this attribute.
368- let enzyme_marker_attr = llvm:: CreateAttrString ( cx. llcx , "enzyme_marker" ) ;
369- attributes:: apply_to_llfn ( outer_fn, Function , & [ enzyme_marker_attr] ) ;
370-
371- let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
372- let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
373- args. push ( fn_to_diff) ;
374-
375- let enzyme_primal_ret = cx. create_metadata ( b"enzyme_primal_return" ) ;
376- if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
377- args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
378- }
379- if attrs. width > 1 {
380- let enzyme_width = cx. create_metadata ( b"enzyme_width" ) ;
381- args. push ( cx. get_metadata_value ( enzyme_width) ) ;
382- args. push ( cx. get_const_int ( cx. type_i64 ( ) , attrs. width as u64 ) ) ;
383- }
384-
385- let has_sret = has_sret ( outer_fn) ;
386- let outer_args: Vec < & llvm:: Value > = fn_args. iter ( ) . map ( |op| op. immediate ( ) ) . collect ( ) ;
387- match_args_from_caller_to_enzyme (
388- & cx,
389- builder,
390- attrs. width ,
391- & mut args,
392- & attrs. input_activity ,
393- & outer_args,
394- has_sret,
395- ) ;
396-
397- let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
398-
399- builder. store_to_place ( call, dest. val ) ;
244+ let enzyme_ty = unsafe { llvm:: LLVMFunctionType ( ret_ty, ptr:: null ( ) , 0 , True ) } ;
245+
246+ // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
247+ // think a bit more about what should go here.
248+ // FIXME(Sa4dUs): have to find a way to get the cc, using `FastCallConv` for now
249+ let cc = 8 ;
250+ let ad_fn = declare_simple_fn (
251+ cx,
252+ & ad_name,
253+ llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
254+ llvm:: UnnamedAddr :: No ,
255+ llvm:: Visibility :: Default ,
256+ enzyme_ty,
257+ ) ;
258+
259+ // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to
260+ // do it's work.
261+ let attr = llvm:: AttributeKind :: NoInline . create_attr ( cx. llcx ) ;
262+ attributes:: apply_to_llfn ( ad_fn, Function , & [ attr] ) ;
263+
264+ let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
265+ let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
266+ args. push ( fn_to_diff) ;
267+
268+ let enzyme_primal_ret = cx. create_metadata ( "enzyme_primal_return" . to_string ( ) ) . unwrap ( ) ;
269+ if matches ! ( attrs. ret_activity, DiffActivity :: Dual | DiffActivity :: Active ) {
270+ args. push ( cx. get_metadata_value ( enzyme_primal_ret) ) ;
271+ }
272+ if attrs. width > 1 {
273+ let enzyme_width = cx. create_metadata ( "enzyme_width" . to_string ( ) ) . unwrap ( ) ;
274+ args. push ( cx. get_metadata_value ( enzyme_width) ) ;
275+ args. push ( cx. get_const_int ( cx. type_i64 ( ) , attrs. width as u64 ) ) ;
276+ }
400277
401- if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
402- if has_sret {
403- // This is what we already have in our outer_fn (shortened):
404- // define void @_foo(ptr <..> sret([32 x i8]) initializes((0, 32)) %0, <...>) {
405- // %7 = call [4 x double] (...) @__enzyme_fwddiff_foo(ptr @square, metadata !"enzyme_width", i64 4, <...>)
406- // <Here we are, we want to add the following two lines>
407- // store [4 x double] %7, ptr %0, align 8
408- // ret void
409- // }
278+ let outer_args: Vec < & llvm:: Value > = fn_args. iter ( ) . map ( |op| op. immediate ( ) ) . collect ( ) ;
410279
411- // now store the result of the enzyme call into the sret pointer.
412- let sret_ptr = outer_args[ 0 ] ;
413- let call_ty = cx. val_ty ( call) ;
414- if attrs. width == 1 {
415- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Struct ) ;
416- } else {
417- assert_eq ! ( cx. type_kind( call_ty) , TypeKind :: Array ) ;
418- }
419- llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
420- }
421- builder. ret_void ( ) ;
422- }
280+ match_args_from_caller_to_enzyme (
281+ & cx,
282+ builder,
283+ attrs. width ,
284+ & mut args,
285+ & attrs. input_activity ,
286+ & outer_args,
287+ ) ;
423288
424- builder. store_to_place ( call, dest . val ) ;
289+ let call = builder. call ( enzyme_ty , None , None , ad_fn , & args , None , None ) ;
425290
426- // Let's crash in case that we messed something up above and generated invalid IR.
427- llvm:: LLVMRustVerifyFunction (
428- outer_fn,
429- llvm:: LLVMRustVerifierFailureAction :: LLVMAbortProcessAction ,
430- ) ;
431- }
291+ builder. store_to_place ( call, dest. val ) ;
432292}
433293
434294pub ( crate ) fn differentiate < ' ll > (
0 commit comments