Skip to content

Commit 0260086

Browse files
dhrcopybara-github
authored andcommitted
Enable pytype support for (nested) calls to auto_config-decorated functions.
PiperOrigin-RevId: 606427622
1 parent 1ee1124 commit 0260086

2 files changed

Lines changed: 150 additions & 48 deletions

File tree

fiddle/_src/experimental/auto_config.py

Lines changed: 115 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
import linecache
3131
import textwrap
3232
import types
33-
from typing import Any, Callable, Optional, Type, cast
33+
import typing
34+
from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar, cast, overload
3435

3536
from fiddle._src import arg_factory
3637
from fiddle._src import building
@@ -49,54 +50,78 @@
4950
_ATTR_SAVE_TEMP_VAR_ID = '_attr_save_temp'
5051
_CLOSURE_WRAPPER_ID = '__auto_config_closure_wrapper__'
5152
_EMPTY_ARGUMENTS = ast.arguments(
52-
posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[])
53+
posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
54+
)
5355
_BUILTINS = frozenset([
54-
builtin for builtin in builtins.__dict__.values()
56+
builtin
57+
for builtin in builtins.__dict__.values()
5558
if inspect.isroutine(builtin) or inspect.isclass(builtin)
5659
])
5760

5861

62+
_GenericCallable = TypeVar('_GenericCallable', bound=Callable[..., Any])
63+
T = TypeVar('T')
64+
65+
5966
@dataclasses.dataclass(frozen=True)
60-
class AutoConfig:
67+
class AutoConfig(Generic[T]):
6168
"""A function wrapper for auto_config'd functions.
6269
6370
In order to support auto_config'ing @classmethod's, we need to customize the
6471
descriptor protocol for the auto_config'd function. This simple wrapper type
65-
is designed to look like a simple `functool.wraps` wrapper, but implements
66-
custom behavior for bound methods.
72+
is designed to look like a `functool.wraps` wrapper, but implements custom
73+
behavior for bound methods.
6774
"""
68-
func: Callable[..., Any]
69-
buildable_func: Callable[..., config.Buildable]
75+
76+
func: T
77+
buildable_func: Callable[..., Any]
7078
always_inline: bool
7179

7280
@property
7381
def nowrap(self):
7482
return True # Tells Flax not to decorate this object, for classmethods.
7583

7684
def __post_init__(self):
77-
# Must copy-over to correctly implement "functools.wraps"-like
78-
# functionality.
79-
for name in ('__module__', '__name__', '__qualname__', '__doc__',
80-
'__annotations__'):
85+
# These attributes must be copied over to in order to correctly implement
86+
# "functools.wraps"-like functionality.
87+
for name in (
88+
'__module__',
89+
'__name__',
90+
'__qualname__',
91+
'__doc__',
92+
'__annotations__',
93+
):
8194
try:
8295
value = getattr(self.func, name)
8396
except AttributeError:
8497
pass
8598
else:
8699
object.__setattr__(self, name, value)
87100

88-
def __call__(self, *args, **kwargs) -> Any:
89-
return self.func(*args, **kwargs)
101+
if typing.TYPE_CHECKING:
102+
__module__: str
103+
__name__: str
104+
__qualname__: str
105+
__doc__: str
106+
__annotations__: Dict[str, Any]
107+
# The following informs type checkers that the call method has the same
108+
# signature/annotations as `func`.
109+
__call__: T
110+
else:
111+
# Actual implementation of __call__, which forwards parameters to `func`.
112+
def __call__(self, *args, **kwargs) -> Any:
113+
return self.func(*args, **kwargs)
90114

91-
def as_buildable(self, *args, **kwargs) -> config.Buildable:
115+
def as_buildable(self, *args, **kwargs) -> Any:
92116
return self.buildable_func(*args, **kwargs)
93117

94118
def __get__(self, obj, objtype=None):
95119
# pytype: disable=attribute-error
96120
return AutoConfig(
97121
func=self.func.__get__(obj, objtype),
98122
buildable_func=self.buildable_func.__get__(obj, objtype),
99-
always_inline=self.always_inline)
123+
always_inline=self.always_inline,
124+
)
100125
# pytype: enable=attribute-error
101126

102127
@property
@@ -116,11 +141,13 @@ class UnsupportedLanguageConstructError(SyntaxError):
116141
class _AutoConfigNodeTransformer(ast.NodeTransformer):
117142
"""A NodeTransformer that adds the auto-config call handler into an AST."""
118143

