1212import importlib
1313import inspect
1414import types
15- from typing import Any , Callable
15+ from typing import Callable
1616
1717from pif import providers
1818
1919
20- def patch_args_decorator [T : Callable ](func : T , patched_kwargs : dict [str , Any ]) -> T :
20+ def patch_args_decorator [T : Callable ](func : T , patched_kwargs : dict [str , providers . Provider ]) -> T :
2121 """
2222 Get a decorated copy of `func` with patched arguments.
2323
@@ -30,8 +30,11 @@ def patch_args_decorator[T: Callable](func: T, patched_kwargs: dict[str, Any]) -
3030
3131 @functools .wraps (func )
3232 def wrapper (* args , ** kwargs ):
33- patched_kwargs .update (kwargs )
34- return func (* args , ** patched_kwargs )
33+ for keyword in patched_kwargs :
34+ if keyword not in kwargs :
35+ kwargs [keyword ] = patched_kwargs [keyword ]()
36+
37+ return func (* args , ** kwargs )
3538
3639 wrapper ._patched_func = func
3740 return wrapper
@@ -63,7 +66,7 @@ def patch_method[T: Callable | types.FunctionType](func: T) -> T:
6366 continue # TODO(scottzach1) Add support for non keyword arguments.
6467
6568 if isinstance (value .default , providers .Provider ):
66- patched_args [name ] = value .default ()
69+ patched_args [name ] = value .default
6770
6871 if patched_args :
6972 return patch_args_decorator (func , patched_args )
0 commit comments