44import contextvars
55from contextlib import contextmanager
66from contextvars import ContextVar
7+ from dataclasses import dataclass
78from random import Random
89from typing import TYPE_CHECKING , cast
910from typing_extensions import (
1920from duron ._core .signal import create_signal
2021from duron ._core .stream import create_stream , run_stream
2122from duron ._decorator .op import CheckpointOp , Op
23+ from duron ._util .linked_dict import LinkedDict
2224from duron .typing import inspect_function
2325
2426if TYPE_CHECKING :
25- from collections .abc import Callable , Coroutine , Generator
27+ from collections .abc import Callable , Coroutine , Generator , Mapping
2628 from contextvars import Token
2729 from types import TracebackType
2830
3840 _P = ParamSpec ("_P" )
3941
4042_context : ContextVar [Context | None ] = ContextVar ("duron.context" , default = None )
41- _metadata : ContextVar [dict [str , JSONValue ] | None ] = ContextVar (
42- "duron.metadata" , default = None
43+
44+
45+ @final
46+ @dataclass (slots = True )
47+ class Annotation :
48+ metadata : LinkedDict [str , JSONValue ]
49+ labels : LinkedDict [str , str ]
50+
51+
52+ _annotation : ContextVar [Annotation | None ] = ContextVar (
53+ "duron.context.annotation" , default = None
4354)
44- _labels : ContextVar [dict [str , str ] | None ] = ContextVar ("duron.labels" , default = None )
4555
4656
4757@final
@@ -119,10 +129,12 @@ async def run(
119129 return_type = inspect_function (fn ).return_type
120130 metadata = None
121131
132+ callable_ = fn .fn if isinstance (fn , Op ) else fn
122133 op = create_op (
123134 self ._loop ,
124135 FnCall (
125- callable = fn .fn if isinstance (fn , Op ) else fn ,
136+ callable = callable_ ,
137+ name = cast ("str" , getattr (callable_ , "__name__" , repr (callable_ ))),
126138 args = args ,
127139 kwargs = kwargs ,
128140 return_type = return_type ,
@@ -156,7 +168,9 @@ def run_stream(
156168 async def create_stream (
157169 self ,
158170 dtype : TypeHint [_T ],
171+ / ,
159172 * ,
173+ name : str | None = None ,
160174 external : bool = False ,
161175 ) -> tuple [Stream [_T , None ], StreamWriter [_T ]]:
162176 if asyncio .get_running_loop () is not self ._loop :
@@ -167,12 +181,11 @@ async def create_stream(
167181 dtype ,
168182 external = external ,
169183 metadata = self ._get_metadata (None ),
170- labels = self ._get_labels (None ),
184+ labels = self ._get_labels ({ "name" : name } if name else None ),
171185 )
172186
173187 async def create_signal (
174- self ,
175- dtype : TypeHint [_T ],
188+ self , dtype : TypeHint [_T ], / , * , name : str | None = None
176189 ) -> tuple [Signal [_T ], SignalWriter [_T ]]:
177190 if asyncio .get_running_loop () is not self ._loop :
178191 msg = "Context time can only be used in the context loop"
@@ -181,7 +194,7 @@ async def create_signal(
181194 self ._loop ,
182195 dtype ,
183196 metadata = self ._get_metadata (None ),
184- labels = self ._get_labels (None ),
197+ labels = self ._get_labels ({ "name" : name } if name else None ),
185198 )
186199
187200 async def create_promise (
@@ -229,57 +242,55 @@ def random(self) -> Random:
229242 return Random (self ._loop .generate_op_id ()) # noqa: S311
230243
231244 @contextmanager
232- def metadata (self , metadata : dict [str , JSONValue ]) -> Generator [None , None , None ]:
233- if asyncio .get_running_loop () is not self ._loop :
234- msg = "Context time can only be used in the context loop"
235- raise RuntimeError (msg )
236- if not metadata :
237- yield
238- return
239-
240- current = _metadata .get ()
241- merged = {** current , ** metadata } if current is not None else metadata
242- token = _metadata .set (merged )
243- try :
244- yield
245- finally :
246- _metadata .reset (token )
247-
248- @contextmanager
249- def labels (self , labels : dict [str , str ]) -> Generator [None , None , None ]:
245+ def annotate (
246+ self ,
247+ * ,
248+ labels : Mapping [str , str ] | None = None ,
249+ metadata : Mapping [str , JSONValue ] | None = None ,
250+ ) -> Generator [None , None , None ]:
250251 if asyncio .get_running_loop () is not self ._loop :
251252 msg = "Context labels can only be used in the context loop"
252253 raise RuntimeError (msg )
253- if not labels :
254+ if not labels and not metadata :
254255 yield
255256 return
256257
257- current = _labels .get ()
258- merged = {** current , ** labels } if current is not None else labels
259- token = _labels .set (merged )
258+ current = _annotation .get ()
259+ token = _annotation .set (
260+ Annotation (
261+ metadata = current .metadata .extend (metadata )
262+ if current
263+ else LinkedDict (metadata ),
264+ labels = current .labels .extend (labels ) if current else LinkedDict (labels ),
265+ )
266+ )
260267 try :
261268 yield
262269 finally :
263- _labels .reset (token )
270+ _annotation .reset (token )
264271
265272 @staticmethod
266273 def _get_metadata (
267274 merge : dict [str , JSONValue ] | None ,
268275 ) -> dict [str , JSONValue ] | None :
269- current = _metadata .get ()
270- if merge is None :
271- return current
276+ anno = _annotation .get ()
277+ current = anno .metadata if anno else None
272278 if current is None :
273279 return merge
274- return {** current , ** merge }
280+ if merge :
281+ return current .extend (merge ).materialize ()
282+ return current .materialize ()
275283
276284 @staticmethod
277285 def _get_labels (
278286 merge : dict [str , str ] | None ,
279287 ) -> dict [str , str ] | None :
280- current = _labels .get ()
281- if merge is None :
282- return current
288+ anno = _annotation .get ()
289+ current = anno .labels if anno else None
290+ if merge :
291+ if current is None :
292+ return merge
293+ return current .extend (merge ).materialize ()
283294 if current is None :
284- return merge
285- return { ** current , ** merge }
295+ return None
296+ return current . materialize ()
0 commit comments