119-
def __init__(self,
120-
source: str,
121-
filename: str,
122-
line_number: int,
123-
allow_control_flow=False):
144+
def __init__(
145+
self,
146+
source: str,
147+
filename: str,
148+
line_number: int,
149+
allow_control_flow=False,
150+
):
124151
"""Initializes the auto config node transformer instance.
125152
126153
Args:
@@ -191,7 +218,8 @@ def _validate_decorator_ordering(self, node: ast.FunctionDef):
191218
raise AssertionError(
192219
f'@{decorator} placed above @auto_config on function {node.name} '
193220
f'at {self._filename}:{self._line_number}. Reorder decorators so '
194-
f'that @auto_config is placed above @{decorator}.')
221+
f'that @auto_config is placed above @{decorator}.'
222+
)
195223

196224
# pylint: disable=invalid-name
197225
def visit_Call(self, node: ast.Call):
@@ -432,7 +460,8 @@ def fn(...): # Or some expression involving a lambda.
432460
*closure_var_definitions,
433461
*module.body,
434462
],
435-
decorator_list=[])
463+
decorator_list=[],
464+
)
436465
],
437466
type_ignores=[],
438467
)
@@ -443,7 +472,8 @@ def fn(...): # Or some expression involving a lambda.
443472
def _find_function_code(code: types.CodeType, fn_name: str):
444473
"""Finds the code object within `code` corresponding to `fn_name`."""
445474
code = [
446-
const for const in code.co_consts
475+
const
476+
for const in code.co_consts
447477
if inspect.iscode(const) and const.co_name == fn_name
448478
]
449479
assert len(code) == 1, f"Couldn't find function code for {fn_name!r}."
@@ -553,7 +583,7 @@ def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs):
553583
return partial_cls(buildable_or_callable, *args, **kwargs)
554584

555585

556-
def exempt(fn_or_cls: Callable[..., Any]) -> Callable[..., Any]:
586+
def exempt(fn_or_cls: _GenericCallable) -> _GenericCallable:
557587
"""Wrap a callable so that it's exempted from auto_config.
558588
559589
This can be used either as a decorator to exempt a function, or used inside
@@ -591,8 +621,36 @@ class ConfigTypes:
591621
arg_factory_cls: Type[partial.ArgFactory] = partial.ArgFactory
592622

593623

624+
@overload
625+
def auto_config(
626+
fn: _GenericCallable,
627+
*,
628+
experimental_allow_dataclass_attribute_access: bool = False,
629+
experimental_allow_control_flow: bool = False,
630+
experimental_always_inline: Optional[bool] = None,
631+
experimental_exemption_policy: Optional[auto_config_policy.Policy] = None,
632+
experimental_config_types: ConfigTypes = ConfigTypes(),
633+
experimental_result_must_contain_buildable: bool = True,
634+
) -> AutoConfig[_GenericCallable]:
635+
...
636+
637+
638+
@overload
639+
def auto_config(
640+
fn: None = None,
641+
*,
642+
experimental_allow_dataclass_attribute_access: bool = False,
643+
experimental_allow_control_flow: bool = False,
644+
experimental_always_inline: Optional[bool] = None,
645+
experimental_exemption_policy: Optional[auto_config_policy.Policy] = None,
646+
experimental_config_types: ConfigTypes = ConfigTypes(),
647+
experimental_result_must_contain_buildable: bool = True,
648+
) -> Callable[[_GenericCallable], AutoConfig[_GenericCallable]]:
649+
...
650+
651+
594652
def auto_config(
595-
fn=None,
653+
fn: Optional[_GenericCallable] = None,
596654
*,
597655
experimental_allow_dataclass_attribute_access=False,
598656
experimental_allow_control_flow: bool = False,
@@ -778,9 +836,11 @@ def auto_config_attr_save_handler(obj, attr, value, allow_dataclass=True):
778836

779837
def make_auto_config(fn):
780838
if not isinstance(fn, (types.FunctionType, classmethod, staticmethod)):
781-
raise ValueError('`auto_config` is only compatible with functions, '
782-
f'`@classmethod`s, and `@staticmethod`s. Got {fn!r} '
783-
f'with type {type(fn)!r}.')
839+
raise ValueError(
840+
'`auto_config` is only compatible with functions, '
841+
f'`@classmethod`s, and `@staticmethod`s. Got {fn!r} '
842+
f'with type {type(fn)!r}.'
843+
)
784844

785845
if isinstance(fn, (classmethod, staticmethod)):
786846
method_type = type(fn)
@@ -799,7 +859,8 @@ def make_auto_config(fn):
799859
source=source,
800860
filename=filename,
801861
line_number=line_number,
802-
allow_control_flow=experimental_allow_control_flow)
862+
allow_control_flow=experimental_allow_control_flow,
863+
)
803864

