44import contextvars
55from contextlib import contextmanager
66from contextvars import ContextVar
7- from dataclasses import dataclass
87from random import Random
98from typing import TYPE_CHECKING , cast
109from typing_extensions import (
1615 overload ,
1716)
1817
19- from duron ._core .ops import Barrier , ExternalPromiseCreate , FnCall , create_op
18+ from duron ._core .ops import (
19+ Barrier ,
20+ ExternalPromiseCreate ,
21+ FnCall ,
22+ OpAnnotations ,
23+ create_op ,
24+ )
2025from duron ._core .signal import create_signal
2126from duron ._core .stream import create_stream , run_stream
22- from duron ._decorator .op import CheckpointOp , Op
23- from duron ._util .linked_dict import LinkedDict
27+ from duron ._decorator .effect import CheckpointFn , EffectFn
2428from duron .typing import inspect_function
2529
2630if TYPE_CHECKING :
3034
3135 from duron ._core .signal import Signal , SignalWriter
3236 from duron ._core .stream import Stream , StreamWriter
33- from duron ._decorator .fn import Fn
37+ from duron ._decorator .durable import DurableFn
3438 from duron ._loop import EventLoop
3539 from duron .codec import JSONValue
3640 from duron .typing import TypeHint
4044 _P = ParamSpec ("_P" )
4145
4246_context : ContextVar [Context | None ] = ContextVar ("duron.context" , default = None )
43-
44-
45- @final
46- @dataclass (slots = True )
47- class Annotation :
48- metadata : LinkedDict [str , JSONValue ]
49- labels : LinkedDict [str , str ]
50-
51-
52- _annotation : ContextVar [Annotation | None ] = ContextVar (
47+ _annotation : ContextVar [OpAnnotations | None ] = ContextVar (
5348 "duron.context.annotation" , default = None
5449)
5550
@@ -58,7 +53,7 @@ class Annotation:
5853class Context :
5954 __slots__ = ("_fn" , "_loop" , "_token" )
6055
61- def __init__ (self , task : Fn [..., object ], loop : EventLoop ) -> None :
56+ def __init__ (self , task : DurableFn [..., object ], loop : EventLoop ) -> None :
6257 self ._loop : EventLoop = loop
6358 self ._fn = task
6459 self ._token : Token [Context | None ] | None = None
@@ -89,24 +84,24 @@ def current() -> Context:
8984 @overload
9085 async def run (
9186 self ,
92- fn : Callable [_P , Coroutine [Any , Any , _T ]] | Op [_P , _T ],
87+ fn : Callable [_P , Coroutine [Any , Any , _T ]] | EffectFn [_P , _T ],
9388 / ,
9489 * args : _P .args ,
9590 ** kwargs : _P .kwargs ,
9691 ) -> _T : ...
9792 @overload
9893 async def run (
9994 self ,
100- fn : Callable [_P , _T ] | CheckpointOp [_P , _T , Any ],
95+ fn : Callable [_P , _T ] | CheckpointFn [_P , _T , Any ],
10196 / ,
10297 * args : _P .args ,
10398 ** kwargs : _P .kwargs ,
10499 ) -> _T : ...
105100 async def run (
106101 self ,
107102 fn : Callable [_P , Coroutine [Any , Any , _T ] | _T ]
108- | Op [_P , _T ]
109- | CheckpointOp [_P , _T , Any ],
103+ | EffectFn [_P , _T ]
104+ | CheckpointFn [_P , _T , Any ],
110105 / ,
111106 * args : _P .args ,
112107 ** kwargs : _P .kwargs ,
@@ -115,39 +110,38 @@ async def run(
115110 msg = "Context time can only be used in the context loop"
116111 raise RuntimeError (msg )
117112
118- if isinstance (fn , CheckpointOp ):
113+ if isinstance (fn , CheckpointFn ):
119114 async with self .run_stream (
120- cast ("CheckpointOp [_P, _T, Any]" , fn ), * args , ** kwargs
115+ cast ("CheckpointFn [_P, _T, Any]" , fn ), * args , ** kwargs
121116 ) as stream :
122117 await stream .discard ()
123118 return await stream
124119
125- if isinstance (fn , Op ):
120+ if isinstance (fn , EffectFn ):
126121 return_type = fn .return_type
127- metadata = fn .metadata
128122 else :
129123 return_type = inspect_function (fn ).return_type
130- metadata = None
131124
132- callable_ = fn .fn if isinstance (fn , Op ) else fn
125+ callable_ = fn .fn if isinstance (fn , EffectFn ) else fn
133126 op = create_op (
134127 self ._loop ,
135128 FnCall (
136129 callable = callable_ ,
137- name = cast ("str" , getattr (callable_ , "__name__" , repr (callable_ ))),
138130 args = args ,
139131 kwargs = kwargs ,
140132 return_type = return_type ,
141133 context = contextvars .copy_context (),
142- metadata = self ._get_metadata (metadata ),
143- labels = self ._get_labels (None ),
134+ annotations = OpAnnotations .extend (
135+ _annotation .get (),
136+ name = cast ("str" , getattr (callable_ , "__name__" , repr (callable_ ))),
137+ ),
144138 ),
145139 )
146140 return cast ("_T" , await op )
147141
148142 def run_stream (
149143 self ,
150- fn : CheckpointOp [_P , _T , _S ],
144+ fn : CheckpointFn [_P , _T , _S ],
151145 / ,
152146 * args : _P .args ,
153147 ** kwargs : _P .kwargs ,
@@ -180,8 +174,11 @@ async def create_stream(
180174 self ._loop ,
181175 dtype ,
182176 external = external ,
183- metadata = self ._get_metadata (None ),
184- labels = self ._get_labels ({"name" : name } if name else None ),
177+ annotations = OpAnnotations .extend (
178+ _annotation .get (),
179+ name = name ,
180+ labels = {"name" : name } if name else None ,
181+ ),
185182 )
186183
187184 async def create_signal (
@@ -193,23 +190,27 @@ async def create_signal(
193190 return await create_signal (
194191 self ._loop ,
195192 dtype ,
196- metadata = self ._get_metadata (None ),
197- labels = self ._get_labels ({"name" : name } if name else None ),
193+ annotations = OpAnnotations .extend (
194+ _annotation .get (),
195+ labels = {"name" : name } if name else None ,
196+ ),
198197 )
199198
200199 async def create_promise (
201- self ,
202- dtype : type [_T ],
200+ self , dtype : type [_T ], / , * , name : str | None = None
203201 ) -> tuple [str , asyncio .Future [_T ]]:
204202 if asyncio .get_running_loop () is not self ._loop :
205203 msg = "Context time can only be used in the context loop"
206204 raise RuntimeError (msg )
207205 fut = create_op (
208206 self ._loop ,
209207 ExternalPromiseCreate (
210- metadata = self ._get_metadata (None ),
211208 return_type = dtype ,
212- labels = self ._get_labels (None ),
209+ annotations = OpAnnotations .extend (
210+ _annotation .get (),
211+ name = name ,
212+ labels = {"name" : name } if name else None ,
213+ ),
213214 ),
214215 )
215216 return (
@@ -257,40 +258,9 @@ def annotate(
257258
258259 current = _annotation .get ()
259260 token = _annotation .set (
260- Annotation (
261- metadata = current .metadata .extend (metadata )
262- if current
263- else LinkedDict (metadata ),
264- labels = current .labels .extend (labels ) if current else LinkedDict (labels ),
265- )
261+ OpAnnotations .extend (current , metadata = metadata , labels = labels )
266262 )
267263 try :
268264 yield
269265 finally :
270266 _annotation .reset (token )
271-
272- @staticmethod
273- def _get_metadata (
274- merge : dict [str , JSONValue ] | None ,
275- ) -> dict [str , JSONValue ] | None :
276- anno = _annotation .get ()
277- current = anno .metadata if anno else None
278- if current is None :
279- return merge
280- if merge :
281- return current .extend (merge ).materialize ()
282- return current .materialize ()
283-
284- @staticmethod
285- def _get_labels (
286- merge : dict [str , str ] | None ,
287- ) -> dict [str , str ] | None :
288- anno = _annotation .get ()
289- current = anno .labels if anno else None
290- if merge :
291- if current is None :
292- return merge
293- return current .extend (merge ).materialize ()
294- if current is None :
295- return None
296- return current .materialize ()
0 commit comments