Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions paderbox/utils/functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import inspect
from typing import Callable
import wrapt


def partial_decorator(
Expand Down Expand Up @@ -155,15 +156,30 @@ def partial_decorator(
... print(a, args, kwargs)
>>> buzz(keyword1=42, keyword2=123)(1, 2, 3, 4, 5)
1 (2, 3, 4, 5) {'keyword1': 42, 'keyword2': 123}
"""
if fn is None:
return functools.partial(
partial_decorator,
chain=chain,
requires_partial_call=requires_partial_call
)

signature = inspect.signature(fn)
Instance methods and class methods
>>> class A:
... @partial_decorator
... @staticmethod
... def s(a, b):
... print(a, b)
...
... @partial_decorator
... @classmethod
... def c(cls, a, b):
... print(a, b)
...
... @partial_decorator
... def a(self, a, b):
... print(a, b)
>>> A().a(b=4)(1)
1 4
>>> A.c(b=4)(1)
1 4
>>> A.s(b=4)(1)
1 4
"""
signature = None

@functools.wraps(fn)
def partial_wrapper(
Expand Down Expand Up @@ -220,4 +236,18 @@ def partial_wrapper(
)
return fn(*args, **kwargs)

return partial_wrapper
# It is hard to get all types of functions/methods correct. The `wrapt`
# package has wrappers that handle all cases correctly, so use that here
@wrapt.decorator
def partial_decorator_(wrapped_, instance, args, kwargs):
nonlocal signature, fn
fn = wrapped_
if signature is None:
# Only get the signature once
signature = inspect.signature(fn)
return partial_wrapper(*args, **kwargs)

if fn is None:
return partial_decorator_
else:
return partial_decorator_(fn)