3030import linecache
3131import textwrap
3232import types
33- from typing import Any , Callable , Optional , Type , cast
33+ import typing
34+ from typing import Any , Callable , Dict , Generic , Optional , Type , TypeVar , cast , overload
3435
3536from fiddle ._src import arg_factory
3637from fiddle ._src import building
4950_ATTR_SAVE_TEMP_VAR_ID = '_attr_save_temp'
5051_CLOSURE_WRAPPER_ID = '__auto_config_closure_wrapper__'
5152_EMPTY_ARGUMENTS = ast .arguments (
52- posonlyargs = [], args = [], kwonlyargs = [], kw_defaults = [], defaults = [])
53+ posonlyargs = [], args = [], kwonlyargs = [], kw_defaults = [], defaults = []
54+ )
5355_BUILTINS = frozenset ([
54- builtin for builtin in builtins .__dict__ .values ()
56+ builtin
57+ for builtin in builtins .__dict__ .values ()
5558 if inspect .isroutine (builtin ) or inspect .isclass (builtin )
5659])
5760
5861
62+ _GenericCallable = TypeVar ('_GenericCallable' , bound = Callable [..., Any ])
63+ T = TypeVar ('T' )
64+
65+
5966@dataclasses .dataclass (frozen = True )
60- class AutoConfig :
67+ class AutoConfig ( Generic [ T ]) :
6168 """A function wrapper for auto_config'd functions.
6269
6370 In order to support auto_config'ing @classmethod's, we need to customize the
6471 descriptor protocol for the auto_config'd function. This simple wrapper type
65- is designed to look like a simple `functool.wraps` wrapper, but implements
66- custom behavior for bound methods.
72+ is designed to look like a `functool.wraps` wrapper, but implements custom
73+ behavior for bound methods.
6774 """
68- func : Callable [..., Any ]
69- buildable_func : Callable [..., config .Buildable ]
75+
76+ func : T
77+ buildable_func : Callable [..., Any ]
7078 always_inline : bool
7179
7280 @property
7381 def nowrap (self ):
7482 return True # Tells Flax not to decorate this object, for classmethods.
7583
7684 def __post_init__ (self ):
77- # Must copy-over to correctly implement "functools.wraps"-like
78- # functionality.
79- for name in ('__module__' , '__name__' , '__qualname__' , '__doc__' ,
80- '__annotations__' ):
85+ # These attributes must be copied over to in order to correctly implement
86+ # "functools.wraps"-like functionality.
87+ for name in (
88+ '__module__' ,
89+ '__name__' ,
90+ '__qualname__' ,
91+ '__doc__' ,
92+ '__annotations__' ,
93+ ):
8194 try :
8295 value = getattr (self .func , name )
8396 except AttributeError :
8497 pass
8598 else :
8699 object .__setattr__ (self , name , value )
87100
88- def __call__ (self , * args , ** kwargs ) -> Any :
89- return self .func (* args , ** kwargs )
101+ if typing .TYPE_CHECKING :
102+ __module__ : str
103+ __name__ : str
104+ __qualname__ : str
105+ __doc__ : str
106+ __annotations__ : Dict [str , Any ]
107+ # The following informs type checkers that the call method has the same
108+ # signature/annotations as `func`.
109+ __call__ : T
110+ else :
111+ # Actual implementation of __call__, which forwards parameters to `func`.
112+ def __call__ (self , * args , ** kwargs ) -> Any :
113+ return self .func (* args , ** kwargs )
90114
91- def as_buildable (self , * args , ** kwargs ) -> config . Buildable :
115+ def as_buildable (self , * args , ** kwargs ) -> Any :
92116 return self .buildable_func (* args , ** kwargs )
93117
94118 def __get__ (self , obj , objtype = None ):
95119 # pytype: disable=attribute-error
96120 return AutoConfig (
97121 func = self .func .__get__ (obj , objtype ),
98122 buildable_func = self .buildable_func .__get__ (obj , objtype ),
99- always_inline = self .always_inline )
123+ always_inline = self .always_inline ,
124+ )
100125 # pytype: enable=attribute-error
101126
102127 @property
@@ -116,11 +141,13 @@ class UnsupportedLanguageConstructError(SyntaxError):
116141class _AutoConfigNodeTransformer (ast .NodeTransformer ):
117142 """A NodeTransformer that adds the auto-config call handler into an AST."""
118143
119- def __init__ (self ,
120- source : str ,
121- filename : str ,
122- line_number : int ,
123- allow_control_flow = False ):
144+ def __init__ (
145+ self ,
146+ source : str ,
147+ filename : str ,
148+ line_number : int ,
149+ allow_control_flow = False ,
150+ ):
124151 """Initializes the auto config node transformer instance.
125152
126153 Args:
@@ -191,7 +218,8 @@ def _validate_decorator_ordering(self, node: ast.FunctionDef):
191218 raise AssertionError (
192219 f'@{ decorator } placed above @auto_config on function { node .name } '
193220 f'at { self ._filename } :{ self ._line_number } . Reorder decorators so '
194- f'that @auto_config is placed above @{ decorator } .' )
221+ f'that @auto_config is placed above @{ decorator } .'
222+ )
195223
196224 # pylint: disable=invalid-name
197225 def visit_Call (self , node : ast .Call ):
@@ -432,7 +460,8 @@ def fn(...): # Or some expression involving a lambda.
432460 * closure_var_definitions ,
433461 * module .body ,
434462 ],
435- decorator_list = [])
463+ decorator_list = [],
464+ )
436465 ],
437466 type_ignores = [],
438467 )
@@ -443,7 +472,8 @@ def fn(...): # Or some expression involving a lambda.
443472def _find_function_code (code : types .CodeType , fn_name : str ):
444473 """Finds the code object within `code` corresponding to `fn_name`."""
445474 code = [
446- const for const in code .co_consts
475+ const
476+ for const in code .co_consts
447477 if inspect .iscode (const ) and const .co_name == fn_name
448478 ]
449479 assert len (code ) == 1 , f"Couldn't find function code for { fn_name !r} ."
@@ -553,7 +583,7 @@ def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs):
553583 return partial_cls (buildable_or_callable , * args , ** kwargs )
554584
555585
556- def exempt (fn_or_cls : Callable [..., Any ] ) -> Callable [..., Any ] :
586+ def exempt (fn_or_cls : _GenericCallable ) -> _GenericCallable :
557587 """Wrap a callable so that it's exempted from auto_config.
558588
559589 This can be used either as a decorator to exempt a function, or used inside
@@ -591,8 +621,36 @@ class ConfigTypes:
591621 arg_factory_cls : Type [partial .ArgFactory ] = partial .ArgFactory
592622
593623
624+ @overload
625+ def auto_config (
626+ fn : _GenericCallable ,
627+ * ,
628+ experimental_allow_dataclass_attribute_access : bool = False ,
629+ experimental_allow_control_flow : bool = False ,
630+ experimental_always_inline : Optional [bool ] = None ,
631+ experimental_exemption_policy : Optional [auto_config_policy .Policy ] = None ,
632+ experimental_config_types : ConfigTypes = ConfigTypes (),
633+ experimental_result_must_contain_buildable : bool = True ,
634+ ) -> AutoConfig [_GenericCallable ]:
635+ ...
636+
637+
638+ @overload
639+ def auto_config (
640+ fn : None = None ,
641+ * ,
642+ experimental_allow_dataclass_attribute_access : bool = False ,
643+ experimental_allow_control_flow : bool = False ,
644+ experimental_always_inline : Optional [bool ] = None ,
645+ experimental_exemption_policy : Optional [auto_config_policy .Policy ] = None ,
646+ experimental_config_types : ConfigTypes = ConfigTypes (),
647+ experimental_result_must_contain_buildable : bool = True ,
648+ ) -> Callable [[_GenericCallable ], AutoConfig [_GenericCallable ]]:
649+ ...
650+
651+
594652def auto_config (
595- fn = None ,
653+ fn : Optional [ _GenericCallable ] = None ,
596654 * ,
597655 experimental_allow_dataclass_attribute_access = False ,
598656 experimental_allow_control_flow : bool = False ,
@@ -778,9 +836,11 @@ def auto_config_attr_save_handler(obj, attr, value, allow_dataclass=True):
778836
779837 def make_auto_config (fn ):
780838 if not isinstance (fn , (types .FunctionType , classmethod , staticmethod )):
781- raise ValueError ('`auto_config` is only compatible with functions, '
782- f'`@classmethod`s, and `@staticmethod`s. Got { fn !r} '
783- f'with type { type (fn )!r} .' )
839+ raise ValueError (
840+ '`auto_config` is only compatible with functions, '
841+ f'`@classmethod`s, and `@staticmethod`s. Got { fn !r} '
842+ f'with type { type (fn )!r} .'
843+ )
784844
785845 if isinstance (fn , (classmethod , staticmethod )):
786846 method_type = type (fn )
@@ -799,7 +859,8 @@ def make_auto_config(fn):
799859 source = source ,
800860 filename = filename ,
801861 line_number = line_number ,
802- allow_control_flow = experimental_allow_control_flow )
862+ allow_control_flow = experimental_allow_control_flow ,
863+ )
803864
804865 # Parse the AST, and modify it by intercepting all `Call`s with the
805866 # `auto_config_call_handler`. Finally, ensure line numbers and code
@@ -882,7 +943,8 @@ def as_buildable(*args, **kwargs):
882943 fn = method_type (fn )
883944 as_buildable = method_type (as_buildable )
884945 return AutoConfig (
885- fn , as_buildable , always_inline = experimental_always_inline )
946+ fn , as_buildable , always_inline = experimental_always_inline
947+ )
886948
887949 # Decorator with empty parenthesis.
888950 if fn is None :
@@ -951,7 +1013,6 @@ def main():
9511013 experimental_always_inline = True
9521014
9531015 def make_unconfig (fn ) -> AutoConfig :
954-
9551016 @functools .wraps (fn )
9561017 def python_implementation (* args , ** kwargs ):
9571018 previous = building ._state .in_build # pytype: disable=module-attr # pylint: disable=protected-access
@@ -965,7 +1026,8 @@ def python_implementation(*args, **kwargs):
9651026 return AutoConfig (
9661027 func = python_implementation ,
9671028 buildable_func = fn ,
968- always_inline = experimental_always_inline )
1029+ always_inline = experimental_always_inline ,
1030+ )
9691031
9701032 # We use this pattern to support using the decorator with and without
9711033 # parenthesis.
@@ -1023,20 +1085,27 @@ def make_experiment():
10231085 doesn't correspond to an ``auto_config``'d function.
10241086 """
10251087 if not isinstance (buildable , config .Config ):
1026- raise ValueError ('Cannot `inline` non-Config buildables; '
1027- f'{ type (buildable )} is not compatible.' )
1088+ raise ValueError (
1089+ 'Cannot `inline` non-Config buildables; '
1090+ f'{ type (buildable )} is not compatible.'
1091+ )
10281092 if not is_auto_config (buildable .__fn_or_cls__ ):
1029- raise ValueError ('Cannot `inline` a non-auto_config function; '
1030- f'`{ buildable .__fn_or_cls__ } ` is not compatible.' )
1093+ raise ValueError (
1094+ 'Cannot `inline` a non-auto_config function; '
1095+ f'`{ buildable .__fn_or_cls__ } ` is not compatible.'
1096+ )
10311097 # Evaluate the `as_buildable` interpretation.
10321098 auto_config_fn = cast (AutoConfig , buildable .__fn_or_cls__ )
10331099 tmp_config = auto_config_fn .as_buildable (** buildable .__arguments__ )
10341100 if not isinstance (tmp_config , config .Buildable ):
1035- raise ValueError ('You cannot currently inline functions that do not return '
1036- '`fdl.Buildable`s.' )
1101+ raise ValueError (
1102+ 'You cannot currently inline functions that do not return '
1103+ '`fdl.Buildable`s.'
1104+ )
10371105
10381106 mutate_buildable .move_buildable_internals (
1039- source = tmp_config , destination = buildable )
1107+ source = tmp_config , destination = buildable
1108+ )
10401109
10411110
10421111def _getsource (fn : Any ) -> str :
@@ -1056,11 +1125,12 @@ def _is_lambda(fn: Any) -> bool:
10561125 return False
10571126 if not (hasattr (fn , '__name__' ) and hasattr (fn , '__code__' )):
10581127 return False
1059- return (( fn .__name__ == '<lambda>' ) or (fn .__code__ .co_name == '<lambda>' ) )
1128+ return (fn .__name__ == '<lambda>' ) or (fn .__code__ .co_name == '<lambda>' )
10601129
10611130
10621131class _LambdaFinder (cst .CSTVisitor ):
10631132 """CST Visitor that searches for the source code for a given lambda func."""
1133+
10641134 METADATA_DEPENDENCIES = (cst .metadata .PositionProvider ,)
10651135
10661136 def __init__ (self , lambda_fn ):
@@ -1095,15 +1165,17 @@ def _getsource_for_lambda(fn: Callable[..., Any]) -> str:
10951165 elif not lambda_finder .candidates :
10961166 raise ValueError (
10971167 'Fiddle auto_config was unable to find the source code for '
1098- f'{ fn } : could not find lambda on line { lambda_finder .lineno } .' )
1168+ f'{ fn } : could not find lambda on line { lambda_finder .lineno } .'
1169+ )
10991170 else :
11001171 # TODO(b/258671226): If desired, we could narrow down which lambda is
11011172 # used based on the signature (or even fancier things like the checking
11021173 # fn.__code__.co_names).
11031174 raise ValueError (
11041175 'Fiddle auto_config was unable to find the source code for '
11051176 f'{ fn } : multiple lambdas found on line { lambda_finder .lineno } ; '
1106- 'try moving each lambda to its own line.' )
1177+ 'try moving each lambda to its own line.'
1178+ )
11071179
11081180
11091181def with_buildable_func (
0 commit comments