Skip to content

Commit 8ad5e7b

Browse files
committed
Add lazy evaluation of patched args
1 parent 406c76c commit 8ad5e7b

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

pif/wiring.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
import importlib
1313
import inspect
1414
import types
15-
from typing import Any, Callable
15+
from typing import Callable
1616

1717
from pif import providers
1818

1919

20-
def patch_args_decorator[T: Callable](func: T, patched_kwargs: dict[str, Any]) -> T:
20+
def patch_args_decorator[T: Callable](func: T, patched_kwargs: dict[str, providers.Provider]) -> T:
2121
"""
2222
Get a decorated copy of `func` with patched arguments.
2323
@@ -30,8 +30,11 @@ def patch_args_decorator[T: Callable](func: T, patched_kwargs: dict[str, Any]) -
3030

3131
@functools.wraps(func)
3232
def wrapper(*args, **kwargs):
33-
patched_kwargs.update(kwargs)
34-
return func(*args, **patched_kwargs)
33+
for keyword in patched_kwargs:
34+
if keyword not in kwargs:
35+
kwargs[keyword] = patched_kwargs[keyword]()
36+
37+
return func(*args, **kwargs)
3538

3639
wrapper._patched_func = func
3740
return wrapper
@@ -63,7 +66,7 @@ def patch_method[T: Callable | types.FunctionType](func: T) -> T:
6366
continue # TODO(scottzach1) Add support for non keyword arguments.
6467

6568
if isinstance(value.default, providers.Provider):
66-
patched_args[name] = value.default()
69+
patched_args[name] = value.default
6770

6871
if patched_args:
6972
return patch_args_decorator(func, patched_args)

tests/test_wiring.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
from unittest.mock import MagicMock
23

34
from pif import providers, wiring
45

@@ -36,3 +37,24 @@ def test_patch_kwarg():
3637
assert not wiring.is_patched(my_func)
3738
assert sig_before == sig_unwired
3839
assert doc_before == doc_unwired
40+
41+
42+
def test_patch_lazy():
43+
"""
44+
Test that our wiring implementation lazily evaluates providers.
45+
"""
46+
mock = MagicMock()
47+
48+
def func(v=None):
49+
return v
50+
51+
assert not mock.call_count
52+
patched = wiring.patch_args_decorator(func, {"v": mock})
53+
assert not mock.call_count
54+
55+
assert func() is None
56+
assert not mock.call_count
57+
assert isinstance(patched(), MagicMock)
58+
assert mock.call_count == 1
59+
assert isinstance(patched(), MagicMock)
60+
assert mock.call_count == 2

0 commit comments

Comments
 (0)