1111 Any ,
1212 Generic ,
1313 ParamSpec ,
14- TypedDict ,
1514 TypeVar ,
1615 cast ,
1716 final ,
1817)
1918
20- from duron .codec import Codec
19+ from typing_extensions import TypedDict
20+
21+ from duron .codec import Codec , JSONValue
2122from duron .context import Context
2223from duron .event_loop import EventLoop , create_loop
2324from duron .log import is_entry
3334 from duron .log import (
3435 Entry ,
3536 ErrorInfo ,
36- JSONValue ,
3737 LogStorage ,
3838 PromiseCompleteEntry ,
3939 )
@@ -73,14 +73,16 @@ def __init__(
7373 self ._run : _TaskRun | None = None
7474
7575 async def start (self , * args : _P .args , ** kwargs : _P .kwargs ) -> None :
76+ def get_init () -> TaskInitParams :
77+ return {
78+ "version" : _CURRENT_VERSION ,
79+ "args" : [codec .encode_json (arg ) for arg in args ],
80+ "kwargs" : {k : codec .encode_json (v ) for k , v in kwargs .items ()},
81+ }
82+
7683 codec = self ._task_fn .codec
77- init : TaskInitParams = {
78- "version" : _CURRENT_VERSION ,
79- "args" : [codec .encode_json (arg ) for arg in args ],
80- "kwargs" : {k : codec .encode_json (v ) for k , v in kwargs .items ()},
81- }
8284 type_info = codec .inspect_function (self ._task_fn .fn )
83- task_prelude = _task_prelude (self ._task_fn , type_info , lambda : init )
85+ task_prelude = _task_prelude (self ._task_fn , type_info , get_init )
8486 self ._run = _TaskRun (
8587 TaskRun (task = task_prelude , return_type = type_info .return_type ),
8688 self ._log ,
@@ -162,7 +164,7 @@ def __init__(
162164 self ._pending_ops : set [bytes ] = set ()
163165 self ._now = 0
164166 self ._offset : object | None = None
165- self ._tasks : dict [str , tuple [asyncio .Task [None ], type | None ]] = {}
167+ self ._tasks : dict [str , tuple [asyncio .Future [None ], type | None ]] = {}
166168
167169 def now (self ) -> int :
168170 if self ._running :
@@ -256,25 +258,34 @@ async def handle_message(self, e: Entry) -> None:
256258 pending_info = self ._pending_task .pop (e ["promise_id" ], None )
257259 task_info = self ._tasks .pop (e ["promise_id" ], None )
258260
259- # Get return type from either pending or running task
261+ id = _decode_id (e ["promise_id" ])
262+ if id not in self ._pending_ops :
263+ return
264+
260265 return_type = None
261266 if pending_info is not None :
262267 _ , return_type = pending_info
263268 elif task_info is not None :
264269 _ , return_type = task_info
270+ else :
271+ print (e )
272+ raise AssertionError ("unreachable" )
265273
266- id = _decode_id (e ["promise_id" ])
267274 if "error" in e :
268275 self ._loop .post_completion (
269276 id ,
270277 exception = _decode_error (e ["error" ]),
271278 )
272279 self ._pending_ops .discard (id )
273280 elif "result" in e :
274- self ._loop .post_completion (
275- id ,
276- result = self ._codec .decode_json (e ["result" ], return_type ),
277- )
281+ try :
282+ result = self ._codec .decode_json (e ["result" ], return_type )
283+ self ._loop .post_completion (id , result = result )
284+ except BaseException as exc :
285+ self ._loop .post_completion (
286+ id ,
287+ exception = exc ,
288+ )
278289 self ._pending_ops .discard (id )
279290 else :
280291 raise ValueError (f"Invalid promise/complete entry: { e !r} " )
@@ -339,6 +350,8 @@ def done(f: OpFuture) -> None:
339350 "type" : "promise/create" ,
340351 })
341352
353+ fut_host : asyncio .Future [None ] = asyncio .Future ()
354+
342355 async def cb () -> None :
343356 entry : PromiseCompleteEntry = {
344357 "ts" : _encode_timestamp (self .now ()),
@@ -353,9 +366,15 @@ async def cb() -> None:
353366 entry ["error" ] = _encode_error (e )
354367 finally :
355368 await self .enqueue_log (entry , True )
369+ fut_host .set_result (None )
356370
357371 _ = self ._loop .create_task (cb ())
358372
373+ self ._tasks [_encode_id (id )] = (
374+ fut_host ,
375+ op .return_type ,
376+ )
377+
359378 case _:
360379 raise NotImplementedError (f"Unsupported op: { op !r} " )
361380
0 commit comments