Skip to content

Commit d50ac0d

Browse files
committed
feat: stream trace
1 parent cafa1bc commit d50ac0d

7 files changed

Lines changed: 181 additions & 60 deletions

File tree

src/duron/_core/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextvars
45
from contextlib import contextmanager
56
from contextvars import ContextVar
67
from random import Random
@@ -125,6 +126,7 @@ async def run(
125126
args=args,
126127
kwargs=kwargs,
127128
return_type=return_type,
129+
context=contextvars.copy_context(),
128130
metadata=self._get_metadata(metadata),
129131
labels=self._get_labels(None),
130132
),

src/duron/_core/invoke.py

Lines changed: 84 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import contextlib
5+
import sys
56
import time
67
from typing import TYPE_CHECKING, Final, Generic, Literal, cast
78
from typing_extensions import (
@@ -38,6 +39,7 @@
3839
from duron.typing import Unspecified, inspect_function
3940

4041
if TYPE_CHECKING:
42+
import contextvars
4143
from collections.abc import Callable, Coroutine
4244
from contextvars import Token
4345
from types import TracebackType
@@ -60,6 +62,7 @@
6062
StreamCreateEntry,
6163
StreamEmitEntry,
6264
)
65+
from duron.tracing._tracer import EntrySpan
6366
from duron.typing import FunctionType, TypeHint
6467

6568

@@ -330,7 +333,11 @@ def __init__(
330333
self._pending_msg: list[Entry] = []
331334
self._pending_task: dict[
332335
str,
333-
tuple[Callable[[], Coroutine[Any, Any, None]], TypeHint[Any]],
336+
tuple[
337+
Callable[[], Coroutine[Any, Any, None]],
338+
contextvars.Context,
339+
TypeHint[Any],
340+
],
334341
] = {}
335342
self._pending_ops: set[str] = set()
336343
self._now = 0
@@ -341,6 +348,7 @@ def __init__(
341348
list[StreamObserver[object]],
342349
TypeHint[Any],
343350
dict[str, str] | None,
351+
EntrySpan | None,
344352
],
345353
] = {}
346354
self._watchers = watchers or []
@@ -405,8 +413,11 @@ async def run(self) -> object:
405413
for msg in self._pending_msg:
406414
await self.enqueue_log(msg)
407415
self._pending_msg.clear()
408-
for key, (task_fn, return_type) in self._pending_task.items():
409-
self._tasks[key] = (asyncio.create_task(task_fn()), return_type)
416+
for key, (task_fn, context, return_type) in self._pending_task.items():
417+
self._tasks[key] = (
418+
_create_task_context(task_fn(), context),
419+
return_type,
420+
)
410421
self._pending_task.clear()
411422

