Skip to content

Commit 01e3109

Browse files
committed
feat: metadata context manager
1 parent 4867090 commit 01e3109

10 files changed

Lines changed: 94 additions & 74 deletions

File tree

examples/agent.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ async def agent_fn(
111111
for tool_call in result.choices[0].message.tool_calls:
112112
await output.send(("call", tool_call.model_dump_json()))
113113
tasks.append(
114-
asyncio.create_task(ctx.run(call_tool, None, tool_call))
114+
asyncio.create_task(ctx.run(call_tool, tool_call))
115115
)
116116
for id_, tool_result in await asyncio.gather(*tasks):
117117
await output.send(("tool", tool_result))
@@ -233,7 +233,6 @@ async def _completion(
233233

234234
return await ctx.run(
235235
_completion,
236-
None,
237236
messages,
238237
)
239238

examples/hello_world.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def generate_lucky_number() -> int:
2828
@duron.fn
2929
async def greeting_flow(ctx: duron.Context, name: str) -> str:
3030
message, lucky_number = await asyncio.gather(
31-
ctx.run(work, None, name), ctx.run(generate_lucky_number)
31+
ctx.run(work, name), ctx.run(generate_lucky_number)
3232
)
3333
return f"{message} Your lucky number is {lucky_number}."
3434

src/duron/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from ._core.config import set_config as set_config
22
from ._core.context import Context as Context
3-
from ._core.options import RunOptions as RunOptions
43
from ._core.signal import Signal as Signal
54
from ._core.signal import SignalInterrupt as SignalInterrupt
65
from ._core.signal import SignalWriter as SignalWriter

src/duron/_core/config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
from dataclasses import dataclass
45
from typing import TYPE_CHECKING
56

@@ -12,9 +13,13 @@
1213
@dataclass(slots=True)
1314
class _Config:
1415
codec: Codec
16+
debug: bool
1517

1618

17-
config = _Config(codec=DefaultCodec())
19+
config = _Config(
20+
codec=DefaultCodec(),
21+
debug=os.getenv("DURON_DEBUG", "0").lower() in {"1", "true"},
22+
)
1823

1924

2025
def set_config(

src/duron/_core/context.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import binascii
5+
from contextlib import contextmanager
56
from contextvars import ContextVar
67
from random import Random
78
from typing import TYPE_CHECKING, cast
@@ -21,11 +22,10 @@
2122
from duron.typing import inspect_function
2223

2324
if TYPE_CHECKING:
24-
from collections.abc import Callable, Coroutine
25+
from collections.abc import Callable, Coroutine, Generator
2526
from contextvars import Token
2627
from types import TracebackType
2728

28-
from duron._core.options import RunOptions
2929
from duron._core.signal import Signal, SignalWriter
3030
from duron._core.stream import Stream, StreamWriter
3131
from duron._decorator.fn import Fn
@@ -38,6 +38,9 @@
3838
_P = ParamSpec("_P")
3939

4040
_context: ContextVar[Context | None] = ContextVar("duron_context", default=None)
41+
_metadata: ContextVar[dict[str, JSONValue] | None] = ContextVar(
42+
"duron_metadata", default=None
43+
)
4144

4245

4346
@final
@@ -76,7 +79,6 @@ def current() -> Context:
7679
async def run(
7780
self,
7881
fn: Callable[_P, Coroutine[Any, Any, _T]] | Op[_P, _T],
79-
options: RunOptions | None = ...,
8082
/,
8183
*args: _P.args,
8284
**kwargs: _P.kwargs,
@@ -85,7 +87,6 @@ async def run(
8587
async def run(
8688
self,
8789
fn: Callable[_P, _T] | CheckpointOp[_P, _T, Any],
88-
options: RunOptions | None = ...,
8990
/,
9091
*args: _P.args,
9192
**kwargs: _P.kwargs,
@@ -95,7 +96,6 @@ async def run(
9596
fn: Callable[_P, Coroutine[Any, Any, _T] | _T]
9697
| Op[_P, _T]
9798
| CheckpointOp[_P, _T, Any],
98-
options: RunOptions | None = None,
9999
/,
100100
*args: _P.args,
101101
**kwargs: _P.kwargs,
@@ -105,7 +105,7 @@ async def run(
105105
raise RuntimeError(msg)
106106

107107
if isinstance(fn, CheckpointOp):
108-
async with self.run_stream(fn, options, *args, **kwargs) as stream:
108+
async with self.run_stream(fn, *args, **kwargs) as stream:
109109
await stream.discard()
110110
return await stream
111111

@@ -123,20 +123,18 @@ async def run(
123123
args=args,
124124
kwargs=kwargs,
125125
return_type=return_type,
126-
metadata=_merge(options.metadata if options else None, metadata),
126+
metadata=self._get_metadata(metadata),
127127
),
128128
)
129129
return cast("_T", await op)
130130

131131
def run_stream(
132132
self,
133133
fn: CheckpointOp[_P, _T, _S],
134-
options: RunOptions | None = None,
135134
/,
136135
*args: _P.args,
137136
**kwargs: _P.kwargs,
138137
) -> AsyncContextManager[Stream[_S, _T]]:
139-
_ = options
140138
if asyncio.get_running_loop() is not self._loop:
141139
msg = "Context time can only be used in the context loop"
142140
raise RuntimeError(msg)
@@ -155,7 +153,6 @@ async def create_stream(
155153
dtype: TypeHint[_T],
156154
*,
157155
external: bool = False,
158-
metadata: dict[str, JSONValue] | None = None,
159156
) -> tuple[Stream[_T, None], StreamWriter[_T]]:
160157
if asyncio.get_running_loop() is not self._loop:
161158
msg = "Context time can only be used in the context loop"
@@ -164,32 +161,28 @@ async def create_stream(
164161
self._loop,
165162
dtype,
166163
external=external,
167-
metadata=metadata,
164+
metadata=self._get_metadata(None),
168165
)
169166

170167
async def create_signal(
171168
self,
172169
dtype: TypeHint[_T],
173-
*,
174-
metadata: dict[str, JSONValue] | None = None,
175170
) -> tuple[Signal[_T], SignalWriter[_T]]:
176171
if asyncio.get_running_loop() is not self._loop:
177172
msg = "Context time can only be used in the context loop"
178173
raise RuntimeError(msg)
179-
return await create_signal(self._loop, dtype, metadata=metadata)
174+
return await create_signal(self._loop, dtype, metadata=self._get_metadata(None))
180175

181176
async def create_promise(
182177
self,
183178
dtype: type[_T],
184-
*,
185-
metadata: dict[str, JSONValue] | None = None,
186179
) -> tuple[str, asyncio.Future[_T]]:
187180
if asyncio.get_running_loop() is not self._loop:
188181
msg = "Context time can only be used in the context loop"
189182
raise RuntimeError(msg)
190183
fut = create_op(
191184
self._loop,
192-
ExternalPromiseCreate(metadata=metadata, return_type=dtype),
185+
ExternalPromiseCreate(metadata=self._get_metadata(None), return_type=dtype),
193186
)
194187
return (
195188
binascii.b2a_base64(fut.id, newline=False).decode(),
@@ -220,12 +213,30 @@ def random(self) -> Random:
220213
raise RuntimeError(msg)
221214
return Random(self._loop.generate_op_id()) # noqa: S311
222215

216+
@contextmanager
217+
def metadata(self, metadata: dict[str, JSONValue]) -> Generator[None, None, None]:
218+
if asyncio.get_running_loop() is not self._loop:
219+
msg = "Context time can only be used in the context loop"
220+
raise RuntimeError(msg)
221+
if not metadata:
222+
yield
223+
return
224+
225+
current = _metadata.get()
226+
merged = {**current, **metadata} if current is not None else metadata
227+
token = _metadata.set(merged)
228+
try:
229+
yield
230+
finally:
231+
_metadata.reset(token)
223232

224-
def _merge(
225-
d1: dict[str, JSONValue] | None, d2: dict[str, JSONValue] | None
226-
) -> dict[str, JSONValue] | None:
227-
if d1 is None:
228-
return d2
229-
if d2 is None:
230-
return d1
231-
return {**d1, **d2}
233+
@staticmethod
234+
def _get_metadata(
235+
merge: dict[str, JSONValue] | None,
236+
) -> dict[str, JSONValue] | None:
237+
current = _metadata.get()
238+
if merge is None:
239+
return current
240+
if current is None:
241+
return merge
242+
return {**current, **merge}

0 commit comments

Comments
 (0)