Skip to content

Commit 18bc644

Browse files
committed
feat: add span names for op
1 parent e4a5527 commit 18bc644

14 files changed

Lines changed: 282 additions & 218 deletions

File tree

examples/agent.py

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from duron import Defer, Signal, SignalInterrupt, Stream, StreamWriter
2424
from duron.codec import Codec
2525
from duron.contrib.storage import FileLogStorage
26-
from duron.tracing import Tracer
26+
from duron.tracing import Tracer, span
2727

2828
if TYPE_CHECKING:
2929
from duron.codec import JSONValue
@@ -65,6 +65,7 @@ async def agent_fn(
6565
},
6666
]
6767
async with input_ as inp:
68+
i = 0
6869
while True:
6970
msgs: list[str] = [msgs async for _, msgs in inp.next_nowait(ctx)]
7071
if not msgs:
@@ -76,58 +77,60 @@ async def agent_fn(
7677
"content": "\n".join(msgs),
7778
})
7879
await output.send(("user", "\n".join(msgs)))
79-
while True:
80-
try:
81-
async with signal:
82-
result = await completion(
83-
ctx,
84-
messages=history,
85-
)
86-
if result.choices[0].message.content:
87-
await output.send((
88-
"assistant",
89-
result.choices[0].message.content,
90-
))
91-
history.append({
92-
"role": "assistant",
93-
"content": result.choices[0].message.content,
94-
"tool_calls": [
95-
{
96-
"id": toolcall.id,
97-
"type": "function",
98-
"function": {
99-
"name": toolcall.function.name,
100-
"arguments": toolcall.function.arguments,
101-
},
102-
}
103-
for toolcall in result.choices[0].message.tool_calls
104-
or []
105-
if toolcall.type == "function"
106-
],
107-
})
108-
if not result.choices[0].message.tool_calls:
109-
break
110-
111-
tasks: list[asyncio.Task[tuple[str, str]]] = []
112-
for tool_call in result.choices[0].message.tool_calls:
113-
await output.send(("call", tool_call.model_dump_json()))
114-
tasks.append(
115-
asyncio.create_task(ctx.run(call_tool, tool_call))
80+
with span(f"Round #{i}"):
81+
i += 1
82+
while True:
83+
try:
84+
async with signal:
85+
result = await completion(
86+
ctx,
87+
messages=history,
11688
)
117-
for id_, tool_result in await asyncio.gather(*tasks):
118-
await output.send(("tool", tool_result))
89+
if result.choices[0].message.content:
90+
await output.send((
91+
"assistant",
92+
result.choices[0].message.content,
93+
))
11994
history.append({
120-
"role": "tool",
121-
"tool_call_id": id_,
122-
"content": tool_result,
95+
"role": "assistant",
96+
"content": result.choices[0].message.content,
97+
"tool_calls": [
98+
{
99+
"id": toolcall.id,
100+
"type": "function",
101+
"function": {
102+
"name": toolcall.function.name,
103+
"arguments": toolcall.function.arguments,
104+
},
105+
}
106+
for toolcall in result.choices[0].message.tool_calls
107+
or []
108+
if toolcall.type == "function"
109+
],
123110
})
124-
except SignalInterrupt:
125-
await output.send(("assistant", "[Interrupted]"))
126-
history.append({
127-
"role": "assistant",
128-
"content": "[Interrupted]",
129-
})
130-
break
111+
if not result.choices[0].message.tool_calls:
112+
break
113+
114+
tasks: list[asyncio.Task[tuple[str, str]]] = []
115+
for tool_call in result.choices[0].message.tool_calls:
116+
await output.send(("call", tool_call.model_dump_json()))
117+
tasks.append(
118+
asyncio.create_task(ctx.run(call_tool, tool_call))
119+
)
120+
for id_, tool_result in await asyncio.gather(*tasks):
121+
await output.send(("tool", tool_result))
122+
history.append({
123+
"role": "tool",
124+
"tool_call_id": id_,
125+
"content": tool_result,
126+
})
127+
except SignalInterrupt:
128+
await output.send(("assistant", "[Interrupted]"))
129+
history.append({
130+
"role": "assistant",
131+
"content": "[Interrupted]",
132+
})
133+
break
131134

132135

133136
@duron.op

examples/hello_world.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import random
66
import sys
7+
from collections.abc import AsyncGenerator
78
from pathlib import Path
89

910
import duron
@@ -30,11 +31,22 @@ async def generate_lucky_number() -> int:
3031
return random.randint(1, 100)
3132

3233

34+
@duron.op(checkpoint=True, initial=lambda: 0, reducer=lambda a, _b: a + 10)
35+
async def count_up(count: int, target: int) -> AsyncGenerator[None, int]:
36+
print("⚡ Counting...")
37+
await asyncio.sleep(0.5)
38+
while count < target:
39+
count = yield
40+
print(f"⚡ Current count: {count}")
41+
await asyncio.sleep(0.05)
42+
43+
3344
@duron.fn
3445
async def greeting_flow(ctx: duron.Context, name: str) -> str:
3546
message, lucky_number = await asyncio.gather(
3647
ctx.run(work, name), ctx.run(generate_lucky_number)
3748
)
49+
_ = await ctx.run(count_up, lucky_number)
3850
return f"{message} Your lucky number is {lucky_number}."
3951

4052

src/duron/_core/context.py

Lines changed: 52 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import contextvars
55
from contextlib import contextmanager
66
from contextvars import ContextVar
7+
from dataclasses import dataclass
78
from random import Random
89
from typing import TYPE_CHECKING, cast
910
from typing_extensions import (
@@ -19,10 +20,11 @@
1920
from duron._core.signal import create_signal
2021
from duron._core.stream import create_stream, run_stream
2122
from duron._decorator.op import CheckpointOp, Op
23+
from duron._util.linked_dict import LinkedDict
2224
from duron.typing import inspect_function
2325

2426
if TYPE_CHECKING:
25-
from collections.abc import Callable, Coroutine, Generator
27+
from collections.abc import Callable, Coroutine, Generator, Mapping
2628
from contextvars import Token
2729
from types import TracebackType
2830

@@ -38,10 +40,18 @@
3840
_P = ParamSpec("_P")
3941

4042
_context: ContextVar[Context | None] = ContextVar("duron.context", default=None)
41-
_metadata: ContextVar[dict[str, JSONValue] | None] = ContextVar(
42-
"duron.metadata", default=None
43+
44+
45+
@final
46+
@dataclass(slots=True)
47+
class Annotation:
48+
metadata: LinkedDict[str, JSONValue]
49+
labels: LinkedDict[str, str]
50+
51+
52+
_annotation: ContextVar[Annotation | None] = ContextVar(
53+
"duron.context.annotation", default=None
4354
)
44-
_labels: ContextVar[dict[str, str] | None] = ContextVar("duron.labels", default=None)
4555

4656

4757
@final
@@ -119,10 +129,12 @@ async def run(
119129
return_type = inspect_function(fn).return_type
120130
metadata = None
121131

132+
callable_ = fn.fn if isinstance(fn, Op) else fn
122133
op = create_op(
123134
self._loop,
124135
FnCall(
125-
callable=fn.fn if isinstance(fn, Op) else fn,
136+
callable=callable_,
137+
name=cast("str", getattr(callable_, "__name__", repr(callable_))),
126138
args=args,
127139
kwargs=kwargs,
128140
return_type=return_type,
@@ -156,7 +168,9 @@ def run_stream(
156168
async def create_stream(
157169
self,
158170
dtype: TypeHint[_T],
171+
/,
159172
*,
173+
name: str | None = None,
160174
external: bool = False,
161175
) -> tuple[Stream[_T, None], StreamWriter[_T]]:
162176
if asyncio.get_running_loop() is not self._loop:
@@ -167,12 +181,11 @@ async def create_stream(
167181
dtype,
168182
external=external,
169183
metadata=self._get_metadata(None),
170-
labels=self._get_labels(None),
184+
labels=self._get_labels({"name": name} if name else None),
171185
)
172186

173187
async def create_signal(
174-
self,
175-
dtype: TypeHint[_T],
188+
self, dtype: TypeHint[_T], /, *, name: str | None = None
176189
) -> tuple[Signal[_T], SignalWriter[_T]]:
177190
if asyncio.get_running_loop() is not self._loop:
178191
msg = "Context time can only be used in the context loop"
@@ -181,7 +194,7 @@ async def create_signal(
181194
self._loop,
182195
dtype,
183196
metadata=self._get_metadata(None),
184-
labels=self._get_labels(None),
197+
labels=self._get_labels({"name": name} if name else None),
185198
)
186199

187200
async def create_promise(
@@ -229,57 +242,55 @@ def random(self) -> Random:
229242
return Random(self._loop.generate_op_id()) # noqa: S311
230243

231244
@contextmanager
232-
def metadata(self, metadata: dict[str, JSONValue]) -> Generator[None, None, None]:
233-
if asyncio.get_running_loop() is not self._loop:
234-
msg = "Context time can only be used in the context loop"
235-
raise RuntimeError(msg)
236-
if not metadata:
237-
yield
238-
return
239-
240-
current = _metadata.get()
241-
merged = {**current, **metadata} if current is not None else metadata
242-
token = _metadata.set(merged)
243-
try:
244-
yield
245-
finally:
246-
_metadata.reset(token)
247-
248-
@contextmanager
249-
def labels(self, labels: dict[str, str]) -> Generator[None, None, None]:
245+
def annotate(
246+
self,
247+
*,
248+
labels: Mapping[str, str] | None = None,
249+
metadata: Mapping[str, JSONValue] | None = None,
250+
) -> Generator[None, None, None]:
250251
if asyncio.get_running_loop() is not self._loop:
251252
msg = "Context labels can only be used in the context loop"
252253
raise RuntimeError(msg)
253-
if not labels:
254+
if not labels and not metadata:
254255
yield
255256
return
256257

257-
current = _labels.get()
258-
merged = {**current, **labels} if current is not None else labels
259-
token = _labels.set(merged)
258+
current = _annotation.get()
259+
token = _annotation.set(
260+
Annotation(
261+
metadata=current.metadata.extend(metadata)
262+
if current
263+
else LinkedDict(metadata),
264+
labels=current.labels.extend(labels) if current else LinkedDict(labels),
265+
)
266+
)
260267
try:
261268
yield
262269
finally:
263-
_labels.reset(token)
270+
_annotation.reset(token)
264271

265272
@staticmethod
266273
def _get_metadata(
267274
merge: dict[str, JSONValue] | None,
268275
) -> dict[str, JSONValue] | None:
269-
current = _metadata.get()
270-
if merge is None:
271-
return current
276+
anno = _annotation.get()
277+
current = anno.metadata if anno else None
272278
if current is None:
273279
return merge
274-
return {**current, **merge}
280+
if merge:
281+
return current.extend(merge).materialize()
282+
return current.materialize()
275283

276284
@staticmethod
277285
def _get_labels(
278286
merge: dict[str, str] | None,
279287
) -> dict[str, str] | None:
280-
current = _labels.get()
281-
if merge is None:
282-
return current
288+
anno = _annotation.get()
289+
current = anno.labels if anno else None
290+
if merge:
291+
if current is None:
292+
return merge
293+
return current.extend(merge).materialize()
283294
if current is None:
284-
return merge
285-
return {**current, **merge}
295+
return None
296+
return current.materialize()

0 commit comments

Comments
 (0)