412423
while waitset := await self._step():
@@ -458,7 +469,7 @@ async def handle_message(
458469

459470
return_type: TypeHint[Any] = Unspecified
460471
if pending_info is not None:
461-
_, return_type = pending_info
472+
_, _, return_type = pending_info
462473
elif task_info is not None:
463474
_, return_type = task_info
464475
else:
@@ -500,7 +511,7 @@ async def handle_message(
500511
id_, exception=ValueError("Stream not found")
501512
)
502513
else:
503-
obs, tv, _ = self._streams[e["stream_id"]]
514+
obs, tv, _, _ = self._streams[e["stream_id"]]
504515
for ob in obs:
505516
ob.on_next(
506517
offset,
@@ -515,7 +526,7 @@ async def handle_message(
515526
id_, exception=ValueError("Stream not found")
516527
)
517528
else:
518-
obs, _, _ = self._streams[e["stream_id"]]
529+
obs, _, _, _ = self._streams[e["stream_id"]]
519530
self._loop.post_completion(id_, result=None)
520531
for ob in obs:
521532
if "error" in e:
@@ -556,17 +567,10 @@ async def enqueue_op(self, id_: str, fut: OpFuture[object]) -> None:
556567
"type": "promise.create",
557568
}
558569

559-
set_metadata(
560-
promise_create_entry,
561-
op.metadata,
562-
)
563-
set_labels(
564-
promise_create_entry,
565-
op.labels,
566-
)
570+
set_metadata(promise_create_entry, op.metadata)
571+
set_labels(promise_create_entry, op.labels)
567572
if self._tracer:
568-
span = fut.context.run(
569-
self._tracer.new_entry_span,
573+
entry_span = self._tracer.new_entry_span(
570574
promise_create_entry,
571575
{
572576
"name": cast(
@@ -576,7 +580,7 @@ async def enqueue_op(self, id_: str, fut: OpFuture[object]) -> None:
576580
},
577581
)
578582
else:
579-
span = NULL_SPAN
583+
entry_span = None
580584
await self.enqueue_log(promise_create_entry)
581585

582586
async def cb() -> None:
@@ -587,7 +591,7 @@ async def cb() -> None:
587591
"type": "promise.complete",
588592
"promise_id": id_,
589593
}
590-
with span as span_:
594+
with entry_span.new_span() if entry_span else NULL_SPAN:
591595
try:
592596
result = op.callable(*op.args, **op.kwargs)
593597
if asyncio.iscoroutine(result):
@@ -598,8 +602,8 @@ async def cb() -> None:
598602
except asyncio.CancelledError as e:
599603
entry["error"] = _encode_error(e)
600604

601-
if self._tracer:
602-
self._tracer.end_entry_span(entry, id_, span_)
605+
if entry_span:
606+
entry_span.end(entry)
603607
await self.enqueue_log(entry)
604608

605609
def done(f: OpFuture[object]) -> None:
@@ -616,9 +620,12 @@ def done(f: OpFuture[object]) -> None:
616620
fut.add_done_callback(done)
617621
sid = id_
618622
if self._running:
619-
self._tasks[sid] = (asyncio.create_task(cb()), op.return_type)
623+
self._tasks[sid] = (
624+
_create_task_context(cb(), op.context),
625+
op.return_type,
626+
)
620627
else:
621-
self._pending_task[sid] = (cb, op.return_type)
628+
self._pending_task[sid] = (cb, op.context, op.return_type)
622629

623630
case StreamCreate():
624631
stream_id = id_
@@ -631,27 +638,46 @@ def done(f: OpFuture[object]) -> None:
631638
if _match_labels(op.labels or {}, matcher):
632639
ob.append(watcher)
633640

634-
self._streams[stream_id] = (ob, op.dtype, op.labels)
635-
636641
stream_create_entry: StreamCreateEntry = {
637642
"ts": self.now(),
638643
"id": stream_id,
639644
"type": "stream.create",
640645
}
646+
if self._tracer:
647+
entry_span = self._tracer.new_entry_span(
648+
stream_create_entry,
649+
{},
650+
)
651+
else:
652+
entry_span = None
653+
654+
self._streams[stream_id] = (ob, op.dtype, op.labels, entry_span)
655+
641656
set_metadata(stream_create_entry, op.metadata)
642657
set_labels(stream_create_entry, op.labels)
643658
await self.enqueue_log(stream_create_entry)
644659

645660
case StreamEmit():
661+
_, _, _, entry_span = self._streams[op.stream_id]
646662
stream_emit_entry: StreamEmitEntry = {
647663
"ts": self.now(),
648664
"id": id_,
649665
"stream_id": op.stream_id,
650666
"type": "stream.emit",
651667
"value": self._codec.encode_json(op.value),
652668
}
669+
if entry_span:
670+
entry_span.attach(
671+
stream_emit_entry,
672+
{
673+
"type": "event",
674+
"ts": self.now(),
675+
"kind": "stream",
676+
},
677+
)
653678
await self.enqueue_log(stream_emit_entry)
654679
case StreamClose():
680+
_, _, _, entry_span = self._streams[op.stream_id]
655681
if op.exception:
656682
stream_close_entry_err: StreamCompleteEntry = {
657683
"ts": self.now(),
@@ -660,6 +686,8 @@ def done(f: OpFuture[object]) -> None:
660686
"type": "stream.complete",
661687
"error": _encode_error(op.exception),
662688
}
689+
if entry_span:
690+
entry_span.end(stream_close_entry_err)
663691
await self.enqueue_log(stream_close_entry_err)
664692
else:
665693
stream_close_entry: StreamCompleteEntry = {
@@ -668,6 +696,8 @@ def done(f: OpFuture[object]) -> None:
668696
"stream_id": op.stream_id,
669697
"type": "stream.complete",
670698
}
699+
if entry_span:
700+
entry_span.end(stream_close_entry)
671701
await self.enqueue_log(stream_close_entry)
672702
case Barrier():
673703
barrier_entry: BarrierEntry = {
@@ -722,7 +752,7 @@ async def send_stream(
722752
) -> int:
723753
cnt = 0
724754
ts = self.now()
725-
for stream_id, (_, _, lb) in self._streams.items():
755+
for stream_id, (_, _, lb, entry_span) in self._streams.items():
726756
if _match_labels(lb or {}, matcher):
727757
entry: StreamEmitEntry = {
728758
"ts": ts,
@@ -731,6 +761,15 @@ async def send_stream(
731761
"stream_id": stream_id,
732762
"value": self._codec.encode_json(value),
733763
}
764+
if entry_span:
765+
entry_span.attach(
766+
entry,
767+
{
768+
"type": "event",
769+
"ts": ts,
770+
"kind": "stream",
771+
},
772+
)
734773
await self.enqueue_log(entry)
735774
cnt += 1
736775
return cnt
@@ -744,12 +783,12 @@ async def close_stream(
744783
ts = self.now()
745784
# Collect matching stream IDs first to avoid modifying dict during iteration
746785
matching_streams = [
747-
stream_id
748-
for stream_id, (_, _, lb) in self._streams.items()
786+
(stream_id, span)
787+
for stream_id, (_, _, lb, span) in self._streams.items()
749788
if _match_labels(lb or {}, matcher)
750789
]
751790

752-
for stream_id in matching_streams:
791+
for stream_id, entry_span in matching_streams:
753792
if error:
754793
entry: StreamCompleteEntry = {
755794
"ts": ts,
@@ -765,6 +804,8 @@ async def close_stream(
765804
"type": "stream.complete",
766805
"stream_id": stream_id,
767806
}
807+
if entry_span:
808+
entry_span.end(entry)
768809
await self.enqueue_log(entry)
769810
cnt += 1
770811
return cnt
@@ -818,3 +859,18 @@ async def close(self, error: BaseException | None = None) -> None:
818859
== 0
819860
):
820861
await asyncio.sleep(0.1)
862+
863+
864+
if sys.version_info >= (3, 11):
865+
866+
def _create_task_context(
867+
coro: Coroutine[Any, Any, _T], context: contextvars.Context
868+
) -> asyncio.Task[_T]:
869+
return asyncio.create_task(coro, context=context)
870+
871+
else:
872+
873+
def _create_task_context(
874+
coro: Coroutine[Any, Any, _T], context: contextvars.Context
875+
) -> asyncio.Task[_T]:
876+
return context.run(asyncio.create_task, coro)

src/duron/_core/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
if TYPE_CHECKING:
99
from collections.abc import Callable, Coroutine
10+
from contextvars import Context
1011

1112
from duron._loop import EventLoop, OpFuture
1213
from duron.codec import JSONValue
@@ -22,6 +23,7 @@ class FnCall:
2223
args: tuple[object, ...]
2324
kwargs: dict[str, object]
2425
return_type: TypeHint[Any]
26+
context: Context
2527
metadata: dict[str, JSONValue] | None = None
2628
labels: dict[str, str] | None = None
2729

src/duron/_core/stream.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import contextlib
5+
import contextvars
56
from abc import ABC, abstractmethod
67
from asyncio.exceptions import CancelledError
78
from collections import deque
@@ -402,6 +403,7 @@ async def __aenter__(self) -> _StreamRun[_U, _T]:
402403
args=(sink,),
403404
kwargs={},
404405
return_type=Unspecified,
406+
context=contextvars.copy_context(),
405407
),
406408
)
407409
return self._stream

src/duron/_loop.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@
4646
class OpFuture(asyncio.Future[_T], Generic[_T]):
4747
__slots__: tuple[str, ...] = ("id", "params")
4848

49-
id: str
50-
params: object
51-
context: Context
52-
5349
def __init__(
5450
self,
5551
id_: str,
@@ -59,7 +55,6 @@ def __init__(
5955
super().__init__(loop=loop)
6056
self.id = id_
6157
self.params = params
62-
self.context = contextvars.copy_context()
6358

6459

6560
@dataclass(slots=True)

0 commit comments

Comments
 (0)