22
33import asyncio
44import binascii
5+ from contextlib import contextmanager
56from contextvars import ContextVar
67from random import Random
78from typing import TYPE_CHECKING , cast
2122from duron .typing import inspect_function
2223
2324if TYPE_CHECKING :
24- from collections .abc import Callable , Coroutine
25+ from collections .abc import Callable , Coroutine , Generator
2526 from contextvars import Token
2627 from types import TracebackType
2728
28- from duron ._core .options import RunOptions
2929 from duron ._core .signal import Signal , SignalWriter
3030 from duron ._core .stream import Stream , StreamWriter
3131 from duron ._decorator .fn import Fn
3838 _P = ParamSpec ("_P" )
3939
4040_context : ContextVar [Context | None ] = ContextVar ("duron_context" , default = None )
41+ _metadata : ContextVar [dict [str , JSONValue ] | None ] = ContextVar (
42+ "duron_metadata" , default = None
43+ )
4144
4245
4346@final
@@ -76,7 +79,6 @@ def current() -> Context:
7679 async def run (
7780 self ,
7881 fn : Callable [_P , Coroutine [Any , Any , _T ]] | Op [_P , _T ],
79- options : RunOptions | None = ...,
8082 / ,
8183 * args : _P .args ,
8284 ** kwargs : _P .kwargs ,
@@ -85,7 +87,6 @@ async def run(
8587 async def run (
8688 self ,
8789 fn : Callable [_P , _T ] | CheckpointOp [_P , _T , Any ],
88- options : RunOptions | None = ...,
8990 / ,
9091 * args : _P .args ,
9192 ** kwargs : _P .kwargs ,
@@ -95,7 +96,6 @@ async def run(
9596 fn : Callable [_P , Coroutine [Any , Any , _T ] | _T ]
9697 | Op [_P , _T ]
9798 | CheckpointOp [_P , _T , Any ],
98- options : RunOptions | None = None ,
9999 / ,
100100 * args : _P .args ,
101101 ** kwargs : _P .kwargs ,
@@ -105,7 +105,7 @@ async def run(
105105 raise RuntimeError (msg )
106106
107107 if isinstance (fn , CheckpointOp ):
108- async with self .run_stream (fn , options , * args , ** kwargs ) as stream :
108+ async with self .run_stream (fn , * args , ** kwargs ) as stream :
109109 await stream .discard ()
110110 return await stream
111111
@@ -123,20 +123,18 @@ async def run(
123123 args = args ,
124124 kwargs = kwargs ,
125125 return_type = return_type ,
126- metadata = _merge ( options . metadata if options else None , metadata ),
126+ metadata = self . _get_metadata ( metadata ),
127127 ),
128128 )
129129 return cast ("_T" , await op )
130130
131131 def run_stream (
132132 self ,
133133 fn : CheckpointOp [_P , _T , _S ],
134- options : RunOptions | None = None ,
135134 / ,
136135 * args : _P .args ,
137136 ** kwargs : _P .kwargs ,
138137 ) -> AsyncContextManager [Stream [_S , _T ]]:
139- _ = options
140138 if asyncio .get_running_loop () is not self ._loop :
141139 msg = "Context time can only be used in the context loop"
142140 raise RuntimeError (msg )
@@ -155,7 +153,6 @@ async def create_stream(
155153 dtype : TypeHint [_T ],
156154 * ,
157155 external : bool = False ,
158- metadata : dict [str , JSONValue ] | None = None ,
159156 ) -> tuple [Stream [_T , None ], StreamWriter [_T ]]:
160157 if asyncio .get_running_loop () is not self ._loop :
161158 msg = "Context time can only be used in the context loop"
@@ -164,32 +161,28 @@ async def create_stream(
164161 self ._loop ,
165162 dtype ,
166163 external = external ,
167- metadata = metadata ,
164+ metadata = self . _get_metadata ( None ) ,
168165 )
169166
170167 async def create_signal (
171168 self ,
172169 dtype : TypeHint [_T ],
173- * ,
174- metadata : dict [str , JSONValue ] | None = None ,
175170 ) -> tuple [Signal [_T ], SignalWriter [_T ]]:
176171 if asyncio .get_running_loop () is not self ._loop :
177172 msg = "Context time can only be used in the context loop"
178173 raise RuntimeError (msg )
179- return await create_signal (self ._loop , dtype , metadata = metadata )
174+ return await create_signal (self ._loop , dtype , metadata = self . _get_metadata ( None ) )
180175
181176 async def create_promise (
182177 self ,
183178 dtype : type [_T ],
184- * ,
185- metadata : dict [str , JSONValue ] | None = None ,
186179 ) -> tuple [str , asyncio .Future [_T ]]:
187180 if asyncio .get_running_loop () is not self ._loop :
188181 msg = "Context time can only be used in the context loop"
189182 raise RuntimeError (msg )
190183 fut = create_op (
191184 self ._loop ,
192- ExternalPromiseCreate (metadata = metadata , return_type = dtype ),
185+ ExternalPromiseCreate (metadata = self . _get_metadata ( None ) , return_type = dtype ),
193186 )
194187 return (
195188 binascii .b2a_base64 (fut .id , newline = False ).decode (),
@@ -220,12 +213,30 @@ def random(self) -> Random:
220213 raise RuntimeError (msg )
221214 return Random (self ._loop .generate_op_id ()) # noqa: S311
222215
216+ @contextmanager
217+ def metadata (self , metadata : dict [str , JSONValue ]) -> Generator [None , None , None ]:
218+ if asyncio .get_running_loop () is not self ._loop :
219+ msg = "Context time can only be used in the context loop"
220+ raise RuntimeError (msg )
221+ if not metadata :
222+ yield
223+ return
224+
225+ current = _metadata .get ()
226+ merged = {** current , ** metadata } if current is not None else metadata
227+ token = _metadata .set (merged )
228+ try :
229+ yield
230+ finally :
231+ _metadata .reset (token )
223232
224- def _merge (
225- d1 : dict [str , JSONValue ] | None , d2 : dict [str , JSONValue ] | None
226- ) -> dict [str , JSONValue ] | None :
227- if d1 is None :
228- return d2
229- if d2 is None :
230- return d1
231- return {** d1 , ** d2 }
233+ @staticmethod
234+ def _get_metadata (
235+ merge : dict [str , JSONValue ] | None ,
236+ ) -> dict [str , JSONValue ] | None :
237+ current = _metadata .get ()
238+ if merge is None :
239+ return current
240+ if current is None :
241+ return merge
242+ return {** current , ** merge }
0 commit comments