22
33import asyncio
44import contextlib
5+ import sys
56import time
67from typing import TYPE_CHECKING , Final , Generic , Literal , cast
78from typing_extensions import (
3839from duron .typing import Unspecified , inspect_function
3940
4041if TYPE_CHECKING :
42+ import contextvars
4143 from collections .abc import Callable , Coroutine
4244 from contextvars import Token
4345 from types import TracebackType
6062 StreamCreateEntry ,
6163 StreamEmitEntry ,
6264 )
65+ from duron .tracing ._tracer import EntrySpan
6366 from duron .typing import FunctionType , TypeHint
6467
6568
@@ -330,7 +333,11 @@ def __init__(
330333 self ._pending_msg : list [Entry ] = []
331334 self ._pending_task : dict [
332335 str ,
333- tuple [Callable [[], Coroutine [Any , Any , None ]], TypeHint [Any ]],
336+ tuple [
337+ Callable [[], Coroutine [Any , Any , None ]],
338+ contextvars .Context ,
339+ TypeHint [Any ],
340+ ],
334341 ] = {}
335342 self ._pending_ops : set [str ] = set ()
336343 self ._now = 0
@@ -341,6 +348,7 @@ def __init__(
341348 list [StreamObserver [object ]],
342349 TypeHint [Any ],
343350 dict [str , str ] | None ,
351+ EntrySpan | None ,
344352 ],
345353 ] = {}
346354 self ._watchers = watchers or []
@@ -405,8 +413,11 @@ async def run(self) -> object:
405413 for msg in self ._pending_msg :
406414 await self .enqueue_log (msg )
407415 self ._pending_msg .clear ()
408- for key , (task_fn , return_type ) in self ._pending_task .items ():
409- self ._tasks [key ] = (asyncio .create_task (task_fn ()), return_type )
416+ for key , (task_fn , context , return_type ) in self ._pending_task .items ():
417+ self ._tasks [key ] = (
418+ _create_task_context (task_fn (), context ),
419+ return_type ,
420+ )
410421 self ._pending_task .clear ()
411422
412423 while waitset := await self ._step ():
@@ -458,7 +469,7 @@ async def handle_message(
458469
459470 return_type : TypeHint [Any ] = Unspecified
460471 if pending_info is not None :
461- _ , return_type = pending_info
472+ _ , _ , return_type = pending_info
462473 elif task_info is not None :
463474 _ , return_type = task_info
464475 else :
@@ -500,7 +511,7 @@ async def handle_message(
500511 id_ , exception = ValueError ("Stream not found" )
501512 )
502513 else :
503- obs , tv , _ = self ._streams [e ["stream_id" ]]
514+ obs , tv , _ , _ = self ._streams [e ["stream_id" ]]
504515 for ob in obs :
505516 ob .on_next (
506517 offset ,
@@ -515,7 +526,7 @@ async def handle_message(
515526 id_ , exception = ValueError ("Stream not found" )
516527 )
517528 else :
518- obs , _ , _ = self ._streams [e ["stream_id" ]]
529+ obs , _ , _ , _ = self ._streams [e ["stream_id" ]]
519530 self ._loop .post_completion (id_ , result = None )
520531 for ob in obs :
521532 if "error" in e :
@@ -556,17 +567,10 @@ async def enqueue_op(self, id_: str, fut: OpFuture[object]) -> None:
556567 "type" : "promise.create" ,
557568 }
558569
559- set_metadata (
560- promise_create_entry ,
561- op .metadata ,
562- )
563- set_labels (
564- promise_create_entry ,
565- op .labels ,
566- )
570+ set_metadata (promise_create_entry , op .metadata )
571+ set_labels (promise_create_entry , op .labels )
567572 if self ._tracer :
568- span = fut .context .run (
569- self ._tracer .new_entry_span ,
573+ entry_span = self ._tracer .new_entry_span (
570574 promise_create_entry ,
571575 {
572576 "name" : cast (
@@ -576,7 +580,7 @@ async def enqueue_op(self, id_: str, fut: OpFuture[object]) -> None:
576580 },
577581 )
578582 else :
579- span = NULL_SPAN
583+ entry_span = None
580584 await self .enqueue_log (promise_create_entry )
581585
582586 async def cb () -> None :
@@ -587,7 +591,7 @@ async def cb() -> None:
587591 "type" : "promise.complete" ,
588592 "promise_id" : id_ ,
589593 }
590- with span as span_ :
594+ with entry_span . new_span () if entry_span else NULL_SPAN :
591595 try :
592596 result = op .callable (* op .args , ** op .kwargs )
593597 if asyncio .iscoroutine (result ):
@@ -598,8 +602,8 @@ async def cb() -> None:
598602 except asyncio .CancelledError as e :
599603 entry ["error" ] = _encode_error (e )
600604
601- if self . _tracer :
602- self . _tracer . end_entry_span (entry , id_ , span_ )
605+ if entry_span :
606+ entry_span . end (entry )
603607 await self .enqueue_log (entry )
604608
605609 def done (f : OpFuture [object ]) -> None :
@@ -616,9 +620,12 @@ def done(f: OpFuture[object]) -> None:
616620 fut .add_done_callback (done )
617621 sid = id_
618622 if self ._running :
619- self ._tasks [sid ] = (asyncio .create_task (cb ()), op .return_type )
623+ self ._tasks [sid ] = (
624+ _create_task_context (cb (), op .context ),
625+ op .return_type ,
626+ )
620627 else :
621- self ._pending_task [sid ] = (cb , op .return_type )
628+ self ._pending_task [sid ] = (cb , op .context , op . return_type )
622629
623630 case StreamCreate ():
624631 stream_id = id_
@@ -631,27 +638,46 @@ def done(f: OpFuture[object]) -> None:
631638 if _match_labels (op .labels or {}, matcher ):
632639 ob .append (watcher )
633640
634- self ._streams [stream_id ] = (ob , op .dtype , op .labels )
635-
636641 stream_create_entry : StreamCreateEntry = {
637642 "ts" : self .now (),
638643 "id" : stream_id ,
639644 "type" : "stream.create" ,
640645 }
646+ if self ._tracer :
647+ entry_span = self ._tracer .new_entry_span (
648+ stream_create_entry ,
649+ {},
650+ )
651+ else :
652+ entry_span = None
653+
654+ self ._streams [stream_id ] = (ob , op .dtype , op .labels , entry_span )
655+
641656 set_metadata (stream_create_entry , op .metadata )
642657 set_labels (stream_create_entry , op .labels )
643658 await self .enqueue_log (stream_create_entry )
644659
645660 case StreamEmit ():
661+ _ , _ , _ , entry_span = self ._streams [op .stream_id ]
646662 stream_emit_entry : StreamEmitEntry = {
647663 "ts" : self .now (),
648664 "id" : id_ ,
649665 "stream_id" : op .stream_id ,
650666 "type" : "stream.emit" ,
651667 "value" : self ._codec .encode_json (op .value ),
652668 }
669+ if entry_span :
670+ entry_span .attach (
671+ stream_emit_entry ,
672+ {
673+ "type" : "event" ,
674+ "ts" : self .now (),
675+ "kind" : "stream" ,
676+ },
677+ )
653678 await self .enqueue_log (stream_emit_entry )
654679 case StreamClose ():
680+ _ , _ , _ , entry_span = self ._streams [op .stream_id ]
655681 if op .exception :
656682 stream_close_entry_err : StreamCompleteEntry = {
657683 "ts" : self .now (),
@@ -660,6 +686,8 @@ def done(f: OpFuture[object]) -> None:
660686 "type" : "stream.complete" ,
661687 "error" : _encode_error (op .exception ),
662688 }
689+ if entry_span :
690+ entry_span .end (stream_close_entry_err )
663691 await self .enqueue_log (stream_close_entry_err )
664692 else :
665693 stream_close_entry : StreamCompleteEntry = {
@@ -668,6 +696,8 @@ def done(f: OpFuture[object]) -> None:
668696 "stream_id" : op .stream_id ,
669697 "type" : "stream.complete" ,
670698 }
699+ if entry_span :
700+ entry_span .end (stream_close_entry )
671701 await self .enqueue_log (stream_close_entry )
672702 case Barrier ():
673703 barrier_entry : BarrierEntry = {
@@ -722,7 +752,7 @@ async def send_stream(
722752 ) -> int :
723753 cnt = 0
724754 ts = self .now ()
725- for stream_id , (_ , _ , lb ) in self ._streams .items ():
755+ for stream_id , (_ , _ , lb , entry_span ) in self ._streams .items ():
726756 if _match_labels (lb or {}, matcher ):
727757 entry : StreamEmitEntry = {
728758 "ts" : ts ,
@@ -731,6 +761,15 @@ async def send_stream(
731761 "stream_id" : stream_id ,
732762 "value" : self ._codec .encode_json (value ),
733763 }
764+ if entry_span :
765+ entry_span .attach (
766+ entry ,
767+ {
768+ "type" : "event" ,
769+ "ts" : ts ,
770+ "kind" : "stream" ,
771+ },
772+ )
734773 await self .enqueue_log (entry )
735774 cnt += 1
736775 return cnt
@@ -744,12 +783,12 @@ async def close_stream(
744783 ts = self .now ()
745784 # Collect matching stream IDs first to avoid modifying dict during iteration
746785 matching_streams = [
747- stream_id
748- for stream_id , (_ , _ , lb ) in self ._streams .items ()
786+ ( stream_id , span )
787+ for stream_id , (_ , _ , lb , span ) in self ._streams .items ()
749788 if _match_labels (lb or {}, matcher )
750789 ]
751790
752- for stream_id in matching_streams :
791+ for stream_id , entry_span in matching_streams :
753792 if error :
754793 entry : StreamCompleteEntry = {
755794 "ts" : ts ,
@@ -765,6 +804,8 @@ async def close_stream(
765804 "type" : "stream.complete" ,
766805 "stream_id" : stream_id ,
767806 }
807+ if entry_span :
808+ entry_span .end (entry )
768809 await self .enqueue_log (entry )
769810 cnt += 1
770811 return cnt
@@ -818,3 +859,18 @@ async def close(self, error: BaseException | None = None) -> None:
818859 == 0
819860 ):
820861 await asyncio .sleep (0.1 )
862+
863+
864+ if sys .version_info >= (3 , 11 ):
865+
866+ def _create_task_context (
867+ coro : Coroutine [Any , Any , _T ], context : contextvars .Context
868+ ) -> asyncio .Task [_T ]:
869+ return asyncio .create_task (coro , context = context )
870+
871+ else :
872+
873+ def _create_task_context (
874+ coro : Coroutine [Any , Any , _T ], context : contextvars .Context
875+ ) -> asyncio .Task [_T ]:
876+ return context .run (asyncio .create_task , coro )
0 commit comments