Skip to content

Commit 5335cfd

Browse files
committed
wip stream writer context manager
1 parent 70766f1 commit 5335cfd

1 file changed

Lines changed: 25 additions & 44 deletions

File tree

src/duron/_core/invoke.py

Lines changed: 25 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@
4141

4242
if TYPE_CHECKING:
4343
from asyncio.exceptions import CancelledError
44-
from collections.abc import Callable, Coroutine
45-
from contextvars import Token
46-
from types import TracebackType
44+
from collections.abc import AsyncGenerator, Callable, Coroutine
4745

4846
from duron._core.ops import Op, StreamObserver
4947
from duron._decorator.durable import DurableFn
@@ -63,16 +61,17 @@
6361
from duron.typing import FunctionType
6462

6563

66-
_T_co = TypeVar("_T_co", covariant=True)
6764
_T = TypeVar("_T")
6865
_P = ParamSpec("_P")
66+
_OutT = TypeVar("_OutT", covariant=True) # noqa: PLC0105
6967

7068
_CURRENT_VERSION: Final = 0
7169

7270

73-
def invoke(
74-
fn: DurableFn[_P, _T_co], log: LogStorage, /, *, tracer: Tracer | None = None
75-
) -> contextlib.AbstractAsyncContextManager[DurableRun[_P, _T_co]]:
71+
@contextlib.asynccontextmanager
72+
async def invoke(
73+
fn: DurableFn[_P, _OutT], log: LogStorage, /, *, tracer: Tracer | None = None
74+
) -> AsyncGenerator[DurableRun[_P, _OutT], None]:
7675
"""Create an invocation context for this durable function.
7776
7877
Args:
@@ -82,18 +81,24 @@ def invoke(
8281
8382
Returns:
8483
Async context manager for Invoke instance
85-
"""
86-
return _InvokeGuard(DurableRun(fn, log), tracer)
84+
""" # noqa: DOC402, DOC202
85+
token = current_tracer.set(tracer)
86+
run = DurableRun(fn, log)
87+
try:
88+
yield run
89+
finally:
90+
await run.close()
91+
current_tracer.reset(token)
8792

8893

8994
@final
90-
class DurableRun(Generic[_P, _T_co]):
95+
class DurableRun(Generic[_P, _OutT]):
9196
__slots__ = ("_fn", "_log", "_run", "_watchers")
9297

93-
def __init__(self, fn: DurableFn[_P, _T_co], log: LogStorage) -> None:
98+
def __init__(self, fn: DurableFn[_P, _OutT], log: LogStorage) -> None:
9499
self._fn = fn
95100
self._log = log
96-
self._run: _InvokeRun | None = None
101+
self._run: _DurableRun | None = None
97102
self._watchers: list[tuple[dict[str, str], StreamObserver]] = []
98103

99104
async def start(self, *args: _P.args, **kwargs: _P.kwargs) -> None:
@@ -111,20 +116,20 @@ async def prelude() -> InitParams: # noqa: RUF029
111116
type_info = inspect_function(self._fn.fn)
112117
p = _invoke_prelude(self._fn, type_info, prelude)
113118
loop = await create_loop()
114-
self._run = _InvokeRun(loop, p, self._log, codec, watchers=self._watchers)
119+
self._run = _DurableRun(loop, p, self._log, codec, watchers=self._watchers)
115120
await self._run.resume()
116121

117122
async def resume(self) -> None:
118123
"""Resume a previously started invocation."""
119124
type_info = inspect_function(self._fn.fn)
120125
prelude = _invoke_prelude(self._fn, type_info, _resume_init)
121126
loop = await create_loop()
122-
self._run = _InvokeRun(
127+
self._run = _DurableRun(
123128
loop, prelude, self._log, self._fn.codec, watchers=self._watchers
124129
)
125130
await self._run.resume()
126131

127-
async def wait(self) -> _T_co:
132+
async def wait(self) -> _OutT:
128133
"""Wait for the durable function invocation to complete \
129134
and return its result.
130135
@@ -137,7 +142,7 @@ async def wait(self) -> _T_co:
137142
if self._run is None:
138143
msg = "Job not started"
139144
raise RuntimeError(msg)
140-
return cast("_T_co", await self._run.run())
145+
return cast("_OutT", await self._run.run())
141146

142147
async def close(self) -> None:
143148
if self._run:
@@ -214,37 +219,13 @@ def open_stream(
214219
return stream
215220
return _StreamWriter(self, name)
216221

217-
def get_run(self) -> _InvokeRun:
222+
def get_run(self) -> _DurableRun:
218223
if self._run is None:
219224
msg = "Job not started"
220225
raise RuntimeError(msg)
221226
return self._run
222227

223228

224-
@final
225-
class _InvokeGuard(Generic[_P, _T_co]):
226-
__slots__ = ("_job", "_token", "_tracer")
227-
228-
def __init__(self, job: DurableRun[_P, _T_co], tracer: Tracer | None) -> None:
229-
self._job = job
230-
self._tracer = tracer
231-
self._token: Token[Tracer | None] | None = None
232-
233-
async def __aenter__(self) -> DurableRun[_P, _T_co]:
234-
self._token = current_tracer.set(self._tracer)
235-
return self._job
236-
237-
async def __aexit__(
238-
self,
239-
exc_type: type[BaseException] | None,
240-
exc_value: BaseException | None,
241-
traceback: TracebackType | None,
242-
) -> None:
243-
await self._job.close()
244-
if self._token:
245-
current_tracer.reset(self._token)
246-
247-
248229
class InitParams(TypedDict):
249230
version: int
250231
args: list[JSONValue]
@@ -253,10 +234,10 @@ class InitParams(TypedDict):
253234

254235

255236
async def _invoke_prelude(
256-
job_fn: DurableFn[..., _T_co],
237+
job_fn: DurableFn[..., _OutT],
257238
type_info: FunctionType,
258239
init: Callable[[], Coroutine[Any, Any, InitParams]],
259-
) -> _T_co:
240+
) -> _OutT:
260241
loop = asyncio.get_running_loop()
261242
assert isinstance(loop, EventLoop) # noqa: S101
262243

@@ -316,7 +297,7 @@ async def _invoke_prelude(
316297

317298

318299
@final
319-
class _InvokeRun:
300+
class _DurableRun:
320301
__slots__ = (
321302
"_codec",
322303
"_lease",

0 commit comments

Comments
 (0)