Skip to content

Commit 004b97f

Browse files
committed
feat: refactor type detection in decorator
1 parent e9a18bb commit 004b97f

8 files changed

Lines changed: 101 additions & 82 deletions

File tree

src/duron/_core/context.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from duron._core.signal import create_signal
1919
from duron._core.stream import create_stream, run_stream
2020
from duron._decorator.op import CheckpointOp, Op
21+
from duron.typing import inspect_function
2122

2223
if TYPE_CHECKING:
2324
from collections.abc import Callable, Coroutine
@@ -41,11 +42,11 @@
4142

4243
@final
4344
class Context:
44-
__slots__ = ("_loop", "_task", "_token")
45+
__slots__ = ("_fn", "_loop", "_token")
4546

4647
def __init__(self, task: Fn[..., object], loop: EventLoop) -> None:
4748
self._loop: EventLoop = loop
48-
self._task = task
49+
self._fn = task
4950
self._token: Token[Context | None] | None = None
5051

5152
def __enter__(self) -> Context:
@@ -109,12 +110,9 @@ async def run(
109110
return await stream
110111

111112
if isinstance(fn, Op):
112-
if fn.return_type:
113-
return_type = fn.return_type
114-
else:
115-
return_type = self._task.codec.inspect_function(fn.fn).return_type
113+
return_type = fn.return_type
116114
else:
117-
return_type = self._task.codec.inspect_function(fn).return_type
115+
return_type = inspect_function(fn).return_type
118116

119117
metadata = options.metadata if options else None
120118
op = create_op(

src/duron/_core/invoke.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from duron._loop import EventLoop, create_loop
3232
from duron.codec import Codec, JSONValue
3333
from duron.log import is_entry
34-
from duron.typing import unspecified
34+
from duron.typing import inspect_function, unspecified
3535

3636
if TYPE_CHECKING:
3737
from collections.abc import Callable, Coroutine
@@ -44,7 +44,7 @@
4444
from duron._core.stream import Stream
4545
from duron._decorator.fn import Fn
4646
from duron._loop import OpFuture, WaitSet
47-
from duron.codec import Codec, FunctionType
47+
from duron.codec import Codec
4848
from duron.log import (
4949
BarrierEntry,
5050
Entry,
@@ -56,7 +56,7 @@
5656
StreamCreateEntry,
5757
StreamEmitEntry,
5858
)
59-
from duron.typing import TypeHint
59+
from duron.typing import FunctionType, TypeHint
6060

6161

6262
_T_co = TypeVar("_T_co", covariant=True)
@@ -100,7 +100,7 @@ def get_init() -> InitParams:
100100
}
101101

102102
codec = self._fn.codec
103-
type_info = codec.inspect_function(self._fn.fn)
103+
type_info = inspect_function(self._fn.fn)
104104
prelude = _invoke_prelude(self._fn, type_info, get_init)
105105
self._run = _InvokeRun(
106106
prelude,
@@ -115,7 +115,7 @@ def cb() -> InitParams:
115115
msg = "not started"
116116
raise RuntimeError(msg)
117117

118-
type_info = self._fn.codec.inspect_function(self._fn.fn)
118+
type_info = inspect_function(self._fn.fn)
119119
prelude = _invoke_prelude(self._fn, type_info, cb)
120120
self._run = _InvokeRun(
121121
prelude,

src/duron/_decorator/op.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import AsyncGenerator
34
from dataclasses import dataclass
45
from typing import (
56
TYPE_CHECKING,
@@ -9,13 +10,15 @@
910
Literal,
1011
ParamSpec,
1112
TypeVar,
13+
get_args,
14+
get_origin,
1215
overload,
1316
)
1417

15-
from duron.typing import unspecified
18+
from duron.typing import inspect_function, unspecified
1619

1720
if TYPE_CHECKING:
18-
from collections.abc import AsyncGenerator, Callable, Coroutine
21+
from collections.abc import Callable, Coroutine
1922

2023
from duron.typing import TypeHint
2124

@@ -32,7 +35,7 @@ class CheckpointOp(Generic[_P, _S, _T]):
3235
initial: Callable[[], _S]
3336
reducer: Callable[[_S, _T], _S]
3437
action_type: TypeHint[_T]
35-
state_type: TypeHint[_S]
38+
return_type: TypeHint[_S]
3639

3740
def __call__(
3841
self,
@@ -108,20 +111,33 @@ def op(
108111
def decorate_ckpt(
109112
fn: Callable[Concatenate[_S, _P], AsyncGenerator[_T, _S]],
110113
) -> CheckpointOp[_P, _S, _T]:
114+
action_type_local = action_type
115+
return_type_local = return_type
116+
if not action_type_local or not return_type_local:
117+
ret = inspect_function(fn).return_type
118+
if get_origin(ret) is AsyncGenerator:
119+
yield_, send = get_args(ret)
120+
if not action_type_local:
121+
action_type_local = yield_
122+
if not return_type_local:
123+
return_type_local = send
111124
return CheckpointOp(
112125
fn=fn,
113126
initial=initial,
114127
reducer=reducer,
115-
action_type=action_type,
116-
state_type=return_type,
128+
action_type=action_type_local,
129+
return_type=return_type_local,
117130
)
118131

119132
return decorate_ckpt
120133

121134
def decorate(
122135
fn: Callable[_P, Coroutine[Any, Any, _T_co]],
123136
) -> Op[_P, _T_co]:
124-
return Op(fn=fn, return_type=return_type)
137+
return Op(
138+
fn=fn,
139+
return_type=return_type or inspect_function(fn).return_type,
140+
)
125141

126142
if fn is not None:
127143
return decorate(fn)

src/duron/codec.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
11
from __future__ import annotations
22

3-
import inspect
43
from abc import ABC, abstractmethod
5-
from dataclasses import dataclass
64
from typing import TYPE_CHECKING, TypeGuard, cast, final
75
from typing_extensions import TypeAliasType, override
86

9-
from duron.typing import unspecified
10-
117
if TYPE_CHECKING:
12-
from collections.abc import Callable
138
from typing_extensions import Any
149

1510
from duron.typing import TypeHint
@@ -37,13 +32,6 @@ def is_json_value(x: object) -> TypeGuard[JSONValue]:
3732
return False
3833

3934

40-
@dataclass(slots=True)
41-
class FunctionType:
42-
return_type: TypeHint[Any]
43-
parameters: list[str]
44-
parameter_types: dict[str, TypeHint[Any]]
45-
46-
4735
class Codec(ABC):
4836
__slots__: tuple[str, ...] = ()
4937

@@ -58,40 +46,6 @@ def decode_json(
5846
/,
5947
) -> object: ...
6048

61-
def inspect_function( # noqa: PLR6301
62-
self,
63-
fn: Callable[..., object],
64-
) -> FunctionType:
65-
try:
66-
sig = inspect.signature(fn, eval_str=True)
67-
except NameError:
68-
sig = inspect.signature(fn)
69-
return_type = (
70-
sig.return_annotation
71-
if sig.return_annotation != inspect.Parameter.empty
72-
else unspecified
73-
)
74-
75-
parameter_names: list[str] = []
76-
parameter_types: dict[str, TypeHint[Any]] = {}
77-
for k, p in sig.parameters.items():
78-
if p.kind in {
79-
inspect.Parameter.VAR_POSITIONAL,
80-
inspect.Parameter.VAR_KEYWORD,
81-
}:
82-
continue
83-
84-
parameter_names.append(k)
85-
parameter_types[p.name] = (
86-
p.annotation if p.annotation != inspect.Parameter.empty else unspecified
87-
)
88-
89-
return FunctionType(
90-
return_type=return_type,
91-
parameters=parameter_names,
92-
parameter_types=parameter_types,
93-
)
94-
9549

9650
@final
9751
class DefaultCodec(Codec):

src/duron/typing/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .hint import TypeHint as TypeHint
2+
from .hint import unspecified as unspecified
3+
from .inspect import FunctionType as FunctionType
4+
from .inspect import inspect_function as inspect_function
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from types import UnionType
24
from typing_extensions import TypeAliasType, TypeVar
35

@@ -23,5 +25,3 @@ def __bool__(self) -> bool:
2325
from typing_extensions import TypeForm
2426

2527
TypeHint = TypeAliasType("TypeHint", TypeForm[_T] | _Unspecified, type_params=(_T,))
26-
27-
__all__ = ["TypeHint", "unspecified"]

src/duron/typing/inspect.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
from dataclasses import dataclass
5+
from typing import TYPE_CHECKING, Any
6+
7+
from duron.typing.hint import unspecified
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Callable
11+
12+
from duron.typing.hint import TypeHint
13+
14+
15+
@dataclass(slots=True)
16+
class FunctionType:
17+
return_type: TypeHint[Any]
18+
parameters: list[str]
19+
parameter_types: dict[str, TypeHint[Any]]
20+
21+
22+
def inspect_function(
23+
fn: Callable[..., object],
24+
) -> FunctionType:
25+
try:
26+
sig = inspect.signature(fn, eval_str=True)
27+
except NameError:
28+
sig = inspect.signature(fn)
29+
return_type = (
30+
sig.return_annotation
31+
if sig.return_annotation != inspect.Parameter.empty
32+
else unspecified
33+
)
34+
35+
parameter_names: list[str] = []
36+
parameter_types: dict[str, TypeHint[Any]] = {}
37+
for k, p in sig.parameters.items():
38+
if p.kind in {
39+
inspect.Parameter.VAR_POSITIONAL,
40+
inspect.Parameter.VAR_KEYWORD,
41+
}:
42+
continue
43+
44+
parameter_names.append(k)
45+
parameter_types[p.name] = (
46+
p.annotation if p.annotation != inspect.Parameter.empty else unspecified
47+
)
48+
49+
return FunctionType(
50+
return_type=return_type,
51+
parameters=parameter_names,
52+
parameter_types=parameter_types,
53+
)
Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
from __future__ import annotations
22

33
from collections.abc import AsyncGenerator
4-
from typing import TYPE_CHECKING
54

6-
from duron.codec import DefaultCodec, FunctionType
7-
8-
if TYPE_CHECKING:
9-
from duron.codec import Codec
10-
11-
codec: Codec = DefaultCodec()
5+
from duron.typing import FunctionType, inspect_function
126

137

148
def test_no_parameters() -> None:
159
def simple_func() -> int:
1610
return 42
1711

18-
result = codec.inspect_function(simple_func)
12+
result = inspect_function(simple_func)
1913

2014
assert isinstance(result, FunctionType)
2115
assert result.parameters == []
@@ -27,7 +21,7 @@ def test_no_parameters_no_return_type() -> None:
2721
def simple_func() -> None:
2822
pass
2923

30-
result = codec.inspect_function(simple_func)
24+
result = inspect_function(simple_func)
3125

3226
assert isinstance(result, FunctionType)
3327
assert result.parameters == []
@@ -39,7 +33,7 @@ def test_type_annotated_parameters() -> None:
3933
def typed_func(_x: int, _y: str, _z: float) -> bool:
4034
return True
4135

42-
result = codec.inspect_function(typed_func)
36+
result = inspect_function(typed_func)
4337

4438
assert result.parameters == ["_x", "_y", "_z"]
4539
assert result.parameter_types == {"_x": int, "_y": str, "_z": float}
@@ -50,7 +44,7 @@ def test_return_type_annotation() -> None:
5044
def func_with_return() -> dict[str, int]:
5145
return {}
5246

53-
result = codec.inspect_function(func_with_return)
47+
result = inspect_function(func_with_return)
5448

5549
assert result.parameters == []
5650
assert result.parameter_types == {}
@@ -61,7 +55,7 @@ def test_complex_type_annotations() -> None:
6155
def complex_func(_x: int | str, _y: list[dict[str, int]]) -> tuple[int, ...]:
6256
return (1, 2, 3)
6357

64-
result = codec.inspect_function(complex_func)
58+
result = inspect_function(complex_func)
6559

6660
assert result.parameters == ["_x", "_y"]
6761
assert result.parameter_types == {"_x": int | str, "_y": list[dict[str, int]]}
@@ -72,7 +66,7 @@ def test_async_function() -> None:
7266
async def async_func(_a: int, _b: str) -> bool: # noqa: RUF029
7367
return True
7468

75-
result = codec.inspect_function(async_func)
69+
result = inspect_function(async_func)
7670

7771
assert result.parameters == ["_a", "_b"]
7872
assert result.parameter_types == {"_a": int, "_b": str}
@@ -83,7 +77,7 @@ def test_varargs_and_kwargs() -> None:
8377
def func_with_varargs(_a: int, *_args: str, **_kwargs: bool) -> None:
8478
pass
8579

86-
result = codec.inspect_function(func_with_varargs)
80+
result = inspect_function(func_with_varargs)
8781

8882
assert result.parameters == ["_a"]
8983
assert result.parameter_types == {"_a": int}
@@ -94,7 +88,7 @@ def test_positional_only_and_keyword_only() -> None:
9488
def func_with_special_args(_a: int, /, _b: str, *, _c: float) -> bool:
9589
return True
9690

97-
result = codec.inspect_function(func_with_special_args)
91+
result = inspect_function(func_with_special_args)
9892

9993
assert result.parameters == ["_a", "_b", "_c"]
10094
assert result.parameter_types == {"_a": int, "_b": str, "_c": float}
@@ -105,7 +99,7 @@ def test_iterator() -> None:
10599
async def generator() -> AsyncGenerator[int]: # noqa: RUF029
106100
yield 1
107101

108-
result = codec.inspect_function(generator)
102+
result = inspect_function(generator)
109103

110104
assert result.parameters == []
111105
assert result.parameter_types == {}

0 commit comments

Comments
 (0)