Skip to content

Commit 991a488

Browse files
committed
Add @wiring.injected to auto-injected methods
This @wiring.injected decorator does not require wiring.wire() to patch a method. This implementation also adds support for positional-only arguments.
1 parent 8ad5e7b commit 991a488

File tree

4 files changed

+120
-40
lines changed

4 files changed

+120
-40
lines changed

README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,12 @@ A simple Python dependency injection framework.
1010
from pif import wiring, providers
1111

1212

13-
def my_function(a: str = providers.Singleton[str](lambda: "hello wolrd")):
13+
@wiring.injected
14+
def my_function(a: str = providers.Singleton[str](lambda: "hello world")):
1415
return a
1516

1617

1718
if __name__ == '__main__':
18-
assert isinstance(my_function(), providers.Singleton)
19-
20-
wiring.wire([__name__])
21-
2219
assert "hello world" == my_function()
2320
```
2421

pif/wiring.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,57 @@
1111
import functools
1212
import importlib
1313
import inspect
14+
import itertools
1415
import types
15-
from typing import Callable
16+
from typing import Any, Callable
1617

1718
from pif import providers
1819

1920

20-
def patch_args_decorator[T: Callable](func: T, patched_kwargs: dict[str, providers.Provider]) -> T:
21+
def patch_args(
22+
signature: inspect.Signature,
23+
args: tuple[Any, ...],
24+
kwargs: dict[str, Any],
25+
) -> tuple[tuple[Any, ...], dict[str, Any]]:
2126
"""
22-
Get a decorated copy of `func` with patched arguments.
27+
Patch the args and kwargs at runtime using the signature for reference.
2328
24-
TODO(scottzach1) - add support for positional kwargs.
29+
:param signature: to lookup method parameters.
30+
:param args: provided at runtime
31+
:param kwargs: provided at runtime.
32+
:return: injected args and kwargs to pass to func.
33+
"""
34+
for i, (name, value) in enumerate(signature.parameters.items()):
35+
if isinstance(value.default, providers.Provider) and i >= len(args):
36+
if value.kind == inspect.Parameter.POSITIONAL_ONLY:
37+
args = (
38+
*args,
39+
*(p.default for p in itertools.islice(signature.parameters.values(), len(args), i)),
40+
value.default(),
41+
)
42+
if (
43+
value.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
44+
and name not in kwargs
45+
):
46+
kwargs[name] = value.default()
47+
return args, kwargs
48+
49+
50+
def injected[T: Callable](func: T) -> T:
51+
"""
52+
Get a decorated copy of `func` with patched arguments.
2553
2654
:param func: to decorate.
27-
:param patched_kwargs: the kwargs to patch.
2855
:return: the decorated function.
2956
"""
57+
signature = inspect.signature(func)
58+
59+
if not any(p for p in signature.parameters.values() if isinstance(p.default, providers.Provider)):
60+
return func
3061

3162
@functools.wraps(func)
3263
def wrapper(*args, **kwargs):
33-
for keyword in patched_kwargs:
34-
if keyword not in kwargs:
35-
kwargs[keyword] = patched_kwargs[keyword]()
64+
args, kwargs = patch_args(signature, args, kwargs)
3665

3766
return func(*args, **kwargs)
3867

@@ -54,22 +83,13 @@ def patch_method[T: Callable | types.FunctionType](func: T) -> T:
5483
"""
5584
Return a "patched" version of the method provided.
5685
57-
If no values required patching, the provided function will be returned unchanged..
86+
If no values required patching, the provided function will be returned unchanged.
5887
5988
:param func: to patch default values.
6089
:return: a "patched" version of the method provided.
6190
"""
62-
patched_args = {}
63-
64-
for name, value in inspect.signature(func).parameters.items():
65-
if value.kind == inspect.Parameter.POSITIONAL_ONLY:
66-
continue # TODO(scottzach1) Add support for non keyword arguments.
67-
68-
if isinstance(value.default, providers.Provider):
69-
patched_args[name] = value.default
70-
71-
if patched_args:
72-
return patch_args_decorator(func, patched_args)
91+
if any(1 for param in inspect.signature(func).parameters.values() if isinstance(param.default, providers.Provider)):
92+
return injected(func)
7393

7494
return func
7595

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ select = [
2525
"SIM", # flake8-simplify
2626
"I", # isort
2727
]
28-
28+
ignore = ["B008"]
2929
fixable = ["ALL"]
3030

3131
[build-system]

tests/test_wiring.py

Lines changed: 77 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33

44
from pif import providers, wiring
55

6-
provider = providers.Singleton[str](lambda: "hello")
76

7+
def provide(s: str) -> providers.Singleton[str]:
8+
return providers.Singleton[str](lambda: f"{s}_injected")
89

9-
def my_func(a: str = provider):
10+
11+
def my_func(a: str = provide("a")):
1012
"""
1113
Our dummy method to test wiring for the module.
1214
"""
@@ -19,21 +21,21 @@ def test_patch_kwarg():
1921
"""
2022
sig_before = inspect.signature(my_func)
2123
doc_before = my_func.__doc__
22-
assert provider == my_func()
24+
assert isinstance(my_func(), providers.Singleton)
2325
assert not wiring.is_patched(my_func)
2426

2527
wiring.wire([__name__])
2628
sig_wired = inspect.signature(my_func)
2729
doc_wired = my_func.__doc__
28-
assert my_func() == "hello"
30+
assert my_func() == "a_injected"
2931
assert sig_before == sig_wired
3032
assert doc_before == doc_wired
3133
assert wiring.is_patched(my_func)
3234

3335
wiring.unwire([__name__])
3436
sig_unwired = inspect.signature(my_func)
3537
doc_unwired = my_func.__doc__
36-
assert provider == my_func()
38+
assert isinstance(my_func(), providers.Singleton)
3739
assert not wiring.is_patched(my_func)
3840
assert sig_before == sig_unwired
3941
assert doc_before == doc_unwired
@@ -44,17 +46,78 @@ def test_patch_lazy():
4446
Test that our wiring implementation lazily evaluates providers.
4547
"""
4648
mock = MagicMock()
47-
48-
def func(v=None):
49-
return v
50-
51-
assert not mock.call_count
52-
patched = wiring.patch_args_decorator(func, {"v": mock})
5349
assert not mock.call_count
5450

55-
assert func() is None
51+
@wiring.injected
52+
def func(v=providers.Singleton[MagicMock](lambda: mock)):
53+
return v()
54+
5655
assert not mock.call_count
57-
assert isinstance(patched(), MagicMock)
56+
assert isinstance(func(), MagicMock)
5857
assert mock.call_count == 1
59-
assert isinstance(patched(), MagicMock)
58+
assert isinstance(func(), MagicMock)
6059
assert mock.call_count == 2
60+
61+
62+
def test_patch_positional_only():
63+
"""
64+
Test patching for POSITIONAL_ONLY arguments.
65+
"""
66+
67+
@wiring.injected
68+
def p1(a, b=provide("b"), c="c_default", /):
69+
return a, b, c
70+
71+
assert p1(None) == (None, "b_injected", "c_default")
72+
assert p1(None, None) == (None, None, "c_default")
73+
74+
@wiring.injected
75+
def p2(a, b=None, c=provide("c"), /):
76+
return a, b, c
77+
78+
assert p2(None) == (None, None, "c_injected")
79+
assert p2(None, None) == (None, None, "c_injected")
80+
assert p2(None, None, "c_override") == (None, None, "c_override")
81+
82+
83+
def test_patch_positional():
84+
"""
85+
Test patching for POSITIONAL_OR_KEYWORD arguments.
86+
"""
87+
88+
@wiring.injected
89+
def p1(a, b=provide("b"), c="c_default"):
90+
return a, b, c
91+
92+
assert p1(None) == (None, "b_injected", "c_default")
93+
assert p1(None, None) == (None, None, "c_default")
94+
assert p1(a=None) == (None, "b_injected", "c_default")
95+
assert p1(a=None, b=None) == (None, None, "c_default")
96+
97+
@wiring.injected
98+
def p2(a, b=None, c=provide("c")):
99+
return a, b, c
100+
101+
assert p2(None) == (None, None, "c_injected")
102+
assert p2(None, None) == (None, None, "c_injected")
103+
assert p2(None, None, "c_override") == (None, None, "c_override")
104+
assert p2(a=None) == (None, None, "c_injected")
105+
assert p2(a=None, b=None) == (None, None, "c_injected")
106+
assert p2(a=None, b=None, c="c_override") == (None, None, "c_override")
107+
108+
109+
def test_patch_positional_or_keyword():
110+
"""
111+
Test patching for VAR_POSITIONAL argument.
112+
"""
113+
114+
@wiring.injected
115+
def p1(a, b=provide("b"), *c, d="d_default", e=provide("e")):
116+
return a, b, *c, d, e
117+
118+
assert p1("a") == ("a", "b_injected", "d_default", "e_injected")
119+
assert p1("a", "b") == ("a", "b", "d_default", "e_injected")
120+
assert p1("a", "b", "c1") == ("a", "b", "c1", "d_default", "e_injected")
121+
assert p1("a", "b", "c1", "c2") == ("a", "b", "c1", "c2", "d_default", "e_injected")
122+
assert p1("a", "b", "c1", "c2", d="d") == ("a", "b", "c1", "c2", "d", "e_injected")
123+
assert p1("a", "b", "c1", "c2", d="d", e="e") == ("a", "b", "c1", "c2", "d", "e")

0 commit comments

Comments
 (0)