66use std:: fmt:: { self , Display , Formatter } ;
77use std:: str:: FromStr ;
88
9- use crate :: expand:: typetree:: TypeTree ;
9+ use rustc_span:: { Symbol , sym} ;
10+
1011use crate :: expand:: { Decodable , Encodable , HashStable_Generic } ;
1112use crate :: { Ty , TyKind } ;
1213
@@ -31,6 +32,12 @@ pub enum DiffMode {
3132 Reverse ,
3233}
3334
35+ impl DiffMode {
36+ pub fn all_modes ( ) -> & ' static [ Symbol ] {
37+ & [ sym:: Source , sym:: Forward , sym:: Reverse ]
38+ }
39+ }
40+
3441/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
3542/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
3643/// we add to the previous shadow value. To not surprise users, we picked different names.
@@ -76,43 +83,20 @@ impl DiffActivity {
7683 use DiffActivity :: * ;
7784 matches ! ( self , |Dual | DualOnly | Dualv | DualvOnly | Const )
7885 }
79- }
80- /// We generate one of these structs for each `#[autodiff(...)]` attribute.
81- #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
82- pub struct AutoDiffItem {
83- /// The name of the function getting differentiated
84- pub source : String ,
85- /// The name of the function being generated
86- pub target : String ,
87- pub attrs : AutoDiffAttrs ,
88- pub inputs : Vec < TypeTree > ,
89- pub output : TypeTree ,
90- }
9186
92- #[ derive( Clone , Eq , PartialEq , Encodable , Decodable , Debug , HashStable_Generic ) ]
93- pub struct AutoDiffAttrs {
94- /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
95- /// e.g. in the [JAX
96- /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
97- pub mode : DiffMode ,
98- /// A user-provided, batching width. If not given, we will default to 1 (no batching).
99- /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
100- /// - Calling the function 50 times with a batch size of 2
101- /// - Calling the function 25 times with a batch size of 4,
102- /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
103- /// cache locality, better re-usal of primal values, and other optimizations.
104- /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
105- /// times, so this massively increases code size. As such, values like 1024 are unlikely to
106- /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
107- /// experiments for now and focus on documenting the implications of a large width.
108- pub width : u32 ,
109- pub ret_activity : DiffActivity ,
110- pub input_activity : Vec < DiffActivity > ,
111- }
112-
113- impl AutoDiffAttrs {
114- pub fn has_primal_ret ( & self ) -> bool {
115- matches ! ( self . ret_activity, DiffActivity :: Active | DiffActivity :: Dual )
87+ pub fn all_activities ( ) -> & ' static [ Symbol ] {
88+ & [
89+ sym:: None ,
90+ sym:: Active ,
91+ sym:: ActiveOnly ,
92+ sym:: Const ,
93+ sym:: Dual ,
94+ sym:: Dualv ,
95+ sym:: DualOnly ,
96+ sym:: DualvOnly ,
97+ sym:: Duplicated ,
98+ sym:: DuplicatedOnly ,
99+ ]
116100 }
117101}
118102
@@ -241,59 +225,3 @@ impl FromStr for DiffActivity {
241225 }
242226 }
243227}
244-
245- impl AutoDiffAttrs {
246- pub fn has_ret_activity ( & self ) -> bool {
247- self . ret_activity != DiffActivity :: None
248- }
249- pub fn has_active_only_ret ( & self ) -> bool {
250- self . ret_activity == DiffActivity :: ActiveOnly
251- }
252-
253- pub const fn error ( ) -> Self {
254- AutoDiffAttrs {
255- mode : DiffMode :: Error ,
256- width : 0 ,
257- ret_activity : DiffActivity :: None ,
258- input_activity : Vec :: new ( ) ,
259- }
260- }
261- pub fn source ( ) -> Self {
262- AutoDiffAttrs {
263- mode : DiffMode :: Source ,
264- width : 0 ,
265- ret_activity : DiffActivity :: None ,
266- input_activity : Vec :: new ( ) ,
267- }
268- }
269-
270- pub fn is_active ( & self ) -> bool {
271- self . mode != DiffMode :: Error
272- }
273-
274- pub fn is_source ( & self ) -> bool {
275- self . mode == DiffMode :: Source
276- }
277- pub fn apply_autodiff ( & self ) -> bool {
278- !matches ! ( self . mode, DiffMode :: Error | DiffMode :: Source )
279- }
280-
281- pub fn into_item (
282- self ,
283- source : String ,
284- target : String ,
285- inputs : Vec < TypeTree > ,
286- output : TypeTree ,
287- ) -> AutoDiffItem {
288- AutoDiffItem { source, target, inputs, output, attrs : self }
289- }
290- }
291-
292- impl fmt:: Display for AutoDiffItem {
293- fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
294- write ! ( f, "Differentiating {} -> {}" , self . source, self . target) ?;
295- write ! ( f, " with attributes: {:?}" , self . attrs) ?;
296- write ! ( f, " with inputs: {:?}" , self . inputs) ?;
297- write ! ( f, " with output: {:?}" , self . output)
298- }
299- }
0 commit comments