1010 TypedDict ,
1111 TypeVar ,
1212 assert_never ,
13+ assert_type ,
1314 final ,
1415 overload ,
1516)
1617
17- from duron ._core .config import config
1818from duron ._core .context import Context
1919from duron ._core .ops import (
2020 Barrier ,
2828from duron ._core .stream import ObserverStream , Stream , StreamWriter
2929from duron ._loop import EventLoop , create_loop
3030from duron .codec import Codec , JSONValue
31- from duron .log import derive_id , is_entry , random_id
31+ from duron .log import derive_id , is_entry , random_id , set_metadata
32+ from duron .tracing import Tracer , use_tracer
3233from duron .typing import Unspecified , inspect_function
3334
3435if TYPE_CHECKING :
@@ -86,8 +87,9 @@ def __init__(
8687 def invoke (
8788 fn : Fn [_P , _T_co ],
8889 log : LogStorage ,
90+ tracer : Tracer | None ,
8991 ) -> contextlib .AbstractAsyncContextManager [Invoke [_P , _T_co ]]:
90- return _InvokeGuard (Invoke (fn , log ))
92+ return _InvokeGuard (Invoke (fn , log ), tracer )
9193
9294 async def start (self , * args : _P .args , ** kwargs : _P .kwargs ) -> None :
9395 def get_init () -> InitParams :
@@ -106,7 +108,6 @@ def get_init() -> InitParams:
106108 self ._log ,
107109 codec ,
108110 watchers = self ._watchers ,
109- debug = config .debug ,
110111 )
111112 await self ._run .resume ()
112113
@@ -122,7 +123,6 @@ def cb() -> InitParams:
122123 self ._log ,
123124 self ._fn .codec ,
124125 watchers = self ._watchers ,
125- debug = config .debug ,
126126 )
127127 await self ._run .resume ()
128128
@@ -217,12 +217,14 @@ def get_run(self) -> _InvokeRun:
217217
218218@final
219219class _InvokeGuard (Generic [_P , _T_co ]):
220- __slots__ = ("_job" ,)
220+ __slots__ = ("_job" , "_tracer" )
221221
222- def __init__ (self , job : Invoke [_P , _T_co ]) -> None :
222+ def __init__ (self , job : Invoke [_P , _T_co ], tracer : Tracer | None ) -> None :
223223 self ._job = job
224+ self ._tracer = use_tracer (tracer )
224225
225226 async def __aenter__ (self ) -> Invoke [_P , _T_co ]:
227+ self ._tracer .__enter__ ()
226228 return self ._job
227229
228230 async def __aexit__ (
@@ -232,6 +234,7 @@ async def __aexit__(
232234 traceback : TracebackType | None ,
233235 ) -> None :
234236 await self ._job .close ()
237+ self ._tracer .__exit__ (exc_type , exc_value , traceback )
235238
236239
237240class InitParams (TypedDict ):
@@ -286,7 +289,7 @@ async def _invoke_prelude(
286289class _InvokeRun :
287290 __slots__ = (
288291 "_codec" ,
289- "_debug " ,
292+ "_lease " ,
290293 "_log" ,
291294 "_loop" ,
292295 "_now" ,
@@ -297,6 +300,7 @@ class _InvokeRun:
297300 "_streams" ,
298301 "_task" ,
299302 "_tasks" ,
303+ "_tracer" ,
300304 "_watchers" ,
301305 )
302306
@@ -310,15 +314,13 @@ def __init__(
310314 tuple [Callable [[dict [str , JSONValue ]], bool ], StreamObserver [object ]]
311315 ]
312316 | None = None ,
313- debug : bool = False ,
314317 ) -> None :
315318 self ._loop = create_loop (asyncio .get_running_loop ())
316- if debug :
317- self ._loop .set_debug (True )
318319 self ._task = self ._loop .create_task (task )
319320 self ._log = log
320321 self ._codec = codec
321- self ._running : bytes | None = None
322+ self ._running : bool = False
323+ self ._lease : bytes | None = None
322324 self ._pending_msg : list [Entry ] = []
323325 self ._pending_task : dict [
324326 str ,
@@ -336,11 +338,12 @@ def __init__(
336338 ],
337339 ] = {}
338340 self ._watchers = watchers or []
339- self ._debug : dict [str , JSONValue ] | None = (
340- {"run.id" : random_id ()} if debug else None
341- )
341+ self ._tracer : Tracer | None = Tracer .current ()
342342
343343 async def close (self ) -> None :
344+ if self ._lease :
345+ await self ._log .release_lease (self ._lease )
346+ self ._lease = None
344347 for task , _ in self ._tasks .values ():
345348 _ = task .cancel ()
346349 with contextlib .suppress (asyncio .CancelledError ):
@@ -359,6 +362,7 @@ def now(self) -> int:
359362 return self ._now
360363
361364 async def resume (self ) -> None :
365+ self ._lease = await self ._log .acquire_lease ()
362366 recvd_msgs : set [str ] = set ()
363367 async for o , entry in self ._log .stream (None , live = False ):
364368 ts = entry ["ts" ]
@@ -378,21 +382,17 @@ async def run(self) -> object:
378382 if self ._task .done ():
379383 return self ._task .result ()
380384
381- self ._running = await self ._log .acquire_lease ()
382- try :
383- for msg in self ._pending_msg :
384- await self .enqueue_log (msg )
385- self ._pending_msg .clear ()
386- for key , (task_fn , return_type ) in self ._pending_task .items ():
387- self ._tasks [key ] = (asyncio .create_task (task_fn ()), return_type )
388- self ._pending_task .clear ()
389-
390- while waitset := await self ._step ():
391- await waitset .block (self .now ())
392- return self ._task .result ()
393- finally :
394- await self ._log .release_lease (self ._running )
395- self ._running = None
385+ self ._running = True
386+ for msg in self ._pending_msg :
387+ await self .enqueue_log (msg )
388+ self ._pending_msg .clear ()
389+ for key , (task_fn , return_type ) in self ._pending_task .items ():
390+ self ._tasks [key ] = (asyncio .create_task (task_fn ()), return_type )
391+ self ._pending_task .clear ()
392+
393+ while waitset := await self ._step ():
394+ await waitset .block (self .now ())
395+ return self ._task .result ()
396396
397397 async def _step (self ) -> WaitSet | None :
398398 while True :
@@ -495,24 +495,23 @@ async def handle_message(
495495 id_ = e ["id" ]
496496 self ._loop .post_completion (id_ , result = offset )
497497 self ._pending_ops .discard (id_ )
498+ elif e ["type" ] == "trace" :
499+ pass
500+ else :
501+ assert_type (e ["type" ], Literal ["promise/create" ])
498502
499503 async def enqueue_log (self , entry : Entry , * , flush : bool = False ) -> None :
500504 if not self ._running :
501505 self ._pending_msg .append (entry )
506+ elif self ._lease is None :
507+ # closed
508+ return
502509 else :
503- if self ._debug :
504- if "debug" in entry :
505- entry ["debug" ] = {
506- ** entry ["debug" ],
507- ** self ._debug ,
508- }
509- else :
510- entry ["debug" ] = self ._debug
511- else :
512- _ = entry .pop ("debug" , None )
513- offset = await self ._log .append (self ._running , entry )
510+ if self ._tracer :
511+ self ._tracer .attach_metadata (entry )
512+ offset = await self ._log .append (self ._lease , entry )
514513 if flush :
515- await self ._log .flush (self ._running )
514+ await self ._log .flush (self ._lease )
516515 await self .handle_message (offset , entry )
517516
518517 async def enqueue_op (self , id_ : str , fut : OpFuture [object ]) -> None :
@@ -524,14 +523,18 @@ async def enqueue_op(self, id_: str, fut: OpFuture[object]) -> None:
524523 "id" : id_ ,
525524 "type" : "promise/create" ,
526525 }
527- if op .metadata :
528- promise_create_entry ["metadata" ] = op .metadata
529- if self ._debug :
530- promise_create_entry ["debug" ] = {
526+
527+ set_metadata (
528+ promise_create_entry ,
529+ op .metadata ,
530+ {
531531 "fn.name" : str (
532532 getattr (op .callable , "__qualname__" , op .callable )
533533 ),
534534 }
535+ if self ._tracer
536+ else None ,
537+ )
535538 await self .enqueue_log (promise_create_entry )
536539
537540 async def cb () -> None :
@@ -590,8 +593,7 @@ def done(f: OpFuture[object]) -> None:
590593 "id" : stream_id ,
591594 "type" : "stream/create" ,
592595 }
593- if op .metadata :
594- stream_create_entry ["metadata" ] = op .metadata
596+ set_metadata (stream_create_entry , op .metadata )
595597 await self .enqueue_log (stream_create_entry )
596598
597599 case StreamEmit ():
@@ -634,8 +636,7 @@ def done(f: OpFuture[object]) -> None:
634636 "id" : id_ ,
635637 "type" : "promise/create" ,
636638 }
637- if op .metadata :
638- promise_create_entry ["metadata" ] = op .metadata
639+ set_metadata (promise_create_entry , op .metadata )
639640 self ._tasks [id_ ] = (asyncio .Future (), op .return_type )
640641 await self .enqueue_log (promise_create_entry )
641642 case _:
0 commit comments