804865
# Parse the AST, and modify it by intercepting all `Call`s with the
805866
# `auto_config_call_handler`. Finally, ensure line numbers and code
@@ -882,7 +943,8 @@ def as_buildable(*args, **kwargs):
882943
fn = method_type(fn)
883944
as_buildable = method_type(as_buildable)
884945
return AutoConfig(
885-
fn, as_buildable, always_inline=experimental_always_inline)
946+
fn, as_buildable, always_inline=experimental_always_inline
947+
)
886948

887949
# Decorator with empty parenthesis.
888950
if fn is None:
@@ -951,7 +1013,6 @@ def main():
9511013
experimental_always_inline = True
9521014

9531015
def make_unconfig(fn) -> AutoConfig:
954-
9551016
@functools.wraps(fn)
9561017
def python_implementation(*args, **kwargs):
9571018
previous = building._state.in_build # pytype: disable=module-attr # pylint: disable=protected-access
@@ -965,7 +1026,8 @@ def python_implementation(*args, **kwargs):
9651026
return AutoConfig(
9661027
func=python_implementation,
9671028
buildable_func=fn,
968-
always_inline=experimental_always_inline)
1029+
always_inline=experimental_always_inline,
1030+
)
9691031

9701032
# We use this pattern to support using the decorator with and without
9711033
# parenthesis.
@@ -1023,20 +1085,27 @@ def make_experiment():
10231085
doesn't correspond to an ``auto_config``'d function.
10241086
"""
10251087
if not isinstance(buildable, config.Config):
1026-
raise ValueError('Cannot `inline` non-Config buildables; '
1027-
f'{type(buildable)} is not compatible.')
1088+
raise ValueError(
1089+
'Cannot `inline` non-Config buildables; '
1090+
f'{type(buildable)} is not compatible.'
1091+
)
10281092
if not is_auto_config(buildable.__fn_or_cls__):
1029-
raise ValueError('Cannot `inline` a non-auto_config function; '
1030-
f'`{buildable.__fn_or_cls__}` is not compatible.')
1093+
raise ValueError(
1094+
'Cannot `inline` a non-auto_config function; '
1095+
f'`{buildable.__fn_or_cls__}` is not compatible.'
1096+
)
10311097
# Evaluate the `as_buildable` interpretation.
10321098
auto_config_fn = cast(AutoConfig, buildable.__fn_or_cls__)
10331099
tmp_config = auto_config_fn.as_buildable(**buildable.__arguments__)
10341100
if not isinstance(tmp_config, config.Buildable):
1035-
raise ValueError('You cannot currently inline functions that do not return '
1036-
'`fdl.Buildable`s.')
1101+
raise ValueError(
1102+
'You cannot currently inline functions that do not return '
1103+
'`fdl.Buildable`s.'
1104+
)
10371105

10381106
mutate_buildable.move_buildable_internals(
1039-
source=tmp_config, destination=buildable)
1107+
source=tmp_config, destination=buildable
1108+
)
10401109

10411110

10421111
def _getsource(fn: Any) -> str:
@@ -1056,11 +1125,12 @@ def _is_lambda(fn: Any) -> bool:
10561125
return False
10571126
if not (hasattr(fn, '__name__') and hasattr(fn, '__code__')):
10581127
return False
1059-
return ((fn.__name__ == '<lambda>') or (fn.__code__.co_name == '<lambda>'))
1128+
return (fn.__name__ == '<lambda>') or (fn.__code__.co_name == '<lambda>')
10601129

10611130

10621131
class _LambdaFinder(cst.CSTVisitor):
10631132
"""CST Visitor that searches for the source code for a given lambda func."""
1133+
10641134
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,)
10651135

10661136
def __init__(self, lambda_fn):
@@ -1095,15 +1165,17 @@ def _getsource_for_lambda(fn: Callable[..., Any]) -> str:
10951165
elif not lambda_finder.candidates:
10961166
raise ValueError(
10971167
'Fiddle auto_config was unable to find the source code for '
1098-
f'{fn}: could not find lambda on line {lambda_finder.lineno}.')
1168+
f'{fn}: could not find lambda on line {lambda_finder.lineno}.'
1169+
)
10991170
else:
11001171
# TODO(b/258671226): If desired, we could narrow down which lambda is
11011172
# used based on the signature (or even fancier things like the checking
11021173
# fn.__code__.co_names).
11031174
raise ValueError(
11041175
'Fiddle auto_config was unable to find the source code for '
11051176
f'{fn}: multiple lambdas found on line {lambda_finder.lineno}; '
1106-
'try moving each lambda to its own line.')
1177+
'try moving each lambda to its own line.'
1178+
)
11071179

11081180

11091181
def with_buildable_func(

0 commit comments

Comments
 (0)