1414)
1515
1616from duron .event_loop import create_loop
17+ from duron .log .codec import DefaultCodec
1718from duron .log .entry import is_entry
1819from duron .ops import FnCall
1920
2021if TYPE_CHECKING :
2122 from collections .abc import AsyncGenerator , Callable
2223
2324 from duron .event_loop import WaitSet
25+ from duron .log .codec import BaseCodec
2426 from duron .log .entry import Entry , ErrorInfo , UnknownEntry
2527 from duron .log .storage import BaseLogStorage , Lease , Offset
2628 from duron .mark import DurableFn
3032
3133
3234class TaskRunner :
33- def __init__ (self ) :
34- pass
35+ def __init__ (self , codec : BaseCodec | None = None ) -> None :
36+ self . _codec : BaseCodec = codec or DefaultCodec ()
3537
3638 async def run (
3739 self ,
3840 task_id : bytes ,
3941 task_co : DurableFn [[], Coroutine [Any , Any , _T ]],
4042 log : BaseLogStorage ,
4143 ) -> _T :
42- return await _Task [_T ](task_id , task_co (), log ).run ()
44+ return await _Task [_T ](task_id , task_co (), log , self . _codec ).run ()
4345
4446
4547@final
4648class _Task (Generic [_T ]):
4749 def __init__ (
48- self , id : bytes , task_co : Coroutine [Any , Any , _T ], log : BaseLogStorage
50+ self ,
51+ id : bytes ,
52+ task_co : Coroutine [Any , Any , _T ],
53+ log : BaseLogStorage ,
54+ codec : BaseCodec ,
4955 ) -> None :
5056 self ._id = id
5157 self ._loop = create_loop (id )
5258 self ._task = self ._loop .create_task (task_co )
5359 self ._log = log
60+ self ._codec = codec
5461 self ._running : Lease | None = None
5562 self ._pending_msg : list [Entry ] = []
5663 self ._pending_task : dict [str , Callable [[], Coroutine [Any , Any , object ]]] = {}
@@ -110,7 +117,7 @@ async def run(self) -> _T:
110117 "id" : _encode_id (self ._id , True ),
111118 "ts" : _encode_timestamp (self .now ()),
112119 "promise_id" : _encode_id (self ._id , False ),
113- "result" : res ,
120+ "result" : self . _codec . encode_json ( res ) ,
114121 },
115122 )
116123 return res
@@ -121,7 +128,7 @@ async def run(self) -> _T:
121128 "id" : _encode_id (self ._id , True ),
122129 "ts" : _encode_timestamp (self .now ()),
123130 "promise_id" : _encode_id (self ._id , False ),
124- "error" : exception_to_error_info ( e ),
131+ "error" : _encode_error ( e , self . _codec ),
125132 },
126133 )
127134 raise
@@ -157,13 +164,13 @@ async def handle_message(self, e: Entry) -> None:
157164 if "error" in e :
158165 self ._loop .post_completion_threadsafe (
159166 id ,
160- exception = error_info_to_exception (e ["error" ]),
167+ exception = _decode_error (e ["error" ], self . _codec ),
161168 )
162169 self ._pending_ops .discard (id )
163170 elif "result" in e :
164171 self ._loop .post_completion_threadsafe (
165172 id ,
166- result = e ["result" ],
173+ result = self . _codec . decode_json ( e ["result" ]) ,
167174 )
168175 self ._pending_ops .discard (id )
169176 else :
@@ -200,17 +207,17 @@ async def cb() -> None:
200207 "id" : _encode_id (id , True ),
201208 "type" : "promise/complete" ,
202209 "promise_id" : _encode_id (id , False ),
203- "result" : result ,
210+ "result" : self . _codec . encode_json ( result ) ,
204211 }
205212 )
206213 except BaseException as e :
207214 await self .enqueue_log (
208215 {
216+ "type" : "promise/complete" ,
209217 "ts" : _encode_timestamp (self .now ()),
210218 "id" : _encode_id (id , True ),
211- "type" : "promise/complete" ,
212219 "promise_id" : _encode_id (id , False ),
213- "error" : exception_to_error_info ( e ),
220+ "error" : _encode_error ( e , self . _codec ),
214221 }
215222 )
216223
@@ -223,20 +230,20 @@ async def cb() -> None:
223230 raise NotImplementedError (f"Unsupported op: { op !r} " )
224231
225232
226- def _encode_id (id : bytes , end : bool ) -> str :
227- if end :
233+ def _encode_id (id : bytes , is_end : bool ) -> str :
234+ if is_end :
228235 return base64 .b64encode (id ).decode () + "-"
229236 else :
230237 return base64 .b64encode (id ).decode () + "+"
231238
232239
233- def _decode_id (s : str ) -> tuple [bytes , bool ]:
234- if s .endswith ("-" ):
235- return base64 .b64decode (s [:- 1 ]), True
236- elif s .endswith ("+" ):
237- return base64 .b64decode (s [:- 1 ]), False
240+ def _decode_id (encoded : str ) -> tuple [bytes , bool ]:
241+ if encoded .endswith ("-" ):
242+ return base64 .b64decode (encoded [:- 1 ]), True
243+ elif encoded .endswith ("+" ):
244+ return base64 .b64decode (encoded [:- 1 ]), False
238245 else :
239- raise ValueError (f"Invalid encoded id: { s !r} " )
246+ raise ValueError (f"Invalid encoded id: { encoded !r} " )
240247
241248
242249def _encode_timestamp (ts_ns : int ) -> int :
@@ -247,12 +254,20 @@ def _decode_timestamp(ts: int) -> int:
247254 return ts * 1_000
248255
249256
250- def exception_to_error_info (e : BaseException ) -> ErrorInfo :
257+ def _encode_error (error : BaseException , codec : BaseCodec ) -> ErrorInfo :
258+ """Convert exception to ErrorInfo dict."""
251259 return {
252- "code" : 1 ,
253- "message" : str (e ),
260+ "code" : - 1 ,
261+ "message" : str (error ),
262+ "state" : codec .encode_state (error ),
254263 }
255264
256265
257- def error_info_to_exception (e : ErrorInfo ) -> Exception :
258- return Exception (f"[{ e ['code' ]} ] { e ['message' ]} " )
266+ def _decode_error (error_info : ErrorInfo , codec : BaseCodec ) -> BaseException :
267+ """Convert ErrorInfo dict to exception."""
268+ try :
269+ if "state" not in error_info :
270+ return Exception (f"[{ error_info ['code' ]} ] { error_info ['message' ]} " )
271+ return cast ("BaseException" , codec .decode_state (error_info ["state" ]))
272+ except Exception :
273+ return Exception (f"[{ error_info ['code' ]} ] { error_info ['message' ]} " )
0 commit comments