|
| 1 | +# _ _ _ _ |
| 2 | +# ___ ___ ___ | |_| |_ ______ _ ___| |__ / | |
| 3 | +# / __|/ __/ _ \| __| __|_ / _` |/ __| '_ \| | |
| 4 | +# \__ \ (_| (_) | |_| |_ / / (_| | (__| | | | | |
| 5 | +# |___/\___\___/ \__|\__/___\__,_|\___|_| |_|_| |
| 6 | +# |
| 7 | +# Zac Scott (github.com/scottzach1) |
| 8 | +# |
| 9 | +# https://github.com/scottzach1/python-injector-framework |
| 10 | + |
| 11 | +import functools |
| 12 | +import importlib |
| 13 | +import inspect |
| 14 | +import types |
| 15 | +from typing import Any, Callable |
| 16 | + |
| 17 | +from pif import providers |
| 18 | + |
| 19 | + |
| 20 | +def patch_args_decorator[T: Callable](func: T, patched_kwargs: dict[str, Any]) -> T: |
| 21 | + """ |
| 22 | + Get a decorated copy of `func` with patched arguments. |
| 23 | +
|
| 24 | + TODO(scottzach1) - add support for positional kwargs. |
| 25 | +
|
| 26 | + :param func: to decorate. |
| 27 | + :param patched_kwargs: the kwargs to patch. |
| 28 | + :return: the decorated function. |
| 29 | + """ |
| 30 | + |
| 31 | + @functools.wraps(func) |
| 32 | + def wrapper(*args, **kwargs): |
| 33 | + patched_kwargs.update(kwargs) |
| 34 | + return func(*args, **patched_kwargs) |
| 35 | + |
| 36 | + wrapper._patched_func = func |
| 37 | + return wrapper |
| 38 | + |
| 39 | + |
| 40 | +def is_patched(func: Callable | types.FunctionType) -> bool: |
| 41 | + """ |
| 42 | + Checks if a function has been "patched" by the `patch_args_decorator` |
| 43 | +
|
| 44 | + :param func: the function to check. |
| 45 | + :return: True if patched, False otherwise. |
| 46 | + """ |
| 47 | + return hasattr(func, "_patched_func") |
| 48 | + |
| 49 | + |
| 50 | +def patch_method[T: Callable | types.FunctionType](func: T) -> T: |
| 51 | + """ |
| 52 | + Return a "patched" version of the method provided. |
| 53 | +
|
| 54 | + If no values required patching, the provided function will be returned unchanged.. |
| 55 | +
|
| 56 | + :param func: to patch default values. |
| 57 | + :return: a "patched" version of the method provided. |
| 58 | + """ |
| 59 | + patched_args = {} |
| 60 | + |
| 61 | + for name, value in inspect.signature(func).parameters.items(): |
| 62 | + if value.kind == inspect.Parameter.POSITIONAL_ONLY: |
| 63 | + continue # TODO(scottzach1) Add support for non keyword arguments. |
| 64 | + |
| 65 | + if isinstance(value.default, providers.Provider): |
| 66 | + patched_args[name] = value.default() |
| 67 | + |
| 68 | + if patched_args: |
| 69 | + return patch_args_decorator(func, patched_args) |
| 70 | + |
| 71 | + return func |
| 72 | + |
| 73 | + |
| 74 | +def unpatch_method[T: Callable | types.FunctionType](func: T) -> T: |
| 75 | + """ |
| 76 | + Get an "unpatched" copy of a method. |
| 77 | +
|
| 78 | + If the value was not patched, the provided function will be returned unchanged. |
| 79 | +
|
| 80 | + :param func: the function to unpatch. |
| 81 | + :return: the unpatched provided function. |
| 82 | + """ |
| 83 | + return getattr(func, "_patched_func", func) |
| 84 | + |
| 85 | + |
| 86 | +def wire(modules: list[types.ModuleType | str]) -> None: |
| 87 | + """ |
| 88 | + Patch all methods in the module containing `Provide` default arguments. |
| 89 | +
|
| 90 | + :param modules: list of modules to wire. |
| 91 | + """ |
| 92 | + for module in modules: |
| 93 | + if isinstance(module, str): |
| 94 | + module = importlib.import_module(module) |
| 95 | + |
| 96 | + for name, obj in inspect.getmembers(module): |
| 97 | + if inspect.isfunction(obj): |
| 98 | + if obj is not (patched := patch_method(obj)): |
| 99 | + setattr(module, name, patched) |
| 100 | + elif inspect.isclass(obj): |
| 101 | + for method_name, method in inspect.getmembers(obj, inspect.isfunction): |
| 102 | + if method is not (patched := patch_method(method)): |
| 103 | + setattr(obj, method_name, patched) |
| 104 | + |
| 105 | + |
| 106 | +def unwire(modules: list[types.ModuleType]) -> None: |
| 107 | + """ |
| 108 | + Unpatch all methods in the module containing `Provide` default arguments. |
| 109 | +
|
| 110 | + :param modules: list of modules to wire. |
| 111 | + """ |
| 112 | + for module in modules: |
| 113 | + if isinstance(module, str): |
| 114 | + module = importlib.import_module(module) |
| 115 | + |
| 116 | + for name, obj in inspect.getmembers(module): |
| 117 | + if inspect.isfunction(obj): |
| 118 | + if obj is not (unpatched := unpatch_method(obj)): |
| 119 | + setattr(module, name, unpatched) |
| 120 | + elif inspect.isclass(obj): |
| 121 | + for method_name, method in inspect.getmembers(obj, inspect.isfunction): |
| 122 | + if method is not (unpatched := unpatch_method(method)): |
| 123 | + setattr(obj, method_name, unpatched) |
0 commit comments