4141
4242if TYPE_CHECKING :
4343 from asyncio .exceptions import CancelledError
44- from collections .abc import Callable , Coroutine
45- from contextvars import Token
46- from types import TracebackType
44+ from collections .abc import AsyncGenerator , Callable , Coroutine
4745
4846 from duron ._core .ops import Op , StreamObserver
4947 from duron ._decorator .durable import DurableFn
6361 from duron .typing import FunctionType
6462
6563
66- _T_co = TypeVar ("_T_co" , covariant = True )
6764_T = TypeVar ("_T" )
6865_P = ParamSpec ("_P" )
66+ _OutT = TypeVar ("_OutT" , covariant = True ) # noqa: PLC0105
6967
7068_CURRENT_VERSION : Final = 0
7169
7270
73- def invoke (
74- fn : DurableFn [_P , _T_co ], log : LogStorage , / , * , tracer : Tracer | None = None
75- ) -> contextlib .AbstractAsyncContextManager [DurableRun [_P , _T_co ]]:
71+ @contextlib .asynccontextmanager
72+ async def invoke (
73+ fn : DurableFn [_P , _OutT ], log : LogStorage , / , * , tracer : Tracer | None = None
74+ ) -> AsyncGenerator [DurableRun [_P , _OutT ], None ]:
7675 """Create an invocation context for this durable function.
7776
7877 Args:
@@ -82,18 +81,24 @@ def invoke(
8281
8382 Returns:
8483 Async context manager for Invoke instance
85- """
86- return _InvokeGuard (DurableRun (fn , log ), tracer )
84+ """ # noqa: DOC402, DOC202
85+ token = current_tracer .set (tracer )
86+ run = DurableRun (fn , log )
87+ try :
88+ yield run
89+ finally :
90+ await run .close ()
91+ current_tracer .reset (token )
8792
8893
8994@final
90- class DurableRun (Generic [_P , _T_co ]):
95+ class DurableRun (Generic [_P , _OutT ]):
9196 __slots__ = ("_fn" , "_log" , "_run" , "_watchers" )
9297
93- def __init__ (self , fn : DurableFn [_P , _T_co ], log : LogStorage ) -> None :
98+ def __init__ (self , fn : DurableFn [_P , _OutT ], log : LogStorage ) -> None :
9499 self ._fn = fn
95100 self ._log = log
96- self ._run : _InvokeRun | None = None
101+ self ._run : _DurableRun | None = None
97102 self ._watchers : list [tuple [dict [str , str ], StreamObserver ]] = []
98103
99104 async def start (self , * args : _P .args , ** kwargs : _P .kwargs ) -> None :
@@ -111,20 +116,20 @@ async def prelude() -> InitParams: # noqa: RUF029
111116 type_info = inspect_function (self ._fn .fn )
112117 p = _invoke_prelude (self ._fn , type_info , prelude )
113118 loop = await create_loop ()
114- self ._run = _InvokeRun (loop , p , self ._log , codec , watchers = self ._watchers )
119+ self ._run = _DurableRun (loop , p , self ._log , codec , watchers = self ._watchers )
115120 await self ._run .resume ()
116121
117122 async def resume (self ) -> None :
118123 """Resume a previously started invocation."""
119124 type_info = inspect_function (self ._fn .fn )
120125 prelude = _invoke_prelude (self ._fn , type_info , _resume_init )
121126 loop = await create_loop ()
122- self ._run = _InvokeRun (
127+ self ._run = _DurableRun (
123128 loop , prelude , self ._log , self ._fn .codec , watchers = self ._watchers
124129 )
125130 await self ._run .resume ()
126131
127- async def wait (self ) -> _T_co :
132+ async def wait (self ) -> _OutT :
128133 """Wait for the durable function invocation to complete \
129134 and return its result.
130135
@@ -137,7 +142,7 @@ async def wait(self) -> _T_co:
137142 if self ._run is None :
138143 msg = "Job not started"
139144 raise RuntimeError (msg )
140- return cast ("_T_co " , await self ._run .run ())
145+ return cast ("_OutT " , await self ._run .run ())
141146
142147 async def close (self ) -> None :
143148 if self ._run :
@@ -214,37 +219,13 @@ def open_stream(
214219 return stream
215220 return _StreamWriter (self , name )
216221
217- def get_run (self ) -> _InvokeRun :
222+ def get_run (self ) -> _DurableRun :
218223 if self ._run is None :
219224 msg = "Job not started"
220225 raise RuntimeError (msg )
221226 return self ._run
222227
223228
224- @final
225- class _InvokeGuard (Generic [_P , _T_co ]):
226- __slots__ = ("_job" , "_token" , "_tracer" )
227-
228- def __init__ (self , job : DurableRun [_P , _T_co ], tracer : Tracer | None ) -> None :
229- self ._job = job
230- self ._tracer = tracer
231- self ._token : Token [Tracer | None ] | None = None
232-
233- async def __aenter__ (self ) -> DurableRun [_P , _T_co ]:
234- self ._token = current_tracer .set (self ._tracer )
235- return self ._job
236-
237- async def __aexit__ (
238- self ,
239- exc_type : type [BaseException ] | None ,
240- exc_value : BaseException | None ,
241- traceback : TracebackType | None ,
242- ) -> None :
243- await self ._job .close ()
244- if self ._token :
245- current_tracer .reset (self ._token )
246-
247-
248229class InitParams (TypedDict ):
249230 version : int
250231 args : list [JSONValue ]
@@ -253,10 +234,10 @@ class InitParams(TypedDict):
253234
254235
255236async def _invoke_prelude (
256- job_fn : DurableFn [..., _T_co ],
237+ job_fn : DurableFn [..., _OutT ],
257238 type_info : FunctionType ,
258239 init : Callable [[], Coroutine [Any , Any , InitParams ]],
259- ) -> _T_co :
240+ ) -> _OutT :
260241 loop = asyncio .get_running_loop ()
261242 assert isinstance (loop , EventLoop ) # noqa: S101
262243
@@ -316,7 +297,7 @@ async def _invoke_prelude(
316297
317298
318299@final
319- class _InvokeRun :
300+ class _DurableRun :
320301 __slots__ = (
321302 "_codec" ,
322303 "_lease" ,
0 commit comments